llama_cpp 0.14.5 → 0.14.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/vendor/tmp/llama.cpp/Makefile +18 -6
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +135 -46
- data/vendor/tmp/llama.cpp/ggml-impl.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +130 -83
- data/vendor/tmp/llama.cpp/ggml-metal.metal +505 -1467
- data/vendor/tmp/llama.cpp/ggml-quants.c +1 -1
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +65 -52
- data/vendor/tmp/llama.cpp/ggml.c +153 -87
- data/vendor/tmp/llama.cpp/ggml.h +5 -4
- data/vendor/tmp/llama.cpp/llama.cpp +885 -144
- data/vendor/tmp/llama.cpp/sgemm.cpp +1148 -0
- data/vendor/tmp/llama.cpp/sgemm.h +12 -0
- metadata +4 -2
@@ -213,6 +213,15 @@ kernel void kernel_scale_4(
|
|
213
213
|
dst[tpig] = src0[tpig] * scale;
|
214
214
|
}
|
215
215
|
|
216
|
+
kernel void kernel_clamp(
|
217
|
+
device const float * src0,
|
218
|
+
device float * dst,
|
219
|
+
constant float & min,
|
220
|
+
constant float & max,
|
221
|
+
uint tpig[[thread_position_in_grid]]) {
|
222
|
+
dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
|
223
|
+
}
|
224
|
+
|
216
225
|
kernel void kernel_relu(
|
217
226
|
device const float * src0,
|
218
227
|
device float * dst,
|
@@ -233,6 +242,15 @@ constant float GELU_QUICK_COEF = -1.702f;
|
|
233
242
|
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
234
243
|
|
235
244
|
kernel void kernel_gelu(
|
245
|
+
device const float * src0,
|
246
|
+
device float * dst,
|
247
|
+
uint tpig[[thread_position_in_grid]]) {
|
248
|
+
device const float & x = src0[tpig];
|
249
|
+
|
250
|
+
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
251
|
+
}
|
252
|
+
|
253
|
+
kernel void kernel_gelu_4(
|
236
254
|
device const float4 * src0,
|
237
255
|
device float4 * dst,
|
238
256
|
uint tpig[[thread_position_in_grid]]) {
|
@@ -246,6 +264,15 @@ kernel void kernel_gelu(
|
|
246
264
|
}
|
247
265
|
|
248
266
|
kernel void kernel_gelu_quick(
|
267
|
+
device const float * src0,
|
268
|
+
device float * dst,
|
269
|
+
uint tpig[[thread_position_in_grid]]) {
|
270
|
+
device const float & x = src0[tpig];
|
271
|
+
|
272
|
+
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
273
|
+
}
|
274
|
+
|
275
|
+
kernel void kernel_gelu_quick_4(
|
249
276
|
device const float4 * src0,
|
250
277
|
device float4 * dst,
|
251
278
|
uint tpig[[thread_position_in_grid]]) {
|
@@ -255,6 +282,14 @@ kernel void kernel_gelu_quick(
|
|
255
282
|
}
|
256
283
|
|
257
284
|
kernel void kernel_silu(
|
285
|
+
device const float * src0,
|
286
|
+
device float * dst,
|
287
|
+
uint tpig[[thread_position_in_grid]]) {
|
288
|
+
device const float & x = src0[tpig];
|
289
|
+
dst[tpig] = x / (1.0f + exp(-x));
|
290
|
+
}
|
291
|
+
|
292
|
+
kernel void kernel_silu_4(
|
258
293
|
device const float4 * src0,
|
259
294
|
device float4 * dst,
|
260
295
|
uint tpig[[thread_position_in_grid]]) {
|
@@ -866,6 +901,7 @@ void mul_vec_q_n_f32_impl(
|
|
866
901
|
int64_t ne1,
|
867
902
|
uint r2,
|
868
903
|
uint r3,
|
904
|
+
threadgroup int8_t * shared_values,
|
869
905
|
uint3 tgpig, uint tiisg, uint sgitg) {
|
870
906
|
const int nb = ne00/QK4_0;
|
871
907
|
|
@@ -942,7 +978,7 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
942
978
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
943
979
|
uint tiisg[[thread_index_in_simdgroup]],
|
944
980
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
945
|
-
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
981
|
+
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
946
982
|
}
|
947
983
|
|
948
984
|
kernel void kernel_mul_mv_q4_1_f32(
|
@@ -968,7 +1004,7 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
968
1004
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
969
1005
|
uint tiisg[[thread_index_in_simdgroup]],
|
970
1006
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
971
|
-
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
1007
|
+
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
972
1008
|
}
|
973
1009
|
|
974
1010
|
kernel void kernel_mul_mv_q5_0_f32(
|
@@ -994,7 +1030,7 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
994
1030
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
995
1031
|
uint tiisg[[thread_index_in_simdgroup]],
|
996
1032
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
997
|
-
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
1033
|
+
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
998
1034
|
}
|
999
1035
|
|
1000
1036
|
kernel void kernel_mul_mv_q5_1_f32(
|
@@ -1020,7 +1056,7 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
1020
1056
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1021
1057
|
uint tiisg[[thread_index_in_simdgroup]],
|
1022
1058
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1023
|
-
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
1059
|
+
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
1024
1060
|
}
|
1025
1061
|
|
1026
1062
|
|
@@ -1030,18 +1066,19 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|
1030
1066
|
device const void * src0,
|
1031
1067
|
device const float * src1,
|
1032
1068
|
device float * dst,
|
1033
|
-
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1037
|
-
|
1038
|
-
|
1039
|
-
|
1040
|
-
|
1041
|
-
|
1042
|
-
|
1043
|
-
|
1044
|
-
|
1069
|
+
int64_t ne00,
|
1070
|
+
int64_t ne01,
|
1071
|
+
int64_t ne02,
|
1072
|
+
int64_t ne10,
|
1073
|
+
int64_t ne12,
|
1074
|
+
int64_t ne0,
|
1075
|
+
int64_t ne1,
|
1076
|
+
uint r2,
|
1077
|
+
uint r3,
|
1078
|
+
threadgroup int8_t * shared_values,
|
1079
|
+
uint3 tgpig,
|
1080
|
+
uint tiisg,
|
1081
|
+
uint sgitg) {
|
1045
1082
|
const int nr = N_DST;
|
1046
1083
|
const int nsg = N_SIMDGROUP;
|
1047
1084
|
const int nw = N_SIMDWIDTH;
|
@@ -1119,7 +1156,7 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
1119
1156
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1120
1157
|
uint tiisg[[thread_index_in_simdgroup]],
|
1121
1158
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1122
|
-
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
1159
|
+
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
1123
1160
|
}
|
1124
1161
|
|
1125
1162
|
#define N_F32_F32 4
|
@@ -1128,24 +1165,24 @@ void kernel_mul_mv_f32_f32_impl(
|
|
1128
1165
|
device const char * src0,
|
1129
1166
|
device const char * src1,
|
1130
1167
|
device float * dst,
|
1131
|
-
|
1132
|
-
|
1133
|
-
|
1134
|
-
|
1135
|
-
|
1136
|
-
|
1137
|
-
|
1138
|
-
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
|
1168
|
+
int64_t ne00,
|
1169
|
+
int64_t ne01,
|
1170
|
+
int64_t ne02,
|
1171
|
+
uint64_t nb00,
|
1172
|
+
uint64_t nb01,
|
1173
|
+
uint64_t nb02,
|
1174
|
+
int64_t ne10,
|
1175
|
+
int64_t ne11,
|
1176
|
+
int64_t ne12,
|
1177
|
+
uint64_t nb10,
|
1178
|
+
uint64_t nb11,
|
1179
|
+
uint64_t nb12,
|
1180
|
+
int64_t ne0,
|
1181
|
+
int64_t ne1,
|
1182
|
+
uint r2,
|
1183
|
+
uint r3,
|
1184
|
+
uint3 tgpig,
|
1185
|
+
uint tiisg) {
|
1149
1186
|
|
1150
1187
|
const int64_t r0 = tgpig.x;
|
1151
1188
|
const int64_t rb = tgpig.y*N_F32_F32;
|
@@ -1398,24 +1435,24 @@ void kernel_mul_mv_f16_f32_impl(
|
|
1398
1435
|
device const char * src0,
|
1399
1436
|
device const char * src1,
|
1400
1437
|
device float * dst,
|
1401
|
-
|
1402
|
-
|
1403
|
-
|
1404
|
-
|
1405
|
-
|
1406
|
-
|
1407
|
-
|
1408
|
-
|
1409
|
-
|
1410
|
-
|
1411
|
-
|
1412
|
-
|
1413
|
-
|
1414
|
-
|
1415
|
-
|
1416
|
-
|
1417
|
-
|
1418
|
-
|
1438
|
+
int64_t ne00,
|
1439
|
+
int64_t ne01,
|
1440
|
+
int64_t ne02,
|
1441
|
+
uint64_t nb00,
|
1442
|
+
uint64_t nb01,
|
1443
|
+
uint64_t nb02,
|
1444
|
+
int64_t ne10,
|
1445
|
+
int64_t ne11,
|
1446
|
+
int64_t ne12,
|
1447
|
+
uint64_t nb10,
|
1448
|
+
uint64_t nb11,
|
1449
|
+
uint64_t nb12,
|
1450
|
+
int64_t ne0,
|
1451
|
+
int64_t ne1,
|
1452
|
+
uint r2,
|
1453
|
+
uint r3,
|
1454
|
+
uint3 tgpig,
|
1455
|
+
uint tiisg) {
|
1419
1456
|
|
1420
1457
|
const int64_t r0 = tgpig.x;
|
1421
1458
|
const int64_t rb = tgpig.y*N_F16_F32;
|
@@ -2700,18 +2737,19 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
2700
2737
|
device const void * src0,
|
2701
2738
|
device const float * src1,
|
2702
2739
|
device float * dst,
|
2703
|
-
|
2704
|
-
|
2705
|
-
|
2706
|
-
|
2707
|
-
|
2708
|
-
|
2709
|
-
|
2710
|
-
|
2711
|
-
|
2712
|
-
|
2713
|
-
|
2714
|
-
|
2740
|
+
int64_t ne00,
|
2741
|
+
int64_t ne01,
|
2742
|
+
int64_t ne02,
|
2743
|
+
int64_t ne10,
|
2744
|
+
int64_t ne12,
|
2745
|
+
int64_t ne0,
|
2746
|
+
int64_t ne1,
|
2747
|
+
uint r2,
|
2748
|
+
uint r3,
|
2749
|
+
threadgroup int8_t * shared_values,
|
2750
|
+
uint3 tgpig,
|
2751
|
+
uint tiisg,
|
2752
|
+
uint sgitg) {
|
2715
2753
|
|
2716
2754
|
const int nb = ne00/QK_K;
|
2717
2755
|
const int r0 = tgpig.x;
|
@@ -2871,7 +2909,7 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
2871
2909
|
uint tiisg[[thread_index_in_simdgroup]],
|
2872
2910
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2873
2911
|
|
2874
|
-
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
2912
|
+
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
2875
2913
|
}
|
2876
2914
|
|
2877
2915
|
#if QK_K == 256
|
@@ -2879,18 +2917,19 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
2879
2917
|
device const void * src0,
|
2880
2918
|
device const float * src1,
|
2881
2919
|
device float * dst,
|
2882
|
-
|
2883
|
-
|
2884
|
-
|
2885
|
-
|
2886
|
-
|
2887
|
-
|
2888
|
-
|
2889
|
-
|
2890
|
-
|
2891
|
-
|
2892
|
-
|
2893
|
-
|
2920
|
+
int64_t ne00,
|
2921
|
+
int64_t ne01,
|
2922
|
+
int64_t ne02,
|
2923
|
+
int64_t ne10,
|
2924
|
+
int64_t ne12,
|
2925
|
+
int64_t ne0,
|
2926
|
+
int64_t ne1,
|
2927
|
+
uint r2,
|
2928
|
+
uint r3,
|
2929
|
+
threadgroup int8_t * shared_values,
|
2930
|
+
uint3 tgpig,
|
2931
|
+
uint tiisg,
|
2932
|
+
uint sgitg) {
|
2894
2933
|
|
2895
2934
|
const int nb = ne00/QK_K;
|
2896
2935
|
|
@@ -3046,6 +3085,7 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
3046
3085
|
constant int64_t & ne1,
|
3047
3086
|
constant uint & r2,
|
3048
3087
|
constant uint & r3,
|
3088
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
3049
3089
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
3050
3090
|
uint tiisg[[thread_index_in_simdgroup]],
|
3051
3091
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -3135,7 +3175,7 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
3135
3175
|
uint tiisg[[thread_index_in_simdgroup]],
|
3136
3176
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3137
3177
|
|
3138
|
-
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3178
|
+
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
3139
3179
|
}
|
3140
3180
|
|
3141
3181
|
#if QK_K == 256
|
@@ -3143,18 +3183,19 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
3143
3183
|
device const void * src0,
|
3144
3184
|
device const float * src1,
|
3145
3185
|
device float * dst,
|
3146
|
-
|
3147
|
-
|
3148
|
-
|
3149
|
-
|
3150
|
-
|
3151
|
-
|
3152
|
-
|
3153
|
-
|
3154
|
-
|
3155
|
-
|
3156
|
-
|
3157
|
-
|
3186
|
+
int64_t ne00,
|
3187
|
+
int64_t ne01,
|
3188
|
+
int64_t ne02,
|
3189
|
+
int64_t ne10,
|
3190
|
+
int64_t ne12,
|
3191
|
+
int64_t ne0,
|
3192
|
+
int64_t ne1,
|
3193
|
+
uint r2,
|
3194
|
+
uint r3,
|
3195
|
+
threadgroup int8_t * shared_values,
|
3196
|
+
uint3 tgpig,
|
3197
|
+
uint tiisg,
|
3198
|
+
uint sgitg) {
|
3158
3199
|
|
3159
3200
|
const uint16_t kmask1 = 0x3f3f;
|
3160
3201
|
const uint16_t kmask2 = 0x0f0f;
|
@@ -3265,6 +3306,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
3265
3306
|
constant int64_t & ne1,
|
3266
3307
|
constant uint & r2,
|
3267
3308
|
constant uint & r3,
|
3309
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
3268
3310
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
3269
3311
|
uint tiisg[[thread_index_in_simdgroup]],
|
3270
3312
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -3373,25 +3415,26 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
3373
3415
|
uint tiisg[[thread_index_in_simdgroup]],
|
3374
3416
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3375
3417
|
|
3376
|
-
kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3418
|
+
kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
3377
3419
|
}
|
3378
3420
|
|
3379
3421
|
void kernel_mul_mv_q5_K_f32_impl(
|
3380
3422
|
device const void * src0,
|
3381
3423
|
device const float * src1,
|
3382
3424
|
device float * dst,
|
3383
|
-
|
3384
|
-
|
3385
|
-
|
3386
|
-
|
3387
|
-
|
3388
|
-
|
3389
|
-
|
3390
|
-
|
3391
|
-
|
3392
|
-
|
3393
|
-
|
3394
|
-
|
3425
|
+
int64_t ne00,
|
3426
|
+
int64_t ne01,
|
3427
|
+
int64_t ne02,
|
3428
|
+
int64_t ne10,
|
3429
|
+
int64_t ne12,
|
3430
|
+
int64_t ne0,
|
3431
|
+
int64_t ne1,
|
3432
|
+
uint r2,
|
3433
|
+
uint r3,
|
3434
|
+
threadgroup int8_t * shared_values,
|
3435
|
+
uint3 tgpig,
|
3436
|
+
uint tiisg,
|
3437
|
+
uint sgitg) {
|
3395
3438
|
|
3396
3439
|
const int nb = ne00/QK_K;
|
3397
3440
|
|
@@ -3579,25 +3622,26 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
3579
3622
|
uint tiisg[[thread_index_in_simdgroup]],
|
3580
3623
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3581
3624
|
|
3582
|
-
kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3625
|
+
kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
3583
3626
|
}
|
3584
3627
|
|
3585
3628
|
void kernel_mul_mv_q6_K_f32_impl(
|
3586
3629
|
device const void * src0,
|
3587
3630
|
device const float * src1,
|
3588
3631
|
device float * dst,
|
3589
|
-
|
3590
|
-
|
3591
|
-
|
3592
|
-
|
3593
|
-
|
3594
|
-
|
3595
|
-
|
3596
|
-
|
3597
|
-
|
3598
|
-
|
3599
|
-
|
3600
|
-
|
3632
|
+
int64_t ne00,
|
3633
|
+
int64_t ne01,
|
3634
|
+
int64_t ne02,
|
3635
|
+
int64_t ne10,
|
3636
|
+
int64_t ne12,
|
3637
|
+
int64_t ne0,
|
3638
|
+
int64_t ne1,
|
3639
|
+
uint r2,
|
3640
|
+
uint r3,
|
3641
|
+
threadgroup int8_t * shared_values,
|
3642
|
+
uint3 tgpig,
|
3643
|
+
uint tiisg,
|
3644
|
+
uint sgitg) {
|
3601
3645
|
|
3602
3646
|
const uint8_t kmask1 = 0x03;
|
3603
3647
|
const uint8_t kmask2 = 0x0C;
|
@@ -3713,7 +3757,7 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
3713
3757
|
uint tiisg[[thread_index_in_simdgroup]],
|
3714
3758
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3715
3759
|
|
3716
|
-
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3760
|
+
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
3717
3761
|
}
|
3718
3762
|
|
3719
3763
|
// ======================= "True" 2-bit
|
@@ -3722,19 +3766,19 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
3722
3766
|
device const void * src0,
|
3723
3767
|
device const float * src1,
|
3724
3768
|
device float * dst,
|
3725
|
-
|
3726
|
-
|
3727
|
-
|
3728
|
-
|
3729
|
-
|
3730
|
-
|
3731
|
-
|
3732
|
-
|
3733
|
-
|
3734
|
-
threadgroup int8_t * shared_values
|
3735
|
-
|
3736
|
-
|
3737
|
-
|
3769
|
+
int64_t ne00,
|
3770
|
+
int64_t ne01,
|
3771
|
+
int64_t ne02,
|
3772
|
+
int64_t ne10,
|
3773
|
+
int64_t ne12,
|
3774
|
+
int64_t ne0,
|
3775
|
+
int64_t ne1,
|
3776
|
+
uint r2,
|
3777
|
+
uint r3,
|
3778
|
+
threadgroup int8_t * shared_values,
|
3779
|
+
uint3 tgpig,
|
3780
|
+
uint tiisg,
|
3781
|
+
uint sgitg) {
|
3738
3782
|
|
3739
3783
|
const int nb = ne00/QK_K;
|
3740
3784
|
const int r0 = tgpig.x;
|
@@ -3851,19 +3895,19 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
3851
3895
|
device const void * src0,
|
3852
3896
|
device const float * src1,
|
3853
3897
|
device float * dst,
|
3854
|
-
|
3855
|
-
|
3856
|
-
|
3857
|
-
|
3858
|
-
|
3859
|
-
|
3860
|
-
|
3861
|
-
|
3862
|
-
|
3863
|
-
threadgroup int8_t * shared_values
|
3864
|
-
|
3865
|
-
|
3866
|
-
|
3898
|
+
int64_t ne00,
|
3899
|
+
int64_t ne01,
|
3900
|
+
int64_t ne02,
|
3901
|
+
int64_t ne10,
|
3902
|
+
int64_t ne12,
|
3903
|
+
int64_t ne0,
|
3904
|
+
int64_t ne1,
|
3905
|
+
uint r2,
|
3906
|
+
uint r3,
|
3907
|
+
threadgroup int8_t * shared_values,
|
3908
|
+
uint3 tgpig,
|
3909
|
+
uint tiisg,
|
3910
|
+
uint sgitg) {
|
3867
3911
|
|
3868
3912
|
const int nb = ne00/QK_K;
|
3869
3913
|
const int r0 = tgpig.x;
|
@@ -3990,19 +4034,19 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
3990
4034
|
device const void * src0,
|
3991
4035
|
device const float * src1,
|
3992
4036
|
device float * dst,
|
3993
|
-
|
3994
|
-
|
3995
|
-
|
3996
|
-
|
3997
|
-
|
3998
|
-
|
3999
|
-
|
4000
|
-
|
4001
|
-
|
4002
|
-
threadgroup int8_t * shared_values
|
4003
|
-
|
4004
|
-
|
4005
|
-
|
4037
|
+
int64_t ne00,
|
4038
|
+
int64_t ne01,
|
4039
|
+
int64_t ne02,
|
4040
|
+
int64_t ne10,
|
4041
|
+
int64_t ne12,
|
4042
|
+
int64_t ne0,
|
4043
|
+
int64_t ne1,
|
4044
|
+
uint r2,
|
4045
|
+
uint r3,
|
4046
|
+
threadgroup int8_t * shared_values,
|
4047
|
+
uint3 tgpig,
|
4048
|
+
uint tiisg,
|
4049
|
+
uint sgitg) {
|
4006
4050
|
|
4007
4051
|
const int nb = ne00/QK_K;
|
4008
4052
|
const int r0 = tgpig.x;
|
@@ -4122,19 +4166,19 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
4122
4166
|
device const void * src0,
|
4123
4167
|
device const float * src1,
|
4124
4168
|
device float * dst,
|
4125
|
-
|
4126
|
-
|
4127
|
-
|
4128
|
-
|
4129
|
-
|
4130
|
-
|
4131
|
-
|
4132
|
-
|
4133
|
-
|
4134
|
-
threadgroup int8_t * shared_values
|
4135
|
-
|
4136
|
-
|
4137
|
-
|
4169
|
+
int64_t ne00,
|
4170
|
+
int64_t ne01,
|
4171
|
+
int64_t ne02,
|
4172
|
+
int64_t ne10,
|
4173
|
+
int64_t ne12,
|
4174
|
+
int64_t ne0,
|
4175
|
+
int64_t ne1,
|
4176
|
+
uint r2,
|
4177
|
+
uint r3,
|
4178
|
+
threadgroup int8_t * shared_values,
|
4179
|
+
uint3 tgpig,
|
4180
|
+
uint tiisg,
|
4181
|
+
uint sgitg) {
|
4138
4182
|
|
4139
4183
|
const int nb = ne00/QK_K;
|
4140
4184
|
const int r0 = tgpig.x;
|
@@ -4254,19 +4298,19 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|
4254
4298
|
device const void * src0,
|
4255
4299
|
device const float * src1,
|
4256
4300
|
device float * dst,
|
4257
|
-
|
4258
|
-
|
4259
|
-
|
4260
|
-
|
4261
|
-
|
4262
|
-
|
4263
|
-
|
4264
|
-
|
4265
|
-
|
4266
|
-
threadgroup int8_t * shared_values
|
4267
|
-
|
4268
|
-
|
4269
|
-
|
4301
|
+
int64_t ne00,
|
4302
|
+
int64_t ne01,
|
4303
|
+
int64_t ne02,
|
4304
|
+
int64_t ne10,
|
4305
|
+
int64_t ne12,
|
4306
|
+
int64_t ne0,
|
4307
|
+
int64_t ne1,
|
4308
|
+
uint r2,
|
4309
|
+
uint r3,
|
4310
|
+
threadgroup int8_t * shared_values,
|
4311
|
+
uint3 tgpig,
|
4312
|
+
uint tiisg,
|
4313
|
+
uint sgitg) {
|
4270
4314
|
|
4271
4315
|
const int nb = ne00/QK_K;
|
4272
4316
|
const int r0 = tgpig.x;
|
@@ -4387,18 +4431,19 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
4387
4431
|
device const void * src0,
|
4388
4432
|
device const float * src1,
|
4389
4433
|
device float * dst,
|
4390
|
-
|
4391
|
-
|
4392
|
-
|
4393
|
-
|
4394
|
-
|
4395
|
-
|
4396
|
-
|
4397
|
-
|
4398
|
-
|
4399
|
-
|
4400
|
-
|
4401
|
-
|
4434
|
+
int64_t ne00,
|
4435
|
+
int64_t ne01,
|
4436
|
+
int64_t ne02,
|
4437
|
+
int64_t ne10,
|
4438
|
+
int64_t ne12,
|
4439
|
+
int64_t ne0,
|
4440
|
+
int64_t ne1,
|
4441
|
+
uint r2,
|
4442
|
+
uint r3,
|
4443
|
+
threadgroup int8_t * shared_value,
|
4444
|
+
uint3 tgpig,
|
4445
|
+
uint tiisg,
|
4446
|
+
uint sgitg) {
|
4402
4447
|
|
4403
4448
|
const int nb = ne00/QK_K;
|
4404
4449
|
const int r0 = tgpig.x;
|
@@ -4476,18 +4521,19 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
4476
4521
|
device const void * src0,
|
4477
4522
|
device const float * src1,
|
4478
4523
|
device float * dst,
|
4479
|
-
|
4480
|
-
|
4481
|
-
|
4482
|
-
|
4483
|
-
|
4484
|
-
|
4485
|
-
|
4486
|
-
|
4487
|
-
|
4488
|
-
|
4489
|
-
|
4490
|
-
|
4524
|
+
int64_t ne00,
|
4525
|
+
int64_t ne01,
|
4526
|
+
int64_t ne02,
|
4527
|
+
int64_t ne10,
|
4528
|
+
int64_t ne12,
|
4529
|
+
int64_t ne0,
|
4530
|
+
int64_t ne1,
|
4531
|
+
uint r2,
|
4532
|
+
uint r3,
|
4533
|
+
threadgroup int8_t * shared_value,
|
4534
|
+
uint3 tgpig,
|
4535
|
+
uint tiisg,
|
4536
|
+
uint sgitg) {
|
4491
4537
|
|
4492
4538
|
const int nb = ne00/QK_K;
|
4493
4539
|
const int r0 = tgpig.x;
|
@@ -4584,20 +4630,21 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
4584
4630
|
device const void * src0,
|
4585
4631
|
device const float * src1,
|
4586
4632
|
device float * dst,
|
4587
|
-
|
4588
|
-
|
4589
|
-
|
4590
|
-
|
4591
|
-
|
4592
|
-
|
4593
|
-
|
4594
|
-
|
4595
|
-
|
4596
|
-
threadgroup
|
4597
|
-
|
4598
|
-
|
4599
|
-
|
4633
|
+
int64_t ne00,
|
4634
|
+
int64_t ne01,
|
4635
|
+
int64_t ne02,
|
4636
|
+
int64_t ne10,
|
4637
|
+
int64_t ne12,
|
4638
|
+
int64_t ne0,
|
4639
|
+
int64_t ne1,
|
4640
|
+
uint r2,
|
4641
|
+
uint r3,
|
4642
|
+
threadgroup int8_t * shared_values_i8,
|
4643
|
+
uint3 tgpig,
|
4644
|
+
uint tiisg,
|
4645
|
+
uint sgitg) {
|
4600
4646
|
|
4647
|
+
threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
|
4601
4648
|
const int nb = ne00/QK4_NL;
|
4602
4649
|
const int r0 = tgpig.x;
|
4603
4650
|
const int r1 = tgpig.y;
|
@@ -4678,20 +4725,21 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
4678
4725
|
device const void * src0,
|
4679
4726
|
device const float * src1,
|
4680
4727
|
device float * dst,
|
4681
|
-
|
4682
|
-
|
4683
|
-
|
4684
|
-
|
4685
|
-
|
4686
|
-
|
4687
|
-
|
4688
|
-
|
4689
|
-
|
4690
|
-
threadgroup
|
4691
|
-
|
4692
|
-
|
4693
|
-
|
4728
|
+
int64_t ne00,
|
4729
|
+
int64_t ne01,
|
4730
|
+
int64_t ne02,
|
4731
|
+
int64_t ne10,
|
4732
|
+
int64_t ne12,
|
4733
|
+
int64_t ne0,
|
4734
|
+
int64_t ne1,
|
4735
|
+
uint r2,
|
4736
|
+
uint r3,
|
4737
|
+
threadgroup int8_t * shared_values_i8,
|
4738
|
+
uint3 tgpig,
|
4739
|
+
uint tiisg,
|
4740
|
+
uint sgitg) {
|
4694
4741
|
|
4742
|
+
threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
|
4695
4743
|
const int nb = ne00/QK_K;
|
4696
4744
|
const int r0 = tgpig.x;
|
4697
4745
|
const int r1 = tgpig.y;
|
@@ -4794,7 +4842,7 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
|
4794
4842
|
uint tiisg[[thread_index_in_simdgroup]],
|
4795
4843
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4796
4844
|
|
4797
|
-
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
4845
|
+
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
4798
4846
|
}
|
4799
4847
|
|
4800
4848
|
[[host_name("kernel_mul_mv_iq1_m_f32")]]
|
@@ -4822,7 +4870,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
|
|
4822
4870
|
uint tiisg[[thread_index_in_simdgroup]],
|
4823
4871
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4824
4872
|
|
4825
|
-
kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
4873
|
+
kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
4826
4874
|
}
|
4827
4875
|
|
4828
4876
|
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
@@ -4846,7 +4894,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
|
|
4846
4894
|
constant int64_t & ne1,
|
4847
4895
|
constant uint & r2,
|
4848
4896
|
constant uint & r3,
|
4849
|
-
threadgroup
|
4897
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
4850
4898
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
4851
4899
|
uint tiisg[[thread_index_in_simdgroup]],
|
4852
4900
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -4875,7 +4923,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
4875
4923
|
constant int64_t & ne1,
|
4876
4924
|
constant uint & r2,
|
4877
4925
|
constant uint & r3,
|
4878
|
-
threadgroup
|
4926
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
4879
4927
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
4880
4928
|
uint tiisg[[thread_index_in_simdgroup]],
|
4881
4929
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -5632,25 +5680,25 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
5632
5680
|
}
|
5633
5681
|
}
|
5634
5682
|
|
5635
|
-
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in
|
5683
|
+
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
|
5636
5684
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
5637
5685
|
void kernel_mul_mm_id_impl(
|
5638
5686
|
device const uchar * src0,
|
5639
5687
|
device const uchar * src1,
|
5640
|
-
threadgroup
|
5688
|
+
threadgroup ushort2 * rowids,
|
5641
5689
|
device float * dst,
|
5642
5690
|
constant int64_t & ne00,
|
5643
5691
|
constant int64_t & ne02,
|
5644
5692
|
constant uint64_t & nb01,
|
5645
5693
|
constant uint64_t & nb02,
|
5694
|
+
constant int64_t & ne11,
|
5646
5695
|
constant int64_t & ne12,
|
5647
5696
|
constant uint64_t & nb10,
|
5648
5697
|
constant uint64_t & nb11,
|
5649
5698
|
constant uint64_t & nb12,
|
5650
5699
|
constant int64_t & ne0,
|
5651
5700
|
int64_t ne1,
|
5652
|
-
|
5653
|
-
constant uint & r3,
|
5701
|
+
int64_t ne0ne1,
|
5654
5702
|
threadgroup uchar * shared_memory,
|
5655
5703
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5656
5704
|
uint tiitg[[thread_index_in_threadgroup]],
|
@@ -5661,7 +5709,6 @@ void kernel_mul_mm_id_impl(
|
|
5661
5709
|
|
5662
5710
|
const uint r0 = tgpig.y;
|
5663
5711
|
const uint r1 = tgpig.x;
|
5664
|
-
const uint im = tgpig.z;
|
5665
5712
|
|
5666
5713
|
if (r1 * BLOCK_SIZE_N >= ne1) return;
|
5667
5714
|
|
@@ -5679,19 +5726,16 @@ void kernel_mul_mm_id_impl(
|
|
5679
5726
|
for (int i = 0; i < 8; i++){
|
5680
5727
|
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
5681
5728
|
}
|
5682
|
-
|
5683
5729
|
short il = (tiitg % THREAD_PER_ROW);
|
5684
5730
|
|
5685
|
-
const uint i12 = im%ne12;
|
5686
|
-
const uint i13 = im/ne12;
|
5687
|
-
|
5688
|
-
uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
|
5689
5731
|
ushort offset1 = il/nl;
|
5690
5732
|
|
5691
|
-
|
5733
|
+
threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
|
5734
|
+
|
5735
|
+
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
|
5692
5736
|
device const float * y = (device const float *)(src1
|
5693
|
-
+ nb12 *
|
5694
|
-
+ nb11 *
|
5737
|
+
+ nb12 * id[1]
|
5738
|
+
+ nb11 * (id[0] % ne11)
|
5695
5739
|
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
5696
5740
|
|
5697
5741
|
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
@@ -5720,11 +5764,11 @@ void kernel_mul_mm_id_impl(
|
|
5720
5764
|
|
5721
5765
|
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
5722
5766
|
for (int i = 0; i < 4; i++) {
|
5723
|
-
simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
|
5767
|
+
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
|
5724
5768
|
}
|
5725
5769
|
simdgroup_barrier(mem_flags::mem_none);
|
5726
5770
|
for (int i = 0; i < 2; i++) {
|
5727
|
-
simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
|
5771
|
+
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
|
5728
5772
|
}
|
5729
5773
|
|
5730
5774
|
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
@@ -5746,11 +5790,13 @@ void kernel_mul_mm_id_impl(
|
|
5746
5790
|
|
5747
5791
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
5748
5792
|
|
5749
|
-
device float * C = dst + (BLOCK_SIZE_M * r0)
|
5793
|
+
device float * C = dst + (BLOCK_SIZE_M * r0);
|
5750
5794
|
if (sgitg == 0) {
|
5751
|
-
for (int
|
5752
|
-
|
5753
|
-
|
5795
|
+
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
5796
|
+
threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
|
5797
|
+
int joff = jid[0] * ne0 + jid[1] * ne0ne1;
|
5798
|
+
for (int i = 0; i < n_rows; i++) {
|
5799
|
+
*(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
|
5754
5800
|
}
|
5755
5801
|
}
|
5756
5802
|
}
|
@@ -5805,11 +5851,14 @@ kernel void kernel_mul_mm_id(
|
|
5805
5851
|
device const uchar * src1,
|
5806
5852
|
device float * dst,
|
5807
5853
|
device const uchar * ids,
|
5854
|
+
constant int64_t & nei0,
|
5855
|
+
constant int64_t & nei1,
|
5808
5856
|
constant uint64_t & nbi1,
|
5809
5857
|
constant int64_t & ne00,
|
5810
5858
|
constant int64_t & ne02,
|
5811
5859
|
constant uint64_t & nb01,
|
5812
5860
|
constant uint64_t & nb02,
|
5861
|
+
constant int64_t & ne11,
|
5813
5862
|
constant int64_t & ne12,
|
5814
5863
|
constant int64_t & ne13,
|
5815
5864
|
constant uint64_t & nb10,
|
@@ -5818,47 +5867,52 @@ kernel void kernel_mul_mm_id(
|
|
5818
5867
|
constant int64_t & ne0,
|
5819
5868
|
constant int64_t & ne1,
|
5820
5869
|
constant uint64_t & nb1,
|
5821
|
-
constant uint & r2,
|
5822
|
-
constant uint & r3,
|
5823
|
-
constant int & idx,
|
5824
5870
|
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
5825
5871
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5826
5872
|
uint tiitg[[thread_index_in_threadgroup]],
|
5827
5873
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5828
5874
|
|
5829
|
-
|
5830
|
-
|
5831
|
-
device const uchar * src0 = src0s + id*nb02;
|
5875
|
+
const int32_t i02 = tgpig.z;
|
5876
|
+
tgpig.z = 0;
|
5832
5877
|
|
5833
|
-
|
5878
|
+
device const uchar * src0 = src0s + i02*nb02;
|
5834
5879
|
|
5835
|
-
// row indices
|
5836
|
-
threadgroup
|
5880
|
+
// row indices
|
5881
|
+
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
|
5837
5882
|
|
5883
|
+
// TODO: parallelize this loop
|
5838
5884
|
int64_t _ne1 = 0;
|
5839
|
-
for (
|
5840
|
-
|
5841
|
-
|
5885
|
+
for (ushort ii1 = 0; ii1 < nei1; ii1++) {
|
5886
|
+
for (ushort ii0 = 0; ii0 < nei0; ii0++) {
|
5887
|
+
int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
|
5888
|
+
if (id == i02) {
|
5889
|
+
//if (tiitg == 0) {
|
5890
|
+
rowids[_ne1] = ushort2(ii0, ii1);
|
5891
|
+
//}
|
5892
|
+
_ne1++;
|
5893
|
+
}
|
5842
5894
|
}
|
5843
5895
|
}
|
5844
5896
|
|
5897
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
5898
|
+
|
5845
5899
|
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
5846
5900
|
src0,
|
5847
5901
|
src1,
|
5848
|
-
|
5902
|
+
rowids,
|
5849
5903
|
dst,
|
5850
5904
|
ne00,
|
5851
5905
|
ne02,
|
5852
5906
|
nb01,
|
5853
5907
|
nb02,
|
5908
|
+
ne11,
|
5854
5909
|
ne12,
|
5855
5910
|
nb10,
|
5856
5911
|
nb11,
|
5857
5912
|
nb12,
|
5858
5913
|
ne0,
|
5859
5914
|
_ne1,
|
5860
|
-
|
5861
|
-
r3,
|
5915
|
+
ne0*ne1,
|
5862
5916
|
shared_memory,
|
5863
5917
|
tgpig,
|
5864
5918
|
tiitg,
|
@@ -5919,24 +5973,7 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_r
|
|
5919
5973
|
// matrix-matrix multiplication
|
5920
5974
|
//
|
5921
5975
|
|
5922
|
-
typedef
|
5923
|
-
device const uchar * src0,
|
5924
|
-
device const uchar * src1,
|
5925
|
-
device float * dst,
|
5926
|
-
constant int64_t & ne00,
|
5927
|
-
constant int64_t & ne02,
|
5928
|
-
constant uint64_t & nb01,
|
5929
|
-
constant uint64_t & nb02,
|
5930
|
-
constant int64_t & ne12,
|
5931
|
-
constant uint64_t & nb10,
|
5932
|
-
constant uint64_t & nb11,
|
5933
|
-
constant uint64_t & nb12,
|
5934
|
-
constant int64_t & ne0,
|
5935
|
-
constant int64_t & ne1,
|
5936
|
-
constant uint & r2,
|
5937
|
-
constant uint & r3,
|
5938
|
-
threadgroup uchar *,
|
5939
|
-
uint3, uint, uint);
|
5976
|
+
typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
|
5940
5977
|
|
5941
5978
|
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
5942
5979
|
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
@@ -5968,29 +6005,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
|
|
5968
6005
|
// indirect matrix-matrix multiplication
|
5969
6006
|
//
|
5970
6007
|
|
5971
|
-
typedef
|
5972
|
-
device const uchar * src0s,
|
5973
|
-
device const uchar * src1,
|
5974
|
-
device float * dst,
|
5975
|
-
device const uchar * ids,
|
5976
|
-
constant uint64_t & nbi1,
|
5977
|
-
constant int64_t & ne00,
|
5978
|
-
constant int64_t & ne02,
|
5979
|
-
constant uint64_t & nb01,
|
5980
|
-
constant uint64_t & nb02,
|
5981
|
-
constant int64_t & ne12,
|
5982
|
-
constant int64_t & ne13,
|
5983
|
-
constant uint64_t & nb10,
|
5984
|
-
constant uint64_t & nb11,
|
5985
|
-
constant uint64_t & nb12,
|
5986
|
-
constant int64_t & ne0,
|
5987
|
-
constant int64_t & ne1,
|
5988
|
-
constant uint64_t & nb1,
|
5989
|
-
constant uint & r2,
|
5990
|
-
constant uint & r3,
|
5991
|
-
constant int & idx,
|
5992
|
-
threadgroup uchar *,
|
5993
|
-
uint3, uint, uint);
|
6008
|
+
typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
|
5994
6009
|
|
5995
6010
|
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
5996
6011
|
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
@@ -6022,12 +6037,119 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
|
|
6022
6037
|
// matrix-vector multiplication
|
6023
6038
|
//
|
6024
6039
|
|
6025
|
-
|
6026
|
-
|
6040
|
+
typedef void (kernel_mul_mv_impl_t)(
|
6041
|
+
device const char * src0,
|
6042
|
+
device const char * src1,
|
6043
|
+
device float * dst,
|
6044
|
+
int64_t ne00,
|
6045
|
+
int64_t ne01,
|
6046
|
+
int64_t ne02,
|
6047
|
+
uint64_t nb00,
|
6048
|
+
uint64_t nb01,
|
6049
|
+
uint64_t nb02,
|
6050
|
+
int64_t ne10,
|
6051
|
+
int64_t ne11,
|
6052
|
+
int64_t ne12,
|
6053
|
+
uint64_t nb10,
|
6054
|
+
uint64_t nb11,
|
6055
|
+
uint64_t nb12,
|
6056
|
+
int64_t ne0,
|
6057
|
+
int64_t ne1,
|
6058
|
+
uint r2,
|
6059
|
+
uint r3,
|
6060
|
+
uint3 tgpig,
|
6061
|
+
uint tiisg);
|
6062
|
+
|
6063
|
+
typedef void (kernel_mul_mv2_impl_t)(
|
6064
|
+
device const void * src0,
|
6065
|
+
device const float * src1,
|
6066
|
+
device float * dst,
|
6067
|
+
int64_t ne00,
|
6068
|
+
int64_t ne01,
|
6069
|
+
int64_t ne02,
|
6070
|
+
int64_t ne10,
|
6071
|
+
int64_t ne12,
|
6072
|
+
int64_t ne0,
|
6073
|
+
int64_t ne1,
|
6074
|
+
uint r2,
|
6075
|
+
uint r3,
|
6076
|
+
threadgroup int8_t * shared_values,
|
6077
|
+
uint3 tgpig,
|
6078
|
+
uint tiisg,
|
6079
|
+
uint sgitg);
|
6080
|
+
|
6081
|
+
template<kernel_mul_mv_impl_t impl_fn>
|
6082
|
+
void mmv_fn(
|
6083
|
+
device const char * src0,
|
6084
|
+
device const char * src1,
|
6085
|
+
device float * dst,
|
6086
|
+
int64_t ne00,
|
6087
|
+
int64_t ne01,
|
6088
|
+
int64_t ne02,
|
6089
|
+
uint64_t nb00,
|
6090
|
+
uint64_t nb01,
|
6091
|
+
uint64_t nb02,
|
6092
|
+
int64_t ne10,
|
6093
|
+
int64_t ne11,
|
6094
|
+
int64_t ne12,
|
6095
|
+
int64_t ne13,
|
6096
|
+
uint64_t nb10,
|
6097
|
+
uint64_t nb11,
|
6098
|
+
uint64_t nb12,
|
6099
|
+
int64_t ne0,
|
6100
|
+
int64_t ne1,
|
6101
|
+
uint64_t nb1,
|
6102
|
+
uint r2,
|
6103
|
+
uint r3,
|
6104
|
+
threadgroup int8_t * shared_values,
|
6105
|
+
uint3 tgpig,
|
6106
|
+
uint tiitg,
|
6107
|
+
uint tiisg,
|
6108
|
+
uint sgitg) {
|
6109
|
+
impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
|
6110
|
+
}
|
6111
|
+
|
6112
|
+
template<kernel_mul_mv2_impl_t impl_fn>
|
6113
|
+
void mmv_fn(
|
6114
|
+
device const char * src0,
|
6115
|
+
device const char * src1,
|
6116
|
+
device float * dst,
|
6117
|
+
int64_t ne00,
|
6118
|
+
int64_t ne01,
|
6119
|
+
int64_t ne02,
|
6120
|
+
uint64_t nb00,
|
6121
|
+
uint64_t nb01,
|
6122
|
+
uint64_t nb02,
|
6123
|
+
int64_t ne10,
|
6124
|
+
int64_t ne11,
|
6125
|
+
int64_t ne12,
|
6126
|
+
int64_t ne13,
|
6127
|
+
uint64_t nb10,
|
6128
|
+
uint64_t nb11,
|
6129
|
+
uint64_t nb12,
|
6130
|
+
int64_t ne0,
|
6131
|
+
int64_t ne1,
|
6132
|
+
uint64_t nb1,
|
6133
|
+
uint r2,
|
6134
|
+
uint r3,
|
6135
|
+
threadgroup int8_t * shared_values,
|
6136
|
+
uint3 tgpig,
|
6137
|
+
uint tiitg,
|
6138
|
+
uint tiisg,
|
6139
|
+
uint sgitg) {
|
6140
|
+
impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
|
6141
|
+
}
|
6142
|
+
|
6143
|
+
typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t;
|
6144
|
+
|
6145
|
+
template<mul_mv_impl_fn_t impl_fn>
|
6146
|
+
kernel void kernel_mul_mv_id(
|
6027
6147
|
device const char * src0s,
|
6028
6148
|
device const char * src1,
|
6029
6149
|
device float * dst,
|
6030
6150
|
device const char * ids,
|
6151
|
+
constant int64_t & nei0,
|
6152
|
+
constant int64_t & nei1,
|
6031
6153
|
constant uint64_t & nbi1,
|
6032
6154
|
constant int64_t & ne00,
|
6033
6155
|
constant int64_t & ne01,
|
@@ -6045,1164 +6167,80 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
6045
6167
|
constant int64_t & ne0,
|
6046
6168
|
constant int64_t & ne1,
|
6047
6169
|
constant uint64_t & nb1,
|
6048
|
-
|
6049
|
-
constant uint & r3,
|
6050
|
-
constant int & idx,
|
6170
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6051
6171
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6052
6172
|
uint tiitg[[thread_index_in_threadgroup]],
|
6053
6173
|
uint tiisg[[thread_index_in_simdgroup]],
|
6054
6174
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6055
|
-
const
|
6056
|
-
|
6057
|
-
|
6058
|
-
|
6059
|
-
|
6060
|
-
device const
|
6061
|
-
|
6062
|
-
|
6063
|
-
|
6064
|
-
|
6065
|
-
|
6066
|
-
|
6067
|
-
|
6068
|
-
|
6069
|
-
|
6070
|
-
|
6071
|
-
|
6072
|
-
|
6073
|
-
|
6074
|
-
|
6075
|
-
|
6076
|
-
|
6077
|
-
|
6078
|
-
|
6079
|
-
|
6080
|
-
|
6081
|
-
|
6175
|
+
const int iid1 = tgpig.z/nei0;
|
6176
|
+
const int idx = tgpig.z%nei0;
|
6177
|
+
|
6178
|
+
tgpig.z = 0;
|
6179
|
+
|
6180
|
+
const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx];
|
6181
|
+
|
6182
|
+
const int64_t i11 = idx % ne11;
|
6183
|
+
const int64_t i12 = iid1;
|
6184
|
+
|
6185
|
+
const int64_t i1 = idx;
|
6186
|
+
const int64_t i2 = i12;
|
6187
|
+
|
6188
|
+
device const char * src0_cur = src0s + i02*nb02;
|
6189
|
+
device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
|
6190
|
+
device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0;
|
6191
|
+
|
6192
|
+
impl_fn(
|
6193
|
+
/* src0 */ src0_cur,
|
6194
|
+
/* src1 */ src1_cur,
|
6195
|
+
/* dst */ dst_cur,
|
6196
|
+
/* ne00 */ ne00,
|
6197
|
+
/* ne01 */ ne01,
|
6198
|
+
/* ne02 */ 1,//ne02,
|
6199
|
+
/* nb00 */ nb00,
|
6200
|
+
/* nb01 */ nb01,
|
6201
|
+
/* nb02 */ nb02,
|
6202
|
+
/* ne10 */ ne10,
|
6203
|
+
/* ne11 */ 1,//ne11,
|
6204
|
+
/* ne12 */ 1,//ne12,
|
6205
|
+
/* ne13 */ 1,//ne13,
|
6206
|
+
/* nb10 */ nb10,
|
6207
|
+
/* nb11 */ nb11,
|
6208
|
+
/* nb12 */ nb12,
|
6209
|
+
/* ne0 */ ne0,
|
6210
|
+
/* ne1 */ 1,//ne1,
|
6211
|
+
/* nb1 */ nb1,
|
6212
|
+
/* r2 */ 1,
|
6213
|
+
/* r3 */ 1,
|
6214
|
+
shared_values,
|
6082
6215
|
tgpig,
|
6083
|
-
|
6216
|
+
tiitg,
|
6217
|
+
tiisg,
|
6218
|
+
sgitg);
|
6084
6219
|
}
|
6085
6220
|
|
6086
|
-
|
6087
|
-
|
6088
|
-
|
6089
|
-
|
6090
|
-
|
6091
|
-
|
6092
|
-
|
6093
|
-
|
6094
|
-
|
6095
|
-
|
6096
|
-
|
6097
|
-
|
6098
|
-
|
6099
|
-
|
6100
|
-
|
6101
|
-
|
6102
|
-
|
6103
|
-
|
6104
|
-
|
6105
|
-
|
6106
|
-
|
6107
|
-
|
6108
|
-
|
6109
|
-
|
6110
|
-
constant uint & r3,
|
6111
|
-
constant int & idx,
|
6112
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6113
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6114
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6115
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6116
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6117
|
-
|
6118
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6119
|
-
|
6120
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6121
|
-
device const char * src0 = src0s + id*nb02;
|
6122
|
-
|
6123
|
-
kernel_mul_mv_f16_f32_impl(
|
6124
|
-
src0,
|
6125
|
-
src1 + bid*nb11,
|
6126
|
-
dst + bid*ne0,
|
6127
|
-
ne00,
|
6128
|
-
ne01,
|
6129
|
-
ne02,
|
6130
|
-
nb00,
|
6131
|
-
nb01,
|
6132
|
-
nb02,
|
6133
|
-
ne10,
|
6134
|
-
ne11,
|
6135
|
-
ne12,
|
6136
|
-
nb10,
|
6137
|
-
nb11,
|
6138
|
-
nb12,
|
6139
|
-
ne0,
|
6140
|
-
ne1,
|
6141
|
-
r2,
|
6142
|
-
r3,
|
6143
|
-
tgpig,
|
6144
|
-
tiisg);
|
6145
|
-
}
|
6146
|
-
|
6147
|
-
[[host_name("kernel_mul_mv_id_q8_0_f32")]]
|
6148
|
-
kernel void kernel_mul_mv_id_q8_0_f32(
|
6149
|
-
device const char * src0s,
|
6150
|
-
device const char * src1,
|
6151
|
-
device float * dst,
|
6152
|
-
device const char * ids,
|
6153
|
-
constant uint64_t & nbi1,
|
6154
|
-
constant int64_t & ne00,
|
6155
|
-
constant int64_t & ne01,
|
6156
|
-
constant int64_t & ne02,
|
6157
|
-
constant uint64_t & nb00,
|
6158
|
-
constant uint64_t & nb01,
|
6159
|
-
constant uint64_t & nb02,
|
6160
|
-
constant int64_t & ne10,
|
6161
|
-
constant int64_t & ne11,
|
6162
|
-
constant int64_t & ne12,
|
6163
|
-
constant int64_t & ne13,
|
6164
|
-
constant uint64_t & nb10,
|
6165
|
-
constant uint64_t & nb11,
|
6166
|
-
constant uint64_t & nb12,
|
6167
|
-
constant int64_t & ne0,
|
6168
|
-
constant int64_t & ne1,
|
6169
|
-
constant uint64_t & nb1,
|
6170
|
-
constant uint & r2,
|
6171
|
-
constant uint & r3,
|
6172
|
-
constant int & idx,
|
6173
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6174
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6175
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6176
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6177
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6178
|
-
|
6179
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6180
|
-
|
6181
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6182
|
-
device const char * src0 = src0s + id*nb02;
|
6183
|
-
|
6184
|
-
kernel_mul_mv_q8_0_f32_impl(
|
6185
|
-
src0,
|
6186
|
-
(device const float *) (src1 + bid*nb11),
|
6187
|
-
dst + bid*ne0,
|
6188
|
-
ne00,
|
6189
|
-
ne01,
|
6190
|
-
ne02,
|
6191
|
-
ne10,
|
6192
|
-
ne12,
|
6193
|
-
ne0,
|
6194
|
-
ne1,
|
6195
|
-
r2,
|
6196
|
-
r3,
|
6197
|
-
tgpig,
|
6198
|
-
tiisg,
|
6199
|
-
sgitg);
|
6200
|
-
}
|
6201
|
-
|
6202
|
-
[[host_name("kernel_mul_mv_id_q4_0_f32")]]
|
6203
|
-
kernel void kernel_mul_mv_id_q4_0_f32(
|
6204
|
-
device const char * src0s,
|
6205
|
-
device const char * src1,
|
6206
|
-
device float * dst,
|
6207
|
-
device const char * ids,
|
6208
|
-
constant uint64_t & nbi1,
|
6209
|
-
constant int64_t & ne00,
|
6210
|
-
constant int64_t & ne01,
|
6211
|
-
constant int64_t & ne02,
|
6212
|
-
constant uint64_t & nb00,
|
6213
|
-
constant uint64_t & nb01,
|
6214
|
-
constant uint64_t & nb02,
|
6215
|
-
constant int64_t & ne10,
|
6216
|
-
constant int64_t & ne11,
|
6217
|
-
constant int64_t & ne12,
|
6218
|
-
constant int64_t & ne13,
|
6219
|
-
constant uint64_t & nb10,
|
6220
|
-
constant uint64_t & nb11,
|
6221
|
-
constant uint64_t & nb12,
|
6222
|
-
constant int64_t & ne0,
|
6223
|
-
constant int64_t & ne1,
|
6224
|
-
constant uint64_t & nb1,
|
6225
|
-
constant uint & r2,
|
6226
|
-
constant uint & r3,
|
6227
|
-
constant int & idx,
|
6228
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6229
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6230
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6231
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6232
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6233
|
-
|
6234
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6235
|
-
|
6236
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6237
|
-
device const char * src0 = src0s + id*nb02;
|
6238
|
-
|
6239
|
-
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
6240
|
-
src0,
|
6241
|
-
(device const float *) (src1 + bid*nb11),
|
6242
|
-
dst + bid*ne0,
|
6243
|
-
ne00,
|
6244
|
-
ne01,
|
6245
|
-
ne02,
|
6246
|
-
ne10,
|
6247
|
-
ne12,
|
6248
|
-
ne0,
|
6249
|
-
ne1,
|
6250
|
-
r2,
|
6251
|
-
r3,
|
6252
|
-
tgpig,
|
6253
|
-
tiisg,
|
6254
|
-
sgitg);
|
6255
|
-
}
|
6256
|
-
|
6257
|
-
[[host_name("kernel_mul_mv_id_q4_1_f32")]]
|
6258
|
-
kernel void kernel_mul_mv_id_q4_1_f32(
|
6259
|
-
device const char * src0s,
|
6260
|
-
device const char * src1,
|
6261
|
-
device float * dst,
|
6262
|
-
device const char * ids,
|
6263
|
-
constant uint64_t & nbi1,
|
6264
|
-
constant int64_t & ne00,
|
6265
|
-
constant int64_t & ne01,
|
6266
|
-
constant int64_t & ne02,
|
6267
|
-
constant uint64_t & nb00,
|
6268
|
-
constant uint64_t & nb01,
|
6269
|
-
constant uint64_t & nb02,
|
6270
|
-
constant int64_t & ne10,
|
6271
|
-
constant int64_t & ne11,
|
6272
|
-
constant int64_t & ne12,
|
6273
|
-
constant int64_t & ne13,
|
6274
|
-
constant uint64_t & nb10,
|
6275
|
-
constant uint64_t & nb11,
|
6276
|
-
constant uint64_t & nb12,
|
6277
|
-
constant int64_t & ne0,
|
6278
|
-
constant int64_t & ne1,
|
6279
|
-
constant uint64_t & nb1,
|
6280
|
-
constant uint & r2,
|
6281
|
-
constant uint & r3,
|
6282
|
-
constant int & idx,
|
6283
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6284
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6285
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6286
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6287
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6288
|
-
|
6289
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6290
|
-
|
6291
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6292
|
-
device const char * src0 = src0s + id*nb02;
|
6293
|
-
|
6294
|
-
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
6295
|
-
src0,
|
6296
|
-
(device const float *) (src1 + bid*nb11),
|
6297
|
-
dst + bid*ne0,
|
6298
|
-
ne00,
|
6299
|
-
ne01,
|
6300
|
-
ne02,
|
6301
|
-
ne10,
|
6302
|
-
ne12,
|
6303
|
-
ne0,
|
6304
|
-
ne1,
|
6305
|
-
r2,
|
6306
|
-
r3,
|
6307
|
-
tgpig,
|
6308
|
-
tiisg,
|
6309
|
-
sgitg);
|
6310
|
-
}
|
6311
|
-
|
6312
|
-
[[host_name("kernel_mul_mv_id_q5_0_f32")]]
|
6313
|
-
kernel void kernel_mul_mv_id_q5_0_f32(
|
6314
|
-
device const char * src0s,
|
6315
|
-
device const char * src1,
|
6316
|
-
device float * dst,
|
6317
|
-
device const char * ids,
|
6318
|
-
constant uint64_t & nbi1,
|
6319
|
-
constant int64_t & ne00,
|
6320
|
-
constant int64_t & ne01,
|
6321
|
-
constant int64_t & ne02,
|
6322
|
-
constant uint64_t & nb00,
|
6323
|
-
constant uint64_t & nb01,
|
6324
|
-
constant uint64_t & nb02,
|
6325
|
-
constant int64_t & ne10,
|
6326
|
-
constant int64_t & ne11,
|
6327
|
-
constant int64_t & ne12,
|
6328
|
-
constant int64_t & ne13,
|
6329
|
-
constant uint64_t & nb10,
|
6330
|
-
constant uint64_t & nb11,
|
6331
|
-
constant uint64_t & nb12,
|
6332
|
-
constant int64_t & ne0,
|
6333
|
-
constant int64_t & ne1,
|
6334
|
-
constant uint64_t & nb1,
|
6335
|
-
constant uint & r2,
|
6336
|
-
constant uint & r3,
|
6337
|
-
constant int & idx,
|
6338
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6339
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6340
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6341
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6342
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6343
|
-
|
6344
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6345
|
-
|
6346
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6347
|
-
device const char * src0 = src0s + id*nb02;
|
6348
|
-
|
6349
|
-
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
6350
|
-
src0,
|
6351
|
-
(device const float *) (src1 + bid*nb11),
|
6352
|
-
dst + bid*ne0,
|
6353
|
-
ne00,
|
6354
|
-
ne01,
|
6355
|
-
ne02,
|
6356
|
-
ne10,
|
6357
|
-
ne12,
|
6358
|
-
ne0,
|
6359
|
-
ne1,
|
6360
|
-
r2,
|
6361
|
-
r3,
|
6362
|
-
tgpig,
|
6363
|
-
tiisg,
|
6364
|
-
sgitg);
|
6365
|
-
}
|
6366
|
-
|
6367
|
-
[[host_name("kernel_mul_mv_id_q5_1_f32")]]
|
6368
|
-
kernel void kernel_mul_mv_id_q5_1_f32(
|
6369
|
-
device const char * src0s,
|
6370
|
-
device const char * src1,
|
6371
|
-
device float * dst,
|
6372
|
-
device const char * ids,
|
6373
|
-
constant uint64_t & nbi1,
|
6374
|
-
constant int64_t & ne00,
|
6375
|
-
constant int64_t & ne01,
|
6376
|
-
constant int64_t & ne02,
|
6377
|
-
constant uint64_t & nb00,
|
6378
|
-
constant uint64_t & nb01,
|
6379
|
-
constant uint64_t & nb02,
|
6380
|
-
constant int64_t & ne10,
|
6381
|
-
constant int64_t & ne11,
|
6382
|
-
constant int64_t & ne12,
|
6383
|
-
constant int64_t & ne13,
|
6384
|
-
constant uint64_t & nb10,
|
6385
|
-
constant uint64_t & nb11,
|
6386
|
-
constant uint64_t & nb12,
|
6387
|
-
constant int64_t & ne0,
|
6388
|
-
constant int64_t & ne1,
|
6389
|
-
constant uint64_t & nb1,
|
6390
|
-
constant uint & r2,
|
6391
|
-
constant uint & r3,
|
6392
|
-
constant int & idx,
|
6393
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6394
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6395
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6396
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6397
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6398
|
-
|
6399
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6400
|
-
|
6401
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6402
|
-
device const char * src0 = src0s + id*nb02;
|
6403
|
-
|
6404
|
-
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
6405
|
-
src0,
|
6406
|
-
(device const float *) (src1 + bid*nb11),
|
6407
|
-
dst + bid*ne0,
|
6408
|
-
ne00,
|
6409
|
-
ne01,
|
6410
|
-
ne02,
|
6411
|
-
ne10,
|
6412
|
-
ne12,
|
6413
|
-
ne0,
|
6414
|
-
ne1,
|
6415
|
-
r2,
|
6416
|
-
r3,
|
6417
|
-
tgpig,
|
6418
|
-
tiisg,
|
6419
|
-
sgitg);
|
6420
|
-
}
|
6421
|
-
|
6422
|
-
[[host_name("kernel_mul_mv_id_q2_K_f32")]]
|
6423
|
-
kernel void kernel_mul_mv_id_q2_K_f32(
|
6424
|
-
device const char * src0s,
|
6425
|
-
device const char * src1,
|
6426
|
-
device float * dst,
|
6427
|
-
device const char * ids,
|
6428
|
-
constant uint64_t & nbi1,
|
6429
|
-
constant int64_t & ne00,
|
6430
|
-
constant int64_t & ne01,
|
6431
|
-
constant int64_t & ne02,
|
6432
|
-
constant uint64_t & nb00,
|
6433
|
-
constant uint64_t & nb01,
|
6434
|
-
constant uint64_t & nb02,
|
6435
|
-
constant int64_t & ne10,
|
6436
|
-
constant int64_t & ne11,
|
6437
|
-
constant int64_t & ne12,
|
6438
|
-
constant int64_t & ne13,
|
6439
|
-
constant uint64_t & nb10,
|
6440
|
-
constant uint64_t & nb11,
|
6441
|
-
constant uint64_t & nb12,
|
6442
|
-
constant int64_t & ne0,
|
6443
|
-
constant int64_t & ne1,
|
6444
|
-
constant uint64_t & nb1,
|
6445
|
-
constant uint & r2,
|
6446
|
-
constant uint & r3,
|
6447
|
-
constant int & idx,
|
6448
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6449
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6450
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6451
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6452
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6453
|
-
|
6454
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6455
|
-
|
6456
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6457
|
-
device const char * src0 = src0s + id*nb02;
|
6458
|
-
|
6459
|
-
kernel_mul_mv_q2_K_f32_impl(
|
6460
|
-
src0,
|
6461
|
-
(device const float *) (src1 + bid*nb11),
|
6462
|
-
dst + bid*ne0,
|
6463
|
-
ne00,
|
6464
|
-
ne01,
|
6465
|
-
ne02,
|
6466
|
-
ne10,
|
6467
|
-
ne12,
|
6468
|
-
ne0,
|
6469
|
-
ne1,
|
6470
|
-
r2,
|
6471
|
-
r3,
|
6472
|
-
tgpig,
|
6473
|
-
tiisg,
|
6474
|
-
sgitg);
|
6475
|
-
}
|
6476
|
-
|
6477
|
-
[[host_name("kernel_mul_mv_id_q3_K_f32")]]
|
6478
|
-
kernel void kernel_mul_mv_id_q3_K_f32(
|
6479
|
-
device const char * src0s,
|
6480
|
-
device const char * src1,
|
6481
|
-
device float * dst,
|
6482
|
-
device const char * ids,
|
6483
|
-
constant uint64_t & nbi1,
|
6484
|
-
constant int64_t & ne00,
|
6485
|
-
constant int64_t & ne01,
|
6486
|
-
constant int64_t & ne02,
|
6487
|
-
constant uint64_t & nb00,
|
6488
|
-
constant uint64_t & nb01,
|
6489
|
-
constant uint64_t & nb02,
|
6490
|
-
constant int64_t & ne10,
|
6491
|
-
constant int64_t & ne11,
|
6492
|
-
constant int64_t & ne12,
|
6493
|
-
constant int64_t & ne13,
|
6494
|
-
constant uint64_t & nb10,
|
6495
|
-
constant uint64_t & nb11,
|
6496
|
-
constant uint64_t & nb12,
|
6497
|
-
constant int64_t & ne0,
|
6498
|
-
constant int64_t & ne1,
|
6499
|
-
constant uint64_t & nb1,
|
6500
|
-
constant uint & r2,
|
6501
|
-
constant uint & r3,
|
6502
|
-
constant int & idx,
|
6503
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6504
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6505
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6506
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6507
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6508
|
-
|
6509
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6510
|
-
|
6511
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6512
|
-
device const char * src0 = src0s + id*nb02;
|
6513
|
-
|
6514
|
-
kernel_mul_mv_q3_K_f32_impl(
|
6515
|
-
src0,
|
6516
|
-
(device const float *) (src1 + bid*nb11),
|
6517
|
-
dst + bid*ne0,
|
6518
|
-
ne00,
|
6519
|
-
ne01,
|
6520
|
-
ne02,
|
6521
|
-
ne10,
|
6522
|
-
ne12,
|
6523
|
-
ne0,
|
6524
|
-
ne1,
|
6525
|
-
r2,
|
6526
|
-
r3,
|
6527
|
-
tgpig,
|
6528
|
-
tiisg,
|
6529
|
-
sgitg);
|
6530
|
-
}
|
6531
|
-
|
6532
|
-
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
|
6533
|
-
kernel void kernel_mul_mv_id_q4_K_f32(
|
6534
|
-
device const char * src0s,
|
6535
|
-
device const char * src1,
|
6536
|
-
device float * dst,
|
6537
|
-
device const char * ids,
|
6538
|
-
constant uint64_t & nbi1,
|
6539
|
-
constant int64_t & ne00,
|
6540
|
-
constant int64_t & ne01,
|
6541
|
-
constant int64_t & ne02,
|
6542
|
-
constant uint64_t & nb00,
|
6543
|
-
constant uint64_t & nb01,
|
6544
|
-
constant uint64_t & nb02,
|
6545
|
-
constant int64_t & ne10,
|
6546
|
-
constant int64_t & ne11,
|
6547
|
-
constant int64_t & ne12,
|
6548
|
-
constant int64_t & ne13,
|
6549
|
-
constant uint64_t & nb10,
|
6550
|
-
constant uint64_t & nb11,
|
6551
|
-
constant uint64_t & nb12,
|
6552
|
-
constant int64_t & ne0,
|
6553
|
-
constant int64_t & ne1,
|
6554
|
-
constant uint64_t & nb1,
|
6555
|
-
constant uint & r2,
|
6556
|
-
constant uint & r3,
|
6557
|
-
constant int & idx,
|
6558
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6559
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6560
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6561
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6562
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6563
|
-
|
6564
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6565
|
-
|
6566
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6567
|
-
device const char * src0 = src0s + id*nb02;
|
6568
|
-
|
6569
|
-
kernel_mul_mv_q4_K_f32_impl(
|
6570
|
-
src0,
|
6571
|
-
(device const float *) (src1 + bid*nb11),
|
6572
|
-
dst + bid*ne0,
|
6573
|
-
ne00,
|
6574
|
-
ne01,
|
6575
|
-
ne02,
|
6576
|
-
ne10,
|
6577
|
-
ne12,
|
6578
|
-
ne0,
|
6579
|
-
ne1,
|
6580
|
-
r2,
|
6581
|
-
r3,
|
6582
|
-
tgpig,
|
6583
|
-
tiisg,
|
6584
|
-
sgitg);
|
6585
|
-
}
|
6586
|
-
|
6587
|
-
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
|
6588
|
-
kernel void kernel_mul_mv_id_q5_K_f32(
|
6589
|
-
device const char * src0s,
|
6590
|
-
device const char * src1,
|
6591
|
-
device float * dst,
|
6592
|
-
device const char * ids,
|
6593
|
-
constant uint64_t & nbi1,
|
6594
|
-
constant int64_t & ne00,
|
6595
|
-
constant int64_t & ne01,
|
6596
|
-
constant int64_t & ne02,
|
6597
|
-
constant uint64_t & nb00,
|
6598
|
-
constant uint64_t & nb01,
|
6599
|
-
constant uint64_t & nb02,
|
6600
|
-
constant int64_t & ne10,
|
6601
|
-
constant int64_t & ne11,
|
6602
|
-
constant int64_t & ne12,
|
6603
|
-
constant int64_t & ne13,
|
6604
|
-
constant uint64_t & nb10,
|
6605
|
-
constant uint64_t & nb11,
|
6606
|
-
constant uint64_t & nb12,
|
6607
|
-
constant int64_t & ne0,
|
6608
|
-
constant int64_t & ne1,
|
6609
|
-
constant uint64_t & nb1,
|
6610
|
-
constant uint & r2,
|
6611
|
-
constant uint & r3,
|
6612
|
-
constant int & idx,
|
6613
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6614
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6615
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6616
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6617
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6618
|
-
|
6619
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6620
|
-
|
6621
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6622
|
-
device const char * src0 = src0s + id*nb02;
|
6623
|
-
|
6624
|
-
kernel_mul_mv_q5_K_f32_impl(
|
6625
|
-
src0,
|
6626
|
-
(device const float *) (src1 + bid*nb11),
|
6627
|
-
dst + bid*ne0,
|
6628
|
-
ne00,
|
6629
|
-
ne01,
|
6630
|
-
ne02,
|
6631
|
-
ne10,
|
6632
|
-
ne12,
|
6633
|
-
ne0,
|
6634
|
-
ne1,
|
6635
|
-
r2,
|
6636
|
-
r3,
|
6637
|
-
tgpig,
|
6638
|
-
tiisg,
|
6639
|
-
sgitg);
|
6640
|
-
}
|
6641
|
-
|
6642
|
-
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
|
6643
|
-
kernel void kernel_mul_mv_id_q6_K_f32(
|
6644
|
-
device const char * src0s,
|
6645
|
-
device const char * src1,
|
6646
|
-
device float * dst,
|
6647
|
-
device const char * ids,
|
6648
|
-
constant uint64_t & nbi1,
|
6649
|
-
constant int64_t & ne00,
|
6650
|
-
constant int64_t & ne01,
|
6651
|
-
constant int64_t & ne02,
|
6652
|
-
constant uint64_t & nb00,
|
6653
|
-
constant uint64_t & nb01,
|
6654
|
-
constant uint64_t & nb02,
|
6655
|
-
constant int64_t & ne10,
|
6656
|
-
constant int64_t & ne11,
|
6657
|
-
constant int64_t & ne12,
|
6658
|
-
constant int64_t & ne13,
|
6659
|
-
constant uint64_t & nb10,
|
6660
|
-
constant uint64_t & nb11,
|
6661
|
-
constant uint64_t & nb12,
|
6662
|
-
constant int64_t & ne0,
|
6663
|
-
constant int64_t & ne1,
|
6664
|
-
constant uint64_t & nb1,
|
6665
|
-
constant uint & r2,
|
6666
|
-
constant uint & r3,
|
6667
|
-
constant int & idx,
|
6668
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6669
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6670
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6671
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6672
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6673
|
-
|
6674
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6675
|
-
|
6676
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6677
|
-
device const char * src0 = src0s + id*nb02;
|
6678
|
-
|
6679
|
-
kernel_mul_mv_q6_K_f32_impl(
|
6680
|
-
src0,
|
6681
|
-
(device const float *) (src1 + bid*nb11),
|
6682
|
-
dst + bid*ne0,
|
6683
|
-
ne00,
|
6684
|
-
ne01,
|
6685
|
-
ne02,
|
6686
|
-
ne10,
|
6687
|
-
ne12,
|
6688
|
-
ne0,
|
6689
|
-
ne1,
|
6690
|
-
r2,
|
6691
|
-
r3,
|
6692
|
-
tgpig,
|
6693
|
-
tiisg,
|
6694
|
-
sgitg);
|
6695
|
-
}
|
6696
|
-
|
6697
|
-
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
|
6698
|
-
kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
6699
|
-
device const char * src0s,
|
6700
|
-
device const char * src1,
|
6701
|
-
device float * dst,
|
6702
|
-
device const char * ids,
|
6703
|
-
constant uint64_t & nbi1,
|
6704
|
-
constant int64_t & ne00,
|
6705
|
-
constant int64_t & ne01,
|
6706
|
-
constant int64_t & ne02,
|
6707
|
-
constant uint64_t & nb00,
|
6708
|
-
constant uint64_t & nb01,
|
6709
|
-
constant uint64_t & nb02,
|
6710
|
-
constant int64_t & ne10,
|
6711
|
-
constant int64_t & ne11,
|
6712
|
-
constant int64_t & ne12,
|
6713
|
-
constant int64_t & ne13,
|
6714
|
-
constant uint64_t & nb10,
|
6715
|
-
constant uint64_t & nb11,
|
6716
|
-
constant uint64_t & nb12,
|
6717
|
-
constant int64_t & ne0,
|
6718
|
-
constant int64_t & ne1,
|
6719
|
-
constant uint64_t & nb1,
|
6720
|
-
constant uint & r2,
|
6721
|
-
constant uint & r3,
|
6722
|
-
constant int & idx,
|
6723
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6724
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6725
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6726
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6727
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6728
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6729
|
-
|
6730
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6731
|
-
|
6732
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6733
|
-
device const char * src0 = src0s + id*nb02;
|
6734
|
-
|
6735
|
-
kernel_mul_mv_iq2_xxs_f32_impl(
|
6736
|
-
src0,
|
6737
|
-
(device const float *) (src1 + bid*nb11),
|
6738
|
-
dst + bid*ne0,
|
6739
|
-
ne00,
|
6740
|
-
ne01,
|
6741
|
-
ne02,
|
6742
|
-
ne10,
|
6743
|
-
ne12,
|
6744
|
-
ne0,
|
6745
|
-
ne1,
|
6746
|
-
r2,
|
6747
|
-
r3,
|
6748
|
-
shared_values,
|
6749
|
-
tgpig,
|
6750
|
-
tiisg,
|
6751
|
-
sgitg);
|
6752
|
-
}
|
6753
|
-
|
6754
|
-
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
|
6755
|
-
kernel void kernel_mul_mv_id_iq2_xs_f32(
|
6756
|
-
device const char * src0s,
|
6757
|
-
device const char * src1,
|
6758
|
-
device float * dst,
|
6759
|
-
device const char * ids,
|
6760
|
-
constant uint64_t & nbi1,
|
6761
|
-
constant int64_t & ne00,
|
6762
|
-
constant int64_t & ne01,
|
6763
|
-
constant int64_t & ne02,
|
6764
|
-
constant uint64_t & nb00,
|
6765
|
-
constant uint64_t & nb01,
|
6766
|
-
constant uint64_t & nb02,
|
6767
|
-
constant int64_t & ne10,
|
6768
|
-
constant int64_t & ne11,
|
6769
|
-
constant int64_t & ne12,
|
6770
|
-
constant int64_t & ne13,
|
6771
|
-
constant uint64_t & nb10,
|
6772
|
-
constant uint64_t & nb11,
|
6773
|
-
constant uint64_t & nb12,
|
6774
|
-
constant int64_t & ne0,
|
6775
|
-
constant int64_t & ne1,
|
6776
|
-
constant uint64_t & nb1,
|
6777
|
-
constant uint & r2,
|
6778
|
-
constant uint & r3,
|
6779
|
-
constant int & idx,
|
6780
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6781
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6782
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6783
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6784
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6785
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6786
|
-
|
6787
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6788
|
-
|
6789
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6790
|
-
device const char * src0 = src0s + id*nb02;
|
6791
|
-
|
6792
|
-
kernel_mul_mv_iq2_xs_f32_impl(
|
6793
|
-
src0,
|
6794
|
-
(device const float *) (src1 + bid*nb11),
|
6795
|
-
dst + bid*ne0,
|
6796
|
-
ne00,
|
6797
|
-
ne01,
|
6798
|
-
ne02,
|
6799
|
-
ne10,
|
6800
|
-
ne12,
|
6801
|
-
ne0,
|
6802
|
-
ne1,
|
6803
|
-
r2,
|
6804
|
-
r3,
|
6805
|
-
shared_values,
|
6806
|
-
tgpig,
|
6807
|
-
tiisg,
|
6808
|
-
sgitg);
|
6809
|
-
}
|
6810
|
-
|
6811
|
-
[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
|
6812
|
-
kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
6813
|
-
device const char * src0s,
|
6814
|
-
device const char * src1,
|
6815
|
-
device float * dst,
|
6816
|
-
device const char * ids,
|
6817
|
-
constant uint64_t & nbi1,
|
6818
|
-
constant int64_t & ne00,
|
6819
|
-
constant int64_t & ne01,
|
6820
|
-
constant int64_t & ne02,
|
6821
|
-
constant uint64_t & nb00,
|
6822
|
-
constant uint64_t & nb01,
|
6823
|
-
constant uint64_t & nb02,
|
6824
|
-
constant int64_t & ne10,
|
6825
|
-
constant int64_t & ne11,
|
6826
|
-
constant int64_t & ne12,
|
6827
|
-
constant int64_t & ne13,
|
6828
|
-
constant uint64_t & nb10,
|
6829
|
-
constant uint64_t & nb11,
|
6830
|
-
constant uint64_t & nb12,
|
6831
|
-
constant int64_t & ne0,
|
6832
|
-
constant int64_t & ne1,
|
6833
|
-
constant uint64_t & nb1,
|
6834
|
-
constant uint & r2,
|
6835
|
-
constant uint & r3,
|
6836
|
-
constant int & idx,
|
6837
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6838
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6839
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6840
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6841
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6842
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6843
|
-
|
6844
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6845
|
-
|
6846
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6847
|
-
device const char * src0 = src0s + id*nb02;
|
6848
|
-
|
6849
|
-
kernel_mul_mv_iq3_xxs_f32_impl(
|
6850
|
-
src0,
|
6851
|
-
(device const float *) (src1 + bid*nb11),
|
6852
|
-
dst + bid*ne0,
|
6853
|
-
ne00,
|
6854
|
-
ne01,
|
6855
|
-
ne02,
|
6856
|
-
ne10,
|
6857
|
-
ne12,
|
6858
|
-
ne0,
|
6859
|
-
ne1,
|
6860
|
-
r2,
|
6861
|
-
r3,
|
6862
|
-
shared_values,
|
6863
|
-
tgpig,
|
6864
|
-
tiisg,
|
6865
|
-
sgitg);
|
6866
|
-
}
|
6867
|
-
|
6868
|
-
[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
|
6869
|
-
kernel void kernel_mul_mv_id_iq3_s_f32(
|
6870
|
-
device const char * src0s,
|
6871
|
-
device const char * src1,
|
6872
|
-
device float * dst,
|
6873
|
-
device const char * ids,
|
6874
|
-
constant uint64_t & nbi1,
|
6875
|
-
constant int64_t & ne00,
|
6876
|
-
constant int64_t & ne01,
|
6877
|
-
constant int64_t & ne02,
|
6878
|
-
constant uint64_t & nb00,
|
6879
|
-
constant uint64_t & nb01,
|
6880
|
-
constant uint64_t & nb02,
|
6881
|
-
constant int64_t & ne10,
|
6882
|
-
constant int64_t & ne11,
|
6883
|
-
constant int64_t & ne12,
|
6884
|
-
constant int64_t & ne13,
|
6885
|
-
constant uint64_t & nb10,
|
6886
|
-
constant uint64_t & nb11,
|
6887
|
-
constant uint64_t & nb12,
|
6888
|
-
constant int64_t & ne0,
|
6889
|
-
constant int64_t & ne1,
|
6890
|
-
constant uint64_t & nb1,
|
6891
|
-
constant uint & r2,
|
6892
|
-
constant uint & r3,
|
6893
|
-
constant int & idx,
|
6894
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6895
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6896
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6897
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6898
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6899
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6900
|
-
|
6901
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6902
|
-
|
6903
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6904
|
-
device const char * src0 = src0s + id*nb02;
|
6905
|
-
|
6906
|
-
kernel_mul_mv_iq3_s_f32_impl(
|
6907
|
-
src0,
|
6908
|
-
(device const float *) (src1 + bid*nb11),
|
6909
|
-
dst + bid*ne0,
|
6910
|
-
ne00,
|
6911
|
-
ne01,
|
6912
|
-
ne02,
|
6913
|
-
ne10,
|
6914
|
-
ne12,
|
6915
|
-
ne0,
|
6916
|
-
ne1,
|
6917
|
-
r2,
|
6918
|
-
r3,
|
6919
|
-
shared_values,
|
6920
|
-
tgpig,
|
6921
|
-
tiisg,
|
6922
|
-
sgitg);
|
6923
|
-
}
|
6924
|
-
|
6925
|
-
[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
|
6926
|
-
kernel void kernel_mul_mv_id_iq2_s_f32(
|
6927
|
-
device const char * src0s,
|
6928
|
-
device const char * src1,
|
6929
|
-
device float * dst,
|
6930
|
-
device const char * ids,
|
6931
|
-
constant uint64_t & nbi1,
|
6932
|
-
constant int64_t & ne00,
|
6933
|
-
constant int64_t & ne01,
|
6934
|
-
constant int64_t & ne02,
|
6935
|
-
constant uint64_t & nb00,
|
6936
|
-
constant uint64_t & nb01,
|
6937
|
-
constant uint64_t & nb02,
|
6938
|
-
constant int64_t & ne10,
|
6939
|
-
constant int64_t & ne11,
|
6940
|
-
constant int64_t & ne12,
|
6941
|
-
constant int64_t & ne13,
|
6942
|
-
constant uint64_t & nb10,
|
6943
|
-
constant uint64_t & nb11,
|
6944
|
-
constant uint64_t & nb12,
|
6945
|
-
constant int64_t & ne0,
|
6946
|
-
constant int64_t & ne1,
|
6947
|
-
constant uint64_t & nb1,
|
6948
|
-
constant uint & r2,
|
6949
|
-
constant uint & r3,
|
6950
|
-
constant int & idx,
|
6951
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6952
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6953
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6954
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6955
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6956
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6957
|
-
|
6958
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6959
|
-
|
6960
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6961
|
-
device const char * src0 = src0s + id*nb02;
|
6962
|
-
|
6963
|
-
kernel_mul_mv_iq2_s_f32_impl(
|
6964
|
-
src0,
|
6965
|
-
(device const float *) (src1 + bid*nb11),
|
6966
|
-
dst + bid*ne0,
|
6967
|
-
ne00,
|
6968
|
-
ne01,
|
6969
|
-
ne02,
|
6970
|
-
ne10,
|
6971
|
-
ne12,
|
6972
|
-
ne0,
|
6973
|
-
ne1,
|
6974
|
-
r2,
|
6975
|
-
r3,
|
6976
|
-
shared_values,
|
6977
|
-
tgpig,
|
6978
|
-
tiisg,
|
6979
|
-
sgitg);
|
6980
|
-
}
|
6981
|
-
|
6982
|
-
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
6983
|
-
kernel void kernel_mul_mv_id_iq1_s_f32(
|
6984
|
-
device const char * src0s,
|
6985
|
-
device const char * src1,
|
6986
|
-
device float * dst,
|
6987
|
-
device const char * ids,
|
6988
|
-
constant uint64_t & nbi1,
|
6989
|
-
constant int64_t & ne00,
|
6990
|
-
constant int64_t & ne01,
|
6991
|
-
constant int64_t & ne02,
|
6992
|
-
constant uint64_t & nb00,
|
6993
|
-
constant uint64_t & nb01,
|
6994
|
-
constant uint64_t & nb02,
|
6995
|
-
constant int64_t & ne10,
|
6996
|
-
constant int64_t & ne11,
|
6997
|
-
constant int64_t & ne12,
|
6998
|
-
constant int64_t & ne13,
|
6999
|
-
constant uint64_t & nb10,
|
7000
|
-
constant uint64_t & nb11,
|
7001
|
-
constant uint64_t & nb12,
|
7002
|
-
constant int64_t & ne0,
|
7003
|
-
constant int64_t & ne1,
|
7004
|
-
constant uint64_t & nb1,
|
7005
|
-
constant uint & r2,
|
7006
|
-
constant uint & r3,
|
7007
|
-
constant int & idx,
|
7008
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
7009
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
7010
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
7011
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
7012
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
7013
|
-
|
7014
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
7015
|
-
|
7016
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7017
|
-
device const char * src0 = src0s + id*nb02;
|
7018
|
-
|
7019
|
-
kernel_mul_mv_iq1_s_f32_impl(
|
7020
|
-
src0,
|
7021
|
-
(device const float *) (src1 + bid*nb11),
|
7022
|
-
dst + bid*ne0,
|
7023
|
-
ne00,
|
7024
|
-
ne01,
|
7025
|
-
ne02,
|
7026
|
-
ne10,
|
7027
|
-
ne12,
|
7028
|
-
ne0,
|
7029
|
-
ne1,
|
7030
|
-
r2,
|
7031
|
-
r3,
|
7032
|
-
tgpig,
|
7033
|
-
tiisg,
|
7034
|
-
sgitg);
|
7035
|
-
}
|
7036
|
-
|
7037
|
-
[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
|
7038
|
-
kernel void kernel_mul_mv_id_iq1_m_f32(
|
7039
|
-
device const char * src0s,
|
7040
|
-
device const char * src1,
|
7041
|
-
device float * dst,
|
7042
|
-
device const char * ids,
|
7043
|
-
constant uint64_t & nbi1,
|
7044
|
-
constant int64_t & ne00,
|
7045
|
-
constant int64_t & ne01,
|
7046
|
-
constant int64_t & ne02,
|
7047
|
-
constant uint64_t & nb00,
|
7048
|
-
constant uint64_t & nb01,
|
7049
|
-
constant uint64_t & nb02,
|
7050
|
-
constant int64_t & ne10,
|
7051
|
-
constant int64_t & ne11,
|
7052
|
-
constant int64_t & ne12,
|
7053
|
-
constant int64_t & ne13,
|
7054
|
-
constant uint64_t & nb10,
|
7055
|
-
constant uint64_t & nb11,
|
7056
|
-
constant uint64_t & nb12,
|
7057
|
-
constant int64_t & ne0,
|
7058
|
-
constant int64_t & ne1,
|
7059
|
-
constant uint64_t & nb1,
|
7060
|
-
constant uint & r2,
|
7061
|
-
constant uint & r3,
|
7062
|
-
constant int & idx,
|
7063
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
7064
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
7065
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
7066
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
7067
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
7068
|
-
|
7069
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
7070
|
-
|
7071
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7072
|
-
device const char * src0 = src0s + id*nb02;
|
7073
|
-
|
7074
|
-
kernel_mul_mv_iq1_m_f32_impl(
|
7075
|
-
src0,
|
7076
|
-
(device const float *) (src1 + bid*nb11),
|
7077
|
-
dst + bid*ne0,
|
7078
|
-
ne00,
|
7079
|
-
ne01,
|
7080
|
-
ne02,
|
7081
|
-
ne10,
|
7082
|
-
ne12,
|
7083
|
-
ne0,
|
7084
|
-
ne1,
|
7085
|
-
r2,
|
7086
|
-
r3,
|
7087
|
-
tgpig,
|
7088
|
-
tiisg,
|
7089
|
-
sgitg);
|
7090
|
-
}
|
7091
|
-
|
7092
|
-
[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
|
7093
|
-
kernel void kernel_mul_mv_id_iq4_nl_f32(
|
7094
|
-
device const char * src0s,
|
7095
|
-
device const char * src1,
|
7096
|
-
device float * dst,
|
7097
|
-
device const char * ids,
|
7098
|
-
constant uint64_t & nbi1,
|
7099
|
-
constant int64_t & ne00,
|
7100
|
-
constant int64_t & ne01,
|
7101
|
-
constant int64_t & ne02,
|
7102
|
-
constant uint64_t & nb00,
|
7103
|
-
constant uint64_t & nb01,
|
7104
|
-
constant uint64_t & nb02,
|
7105
|
-
constant int64_t & ne10,
|
7106
|
-
constant int64_t & ne11,
|
7107
|
-
constant int64_t & ne12,
|
7108
|
-
constant int64_t & ne13,
|
7109
|
-
constant uint64_t & nb10,
|
7110
|
-
constant uint64_t & nb11,
|
7111
|
-
constant uint64_t & nb12,
|
7112
|
-
constant int64_t & ne0,
|
7113
|
-
constant int64_t & ne1,
|
7114
|
-
constant uint64_t & nb1,
|
7115
|
-
constant uint & r2,
|
7116
|
-
constant uint & r3,
|
7117
|
-
constant int & idx,
|
7118
|
-
threadgroup float * shared_values [[threadgroup(0)]],
|
7119
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
7120
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
7121
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
7122
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
7123
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
7124
|
-
|
7125
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
7126
|
-
|
7127
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7128
|
-
device const char * src0 = src0s + id*nb02;
|
7129
|
-
|
7130
|
-
kernel_mul_mv_iq4_nl_f32_impl(
|
7131
|
-
src0,
|
7132
|
-
(device const float *) (src1 + bid*nb11),
|
7133
|
-
dst + bid*ne0,
|
7134
|
-
ne00,
|
7135
|
-
ne01,
|
7136
|
-
ne02,
|
7137
|
-
ne10,
|
7138
|
-
ne12,
|
7139
|
-
ne0,
|
7140
|
-
ne1,
|
7141
|
-
r2,
|
7142
|
-
r3,
|
7143
|
-
shared_values,
|
7144
|
-
tgpig,
|
7145
|
-
tiisg,
|
7146
|
-
sgitg);
|
7147
|
-
}
|
7148
|
-
|
7149
|
-
[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
|
7150
|
-
kernel void kernel_mul_mv_id_iq4_xs_f32(
|
7151
|
-
device const char * src0s,
|
7152
|
-
device const char * src1,
|
7153
|
-
device float * dst,
|
7154
|
-
device const char * ids,
|
7155
|
-
constant uint64_t & nbi1,
|
7156
|
-
constant int64_t & ne00,
|
7157
|
-
constant int64_t & ne01,
|
7158
|
-
constant int64_t & ne02,
|
7159
|
-
constant uint64_t & nb00,
|
7160
|
-
constant uint64_t & nb01,
|
7161
|
-
constant uint64_t & nb02,
|
7162
|
-
constant int64_t & ne10,
|
7163
|
-
constant int64_t & ne11,
|
7164
|
-
constant int64_t & ne12,
|
7165
|
-
constant int64_t & ne13,
|
7166
|
-
constant uint64_t & nb10,
|
7167
|
-
constant uint64_t & nb11,
|
7168
|
-
constant uint64_t & nb12,
|
7169
|
-
constant int64_t & ne0,
|
7170
|
-
constant int64_t & ne1,
|
7171
|
-
constant uint64_t & nb1,
|
7172
|
-
constant uint & r2,
|
7173
|
-
constant uint & r3,
|
7174
|
-
constant int & idx,
|
7175
|
-
threadgroup float * shared_values [[threadgroup(0)]],
|
7176
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
7177
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
7178
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
7179
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
7180
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
7181
|
-
|
7182
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
7183
|
-
|
7184
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7185
|
-
device const char * src0 = src0s + id*nb02;
|
7186
|
-
|
7187
|
-
#if QK_K == 64
|
7188
|
-
kernel_mul_mv_iq4_nl_f32_impl(
|
7189
|
-
#else
|
7190
|
-
kernel_mul_mv_iq4_xs_f32_impl(
|
6221
|
+
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t;
|
6222
|
+
|
6223
|
+
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
|
6224
|
+
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
|
6225
|
+
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
|
6226
|
+
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
6227
|
+
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
6228
|
+
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
6229
|
+
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
6230
|
+
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
|
6231
|
+
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
|
6232
|
+
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
|
6233
|
+
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
|
6234
|
+
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
|
6235
|
+
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
|
6236
|
+
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
|
6237
|
+
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
|
6238
|
+
template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
|
6239
|
+
template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
|
6240
|
+
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
|
6241
|
+
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
|
6242
|
+
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
|
6243
|
+
#if QK_K != 64
|
6244
|
+
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
|
7191
6245
|
#endif
|
7192
|
-
|
7193
|
-
(device const float *) (src1 + bid*nb11),
|
7194
|
-
dst + bid*ne0,
|
7195
|
-
ne00,
|
7196
|
-
ne01,
|
7197
|
-
ne02,
|
7198
|
-
ne10,
|
7199
|
-
ne12,
|
7200
|
-
ne0,
|
7201
|
-
ne1,
|
7202
|
-
r2,
|
7203
|
-
r3,
|
7204
|
-
shared_values,
|
7205
|
-
tgpig,
|
7206
|
-
tiisg,
|
7207
|
-
sgitg);
|
7208
|
-
}
|
6246
|
+
|