whisper.rn 0.4.0-rc.7 → 0.4.0-rc.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/cpp/coreml/whisper-encoder.mm +1 -1
- package/cpp/ggml-alloc.c +41 -11
- package/cpp/ggml-alloc.h +3 -1
- package/cpp/ggml-backend-impl.h +38 -34
- package/cpp/ggml-backend.c +630 -269
- package/cpp/ggml-backend.h +58 -30
- package/cpp/ggml-impl.h +3 -0
- package/cpp/ggml-metal-whisper.metal +1253 -341
- package/cpp/ggml-metal.h +6 -54
- package/cpp/ggml-metal.m +2004 -1987
- package/cpp/ggml-quants.c +2230 -421
- package/cpp/ggml-quants.h +39 -1
- package/cpp/ggml.c +735 -265
- package/cpp/ggml.h +94 -43
- package/cpp/whisper.cpp +118 -86
- package/ios/RNWhisperContext.mm +2 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/version.json +1 -1
- package/package.json +1 -1
- package/src/version.json +1 -1
|
@@ -59,26 +59,26 @@ kernel void kernel_add(
|
|
|
59
59
|
constant int64_t & ne01,
|
|
60
60
|
constant int64_t & ne02,
|
|
61
61
|
constant int64_t & ne03,
|
|
62
|
-
constant
|
|
63
|
-
constant
|
|
64
|
-
constant
|
|
65
|
-
constant
|
|
62
|
+
constant uint64_t & nb00,
|
|
63
|
+
constant uint64_t & nb01,
|
|
64
|
+
constant uint64_t & nb02,
|
|
65
|
+
constant uint64_t & nb03,
|
|
66
66
|
constant int64_t & ne10,
|
|
67
67
|
constant int64_t & ne11,
|
|
68
68
|
constant int64_t & ne12,
|
|
69
69
|
constant int64_t & ne13,
|
|
70
|
-
constant
|
|
71
|
-
constant
|
|
72
|
-
constant
|
|
73
|
-
constant
|
|
70
|
+
constant uint64_t & nb10,
|
|
71
|
+
constant uint64_t & nb11,
|
|
72
|
+
constant uint64_t & nb12,
|
|
73
|
+
constant uint64_t & nb13,
|
|
74
74
|
constant int64_t & ne0,
|
|
75
75
|
constant int64_t & ne1,
|
|
76
76
|
constant int64_t & ne2,
|
|
77
77
|
constant int64_t & ne3,
|
|
78
|
-
constant
|
|
79
|
-
constant
|
|
80
|
-
constant
|
|
81
|
-
constant
|
|
78
|
+
constant uint64_t & nb0,
|
|
79
|
+
constant uint64_t & nb1,
|
|
80
|
+
constant uint64_t & nb2,
|
|
81
|
+
constant uint64_t & nb3,
|
|
82
82
|
constant int64_t & offs,
|
|
83
83
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
84
84
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
@@ -109,26 +109,26 @@ kernel void kernel_mul(
|
|
|
109
109
|
constant int64_t & ne01,
|
|
110
110
|
constant int64_t & ne02,
|
|
111
111
|
constant int64_t & ne03,
|
|
112
|
-
constant
|
|
113
|
-
constant
|
|
114
|
-
constant
|
|
115
|
-
constant
|
|
112
|
+
constant uint64_t & nb00,
|
|
113
|
+
constant uint64_t & nb01,
|
|
114
|
+
constant uint64_t & nb02,
|
|
115
|
+
constant uint64_t & nb03,
|
|
116
116
|
constant int64_t & ne10,
|
|
117
117
|
constant int64_t & ne11,
|
|
118
118
|
constant int64_t & ne12,
|
|
119
119
|
constant int64_t & ne13,
|
|
120
|
-
constant
|
|
121
|
-
constant
|
|
122
|
-
constant
|
|
123
|
-
constant
|
|
120
|
+
constant uint64_t & nb10,
|
|
121
|
+
constant uint64_t & nb11,
|
|
122
|
+
constant uint64_t & nb12,
|
|
123
|
+
constant uint64_t & nb13,
|
|
124
124
|
constant int64_t & ne0,
|
|
125
125
|
constant int64_t & ne1,
|
|
126
126
|
constant int64_t & ne2,
|
|
127
127
|
constant int64_t & ne3,
|
|
128
|
-
constant
|
|
129
|
-
constant
|
|
130
|
-
constant
|
|
131
|
-
constant
|
|
128
|
+
constant uint64_t & nb0,
|
|
129
|
+
constant uint64_t & nb1,
|
|
130
|
+
constant uint64_t & nb2,
|
|
131
|
+
constant uint64_t & nb3,
|
|
132
132
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
133
133
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
134
134
|
uint3 ntg[[threads_per_threadgroup]]) {
|
|
@@ -158,26 +158,26 @@ kernel void kernel_div(
|
|
|
158
158
|
constant int64_t & ne01,
|
|
159
159
|
constant int64_t & ne02,
|
|
160
160
|
constant int64_t & ne03,
|
|
161
|
-
constant
|
|
162
|
-
constant
|
|
163
|
-
constant
|
|
164
|
-
constant
|
|
161
|
+
constant uint64_t & nb00,
|
|
162
|
+
constant uint64_t & nb01,
|
|
163
|
+
constant uint64_t & nb02,
|
|
164
|
+
constant uint64_t & nb03,
|
|
165
165
|
constant int64_t & ne10,
|
|
166
166
|
constant int64_t & ne11,
|
|
167
167
|
constant int64_t & ne12,
|
|
168
168
|
constant int64_t & ne13,
|
|
169
|
-
constant
|
|
170
|
-
constant
|
|
171
|
-
constant
|
|
172
|
-
constant
|
|
169
|
+
constant uint64_t & nb10,
|
|
170
|
+
constant uint64_t & nb11,
|
|
171
|
+
constant uint64_t & nb12,
|
|
172
|
+
constant uint64_t & nb13,
|
|
173
173
|
constant int64_t & ne0,
|
|
174
174
|
constant int64_t & ne1,
|
|
175
175
|
constant int64_t & ne2,
|
|
176
176
|
constant int64_t & ne3,
|
|
177
|
-
constant
|
|
178
|
-
constant
|
|
179
|
-
constant
|
|
180
|
-
constant
|
|
177
|
+
constant uint64_t & nb0,
|
|
178
|
+
constant uint64_t & nb1,
|
|
179
|
+
constant uint64_t & nb2,
|
|
180
|
+
constant uint64_t & nb3,
|
|
181
181
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
182
182
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
183
183
|
uint3 ntg[[threads_per_threadgroup]]) {
|
|
@@ -205,7 +205,7 @@ kernel void kernel_add_row(
|
|
|
205
205
|
device const float4 * src0,
|
|
206
206
|
device const float4 * src1,
|
|
207
207
|
device float4 * dst,
|
|
208
|
-
constant
|
|
208
|
+
constant uint64_t & nb [[buffer(28)]],
|
|
209
209
|
uint tpig[[thread_position_in_grid]]) {
|
|
210
210
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
|
211
211
|
}
|
|
@@ -214,7 +214,7 @@ kernel void kernel_mul_row(
|
|
|
214
214
|
device const float4 * src0,
|
|
215
215
|
device const float4 * src1,
|
|
216
216
|
device float4 * dst,
|
|
217
|
-
constant
|
|
217
|
+
constant uint64_t & nb [[buffer(28)]],
|
|
218
218
|
uint tpig[[thread_position_in_grid]]) {
|
|
219
219
|
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
|
220
220
|
}
|
|
@@ -223,7 +223,7 @@ kernel void kernel_div_row(
|
|
|
223
223
|
device const float4 * src0,
|
|
224
224
|
device const float4 * src1,
|
|
225
225
|
device float4 * dst,
|
|
226
|
-
constant
|
|
226
|
+
constant uint64_t & nb [[buffer(28)]],
|
|
227
227
|
uint tpig[[thread_position_in_grid]]) {
|
|
228
228
|
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
|
229
229
|
}
|
|
@@ -307,26 +307,26 @@ kernel void kernel_sum_rows(
|
|
|
307
307
|
constant int64_t & ne01,
|
|
308
308
|
constant int64_t & ne02,
|
|
309
309
|
constant int64_t & ne03,
|
|
310
|
-
constant
|
|
311
|
-
constant
|
|
312
|
-
constant
|
|
313
|
-
constant
|
|
310
|
+
constant uint64_t & nb00,
|
|
311
|
+
constant uint64_t & nb01,
|
|
312
|
+
constant uint64_t & nb02,
|
|
313
|
+
constant uint64_t & nb03,
|
|
314
314
|
constant int64_t & ne10,
|
|
315
315
|
constant int64_t & ne11,
|
|
316
316
|
constant int64_t & ne12,
|
|
317
317
|
constant int64_t & ne13,
|
|
318
|
-
constant
|
|
319
|
-
constant
|
|
320
|
-
constant
|
|
321
|
-
constant
|
|
318
|
+
constant uint64_t & nb10,
|
|
319
|
+
constant uint64_t & nb11,
|
|
320
|
+
constant uint64_t & nb12,
|
|
321
|
+
constant uint64_t & nb13,
|
|
322
322
|
constant int64_t & ne0,
|
|
323
323
|
constant int64_t & ne1,
|
|
324
324
|
constant int64_t & ne2,
|
|
325
325
|
constant int64_t & ne3,
|
|
326
|
-
constant
|
|
327
|
-
constant
|
|
328
|
-
constant
|
|
329
|
-
constant
|
|
326
|
+
constant uint64_t & nb0,
|
|
327
|
+
constant uint64_t & nb1,
|
|
328
|
+
constant uint64_t & nb2,
|
|
329
|
+
constant uint64_t & nb3,
|
|
330
330
|
uint3 tpig[[thread_position_in_grid]]) {
|
|
331
331
|
int64_t i3 = tpig.z;
|
|
332
332
|
int64_t i2 = tpig.y;
|
|
@@ -846,7 +846,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|
|
846
846
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
|
847
847
|
//Note: This is a template, but strictly speaking it only applies to
|
|
848
848
|
// quantizations where the block size is 32. It also does not
|
|
849
|
-
//
|
|
849
|
+
// guard against the number of rows not being divisible by
|
|
850
850
|
// N_DST, so this is another explicit assumption of the implementation.
|
|
851
851
|
template<typename block_q_type, int nr, int nsg, int nw>
|
|
852
852
|
void mul_vec_q_n_f32_impl(
|
|
@@ -920,14 +920,21 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
|
920
920
|
device const float * src1,
|
|
921
921
|
device float * dst,
|
|
922
922
|
constant int64_t & ne00,
|
|
923
|
-
constant int64_t & ne01
|
|
924
|
-
constant int64_t & ne02
|
|
925
|
-
constant
|
|
926
|
-
constant
|
|
927
|
-
constant
|
|
928
|
-
constant int64_t &
|
|
929
|
-
constant
|
|
930
|
-
constant
|
|
923
|
+
constant int64_t & ne01,
|
|
924
|
+
constant int64_t & ne02,
|
|
925
|
+
constant uint64_t & nb00,
|
|
926
|
+
constant uint64_t & nb01,
|
|
927
|
+
constant uint64_t & nb02,
|
|
928
|
+
constant int64_t & ne10,
|
|
929
|
+
constant int64_t & ne11,
|
|
930
|
+
constant int64_t & ne12,
|
|
931
|
+
constant uint64_t & nb10,
|
|
932
|
+
constant uint64_t & nb11,
|
|
933
|
+
constant uint64_t & nb12,
|
|
934
|
+
constant int64_t & ne0,
|
|
935
|
+
constant int64_t & ne1,
|
|
936
|
+
constant uint & r2,
|
|
937
|
+
constant uint & r3,
|
|
931
938
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
932
939
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
933
940
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -939,14 +946,21 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
|
939
946
|
device const float * src1,
|
|
940
947
|
device float * dst,
|
|
941
948
|
constant int64_t & ne00,
|
|
942
|
-
constant int64_t & ne01
|
|
943
|
-
constant int64_t & ne02
|
|
944
|
-
constant
|
|
945
|
-
constant
|
|
946
|
-
constant
|
|
947
|
-
constant int64_t &
|
|
948
|
-
constant
|
|
949
|
-
constant
|
|
949
|
+
constant int64_t & ne01,
|
|
950
|
+
constant int64_t & ne02,
|
|
951
|
+
constant uint64_t & nb00,
|
|
952
|
+
constant uint64_t & nb01,
|
|
953
|
+
constant uint64_t & nb02,
|
|
954
|
+
constant int64_t & ne10,
|
|
955
|
+
constant int64_t & ne11,
|
|
956
|
+
constant int64_t & ne12,
|
|
957
|
+
constant uint64_t & nb10,
|
|
958
|
+
constant uint64_t & nb11,
|
|
959
|
+
constant uint64_t & nb12,
|
|
960
|
+
constant int64_t & ne0,
|
|
961
|
+
constant int64_t & ne1,
|
|
962
|
+
constant uint & r2,
|
|
963
|
+
constant uint & r3,
|
|
950
964
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
951
965
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
952
966
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -958,14 +972,21 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
|
958
972
|
device const float * src1,
|
|
959
973
|
device float * dst,
|
|
960
974
|
constant int64_t & ne00,
|
|
961
|
-
constant int64_t & ne01
|
|
962
|
-
constant int64_t & ne02
|
|
963
|
-
constant
|
|
964
|
-
constant
|
|
965
|
-
constant
|
|
966
|
-
constant int64_t &
|
|
967
|
-
constant
|
|
968
|
-
constant
|
|
975
|
+
constant int64_t & ne01,
|
|
976
|
+
constant int64_t & ne02,
|
|
977
|
+
constant uint64_t & nb00,
|
|
978
|
+
constant uint64_t & nb01,
|
|
979
|
+
constant uint64_t & nb02,
|
|
980
|
+
constant int64_t & ne10,
|
|
981
|
+
constant int64_t & ne11,
|
|
982
|
+
constant int64_t & ne12,
|
|
983
|
+
constant uint64_t & nb10,
|
|
984
|
+
constant uint64_t & nb11,
|
|
985
|
+
constant uint64_t & nb12,
|
|
986
|
+
constant int64_t & ne0,
|
|
987
|
+
constant int64_t & ne1,
|
|
988
|
+
constant uint & r2,
|
|
989
|
+
constant uint & r3,
|
|
969
990
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
970
991
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
971
992
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -977,14 +998,21 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
|
977
998
|
device const float * src1,
|
|
978
999
|
device float * dst,
|
|
979
1000
|
constant int64_t & ne00,
|
|
980
|
-
constant int64_t & ne01
|
|
981
|
-
constant int64_t & ne02
|
|
982
|
-
constant
|
|
983
|
-
constant
|
|
984
|
-
constant
|
|
985
|
-
constant int64_t &
|
|
986
|
-
constant
|
|
987
|
-
constant
|
|
1001
|
+
constant int64_t & ne01,
|
|
1002
|
+
constant int64_t & ne02,
|
|
1003
|
+
constant uint64_t & nb00,
|
|
1004
|
+
constant uint64_t & nb01,
|
|
1005
|
+
constant uint64_t & nb02,
|
|
1006
|
+
constant int64_t & ne10,
|
|
1007
|
+
constant int64_t & ne11,
|
|
1008
|
+
constant int64_t & ne12,
|
|
1009
|
+
constant uint64_t & nb10,
|
|
1010
|
+
constant uint64_t & nb11,
|
|
1011
|
+
constant uint64_t & nb12,
|
|
1012
|
+
constant int64_t & ne0,
|
|
1013
|
+
constant int64_t & ne1,
|
|
1014
|
+
constant uint & r2,
|
|
1015
|
+
constant uint & r3,
|
|
988
1016
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
989
1017
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
990
1018
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -1071,12 +1099,19 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
|
1071
1099
|
constant int64_t & ne00,
|
|
1072
1100
|
constant int64_t & ne01,
|
|
1073
1101
|
constant int64_t & ne02,
|
|
1102
|
+
constant uint64_t & nb00,
|
|
1103
|
+
constant uint64_t & nb01,
|
|
1104
|
+
constant uint64_t & nb02,
|
|
1074
1105
|
constant int64_t & ne10,
|
|
1106
|
+
constant int64_t & ne11,
|
|
1075
1107
|
constant int64_t & ne12,
|
|
1108
|
+
constant uint64_t & nb10,
|
|
1109
|
+
constant uint64_t & nb11,
|
|
1110
|
+
constant uint64_t & nb12,
|
|
1076
1111
|
constant int64_t & ne0,
|
|
1077
1112
|
constant int64_t & ne1,
|
|
1078
|
-
constant uint & r2
|
|
1079
|
-
constant uint & r3
|
|
1113
|
+
constant uint & r2,
|
|
1114
|
+
constant uint & r3,
|
|
1080
1115
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1081
1116
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
1082
1117
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -1182,8 +1217,8 @@ kernel void kernel_mul_mv_f32_f32(
|
|
|
1182
1217
|
constant uint64_t & nb12,
|
|
1183
1218
|
constant int64_t & ne0,
|
|
1184
1219
|
constant int64_t & ne1,
|
|
1185
|
-
constant uint & r2
|
|
1186
|
-
constant uint & r3
|
|
1220
|
+
constant uint & r2,
|
|
1221
|
+
constant uint & r3,
|
|
1187
1222
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1188
1223
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1189
1224
|
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
@@ -1209,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16(
|
|
|
1209
1244
|
constant uint64_t & nb12,
|
|
1210
1245
|
constant int64_t & ne0,
|
|
1211
1246
|
constant int64_t & ne1,
|
|
1212
|
-
constant uint & r2
|
|
1213
|
-
constant uint & r3
|
|
1247
|
+
constant uint & r2,
|
|
1248
|
+
constant uint & r3,
|
|
1214
1249
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1215
1250
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1216
1251
|
|
|
@@ -1346,8 +1381,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
|
1346
1381
|
constant uint64_t & nb12,
|
|
1347
1382
|
constant int64_t & ne0,
|
|
1348
1383
|
constant int64_t & ne1,
|
|
1349
|
-
constant uint & r2
|
|
1350
|
-
constant uint & r3
|
|
1384
|
+
constant uint & r2,
|
|
1385
|
+
constant uint & r3,
|
|
1351
1386
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1352
1387
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1353
1388
|
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
@@ -1452,8 +1487,8 @@ kernel void kernel_mul_mv_f16_f32(
|
|
|
1452
1487
|
constant uint64_t & nb12,
|
|
1453
1488
|
constant int64_t & ne0,
|
|
1454
1489
|
constant int64_t & ne1,
|
|
1455
|
-
constant uint & r2
|
|
1456
|
-
constant uint & r3
|
|
1490
|
+
constant uint & r2,
|
|
1491
|
+
constant uint & r3,
|
|
1457
1492
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1458
1493
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1459
1494
|
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
@@ -1478,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
|
1478
1513
|
constant uint64_t & nb12,
|
|
1479
1514
|
constant int64_t & ne0,
|
|
1480
1515
|
constant int64_t & ne1,
|
|
1481
|
-
constant uint & r2
|
|
1482
|
-
constant uint & r3
|
|
1516
|
+
constant uint & r2,
|
|
1517
|
+
constant uint & r3,
|
|
1483
1518
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1484
1519
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1485
1520
|
|
|
@@ -1543,7 +1578,8 @@ kernel void kernel_alibi_f32(
|
|
|
1543
1578
|
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
1544
1579
|
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
1545
1580
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
1546
|
-
|
|
1581
|
+
//const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
1582
|
+
|
|
1547
1583
|
const int64_t k = i3*ne3 + i2;
|
|
1548
1584
|
|
|
1549
1585
|
float m_k;
|
|
@@ -1702,8 +1738,9 @@ kernel void kernel_rope(
|
|
|
1702
1738
|
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
|
1703
1739
|
}
|
|
1704
1740
|
} else {
|
|
1705
|
-
for (int64_t
|
|
1706
|
-
|
|
1741
|
+
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
|
|
1742
|
+
if (ic < n_dims) {
|
|
1743
|
+
const int64_t ib = 0;
|
|
1707
1744
|
|
|
1708
1745
|
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
|
1709
1746
|
const float cur_rot = inv_ndims*ic - ib;
|
|
@@ -1722,6 +1759,14 @@ kernel void kernel_rope(
|
|
|
1722
1759
|
|
|
1723
1760
|
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
1724
1761
|
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
1762
|
+
} else {
|
|
1763
|
+
const int64_t i0 = ic;
|
|
1764
|
+
|
|
1765
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
1766
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1767
|
+
|
|
1768
|
+
dst_data[0] = src[0];
|
|
1769
|
+
dst_data[1] = src[1];
|
|
1725
1770
|
}
|
|
1726
1771
|
}
|
|
1727
1772
|
}
|
|
@@ -2401,21 +2446,18 @@ typedef struct {
|
|
|
2401
2446
|
} block_q6_K;
|
|
2402
2447
|
// 210 bytes / block
|
|
2403
2448
|
|
|
2404
|
-
|
|
2405
|
-
|
|
2406
|
-
|
|
2407
|
-
|
|
2408
|
-
|
|
2409
|
-
|
|
2410
|
-
|
|
2411
|
-
|
|
2412
|
-
|
|
2413
|
-
|
|
2414
|
-
|
|
2415
|
-
|
|
2416
|
-
}
|
|
2417
|
-
return r;
|
|
2418
|
-
}
|
|
2449
|
+
typedef struct {
|
|
2450
|
+
half d;
|
|
2451
|
+
uint16_t qs[QK_K/8];
|
|
2452
|
+
} block_iq2_xxs;
|
|
2453
|
+
// 66 bytes / block for QK_K = 256, so 2.0625 bpw
|
|
2454
|
+
|
|
2455
|
+
typedef struct {
|
|
2456
|
+
half d;
|
|
2457
|
+
uint16_t qs[QK_K/8];
|
|
2458
|
+
uint8_t scales[QK_K/32];
|
|
2459
|
+
} block_iq2_xs;
|
|
2460
|
+
// 74 bytes / block for QK_K = 256, so 2.3125 bpw
|
|
2419
2461
|
|
|
2420
2462
|
//====================================== dot products =========================
|
|
2421
2463
|
|
|
@@ -2575,14 +2617,21 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
|
2575
2617
|
device const float * src1,
|
|
2576
2618
|
device float * dst,
|
|
2577
2619
|
constant int64_t & ne00,
|
|
2578
|
-
constant int64_t & ne01
|
|
2579
|
-
constant int64_t & ne02
|
|
2580
|
-
constant
|
|
2581
|
-
constant
|
|
2582
|
-
constant
|
|
2583
|
-
constant int64_t &
|
|
2584
|
-
constant
|
|
2585
|
-
constant
|
|
2620
|
+
constant int64_t & ne01,
|
|
2621
|
+
constant int64_t & ne02,
|
|
2622
|
+
constant uint64_t & nb00,
|
|
2623
|
+
constant uint64_t & nb01,
|
|
2624
|
+
constant uint64_t & nb02,
|
|
2625
|
+
constant int64_t & ne10,
|
|
2626
|
+
constant int64_t & ne11,
|
|
2627
|
+
constant int64_t & ne12,
|
|
2628
|
+
constant uint64_t & nb10,
|
|
2629
|
+
constant uint64_t & nb11,
|
|
2630
|
+
constant uint64_t & nb12,
|
|
2631
|
+
constant int64_t & ne0,
|
|
2632
|
+
constant int64_t & ne1,
|
|
2633
|
+
constant uint & r2,
|
|
2634
|
+
constant uint & r3,
|
|
2586
2635
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2587
2636
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
2588
2637
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -2832,14 +2881,21 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
2832
2881
|
device const float * src1,
|
|
2833
2882
|
device float * dst,
|
|
2834
2883
|
constant int64_t & ne00,
|
|
2835
|
-
constant int64_t & ne01
|
|
2836
|
-
constant int64_t & ne02
|
|
2837
|
-
constant
|
|
2838
|
-
constant
|
|
2839
|
-
constant
|
|
2840
|
-
constant int64_t &
|
|
2841
|
-
constant
|
|
2842
|
-
constant
|
|
2884
|
+
constant int64_t & ne01,
|
|
2885
|
+
constant int64_t & ne02,
|
|
2886
|
+
constant uint64_t & nb00,
|
|
2887
|
+
constant uint64_t & nb01,
|
|
2888
|
+
constant uint64_t & nb02,
|
|
2889
|
+
constant int64_t & ne10,
|
|
2890
|
+
constant int64_t & ne11,
|
|
2891
|
+
constant int64_t & ne12,
|
|
2892
|
+
constant uint64_t & nb10,
|
|
2893
|
+
constant uint64_t & nb11,
|
|
2894
|
+
constant uint64_t & nb12,
|
|
2895
|
+
constant int64_t & ne0,
|
|
2896
|
+
constant int64_t & ne1,
|
|
2897
|
+
constant uint & r2,
|
|
2898
|
+
constant uint & r3,
|
|
2843
2899
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2844
2900
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
2845
2901
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -2975,8 +3031,8 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
|
2975
3031
|
constant uint & r2,
|
|
2976
3032
|
constant uint & r3,
|
|
2977
3033
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2978
|
-
uint
|
|
2979
|
-
uint
|
|
3034
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3035
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
2980
3036
|
|
|
2981
3037
|
const int ix = tiisg/4; // 0...7
|
|
2982
3038
|
const int it = tiisg%4; // 0...3
|
|
@@ -2985,7 +3041,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
|
2985
3041
|
const int r0 = tgpig.x;
|
|
2986
3042
|
const int r1 = tgpig.y;
|
|
2987
3043
|
const int im = tgpig.z;
|
|
2988
|
-
const int first_row =
|
|
3044
|
+
const int first_row = r0 * N_DST;
|
|
2989
3045
|
const int ib_row = first_row * nb;
|
|
2990
3046
|
|
|
2991
3047
|
const uint i12 = im%ne12;
|
|
@@ -3051,7 +3107,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
|
3051
3107
|
for (int row = 0; row < N_DST; ++row) {
|
|
3052
3108
|
all_sum = simd_sum(sumf[row]);
|
|
3053
3109
|
if (tiisg == 0) {
|
|
3054
|
-
dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
|
|
3110
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
3055
3111
|
}
|
|
3056
3112
|
}
|
|
3057
3113
|
}
|
|
@@ -3063,14 +3119,21 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
3063
3119
|
device const float * src1,
|
|
3064
3120
|
device float * dst,
|
|
3065
3121
|
constant int64_t & ne00,
|
|
3066
|
-
constant int64_t & ne01
|
|
3067
|
-
constant int64_t & ne02
|
|
3068
|
-
constant
|
|
3069
|
-
constant
|
|
3070
|
-
constant
|
|
3071
|
-
constant int64_t &
|
|
3072
|
-
constant
|
|
3073
|
-
constant
|
|
3122
|
+
constant int64_t & ne01,
|
|
3123
|
+
constant int64_t & ne02,
|
|
3124
|
+
constant uint64_t & nb00,
|
|
3125
|
+
constant uint64_t & nb01,
|
|
3126
|
+
constant uint64_t & nb02,
|
|
3127
|
+
constant int64_t & ne10,
|
|
3128
|
+
constant int64_t & ne11,
|
|
3129
|
+
constant int64_t & ne12,
|
|
3130
|
+
constant uint64_t & nb10,
|
|
3131
|
+
constant uint64_t & nb11,
|
|
3132
|
+
constant uint64_t & nb12,
|
|
3133
|
+
constant int64_t & ne0,
|
|
3134
|
+
constant int64_t & ne1,
|
|
3135
|
+
constant uint & r2,
|
|
3136
|
+
constant uint & r3,
|
|
3074
3137
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3075
3138
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
3076
3139
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -3262,14 +3325,21 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
3262
3325
|
device const float * src1,
|
|
3263
3326
|
device float * dst,
|
|
3264
3327
|
constant int64_t & ne00,
|
|
3265
|
-
constant int64_t & ne01
|
|
3266
|
-
constant int64_t & ne02
|
|
3267
|
-
constant
|
|
3268
|
-
constant
|
|
3269
|
-
constant
|
|
3270
|
-
constant int64_t &
|
|
3271
|
-
constant
|
|
3272
|
-
constant
|
|
3328
|
+
constant int64_t & ne01,
|
|
3329
|
+
constant int64_t & ne02,
|
|
3330
|
+
constant uint64_t & nb00,
|
|
3331
|
+
constant uint64_t & nb01,
|
|
3332
|
+
constant uint64_t & nb02,
|
|
3333
|
+
constant int64_t & ne10,
|
|
3334
|
+
constant int64_t & ne11,
|
|
3335
|
+
constant int64_t & ne12,
|
|
3336
|
+
constant uint64_t & nb10,
|
|
3337
|
+
constant uint64_t & nb11,
|
|
3338
|
+
constant uint64_t & nb12,
|
|
3339
|
+
constant int64_t & ne0,
|
|
3340
|
+
constant int64_t & ne1,
|
|
3341
|
+
constant uint & r2,
|
|
3342
|
+
constant uint & r3,
|
|
3273
3343
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3274
3344
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
3275
3345
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -3389,14 +3459,21 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
|
3389
3459
|
device const float * src1,
|
|
3390
3460
|
device float * dst,
|
|
3391
3461
|
constant int64_t & ne00,
|
|
3392
|
-
constant int64_t & ne01
|
|
3393
|
-
constant int64_t & ne02
|
|
3394
|
-
constant
|
|
3395
|
-
constant
|
|
3396
|
-
constant
|
|
3397
|
-
constant int64_t &
|
|
3398
|
-
constant
|
|
3399
|
-
constant
|
|
3462
|
+
constant int64_t & ne01,
|
|
3463
|
+
constant int64_t & ne02,
|
|
3464
|
+
constant uint64_t & nb00,
|
|
3465
|
+
constant uint64_t & nb01,
|
|
3466
|
+
constant uint64_t & nb02,
|
|
3467
|
+
constant int64_t & ne10,
|
|
3468
|
+
constant int64_t & ne11,
|
|
3469
|
+
constant int64_t & ne12,
|
|
3470
|
+
constant uint64_t & nb10,
|
|
3471
|
+
constant uint64_t & nb11,
|
|
3472
|
+
constant uint64_t & nb12,
|
|
3473
|
+
constant int64_t & ne0,
|
|
3474
|
+
constant int64_t & ne1,
|
|
3475
|
+
constant uint & r2,
|
|
3476
|
+
constant uint & r3,
|
|
3400
3477
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3401
3478
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
3402
3479
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -3404,51 +3481,540 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
|
3404
3481
|
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
3405
3482
|
}
|
|
3406
3483
|
|
|
3407
|
-
|
|
3408
|
-
|
|
3409
|
-
|
|
3410
|
-
|
|
3411
|
-
|
|
3412
|
-
|
|
3413
|
-
|
|
3414
|
-
|
|
3415
|
-
|
|
3416
|
-
|
|
3484
|
+
// ======================= "True" 2-bit
|
|
3485
|
+
|
|
3486
|
+
constexpr constant static uint64_t iq2xxs_grid[256] = {
|
|
3487
|
+
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
|
3488
|
+
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
|
|
3489
|
+
0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
|
|
3490
|
+
0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
|
|
3491
|
+
0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
|
|
3492
|
+
0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
|
|
3493
|
+
0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
|
|
3494
|
+
0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
|
|
3495
|
+
0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
|
|
3496
|
+
0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
|
|
3497
|
+
0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
|
|
3498
|
+
0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
|
|
3499
|
+
0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
|
|
3500
|
+
0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
|
|
3501
|
+
0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
|
|
3502
|
+
0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
|
|
3503
|
+
0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
|
|
3504
|
+
0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
|
|
3505
|
+
0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
|
|
3506
|
+
0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
|
|
3507
|
+
0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
|
|
3508
|
+
0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
|
|
3509
|
+
0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
|
|
3510
|
+
0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
|
|
3511
|
+
0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
|
|
3512
|
+
0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
|
|
3513
|
+
0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
|
|
3514
|
+
0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
|
|
3515
|
+
0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
|
|
3516
|
+
0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
|
|
3517
|
+
0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
|
|
3518
|
+
0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
|
|
3519
|
+
0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
|
|
3520
|
+
0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
|
|
3521
|
+
0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
|
|
3522
|
+
0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
|
|
3523
|
+
0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
|
|
3524
|
+
0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
|
|
3525
|
+
0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
|
|
3526
|
+
0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
|
|
3527
|
+
0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
|
|
3528
|
+
0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
|
|
3529
|
+
0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
|
|
3530
|
+
0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
|
|
3531
|
+
0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
|
|
3532
|
+
0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
|
|
3533
|
+
0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
|
|
3534
|
+
0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
|
|
3535
|
+
0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
|
|
3536
|
+
0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
|
|
3537
|
+
0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
|
|
3538
|
+
0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
|
|
3539
|
+
0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
|
|
3540
|
+
0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
|
|
3541
|
+
0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
|
|
3542
|
+
0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
|
|
3543
|
+
0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
|
|
3544
|
+
0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
|
|
3545
|
+
0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
|
|
3546
|
+
0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
|
|
3547
|
+
0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
|
|
3548
|
+
0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
|
|
3549
|
+
0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
|
|
3550
|
+
0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
|
|
3551
|
+
};
|
|
3417
3552
|
|
|
3418
|
-
|
|
3419
|
-
|
|
3420
|
-
|
|
3421
|
-
|
|
3422
|
-
|
|
3423
|
-
|
|
3424
|
-
|
|
3553
|
+
constexpr constant static uint64_t iq2xs_grid[512] = {
|
|
3554
|
+
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
|
3555
|
+
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
|
|
3556
|
+
0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
|
|
3557
|
+
0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
|
|
3558
|
+
0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
|
|
3559
|
+
0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
|
|
3560
|
+
0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
|
|
3561
|
+
0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
|
|
3562
|
+
0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
|
|
3563
|
+
0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
|
|
3564
|
+
0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
|
|
3565
|
+
0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
|
|
3566
|
+
0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
|
|
3567
|
+
0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
|
|
3568
|
+
0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
|
|
3569
|
+
0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
|
|
3570
|
+
0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
|
|
3571
|
+
0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
|
|
3572
|
+
0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
|
|
3573
|
+
0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
|
|
3574
|
+
0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
|
|
3575
|
+
0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
|
|
3576
|
+
0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
|
|
3577
|
+
0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
|
|
3578
|
+
0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
|
|
3579
|
+
0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
|
|
3580
|
+
0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
|
|
3581
|
+
0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
|
|
3582
|
+
0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
|
|
3583
|
+
0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
|
|
3584
|
+
0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
|
|
3585
|
+
0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
|
|
3586
|
+
0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
|
|
3587
|
+
0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
|
|
3588
|
+
0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
|
|
3589
|
+
0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
|
|
3590
|
+
0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
|
|
3591
|
+
0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
|
|
3592
|
+
0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
|
|
3593
|
+
0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
|
|
3594
|
+
0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
|
|
3595
|
+
0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
|
|
3596
|
+
0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
|
|
3597
|
+
0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
|
|
3598
|
+
0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
|
|
3599
|
+
0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
|
|
3600
|
+
0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
|
|
3601
|
+
0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
|
|
3602
|
+
0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
|
|
3603
|
+
0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
|
|
3604
|
+
0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
|
|
3605
|
+
0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
|
|
3606
|
+
0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
|
|
3607
|
+
0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
|
|
3608
|
+
0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
|
|
3609
|
+
0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
|
|
3610
|
+
0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
|
|
3611
|
+
0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
|
|
3612
|
+
0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
|
|
3613
|
+
0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
|
|
3614
|
+
0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
|
|
3615
|
+
0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
|
|
3616
|
+
0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
|
|
3617
|
+
0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
|
|
3618
|
+
0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
|
|
3619
|
+
0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
|
|
3620
|
+
0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
|
|
3621
|
+
0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
|
|
3622
|
+
0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
|
|
3623
|
+
0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
|
|
3624
|
+
0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
|
|
3625
|
+
0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
|
|
3626
|
+
0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
|
|
3627
|
+
0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
|
|
3628
|
+
0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
|
|
3629
|
+
0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
|
|
3630
|
+
0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
|
|
3631
|
+
0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
|
|
3632
|
+
0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
|
|
3633
|
+
0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
|
|
3634
|
+
0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
|
|
3635
|
+
0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
|
|
3636
|
+
0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
|
|
3637
|
+
0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
|
|
3638
|
+
0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
|
|
3639
|
+
0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
|
|
3640
|
+
0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
|
|
3641
|
+
0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
|
|
3642
|
+
0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
|
|
3643
|
+
0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
|
|
3644
|
+
0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
|
|
3645
|
+
0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
|
|
3646
|
+
0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
|
|
3647
|
+
0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
|
|
3648
|
+
0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
|
|
3649
|
+
0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
|
|
3650
|
+
0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
|
|
3651
|
+
0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
|
|
3652
|
+
0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
|
|
3653
|
+
0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
|
|
3654
|
+
0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
|
|
3655
|
+
0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
|
|
3656
|
+
0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
|
|
3657
|
+
0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
|
|
3658
|
+
0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
|
|
3659
|
+
0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
|
|
3660
|
+
0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
|
|
3661
|
+
0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
|
|
3662
|
+
0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
|
|
3663
|
+
0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
|
|
3664
|
+
0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
|
|
3665
|
+
0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
|
|
3666
|
+
0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
|
|
3667
|
+
0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
|
|
3668
|
+
0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
|
|
3669
|
+
0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
|
|
3670
|
+
0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
|
|
3671
|
+
0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
|
|
3672
|
+
0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
|
|
3673
|
+
0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
|
|
3674
|
+
0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
|
|
3675
|
+
0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
|
|
3676
|
+
0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
|
|
3677
|
+
0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
|
|
3678
|
+
0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
|
|
3679
|
+
0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
|
|
3680
|
+
0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
|
|
3681
|
+
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
|
3682
|
+
};
|
|
3425
3683
|
|
|
3426
|
-
|
|
3427
|
-
|
|
3428
|
-
|
|
3429
|
-
|
|
3430
|
-
|
|
3431
|
-
|
|
3432
|
-
|
|
3433
|
-
|
|
3684
|
+
constexpr constant static uint8_t ksigns_iq2xs[128] = {
|
|
3685
|
+
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
|
3686
|
+
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
3687
|
+
160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
|
|
3688
|
+
48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
|
|
3689
|
+
192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
|
|
3690
|
+
80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
|
|
3691
|
+
96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
|
|
3692
|
+
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
|
|
3693
|
+
};
|
|
3434
3694
|
|
|
3435
|
-
|
|
3436
|
-
reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
|
|
3437
|
-
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
|
|
3438
|
-
}
|
|
3439
|
-
}
|
|
3695
|
+
constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
|
|
3440
3696
|
|
|
3441
|
-
|
|
3442
|
-
|
|
3443
|
-
|
|
3444
|
-
|
|
3445
|
-
|
|
3446
|
-
|
|
3447
|
-
|
|
3448
|
-
|
|
3697
|
+
void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
3698
|
+
device const void * src0,
|
|
3699
|
+
device const float * src1,
|
|
3700
|
+
device float * dst,
|
|
3701
|
+
constant int64_t & ne00,
|
|
3702
|
+
constant int64_t & ne01,
|
|
3703
|
+
constant int64_t & ne02,
|
|
3704
|
+
constant int64_t & ne10,
|
|
3705
|
+
constant int64_t & ne12,
|
|
3706
|
+
constant int64_t & ne0,
|
|
3707
|
+
constant int64_t & ne1,
|
|
3708
|
+
constant uint & r2,
|
|
3709
|
+
constant uint & r3,
|
|
3710
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
3711
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3712
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3713
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3449
3714
|
|
|
3450
|
-
|
|
3451
|
-
|
|
3715
|
+
const int nb = ne00/QK_K;
|
|
3716
|
+
const int r0 = tgpig.x;
|
|
3717
|
+
const int r1 = tgpig.y;
|
|
3718
|
+
const int im = tgpig.z;
|
|
3719
|
+
|
|
3720
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
3721
|
+
const int ib_row = first_row * nb;
|
|
3722
|
+
|
|
3723
|
+
const uint i12 = im%ne12;
|
|
3724
|
+
const uint i13 = im/ne12;
|
|
3725
|
+
|
|
3726
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
3727
|
+
|
|
3728
|
+
device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
|
|
3729
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
3730
|
+
|
|
3731
|
+
float yl[32];
|
|
3732
|
+
float sumf[N_DST]={0.f}, all_sum;
|
|
3733
|
+
|
|
3734
|
+
const int nb32 = nb * (QK_K / 32);
|
|
3735
|
+
|
|
3736
|
+
threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
|
|
3737
|
+
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
|
|
3738
|
+
{
|
|
3739
|
+
int nval = 4;
|
|
3740
|
+
int pos = (32*sgitg + tiisg)*nval;
|
|
3741
|
+
for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
|
|
3742
|
+
nval = 2;
|
|
3743
|
+
pos = (32*sgitg + tiisg)*nval;
|
|
3744
|
+
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
|
|
3745
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3746
|
+
}
|
|
3747
|
+
|
|
3748
|
+
#if QK_K == 256
|
|
3749
|
+
const int ix = tiisg;
|
|
3750
|
+
|
|
3751
|
+
device const float * y4 = y + 32 * ix;
|
|
3752
|
+
|
|
3753
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
|
3754
|
+
|
|
3755
|
+
for (int i = 0; i < 32; ++i) {
|
|
3756
|
+
yl[i] = y4[i];
|
|
3757
|
+
}
|
|
3758
|
+
|
|
3759
|
+
const int ibl = ib32 / (QK_K / 32);
|
|
3760
|
+
const int ib = ib32 % (QK_K / 32);
|
|
3761
|
+
|
|
3762
|
+
device const block_iq2_xxs * xr = x + ibl;
|
|
3763
|
+
device const uint16_t * q2 = xr->qs + 4 * ib;
|
|
3764
|
+
device const half * dh = &xr->d;
|
|
3765
|
+
|
|
3766
|
+
for (int row = 0; row < N_DST; row++) {
|
|
3767
|
+
|
|
3768
|
+
const float db = dh[0];
|
|
3769
|
+
device const uint8_t * aux8 = (device const uint8_t *)q2;
|
|
3770
|
+
const uint32_t aux32 = q2[2] | (q2[3] << 16);
|
|
3771
|
+
const float d = db * (0.5f + (aux32 >> 28));
|
|
3772
|
+
|
|
3773
|
+
float sum = 0;
|
|
3774
|
+
for (int l = 0; l < 4; ++l) {
|
|
3775
|
+
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
|
|
3776
|
+
const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
|
|
3777
|
+
for (int j = 0; j < 8; ++j) {
|
|
3778
|
+
sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
|
3779
|
+
}
|
|
3780
|
+
}
|
|
3781
|
+
sumf[row] += d * sum;
|
|
3782
|
+
|
|
3783
|
+
dh += nb*sizeof(block_iq2_xxs)/2;
|
|
3784
|
+
q2 += nb*sizeof(block_iq2_xxs)/2;
|
|
3785
|
+
}
|
|
3786
|
+
|
|
3787
|
+
y4 += 32 * 32;
|
|
3788
|
+
}
|
|
3789
|
+
#else
|
|
3790
|
+
// TODO
|
|
3791
|
+
#endif
|
|
3792
|
+
|
|
3793
|
+
for (int row = 0; row < N_DST; ++row) {
|
|
3794
|
+
all_sum = simd_sum(sumf[row]);
|
|
3795
|
+
if (tiisg == 0) {
|
|
3796
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
|
|
3797
|
+
}
|
|
3798
|
+
}
|
|
3799
|
+
}
|
|
3800
|
+
|
|
3801
|
+
[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
|
|
3802
|
+
kernel void kernel_mul_mv_iq2_xxs_f32(
|
|
3803
|
+
device const void * src0,
|
|
3804
|
+
device const float * src1,
|
|
3805
|
+
device float * dst,
|
|
3806
|
+
constant int64_t & ne00,
|
|
3807
|
+
constant int64_t & ne01,
|
|
3808
|
+
constant int64_t & ne02,
|
|
3809
|
+
constant uint64_t & nb00,
|
|
3810
|
+
constant uint64_t & nb01,
|
|
3811
|
+
constant uint64_t & nb02,
|
|
3812
|
+
constant int64_t & ne10,
|
|
3813
|
+
constant int64_t & ne11,
|
|
3814
|
+
constant int64_t & ne12,
|
|
3815
|
+
constant uint64_t & nb10,
|
|
3816
|
+
constant uint64_t & nb11,
|
|
3817
|
+
constant uint64_t & nb12,
|
|
3818
|
+
constant int64_t & ne0,
|
|
3819
|
+
constant int64_t & ne1,
|
|
3820
|
+
constant uint & r2,
|
|
3821
|
+
constant uint & r3,
|
|
3822
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
3823
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3824
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3825
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3826
|
+
|
|
3827
|
+
kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
3828
|
+
}
|
|
3829
|
+
|
|
3830
|
+
void kernel_mul_mv_iq2_xs_f32_impl(
|
|
3831
|
+
device const void * src0,
|
|
3832
|
+
device const float * src1,
|
|
3833
|
+
device float * dst,
|
|
3834
|
+
constant int64_t & ne00,
|
|
3835
|
+
constant int64_t & ne01,
|
|
3836
|
+
constant int64_t & ne02,
|
|
3837
|
+
constant int64_t & ne10,
|
|
3838
|
+
constant int64_t & ne12,
|
|
3839
|
+
constant int64_t & ne0,
|
|
3840
|
+
constant int64_t & ne1,
|
|
3841
|
+
constant uint & r2,
|
|
3842
|
+
constant uint & r3,
|
|
3843
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
3844
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3845
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3846
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3847
|
+
|
|
3848
|
+
const int nb = ne00/QK_K;
|
|
3849
|
+
const int r0 = tgpig.x;
|
|
3850
|
+
const int r1 = tgpig.y;
|
|
3851
|
+
const int im = tgpig.z;
|
|
3852
|
+
|
|
3853
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
3854
|
+
const int ib_row = first_row * nb;
|
|
3855
|
+
|
|
3856
|
+
const uint i12 = im%ne12;
|
|
3857
|
+
const uint i13 = im/ne12;
|
|
3858
|
+
|
|
3859
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
3860
|
+
|
|
3861
|
+
device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
|
|
3862
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
3863
|
+
|
|
3864
|
+
float yl[32];
|
|
3865
|
+
float sumf[N_DST]={0.f}, all_sum;
|
|
3866
|
+
|
|
3867
|
+
const int nb32 = nb * (QK_K / 32);
|
|
3868
|
+
|
|
3869
|
+
threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
|
|
3870
|
+
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512);
|
|
3871
|
+
{
|
|
3872
|
+
int nval = 8;
|
|
3873
|
+
int pos = (32*sgitg + tiisg)*nval;
|
|
3874
|
+
for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
|
|
3875
|
+
nval = 2;
|
|
3876
|
+
pos = (32*sgitg + tiisg)*nval;
|
|
3877
|
+
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
|
|
3878
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3879
|
+
}
|
|
3880
|
+
|
|
3881
|
+
#if QK_K == 256
|
|
3882
|
+
const int ix = tiisg;
|
|
3883
|
+
|
|
3884
|
+
device const float * y4 = y + 32 * ix;
|
|
3885
|
+
|
|
3886
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
|
3887
|
+
|
|
3888
|
+
for (int i = 0; i < 32; ++i) {
|
|
3889
|
+
yl[i] = y4[i];
|
|
3890
|
+
}
|
|
3891
|
+
|
|
3892
|
+
const int ibl = ib32 / (QK_K / 32);
|
|
3893
|
+
const int ib = ib32 % (QK_K / 32);
|
|
3894
|
+
|
|
3895
|
+
device const block_iq2_xs * xr = x + ibl;
|
|
3896
|
+
device const uint16_t * q2 = xr->qs + 4 * ib;
|
|
3897
|
+
device const uint8_t * sc = xr->scales + ib;
|
|
3898
|
+
device const half * dh = &xr->d;
|
|
3899
|
+
|
|
3900
|
+
for (int row = 0; row < N_DST; row++) {
|
|
3901
|
+
|
|
3902
|
+
const float db = dh[0];
|
|
3903
|
+
const uint8_t ls1 = sc[0] & 0xf;
|
|
3904
|
+
const uint8_t ls2 = sc[0] >> 4;
|
|
3905
|
+
const float d1 = db * (0.5f + ls1);
|
|
3906
|
+
const float d2 = db * (0.5f + ls2);
|
|
3907
|
+
|
|
3908
|
+
float sum1 = 0, sum2 = 0;
|
|
3909
|
+
for (int l = 0; l < 2; ++l) {
|
|
3910
|
+
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
|
|
3911
|
+
const uint8_t signs = shared_signs[(q2[l] >> 9)];
|
|
3912
|
+
for (int j = 0; j < 8; ++j) {
|
|
3913
|
+
sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
|
3914
|
+
}
|
|
3915
|
+
}
|
|
3916
|
+
for (int l = 2; l < 4; ++l) {
|
|
3917
|
+
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
|
|
3918
|
+
const uint8_t signs = shared_signs[(q2[l] >> 9)];
|
|
3919
|
+
for (int j = 0; j < 8; ++j) {
|
|
3920
|
+
sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
|
3921
|
+
}
|
|
3922
|
+
}
|
|
3923
|
+
sumf[row] += d1 * sum1 + d2 * sum2;
|
|
3924
|
+
|
|
3925
|
+
dh += nb*sizeof(block_iq2_xs)/2;
|
|
3926
|
+
q2 += nb*sizeof(block_iq2_xs)/2;
|
|
3927
|
+
sc += nb*sizeof(block_iq2_xs);
|
|
3928
|
+
}
|
|
3929
|
+
|
|
3930
|
+
y4 += 32 * 32;
|
|
3931
|
+
}
|
|
3932
|
+
#else
|
|
3933
|
+
// TODO
|
|
3934
|
+
#endif
|
|
3935
|
+
|
|
3936
|
+
for (int row = 0; row < N_DST; ++row) {
|
|
3937
|
+
all_sum = simd_sum(sumf[row]);
|
|
3938
|
+
if (tiisg == 0) {
|
|
3939
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
|
|
3940
|
+
}
|
|
3941
|
+
}
|
|
3942
|
+
}
|
|
3943
|
+
|
|
3944
|
+
[[host_name("kernel_mul_mv_iq2_xs_f32")]]
|
|
3945
|
+
kernel void kernel_mul_mv_iq2_xs_f32(
|
|
3946
|
+
device const void * src0,
|
|
3947
|
+
device const float * src1,
|
|
3948
|
+
device float * dst,
|
|
3949
|
+
constant int64_t & ne00,
|
|
3950
|
+
constant int64_t & ne01,
|
|
3951
|
+
constant int64_t & ne02,
|
|
3952
|
+
constant uint64_t & nb00,
|
|
3953
|
+
constant uint64_t & nb01,
|
|
3954
|
+
constant uint64_t & nb02,
|
|
3955
|
+
constant int64_t & ne10,
|
|
3956
|
+
constant int64_t & ne11,
|
|
3957
|
+
constant int64_t & ne12,
|
|
3958
|
+
constant uint64_t & nb10,
|
|
3959
|
+
constant uint64_t & nb11,
|
|
3960
|
+
constant uint64_t & nb12,
|
|
3961
|
+
constant int64_t & ne0,
|
|
3962
|
+
constant int64_t & ne1,
|
|
3963
|
+
constant uint & r2,
|
|
3964
|
+
constant uint & r3,
|
|
3965
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
3966
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3967
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3968
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3969
|
+
|
|
3970
|
+
kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
3971
|
+
}
|
|
3972
|
+
|
|
3973
|
+
//============================= templates and their specializations =============================
|
|
3974
|
+
|
|
3975
|
+
// NOTE: this is not dequantizing - we are simply fitting the template
|
|
3976
|
+
template <typename type4x4>
|
|
3977
|
+
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
|
3978
|
+
float4x4 temp = *(((device float4x4 *)src));
|
|
3979
|
+
for (int i = 0; i < 16; i++){
|
|
3980
|
+
reg[i/4][i%4] = temp[i/4][i%4];
|
|
3981
|
+
}
|
|
3982
|
+
}
|
|
3983
|
+
|
|
3984
|
+
template <typename type4x4>
|
|
3985
|
+
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
|
|
3986
|
+
half4x4 temp = *(((device half4x4 *)src));
|
|
3987
|
+
for (int i = 0; i < 16; i++){
|
|
3988
|
+
reg[i/4][i%4] = temp[i/4][i%4];
|
|
3989
|
+
}
|
|
3990
|
+
}
|
|
3991
|
+
|
|
3992
|
+
template <typename type4x4>
|
|
3993
|
+
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
|
3994
|
+
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
|
3995
|
+
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
|
3996
|
+
const float d2 = d1 / 256.f;
|
|
3997
|
+
const float md = -8.h * xb->d;
|
|
3998
|
+
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
|
3999
|
+
const ushort mask1 = mask0 << 8;
|
|
4000
|
+
|
|
4001
|
+
for (int i=0;i<8;i++) {
|
|
4002
|
+
reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
|
|
4003
|
+
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
|
|
4004
|
+
}
|
|
4005
|
+
}
|
|
4006
|
+
|
|
4007
|
+
// Fills the 16 entries of `reg` from 8 packed 16-bit words of a q4_1 block.
//   xb  - source block; scale xb->d, additive term xb->m, packed quants read as
//         uint16_t words starting two words (4 bytes) past the start of the block
//         (presumably skipping d and m - confirm against block_q4_1)
//   il  - zero selects the low nibble of each byte, non-zero the high nibble
//   reg - destination 4x4 tile
// The selected nibbles are masked but never shifted down; instead the scale is
// pre-divided by 16 (high nibble) and by a further 256 (upper byte of the word).
template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
    device const uint16_t * words = ((device const uint16_t *)xb + 2);

    const ushort lo_byte_mask = il ? 0x00F0 : 0x000F;
    const ushort hi_byte_mask = lo_byte_mask << 8;

    const float scale_lo = il ? (xb->d / 16.h) : xb->d;
    const float scale_hi = scale_lo / 256.f;
    const float m        = xb->m;

    // word w = 2*col + pair feeds two neighbouring entries of column `col`
    for (int col = 0; col < 4; ++col) {
        for (int pair = 0; pair < 2; ++pair) {
            const int w = 2*col + pair;
            reg[col][2*pair + 0] = ((words[w] & lo_byte_mask) * scale_lo) + m;
            reg[col][2*pair + 1] = ((words[w] & hi_byte_mask) * scale_hi) + m;
        }
    }
}
|
|
@@ -3514,7 +4080,7 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
|
|
|
3514
4080
|
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
|
3515
4081
|
const half d = xb->d;
|
|
3516
4082
|
|
|
3517
|
-
for (int i=0;i<16;i++) {
|
|
4083
|
+
for (int i = 0; i < 16; i++) {
|
|
3518
4084
|
reg[i/4][i%4] = (qs[i + 16*il] * d);
|
|
3519
4085
|
}
|
|
3520
4086
|
}
|
|
@@ -3556,8 +4122,8 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|
|
3556
4122
|
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
|
|
3557
4123
|
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
|
|
3558
4124
|
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
|
|
3559
|
-
|
|
3560
|
-
const
|
|
4125
|
+
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
|
|
4126
|
+
const float ml = 4.f * dl;
|
|
3561
4127
|
|
|
3562
4128
|
il = (il/2) & 3;
|
|
3563
4129
|
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
|
@@ -3624,7 +4190,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
|
3624
4190
|
uint8_t ul = 1 << (il/2);
|
|
3625
4191
|
il = il & 3;
|
|
3626
4192
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
|
3627
|
-
const float d = il < 2 ? xb->d : xb->d / 16.
|
|
4193
|
+
const float d = il < 2 ? xb->d : xb->d / 16.f;
|
|
3628
4194
|
const float min = xb->dmin;
|
|
3629
4195
|
const float dl = d * sc[0];
|
|
3630
4196
|
const float ml = min * sc[1];
|
|
@@ -3657,17 +4223,17 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
|
3657
4223
|
#if QK_K == 256
|
|
3658
4224
|
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
|
3659
4225
|
qh = qh + 32*(il/8) + 16*(il&1);
|
|
3660
|
-
|
|
4226
|
+
float sc = scales[(il%2) + 2 * ((il/2))];
|
|
3661
4227
|
il = (il/2) & 3;
|
|
3662
4228
|
#else
|
|
3663
4229
|
ql = ql + 16 * (il&1);
|
|
3664
|
-
|
|
4230
|
+
float sc = scales[il];
|
|
3665
4231
|
#endif
|
|
3666
4232
|
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
|
3667
4233
|
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
|
3668
|
-
const
|
|
3669
|
-
const
|
|
3670
|
-
const
|
|
4234
|
+
const float coef = il>1 ? 1.f/16.f : 1.f;
|
|
4235
|
+
const float ml = d_all * sc * 32.f;
|
|
4236
|
+
const float dl = d_all * sc * coef;
|
|
3671
4237
|
for (int i = 0; i < 16; ++i) {
|
|
3672
4238
|
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
|
|
3673
4239
|
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
|
|
@@ -3675,6 +4241,52 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
|
3675
4241
|
}
|
|
3676
4242
|
}
|
|
3677
4243
|
|
|
4244
|
+
// Decodes 16 values of an iq2_xxs block into `reg`.
//   xb  - source block (scale xb->d, packed data xb->qs)
//   il  - 0...15 for QK_K = 256: il/2 is the index of the group of 32 quants,
//         il%2 selects the first or the second 16 quants of that group
//   reg - destination 4x4 tile; each pass of the outer loop fills two of its columns
// Each group of 32 occupies four uint16_t in xb->qs, combined here into two
// 32-bit words: aux32_g supplies one grid index per byte, aux32_s supplies the
// 7-bit sign selectors and, in its top 4 bits, the group scale.
template <typename type4x4>
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
    const int   ib32   = il/2;
    const short half16 = il%2;
    const float d      = xb->d;

    device const uint16_t * q2 = xb->qs + 4*ib32;
    const uint32_t aux32_g = q2[0] | (q2[1] << 16);
    const uint32_t aux32_s = q2[2] | (q2[3] << 16);
    thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;

    const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;

    // two runs of 8 values: run k uses grid byte 2*half16 + k and the 7 sign bits
    // starting at bit 14*half16 + 7*k
    for (short k = 0; k < 2; ++k) {
        constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*half16 + k]);
        const uint8_t signs = ksigns_iq2xs[(aux32_s >> (14*half16 + 7*k)) & 127];
        for (int i = 0; i < 8; ++i) {
            reg[2*k + i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
        }
    }
}
|
|
4268
|
+
|
|
4269
|
+
// Decodes 16 values of an iq2_xs block into `reg`.
//   xb  - source block (scale xb->d, packed data xb->qs, per-group scales xb->scales)
//   il  - 0...15 for QK_K = 256: il/2 is the index of the group of 32 quants,
//         il%2 selects the first or the second 16 quants of that group
//   reg - destination 4x4 tile; each pass of the outer loop fills two of its columns
// Every uint16_t of xb->qs carries a grid index in its low 9 bits and a sign
// selector in the remaining upper bits. The 4-bit scale for the chosen half is
// taken from the low or high nibble of xb->scales[ib32].
template <typename type4x4>
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
    const int   ib32   = il/2;
    const short half16 = il%2;
    const float d      = xb->d;

    device const uint16_t * q2 = xb->qs + 4*ib32;
    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*half16) & 0xf)) * 0.25f;

    // two runs of 8 values, one per uint16_t of the selected half
    for (short k = 0; k < 2; ++k) {
        constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*half16 + k] & 511));
        const uint8_t signs = ksigns_iq2xs[q2[2*half16 + k] >> 9];
        for (int i = 0; i < 8; ++i) {
            reg[2*k + i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
        }
    }
}
|
|
4289
|
+
|
|
3678
4290
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
3679
4291
|
kernel void kernel_get_rows(
|
|
3680
4292
|
device const void * src0,
|
|
@@ -3755,48 +4367,212 @@ kernel void kernel_get_rows_f16(
|
|
|
3755
4367
|
const int64_t i10 = tgpig.x;
|
|
3756
4368
|
const int64_t i11 = tgpig.y;
|
|
3757
4369
|
|
|
3758
|
-
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
4370
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
4371
|
+
|
|
4372
|
+
const int64_t i02 = i11;
|
|
4373
|
+
|
|
4374
|
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
|
4375
|
+
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
4376
|
+
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
|
4377
|
+
}
|
|
4378
|
+
}
|
|
4379
|
+
|
|
4380
|
+
// Row gather for int32 data.
// The threadgroup at grid position (x, y) handles index slot i10 = x, i11 = y:
//   1. reads a 32-bit row index r from src1 at byte offset i11*nb11 + i10*nb10
//   2. copies ne00 int32 elements from src0 at byte offset r*nb01 + i11*nb02
//      to dst at byte offset i11*nb2 + i10*nb1
// Threads of the group split the ne00 elements between them, each starting at
// its own index and stepping by the threadgroup width (tptg.x).
// ne10 is part of the argument list but is not read here.
kernel void kernel_get_rows_i32(
        device const     void * src0,
        device const     char * src1,
        device        int32_t * dst,
        constant      int64_t & ne00,
        constant     uint64_t & nb01,
        constant     uint64_t & nb02,
        constant      int64_t & ne10,
        constant     uint64_t & nb10,
        constant     uint64_t & nb11,
        constant     uint64_t & nb1,
        constant     uint64_t & nb2,
        uint3                   tgpig[[threadgroup_position_in_grid]],
        uint                    tiitg[[thread_index_in_threadgroup]],
        uint3                   tptg [[threads_per_threadgroup]]) {
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;
    const int64_t i02 = i11;

    // row of src0 selected by this index slot
    const int64_t r = ((device const int32_t *) (src1 + i11*nb11 + i10*nb10))[0];

    device const int32_t * src_row = (device const int32_t *) ((device const char *) src0 + r*nb01 + i02*nb02);
    device       int32_t * dst_row = (device       int32_t *) ((device       char *) dst  + i11*nb2 + i10*nb1);

    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
        dst_row[ind] = src_row[ind];
    }
}
|
|
4407
|
+
|
|
4408
|
+
|
|
4409
|
+
// Tiling constants for the simdgroup-matrix based matrix multiplication code.
#define BLOCK_SIZE_M     64  // tile extent along matrix A: 8 simdgroup matrices
#define BLOCK_SIZE_N     32  // tile extent along matrix B: 4 simdgroup matrices
#define BLOCK_SIZE_K     32
#define THREAD_MAT_M     4   // simdgroup matrices of A consumed per simdgroup
#define THREAD_MAT_N     2   // simdgroup matrices of B consumed per simdgroup
#define THREAD_PER_BLOCK 128
#define THREAD_PER_ROW   2   // threads cooperating to load one row of matrix A
#define THREAD_PER_COL   4   // threads cooperating to load one row of matrix B
#define SG_MAT_SIZE      64  // elements in one simdgroup matrix (8x8)
#define SG_MAT_ROW       8
|
|
4419
|
+
|
|
4420
|
+
// Tiled matrix-matrix multiplication: dst = src0 (dequantized on the fly) x src1 (f32).
//
// Each threadgroup produces one tile of dst of at most BLOCK_SIZE_M x BLOCK_SIZE_N
// (64x32) values, selected by tgpig (y -> tile along ne0, x -> tile along ne1,
// z -> batch index im). The k dimension (ne00) is consumed BLOCK_SIZE_K values at
// a time; on every step each thread dequantizes one 4x4 half tile of src0 and
// loads 8 floats of src1 into threadgroup memory, after which every simdgroup
// accumulates its 8 8x8 result matrices with simdgroup_multiply_accumulate.
//
// Template parameters:
//   block_q         - storage type of one src0 block; each block_q contains 16*nl weights
//   nl              - number of 16-weight pieces per block_q
//   dequantize_func - expands piece `il` of a block into a half4x4
//
// Arguments:
//   src0, src1, dst      - operand and result buffers (raw bytes for the operands)
//   ne00                 - length of the shared k dimension
//   ne02, nb01, nb02     - src0 batch extent and byte strides
//   ne12, nb10..nb12     - src1 batch extent and byte strides
//   ne0, ne1             - dst extents; dst is addressed as dst[im*ne1*ne0 + col*ne0 + row]
//   r2, r3               - divisors applied to the src1 batch indices to obtain the src0
//                          batch indices (broadcast of src0 over the batch dimensions)
//   shared_memory        - threadgroup scratch: the src0 tile (half) at byte 0 and the
//                          src1 tile (float) at byte 4096; the start of the buffer is
//                          reused as a float staging area for partial tiles
//                          (NOTE(review): required size is set by the host - confirm)
//   tgpig, tiitg, sgitg  - threadgroup position, thread index, simdgroup index
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
void kernel_mul_mm_impl(device const  uchar * src0,
                        device const  uchar * src1,
                        device        float * dst,
                        constant    int64_t & ne00,
                        constant    int64_t & ne02,
                        constant   uint64_t & nb01,
                        constant   uint64_t & nb02,
                        constant    int64_t & ne12,
                        constant   uint64_t & nb10,
                        constant   uint64_t & nb11,
                        constant   uint64_t & nb12,
                        constant    int64_t & ne0,
                        constant    int64_t & ne1,
                        constant       uint & r2,
                        constant       uint & r3,
                        threadgroup   uchar * shared_memory [[threadgroup(0)]],
                        uint3                 tgpig[[threadgroup_position_in_grid]],
                        uint                  tiitg[[thread_index_in_threadgroup]],
                        uint                  sgitg[[simdgroup_index_in_threadgroup]]) {

    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);

    const uint r0 = tgpig.y;
    const uint r1 = tgpig.x;
    const uint im = tgpig.z;

    // if this block is of 64x32 shape or smaller
    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;

    // a thread shouldn't load data outside of the matrix
    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;

    simdgroup_half8x8  ma[4];
    simdgroup_float8x8 mb[2];
    simdgroup_float8x8 c_res[8];
    for (int i = 0; i < 8; i++){
        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
    }

    short il = (tiitg % THREAD_PER_ROW);

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    // byte offset of the src0 batch slice; kept in 64 bits because nb02 is a 64-bit
    // stride and a 32-bit variable would wrap for offsets of 4 GiB or more
    const uint64_t offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
    ushort         offset1 = il/nl;

    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
    device const float   * y = (device const float   *)(src1
        + nb12 * im
        + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));

    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
        // load data and store to threadgroup memory
        half4x4 temp_a;
        dequantize_func(x, il, temp_a);
        // wait until every thread is done reading the previous contents of sa/sb
        threadgroup_barrier(mem_flags::mem_threadgroup);

        #pragma unroll(16)
        for (int i = 0; i < 16; i++) {
            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8)
            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8)
            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
        }

        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);

        // advance to the next 16-weight piece handled by this thread; step to the
        // next block_q once the pieces of the current one are exhausted
        il = (il + 2 < nl) ? il + 2 : il % 2;
        x  = (il < 2) ? x + (2+nl-1)/nl : x;
        y += BLOCK_SIZE_K;

        // make the freshly written tiles visible to the whole threadgroup
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // load matrices from threadgroup memory and conduct outer products
        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));

        #pragma unroll(4)
        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
            #pragma unroll(4)
            for (int i = 0; i < 4; i++) {
                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
            }
            simdgroup_barrier(mem_flags::mem_none);
            #pragma unroll(2)
            for (int i = 0; i < 2; i++) {
                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
            }

            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;

            #pragma unroll(8)
            for (int i = 0; i < 8; i++){
                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
            }
        }
    }

    if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
        // full tile: every simdgroup stores its 8 result matrices straight into dst
        device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg &  1))
                               + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
        for (int i = 0; i < 8; i++) {
            simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
        }
    } else {
        // block is smaller than 64x32, we should avoid writing data outside of the matrix
        threadgroup_barrier(mem_flags::mem_threadgroup);
        threadgroup float * temp_str = ((threadgroup float *)shared_memory)
                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
        for (int i = 0; i < 8; i++) {
            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // only simdgroup 0 copies the valid n_rows x n_cols part of the staged tile
        // (for sgitg == 0, temp_str is the start of shared_memory)
        device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
        if (sgitg == 0) {
            for (int i = 0; i < n_rows; i++) {
                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                    *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                }
            }
        }
    }
}
|
|
3767
4552
|
|
|
3768
|
-
|
|
3769
|
-
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
|
3770
|
-
#define BLOCK_SIZE_K 32
|
|
3771
|
-
#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
|
|
3772
|
-
#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
|
|
3773
|
-
#define THREAD_PER_BLOCK 128
|
|
3774
|
-
#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
|
|
3775
|
-
#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
|
|
3776
|
-
#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
|
|
3777
|
-
#define SG_MAT_ROW 8
|
|
3778
|
-
|
|
3779
|
-
// each block_q contains 16*nl weights
|
|
4553
|
+
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
|
|
3780
4554
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
3781
|
-
void
|
|
3782
|
-
|
|
3783
|
-
|
|
3784
|
-
|
|
3785
|
-
|
|
3786
|
-
|
|
3787
|
-
|
|
3788
|
-
|
|
3789
|
-
|
|
3790
|
-
|
|
3791
|
-
|
|
3792
|
-
|
|
3793
|
-
|
|
3794
|
-
|
|
3795
|
-
|
|
3796
|
-
|
|
3797
|
-
|
|
3798
|
-
|
|
3799
|
-
|
|
4555
|
+
void kernel_mul_mm_id_impl(
|
|
4556
|
+
device const uchar * src0,
|
|
4557
|
+
device const uchar * src1,
|
|
4558
|
+
thread short * src1ids,
|
|
4559
|
+
device float * dst,
|
|
4560
|
+
constant int64_t & ne00,
|
|
4561
|
+
constant int64_t & ne02,
|
|
4562
|
+
constant uint64_t & nb01,
|
|
4563
|
+
constant uint64_t & nb02,
|
|
4564
|
+
constant int64_t & ne12,
|
|
4565
|
+
constant uint64_t & nb10,
|
|
4566
|
+
constant uint64_t & nb11,
|
|
4567
|
+
constant uint64_t & nb12,
|
|
4568
|
+
constant int64_t & ne0,
|
|
4569
|
+
int64_t ne1,
|
|
4570
|
+
constant uint & r2,
|
|
4571
|
+
constant uint & r3,
|
|
4572
|
+
threadgroup uchar * shared_memory,
|
|
4573
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4574
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
4575
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3800
4576
|
|
|
3801
4577
|
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
|
3802
4578
|
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
|
@@ -3805,6 +4581,8 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
3805
4581
|
const uint r1 = tgpig.x;
|
|
3806
4582
|
const uint im = tgpig.z;
|
|
3807
4583
|
|
|
4584
|
+
if (r1 * BLOCK_SIZE_N >= ne1) return;
|
|
4585
|
+
|
|
3808
4586
|
// if this block is of 64x32 shape or smaller
|
|
3809
4587
|
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
|
3810
4588
|
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
|
@@ -3831,7 +4609,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
3831
4609
|
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
|
3832
4610
|
device const float * y = (device const float *)(src1
|
|
3833
4611
|
+ nb12 * im
|
|
3834
|
-
+ nb11 *
|
|
4612
|
+
+ nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
|
|
3835
4613
|
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
|
3836
4614
|
|
|
3837
4615
|
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
|
@@ -3840,7 +4618,6 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
3840
4618
|
dequantize_func(x, il, temp_a);
|
|
3841
4619
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3842
4620
|
|
|
3843
|
-
#pragma unroll(16)
|
|
3844
4621
|
for (int i = 0; i < 16; i++) {
|
|
3845
4622
|
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
|
3846
4623
|
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
|
@@ -3859,14 +4636,11 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
3859
4636
|
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
|
3860
4637
|
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
|
3861
4638
|
|
|
3862
|
-
#pragma unroll(4)
|
|
3863
4639
|
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
|
3864
|
-
#pragma unroll(4)
|
|
3865
4640
|
for (int i = 0; i < 4; i++) {
|
|
3866
4641
|
simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
|
|
3867
4642
|
}
|
|
3868
4643
|
simdgroup_barrier(mem_flags::mem_none);
|
|
3869
|
-
#pragma unroll(2)
|
|
3870
4644
|
for (int i = 0; i < 2; i++) {
|
|
3871
4645
|
simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
|
|
3872
4646
|
}
|
|
@@ -3874,21 +4648,13 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
3874
4648
|
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
|
3875
4649
|
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
|
3876
4650
|
|
|
3877
|
-
#pragma unroll(8)
|
|
3878
4651
|
for (int i = 0; i < 8; i++){
|
|
3879
4652
|
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
|
3880
4653
|
}
|
|
3881
4654
|
}
|
|
3882
4655
|
}
|
|
3883
4656
|
|
|
3884
|
-
|
|
3885
|
-
device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
|
|
3886
|
-
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
|
|
3887
|
-
for (int i = 0; i < 8; i++) {
|
|
3888
|
-
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
|
|
3889
|
-
}
|
|
3890
|
-
} else {
|
|
3891
|
-
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
|
4657
|
+
{
|
|
3892
4658
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3893
4659
|
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
|
|
3894
4660
|
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
|
@@ -3898,11 +4664,11 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
3898
4664
|
|
|
3899
4665
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3900
4666
|
|
|
3901
|
-
device float * C = dst + (BLOCK_SIZE_M * r0) +
|
|
4667
|
+
device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
|
|
3902
4668
|
if (sgitg == 0) {
|
|
3903
4669
|
for (int i = 0; i < n_rows; i++) {
|
|
3904
4670
|
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
|
3905
|
-
*(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
|
|
4671
|
+
*(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
|
|
3906
4672
|
}
|
|
3907
4673
|
}
|
|
3908
4674
|
}
|
|
@@ -3915,12 +4681,12 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
|
3915
4681
|
device float * dst,
|
|
3916
4682
|
constant int64_t & ne00,
|
|
3917
4683
|
constant int64_t & ne02,
|
|
3918
|
-
constant
|
|
3919
|
-
constant
|
|
4684
|
+
constant uint64_t & nb01,
|
|
4685
|
+
constant uint64_t & nb02,
|
|
3920
4686
|
constant int64_t & ne12,
|
|
3921
|
-
constant
|
|
3922
|
-
constant
|
|
3923
|
-
constant
|
|
4687
|
+
constant uint64_t & nb10,
|
|
4688
|
+
constant uint64_t & nb11,
|
|
4689
|
+
constant uint64_t & nb12,
|
|
3924
4690
|
constant int64_t & ne0,
|
|
3925
4691
|
constant int64_t & ne1,
|
|
3926
4692
|
constant uint & r2,
|
|
@@ -3955,20 +4721,20 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
|
|
|
3955
4721
|
kernel void kernel_mul_mm_id(
|
|
3956
4722
|
device const uchar * ids,
|
|
3957
4723
|
device const uchar * src1,
|
|
3958
|
-
device
|
|
3959
|
-
constant
|
|
4724
|
+
device float * dst,
|
|
4725
|
+
constant uint64_t & nbi1,
|
|
3960
4726
|
constant int64_t & ne00,
|
|
3961
4727
|
constant int64_t & ne02,
|
|
3962
|
-
constant
|
|
3963
|
-
constant
|
|
4728
|
+
constant uint64_t & nb01,
|
|
4729
|
+
constant uint64_t & nb02,
|
|
3964
4730
|
constant int64_t & ne12,
|
|
3965
4731
|
constant int64_t & ne13,
|
|
3966
|
-
constant
|
|
3967
|
-
constant
|
|
3968
|
-
constant
|
|
4732
|
+
constant uint64_t & nb10,
|
|
4733
|
+
constant uint64_t & nb11,
|
|
4734
|
+
constant uint64_t & nb12,
|
|
3969
4735
|
constant int64_t & ne0,
|
|
3970
4736
|
constant int64_t & ne1,
|
|
3971
|
-
constant
|
|
4737
|
+
constant uint64_t & nb1,
|
|
3972
4738
|
constant uint & r2,
|
|
3973
4739
|
constant uint & r3,
|
|
3974
4740
|
constant int & idx,
|
|
@@ -3984,18 +4750,28 @@ kernel void kernel_mul_mm_id(
|
|
|
3984
4750
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3985
4751
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
3986
4752
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3987
|
-
device const uchar *
|
|
4753
|
+
device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
3988
4754
|
|
|
3989
|
-
|
|
4755
|
+
// expert id
|
|
4756
|
+
const int32_t id = tgpig.z/(ne12*ne13);
|
|
3990
4757
|
|
|
3991
4758
|
tgpig.z = tgpig.z%(ne12*ne13);
|
|
3992
4759
|
|
|
3993
|
-
|
|
4760
|
+
// row indices of src1 for expert id
|
|
4761
|
+
int64_t _ne1 = 0;
|
|
4762
|
+
short src1ids[512];
|
|
3994
4763
|
|
|
3995
|
-
|
|
3996
|
-
|
|
3997
|
-
|
|
3998
|
-
|
|
4764
|
+
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
4765
|
+
if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
|
|
4766
|
+
src1ids[_ne1++] = i1;
|
|
4767
|
+
}
|
|
4768
|
+
}
|
|
4769
|
+
|
|
4770
|
+
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
|
4771
|
+
src0s[id],
|
|
4772
|
+
src1,
|
|
4773
|
+
src1ids,
|
|
4774
|
+
dst,
|
|
3999
4775
|
ne00,
|
|
4000
4776
|
ne02,
|
|
4001
4777
|
nb01,
|
|
@@ -4005,7 +4781,7 @@ kernel void kernel_mul_mm_id(
|
|
|
4005
4781
|
nb11,
|
|
4006
4782
|
nb12,
|
|
4007
4783
|
ne0,
|
|
4008
|
-
|
|
4784
|
+
_ne1,
|
|
4009
4785
|
r2,
|
|
4010
4786
|
r3,
|
|
4011
4787
|
shared_memory,
|
|
@@ -4050,6 +4826,8 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows
|
|
|
4050
4826
|
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
4051
4827
|
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
4052
4828
|
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
4829
|
+
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
4830
|
+
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
4053
4831
|
|
|
4054
4832
|
//
|
|
4055
4833
|
// matrix-matrix multiplication
|
|
@@ -4061,12 +4839,12 @@ typedef void (mat_mm_t)(
|
|
|
4061
4839
|
device float * dst,
|
|
4062
4840
|
constant int64_t & ne00,
|
|
4063
4841
|
constant int64_t & ne02,
|
|
4064
|
-
constant
|
|
4065
|
-
constant
|
|
4842
|
+
constant uint64_t & nb01,
|
|
4843
|
+
constant uint64_t & nb02,
|
|
4066
4844
|
constant int64_t & ne12,
|
|
4067
|
-
constant
|
|
4068
|
-
constant
|
|
4069
|
-
constant
|
|
4845
|
+
constant uint64_t & nb10,
|
|
4846
|
+
constant uint64_t & nb11,
|
|
4847
|
+
constant uint64_t & nb12,
|
|
4070
4848
|
constant int64_t & ne0,
|
|
4071
4849
|
constant int64_t & ne1,
|
|
4072
4850
|
constant uint & r2,
|
|
@@ -4086,6 +4864,8 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
|
4086
4864
|
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
4087
4865
|
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
4088
4866
|
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
4867
|
+
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
4868
|
+
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
4089
4869
|
|
|
4090
4870
|
//
|
|
4091
4871
|
// indirect matrix-matrix multiplication
|
|
@@ -4094,20 +4874,20 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
|
4094
4874
|
typedef void (mat_mm_id_t)(
|
|
4095
4875
|
device const uchar * ids,
|
|
4096
4876
|
device const uchar * src1,
|
|
4097
|
-
device
|
|
4098
|
-
constant
|
|
4877
|
+
device float * dst,
|
|
4878
|
+
constant uint64_t & nbi1,
|
|
4099
4879
|
constant int64_t & ne00,
|
|
4100
4880
|
constant int64_t & ne02,
|
|
4101
|
-
constant
|
|
4102
|
-
constant
|
|
4881
|
+
constant uint64_t & nb01,
|
|
4882
|
+
constant uint64_t & nb02,
|
|
4103
4883
|
constant int64_t & ne12,
|
|
4104
4884
|
constant int64_t & ne13,
|
|
4105
|
-
constant
|
|
4106
|
-
constant
|
|
4107
|
-
constant
|
|
4885
|
+
constant uint64_t & nb10,
|
|
4886
|
+
constant uint64_t & nb11,
|
|
4887
|
+
constant uint64_t & nb12,
|
|
4108
4888
|
constant int64_t & ne0,
|
|
4109
4889
|
constant int64_t & ne1,
|
|
4110
|
-
constant
|
|
4890
|
+
constant uint64_t & nb1,
|
|
4111
4891
|
constant uint & r2,
|
|
4112
4892
|
constant uint & r3,
|
|
4113
4893
|
constant int & idx,
|
|
@@ -4134,6 +4914,8 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
|
4134
4914
|
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
4135
4915
|
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
4136
4916
|
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
4917
|
+
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
4918
|
+
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
4137
4919
|
|
|
4138
4920
|
//
|
|
4139
4921
|
// matrix-vector multiplication
|
|
@@ -4143,8 +4925,8 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
|
4143
4925
|
kernel void kernel_mul_mv_id_f32_f32(
|
|
4144
4926
|
device const char * ids,
|
|
4145
4927
|
device const char * src1,
|
|
4146
|
-
device
|
|
4147
|
-
constant
|
|
4928
|
+
device float * dst,
|
|
4929
|
+
constant uint64_t & nbi1,
|
|
4148
4930
|
constant int64_t & ne00,
|
|
4149
4931
|
constant int64_t & ne01,
|
|
4150
4932
|
constant int64_t & ne02,
|
|
@@ -4160,7 +4942,7 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
|
4160
4942
|
constant uint64_t & nb12,
|
|
4161
4943
|
constant int64_t & ne0,
|
|
4162
4944
|
constant int64_t & ne1,
|
|
4163
|
-
constant
|
|
4945
|
+
constant uint64_t & nb1,
|
|
4164
4946
|
constant uint & r2,
|
|
4165
4947
|
constant uint & r3,
|
|
4166
4948
|
constant int & idx,
|
|
@@ -4187,7 +4969,7 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
|
4187
4969
|
kernel_mul_mv_f32_f32_impl(
|
|
4188
4970
|
src0[id],
|
|
4189
4971
|
src1 + bid*nb11,
|
|
4190
|
-
|
|
4972
|
+
dst + bid*ne0,
|
|
4191
4973
|
ne00,
|
|
4192
4974
|
ne01,
|
|
4193
4975
|
ne02,
|
|
@@ -4212,8 +4994,8 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
|
4212
4994
|
kernel void kernel_mul_mv_id_f16_f32(
|
|
4213
4995
|
device const char * ids,
|
|
4214
4996
|
device const char * src1,
|
|
4215
|
-
device
|
|
4216
|
-
constant
|
|
4997
|
+
device float * dst,
|
|
4998
|
+
constant uint64_t & nbi1,
|
|
4217
4999
|
constant int64_t & ne00,
|
|
4218
5000
|
constant int64_t & ne01,
|
|
4219
5001
|
constant int64_t & ne02,
|
|
@@ -4229,7 +5011,7 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
|
4229
5011
|
constant uint64_t & nb12,
|
|
4230
5012
|
constant int64_t & ne0,
|
|
4231
5013
|
constant int64_t & ne1,
|
|
4232
|
-
constant
|
|
5014
|
+
constant uint64_t & nb1,
|
|
4233
5015
|
constant uint & r2,
|
|
4234
5016
|
constant uint & r3,
|
|
4235
5017
|
constant int & idx,
|
|
@@ -4256,7 +5038,7 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
|
4256
5038
|
kernel_mul_mv_f16_f32_impl(
|
|
4257
5039
|
src0[id],
|
|
4258
5040
|
src1 + bid*nb11,
|
|
4259
|
-
|
|
5041
|
+
dst + bid*ne0,
|
|
4260
5042
|
ne00,
|
|
4261
5043
|
ne01,
|
|
4262
5044
|
ne02,
|
|
@@ -4281,8 +5063,8 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
|
4281
5063
|
kernel void kernel_mul_mv_id_q8_0_f32(
|
|
4282
5064
|
device const char * ids,
|
|
4283
5065
|
device const char * src1,
|
|
4284
|
-
device
|
|
4285
|
-
constant
|
|
5066
|
+
device float * dst,
|
|
5067
|
+
constant uint64_t & nbi1,
|
|
4286
5068
|
constant int64_t & ne00,
|
|
4287
5069
|
constant int64_t & ne01,
|
|
4288
5070
|
constant int64_t & ne02,
|
|
@@ -4298,7 +5080,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
|
4298
5080
|
constant uint64_t & nb12,
|
|
4299
5081
|
constant int64_t & ne0,
|
|
4300
5082
|
constant int64_t & ne1,
|
|
4301
|
-
constant
|
|
5083
|
+
constant uint64_t & nb1,
|
|
4302
5084
|
constant uint & r2,
|
|
4303
5085
|
constant uint & r3,
|
|
4304
5086
|
constant int & idx,
|
|
@@ -4325,7 +5107,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
|
4325
5107
|
kernel_mul_mv_q8_0_f32_impl(
|
|
4326
5108
|
src0[id],
|
|
4327
5109
|
(device const float *) (src1 + bid*nb11),
|
|
4328
|
-
|
|
5110
|
+
dst + bid*ne0,
|
|
4329
5111
|
ne00,
|
|
4330
5112
|
ne01,
|
|
4331
5113
|
ne02,
|
|
@@ -4344,8 +5126,8 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
|
4344
5126
|
kernel void kernel_mul_mv_id_q4_0_f32(
|
|
4345
5127
|
device const char * ids,
|
|
4346
5128
|
device const char * src1,
|
|
4347
|
-
device
|
|
4348
|
-
constant
|
|
5129
|
+
device float * dst,
|
|
5130
|
+
constant uint64_t & nbi1,
|
|
4349
5131
|
constant int64_t & ne00,
|
|
4350
5132
|
constant int64_t & ne01,
|
|
4351
5133
|
constant int64_t & ne02,
|
|
@@ -4361,7 +5143,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
|
4361
5143
|
constant uint64_t & nb12,
|
|
4362
5144
|
constant int64_t & ne0,
|
|
4363
5145
|
constant int64_t & ne1,
|
|
4364
|
-
constant
|
|
5146
|
+
constant uint64_t & nb1,
|
|
4365
5147
|
constant uint & r2,
|
|
4366
5148
|
constant uint & r3,
|
|
4367
5149
|
constant int & idx,
|
|
@@ -4388,7 +5170,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
|
4388
5170
|
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
4389
5171
|
src0[id],
|
|
4390
5172
|
(device const float *) (src1 + bid*nb11),
|
|
4391
|
-
|
|
5173
|
+
dst + bid*ne0,
|
|
4392
5174
|
ne00,
|
|
4393
5175
|
ne01,
|
|
4394
5176
|
ne02,
|
|
@@ -4407,8 +5189,8 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
|
4407
5189
|
kernel void kernel_mul_mv_id_q4_1_f32(
|
|
4408
5190
|
device const char * ids,
|
|
4409
5191
|
device const char * src1,
|
|
4410
|
-
device
|
|
4411
|
-
constant
|
|
5192
|
+
device float * dst,
|
|
5193
|
+
constant uint64_t & nbi1,
|
|
4412
5194
|
constant int64_t & ne00,
|
|
4413
5195
|
constant int64_t & ne01,
|
|
4414
5196
|
constant int64_t & ne02,
|
|
@@ -4424,7 +5206,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
|
4424
5206
|
constant uint64_t & nb12,
|
|
4425
5207
|
constant int64_t & ne0,
|
|
4426
5208
|
constant int64_t & ne1,
|
|
4427
|
-
constant
|
|
5209
|
+
constant uint64_t & nb1,
|
|
4428
5210
|
constant uint & r2,
|
|
4429
5211
|
constant uint & r3,
|
|
4430
5212
|
constant int & idx,
|
|
@@ -4451,7 +5233,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
|
4451
5233
|
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
4452
5234
|
src0[id],
|
|
4453
5235
|
(device const float *) (src1 + bid*nb11),
|
|
4454
|
-
|
|
5236
|
+
dst + bid*ne0,
|
|
4455
5237
|
ne00,
|
|
4456
5238
|
ne01,
|
|
4457
5239
|
ne02,
|
|
@@ -4470,8 +5252,8 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
|
4470
5252
|
kernel void kernel_mul_mv_id_q5_0_f32(
|
|
4471
5253
|
device const char * ids,
|
|
4472
5254
|
device const char * src1,
|
|
4473
|
-
device
|
|
4474
|
-
constant
|
|
5255
|
+
device float * dst,
|
|
5256
|
+
constant uint64_t & nbi1,
|
|
4475
5257
|
constant int64_t & ne00,
|
|
4476
5258
|
constant int64_t & ne01,
|
|
4477
5259
|
constant int64_t & ne02,
|
|
@@ -4487,7 +5269,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
|
4487
5269
|
constant uint64_t & nb12,
|
|
4488
5270
|
constant int64_t & ne0,
|
|
4489
5271
|
constant int64_t & ne1,
|
|
4490
|
-
constant
|
|
5272
|
+
constant uint64_t & nb1,
|
|
4491
5273
|
constant uint & r2,
|
|
4492
5274
|
constant uint & r3,
|
|
4493
5275
|
constant int & idx,
|
|
@@ -4514,7 +5296,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
|
4514
5296
|
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
4515
5297
|
src0[id],
|
|
4516
5298
|
(device const float *) (src1 + bid*nb11),
|
|
4517
|
-
|
|
5299
|
+
dst + bid*ne0,
|
|
4518
5300
|
ne00,
|
|
4519
5301
|
ne01,
|
|
4520
5302
|
ne02,
|
|
@@ -4533,8 +5315,8 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
|
4533
5315
|
kernel void kernel_mul_mv_id_q5_1_f32(
|
|
4534
5316
|
device const char * ids,
|
|
4535
5317
|
device const char * src1,
|
|
4536
|
-
device
|
|
4537
|
-
constant
|
|
5318
|
+
device float * dst,
|
|
5319
|
+
constant uint64_t & nbi1,
|
|
4538
5320
|
constant int64_t & ne00,
|
|
4539
5321
|
constant int64_t & ne01,
|
|
4540
5322
|
constant int64_t & ne02,
|
|
@@ -4550,7 +5332,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
|
4550
5332
|
constant uint64_t & nb12,
|
|
4551
5333
|
constant int64_t & ne0,
|
|
4552
5334
|
constant int64_t & ne1,
|
|
4553
|
-
constant
|
|
5335
|
+
constant uint64_t & nb1,
|
|
4554
5336
|
constant uint & r2,
|
|
4555
5337
|
constant uint & r3,
|
|
4556
5338
|
constant int & idx,
|
|
@@ -4577,7 +5359,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
|
4577
5359
|
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
4578
5360
|
src0[id],
|
|
4579
5361
|
(device const float *) (src1 + bid*nb11),
|
|
4580
|
-
|
|
5362
|
+
dst + bid*ne0,
|
|
4581
5363
|
ne00,
|
|
4582
5364
|
ne01,
|
|
4583
5365
|
ne02,
|
|
@@ -4596,8 +5378,8 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
|
4596
5378
|
kernel void kernel_mul_mv_id_q2_K_f32(
|
|
4597
5379
|
device const char * ids,
|
|
4598
5380
|
device const char * src1,
|
|
4599
|
-
device
|
|
4600
|
-
constant
|
|
5381
|
+
device float * dst,
|
|
5382
|
+
constant uint64_t & nbi1,
|
|
4601
5383
|
constant int64_t & ne00,
|
|
4602
5384
|
constant int64_t & ne01,
|
|
4603
5385
|
constant int64_t & ne02,
|
|
@@ -4613,7 +5395,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
|
4613
5395
|
constant uint64_t & nb12,
|
|
4614
5396
|
constant int64_t & ne0,
|
|
4615
5397
|
constant int64_t & ne1,
|
|
4616
|
-
constant
|
|
5398
|
+
constant uint64_t & nb1,
|
|
4617
5399
|
constant uint & r2,
|
|
4618
5400
|
constant uint & r3,
|
|
4619
5401
|
constant int & idx,
|
|
@@ -4640,7 +5422,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
|
4640
5422
|
kernel_mul_mv_q2_K_f32_impl(
|
|
4641
5423
|
src0[id],
|
|
4642
5424
|
(device const float *) (src1 + bid*nb11),
|
|
4643
|
-
|
|
5425
|
+
dst + bid*ne0,
|
|
4644
5426
|
ne00,
|
|
4645
5427
|
ne01,
|
|
4646
5428
|
ne02,
|
|
@@ -4659,8 +5441,8 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
|
4659
5441
|
kernel void kernel_mul_mv_id_q3_K_f32(
|
|
4660
5442
|
device const char * ids,
|
|
4661
5443
|
device const char * src1,
|
|
4662
|
-
device
|
|
4663
|
-
constant
|
|
5444
|
+
device float * dst,
|
|
5445
|
+
constant uint64_t & nbi1,
|
|
4664
5446
|
constant int64_t & ne00,
|
|
4665
5447
|
constant int64_t & ne01,
|
|
4666
5448
|
constant int64_t & ne02,
|
|
@@ -4676,7 +5458,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
|
4676
5458
|
constant uint64_t & nb12,
|
|
4677
5459
|
constant int64_t & ne0,
|
|
4678
5460
|
constant int64_t & ne1,
|
|
4679
|
-
constant
|
|
5461
|
+
constant uint64_t & nb1,
|
|
4680
5462
|
constant uint & r2,
|
|
4681
5463
|
constant uint & r3,
|
|
4682
5464
|
constant int & idx,
|
|
@@ -4703,7 +5485,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
|
4703
5485
|
kernel_mul_mv_q3_K_f32_impl(
|
|
4704
5486
|
src0[id],
|
|
4705
5487
|
(device const float *) (src1 + bid*nb11),
|
|
4706
|
-
|
|
5488
|
+
dst + bid*ne0,
|
|
4707
5489
|
ne00,
|
|
4708
5490
|
ne01,
|
|
4709
5491
|
ne02,
|
|
@@ -4722,8 +5504,8 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
|
4722
5504
|
kernel void kernel_mul_mv_id_q4_K_f32(
|
|
4723
5505
|
device const char * ids,
|
|
4724
5506
|
device const char * src1,
|
|
4725
|
-
device
|
|
4726
|
-
constant
|
|
5507
|
+
device float * dst,
|
|
5508
|
+
constant uint64_t & nbi1,
|
|
4727
5509
|
constant int64_t & ne00,
|
|
4728
5510
|
constant int64_t & ne01,
|
|
4729
5511
|
constant int64_t & ne02,
|
|
@@ -4739,7 +5521,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
|
4739
5521
|
constant uint64_t & nb12,
|
|
4740
5522
|
constant int64_t & ne0,
|
|
4741
5523
|
constant int64_t & ne1,
|
|
4742
|
-
constant
|
|
5524
|
+
constant uint64_t & nb1,
|
|
4743
5525
|
constant uint & r2,
|
|
4744
5526
|
constant uint & r3,
|
|
4745
5527
|
constant int & idx,
|
|
@@ -4766,7 +5548,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
|
4766
5548
|
kernel_mul_mv_q4_K_f32_impl(
|
|
4767
5549
|
src0[id],
|
|
4768
5550
|
(device const float *) (src1 + bid*nb11),
|
|
4769
|
-
|
|
5551
|
+
dst + bid*ne0,
|
|
4770
5552
|
ne00,
|
|
4771
5553
|
ne01,
|
|
4772
5554
|
ne02,
|
|
@@ -4785,8 +5567,8 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
|
4785
5567
|
kernel void kernel_mul_mv_id_q5_K_f32(
|
|
4786
5568
|
device const char * ids,
|
|
4787
5569
|
device const char * src1,
|
|
4788
|
-
device
|
|
4789
|
-
constant
|
|
5570
|
+
device float * dst,
|
|
5571
|
+
constant uint64_t & nbi1,
|
|
4790
5572
|
constant int64_t & ne00,
|
|
4791
5573
|
constant int64_t & ne01,
|
|
4792
5574
|
constant int64_t & ne02,
|
|
@@ -4802,7 +5584,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
|
4802
5584
|
constant uint64_t & nb12,
|
|
4803
5585
|
constant int64_t & ne0,
|
|
4804
5586
|
constant int64_t & ne1,
|
|
4805
|
-
constant
|
|
5587
|
+
constant uint64_t & nb1,
|
|
4806
5588
|
constant uint & r2,
|
|
4807
5589
|
constant uint & r3,
|
|
4808
5590
|
constant int & idx,
|
|
@@ -4829,7 +5611,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
|
4829
5611
|
kernel_mul_mv_q5_K_f32_impl(
|
|
4830
5612
|
src0[id],
|
|
4831
5613
|
(device const float *) (src1 + bid*nb11),
|
|
4832
|
-
|
|
5614
|
+
dst + bid*ne0,
|
|
4833
5615
|
ne00,
|
|
4834
5616
|
ne01,
|
|
4835
5617
|
ne02,
|
|
@@ -4848,8 +5630,8 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
|
4848
5630
|
kernel void kernel_mul_mv_id_q6_K_f32(
|
|
4849
5631
|
device const char * ids,
|
|
4850
5632
|
device const char * src1,
|
|
4851
|
-
device
|
|
4852
|
-
constant
|
|
5633
|
+
device float * dst,
|
|
5634
|
+
constant uint64_t & nbi1,
|
|
4853
5635
|
constant int64_t & ne00,
|
|
4854
5636
|
constant int64_t & ne01,
|
|
4855
5637
|
constant int64_t & ne02,
|
|
@@ -4865,7 +5647,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
|
4865
5647
|
constant uint64_t & nb12,
|
|
4866
5648
|
constant int64_t & ne0,
|
|
4867
5649
|
constant int64_t & ne1,
|
|
4868
|
-
constant
|
|
5650
|
+
constant uint64_t & nb1,
|
|
4869
5651
|
constant uint & r2,
|
|
4870
5652
|
constant uint & r3,
|
|
4871
5653
|
constant int & idx,
|
|
@@ -4892,7 +5674,136 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
|
4892
5674
|
kernel_mul_mv_q6_K_f32_impl(
|
|
4893
5675
|
src0[id],
|
|
4894
5676
|
(device const float *) (src1 + bid*nb11),
|
|
4895
|
-
|
|
5677
|
+
dst + bid*ne0,
|
|
5678
|
+
ne00,
|
|
5679
|
+
ne01,
|
|
5680
|
+
ne02,
|
|
5681
|
+
ne10,
|
|
5682
|
+
ne12,
|
|
5683
|
+
ne0,
|
|
5684
|
+
ne1,
|
|
5685
|
+
r2,
|
|
5686
|
+
r3,
|
|
5687
|
+
tgpig,
|
|
5688
|
+
tiisg,
|
|
5689
|
+
sgitg);
|
|
5690
|
+
}
|
|
5691
|
+
|
|
5692
|
+
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
|
|
5693
|
+
kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
5694
|
+
device const char * ids,
|
|
5695
|
+
device const char * src1,
|
|
5696
|
+
device float * dst,
|
|
5697
|
+
constant uint64_t & nbi1,
|
|
5698
|
+
constant int64_t & ne00,
|
|
5699
|
+
constant int64_t & ne01,
|
|
5700
|
+
constant int64_t & ne02,
|
|
5701
|
+
constant uint64_t & nb00,
|
|
5702
|
+
constant uint64_t & nb01,
|
|
5703
|
+
constant uint64_t & nb02,
|
|
5704
|
+
constant int64_t & ne10,
|
|
5705
|
+
constant int64_t & ne11,
|
|
5706
|
+
constant int64_t & ne12,
|
|
5707
|
+
constant int64_t & ne13,
|
|
5708
|
+
constant uint64_t & nb10,
|
|
5709
|
+
constant uint64_t & nb11,
|
|
5710
|
+
constant uint64_t & nb12,
|
|
5711
|
+
constant int64_t & ne0,
|
|
5712
|
+
constant int64_t & ne1,
|
|
5713
|
+
constant uint64_t & nb1,
|
|
5714
|
+
constant uint & r2,
|
|
5715
|
+
constant uint & r3,
|
|
5716
|
+
constant int & idx,
|
|
5717
|
+
device const char * src00,
|
|
5718
|
+
device const char * src01,
|
|
5719
|
+
device const char * src02,
|
|
5720
|
+
device const char * src03,
|
|
5721
|
+
device const char * src04,
|
|
5722
|
+
device const char * src05,
|
|
5723
|
+
device const char * src06,
|
|
5724
|
+
device const char * src07,
|
|
5725
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
5726
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5727
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5728
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5729
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5730
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5731
|
+
|
|
5732
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5733
|
+
|
|
5734
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5735
|
+
|
|
5736
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5737
|
+
|
|
5738
|
+
kernel_mul_mv_iq2_xxs_f32_impl(
|
|
5739
|
+
src0[id],
|
|
5740
|
+
(device const float *) (src1 + bid*nb11),
|
|
5741
|
+
dst + bid*ne0,
|
|
5742
|
+
ne00,
|
|
5743
|
+
ne01,
|
|
5744
|
+
ne02,
|
|
5745
|
+
ne10,
|
|
5746
|
+
ne12,
|
|
5747
|
+
ne0,
|
|
5748
|
+
ne1,
|
|
5749
|
+
r2,
|
|
5750
|
+
r3,
|
|
5751
|
+
shared_values,
|
|
5752
|
+
tgpig,
|
|
5753
|
+
tiisg,
|
|
5754
|
+
sgitg);
|
|
5755
|
+
}
|
|
5756
|
+
|
|
5757
|
+
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
|
|
5758
|
+
kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
5759
|
+
device const char * ids,
|
|
5760
|
+
device const char * src1,
|
|
5761
|
+
device float * dst,
|
|
5762
|
+
constant uint64_t & nbi1,
|
|
5763
|
+
constant int64_t & ne00,
|
|
5764
|
+
constant int64_t & ne01,
|
|
5765
|
+
constant int64_t & ne02,
|
|
5766
|
+
constant uint64_t & nb00,
|
|
5767
|
+
constant uint64_t & nb01,
|
|
5768
|
+
constant uint64_t & nb02,
|
|
5769
|
+
constant int64_t & ne10,
|
|
5770
|
+
constant int64_t & ne11,
|
|
5771
|
+
constant int64_t & ne12,
|
|
5772
|
+
constant int64_t & ne13,
|
|
5773
|
+
constant uint64_t & nb10,
|
|
5774
|
+
constant uint64_t & nb11,
|
|
5775
|
+
constant uint64_t & nb12,
|
|
5776
|
+
constant int64_t & ne0,
|
|
5777
|
+
constant int64_t & ne1,
|
|
5778
|
+
constant uint64_t & nb1,
|
|
5779
|
+
constant uint & r2,
|
|
5780
|
+
constant uint & r3,
|
|
5781
|
+
constant int & idx,
|
|
5782
|
+
device const char * src00,
|
|
5783
|
+
device const char * src01,
|
|
5784
|
+
device const char * src02,
|
|
5785
|
+
device const char * src03,
|
|
5786
|
+
device const char * src04,
|
|
5787
|
+
device const char * src05,
|
|
5788
|
+
device const char * src06,
|
|
5789
|
+
device const char * src07,
|
|
5790
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
5791
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5792
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5793
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5794
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5795
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5796
|
+
|
|
5797
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5798
|
+
|
|
5799
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5800
|
+
|
|
5801
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5802
|
+
|
|
5803
|
+
kernel_mul_mv_iq2_xs_f32_impl(
|
|
5804
|
+
src0[id],
|
|
5805
|
+
(device const float *) (src1 + bid*nb11),
|
|
5806
|
+
dst + bid*ne0,
|
|
4896
5807
|
ne00,
|
|
4897
5808
|
ne01,
|
|
4898
5809
|
ne02,
|
|
@@ -4902,6 +5813,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
|
4902
5813
|
ne1,
|
|
4903
5814
|
r2,
|
|
4904
5815
|
r3,
|
|
5816
|
+
shared_values,
|
|
4905
5817
|
tgpig,
|
|
4906
5818
|
tiisg,
|
|
4907
5819
|
sgitg);
|