llama_cpp 0.14.5 → 0.14.6
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/vendor/tmp/llama.cpp/Makefile +18 -6
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +135 -46
- data/vendor/tmp/llama.cpp/ggml-impl.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +130 -83
- data/vendor/tmp/llama.cpp/ggml-metal.metal +505 -1467
- data/vendor/tmp/llama.cpp/ggml-quants.c +1 -1
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +65 -52
- data/vendor/tmp/llama.cpp/ggml.c +153 -87
- data/vendor/tmp/llama.cpp/ggml.h +5 -4
- data/vendor/tmp/llama.cpp/llama.cpp +885 -144
- data/vendor/tmp/llama.cpp/sgemm.cpp +1148 -0
- data/vendor/tmp/llama.cpp/sgemm.h +12 -0
- metadata +4 -2
@@ -213,6 +213,15 @@ kernel void kernel_scale_4(
|
|
213
213
|
dst[tpig] = src0[tpig] * scale;
|
214
214
|
}
|
215
215
|
|
216
|
+
kernel void kernel_clamp(
|
217
|
+
device const float * src0,
|
218
|
+
device float * dst,
|
219
|
+
constant float & min,
|
220
|
+
constant float & max,
|
221
|
+
uint tpig[[thread_position_in_grid]]) {
|
222
|
+
dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
|
223
|
+
}
|
224
|
+
|
216
225
|
kernel void kernel_relu(
|
217
226
|
device const float * src0,
|
218
227
|
device float * dst,
|
@@ -233,6 +242,15 @@ constant float GELU_QUICK_COEF = -1.702f;
|
|
233
242
|
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
234
243
|
|
235
244
|
kernel void kernel_gelu(
|
245
|
+
device const float * src0,
|
246
|
+
device float * dst,
|
247
|
+
uint tpig[[thread_position_in_grid]]) {
|
248
|
+
device const float & x = src0[tpig];
|
249
|
+
|
250
|
+
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
251
|
+
}
|
252
|
+
|
253
|
+
kernel void kernel_gelu_4(
|
236
254
|
device const float4 * src0,
|
237
255
|
device float4 * dst,
|
238
256
|
uint tpig[[thread_position_in_grid]]) {
|
@@ -246,6 +264,15 @@ kernel void kernel_gelu(
|
|
246
264
|
}
|
247
265
|
|
248
266
|
kernel void kernel_gelu_quick(
|
267
|
+
device const float * src0,
|
268
|
+
device float * dst,
|
269
|
+
uint tpig[[thread_position_in_grid]]) {
|
270
|
+
device const float & x = src0[tpig];
|
271
|
+
|
272
|
+
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
273
|
+
}
|
274
|
+
|
275
|
+
kernel void kernel_gelu_quick_4(
|
249
276
|
device const float4 * src0,
|
250
277
|
device float4 * dst,
|
251
278
|
uint tpig[[thread_position_in_grid]]) {
|
@@ -255,6 +282,14 @@ kernel void kernel_gelu_quick(
|
|
255
282
|
}
|
256
283
|
|
257
284
|
kernel void kernel_silu(
|
285
|
+
device const float * src0,
|
286
|
+
device float * dst,
|
287
|
+
uint tpig[[thread_position_in_grid]]) {
|
288
|
+
device const float & x = src0[tpig];
|
289
|
+
dst[tpig] = x / (1.0f + exp(-x));
|
290
|
+
}
|
291
|
+
|
292
|
+
kernel void kernel_silu_4(
|
258
293
|
device const float4 * src0,
|
259
294
|
device float4 * dst,
|
260
295
|
uint tpig[[thread_position_in_grid]]) {
|
@@ -866,6 +901,7 @@ void mul_vec_q_n_f32_impl(
|
|
866
901
|
int64_t ne1,
|
867
902
|
uint r2,
|
868
903
|
uint r3,
|
904
|
+
threadgroup int8_t * shared_values,
|
869
905
|
uint3 tgpig, uint tiisg, uint sgitg) {
|
870
906
|
const int nb = ne00/QK4_0;
|
871
907
|
|
@@ -942,7 +978,7 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
942
978
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
943
979
|
uint tiisg[[thread_index_in_simdgroup]],
|
944
980
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
945
|
-
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
981
|
+
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
946
982
|
}
|
947
983
|
|
948
984
|
kernel void kernel_mul_mv_q4_1_f32(
|
@@ -968,7 +1004,7 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
968
1004
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
969
1005
|
uint tiisg[[thread_index_in_simdgroup]],
|
970
1006
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
971
|
-
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
1007
|
+
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
972
1008
|
}
|
973
1009
|
|
974
1010
|
kernel void kernel_mul_mv_q5_0_f32(
|
@@ -994,7 +1030,7 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
994
1030
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
995
1031
|
uint tiisg[[thread_index_in_simdgroup]],
|
996
1032
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
997
|
-
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
1033
|
+
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
998
1034
|
}
|
999
1035
|
|
1000
1036
|
kernel void kernel_mul_mv_q5_1_f32(
|
@@ -1020,7 +1056,7 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
1020
1056
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1021
1057
|
uint tiisg[[thread_index_in_simdgroup]],
|
1022
1058
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1023
|
-
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
1059
|
+
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
1024
1060
|
}
|
1025
1061
|
|
1026
1062
|
|
@@ -1030,18 +1066,19 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|
1030
1066
|
device const void * src0,
|
1031
1067
|
device const float * src1,
|
1032
1068
|
device float * dst,
|
1033
|
-
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1037
|
-
|
1038
|
-
|
1039
|
-
|
1040
|
-
|
1041
|
-
|
1042
|
-
|
1043
|
-
|
1044
|
-
|
1069
|
+
int64_t ne00,
|
1070
|
+
int64_t ne01,
|
1071
|
+
int64_t ne02,
|
1072
|
+
int64_t ne10,
|
1073
|
+
int64_t ne12,
|
1074
|
+
int64_t ne0,
|
1075
|
+
int64_t ne1,
|
1076
|
+
uint r2,
|
1077
|
+
uint r3,
|
1078
|
+
threadgroup int8_t * shared_values,
|
1079
|
+
uint3 tgpig,
|
1080
|
+
uint tiisg,
|
1081
|
+
uint sgitg) {
|
1045
1082
|
const int nr = N_DST;
|
1046
1083
|
const int nsg = N_SIMDGROUP;
|
1047
1084
|
const int nw = N_SIMDWIDTH;
|
@@ -1119,7 +1156,7 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
1119
1156
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1120
1157
|
uint tiisg[[thread_index_in_simdgroup]],
|
1121
1158
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1122
|
-
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
1159
|
+
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
1123
1160
|
}
|
1124
1161
|
|
1125
1162
|
#define N_F32_F32 4
|
@@ -1128,24 +1165,24 @@ void kernel_mul_mv_f32_f32_impl(
|
|
1128
1165
|
device const char * src0,
|
1129
1166
|
device const char * src1,
|
1130
1167
|
device float * dst,
|
1131
|
-
|
1132
|
-
|
1133
|
-
|
1134
|
-
|
1135
|
-
|
1136
|
-
|
1137
|
-
|
1138
|
-
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
|
1168
|
+
int64_t ne00,
|
1169
|
+
int64_t ne01,
|
1170
|
+
int64_t ne02,
|
1171
|
+
uint64_t nb00,
|
1172
|
+
uint64_t nb01,
|
1173
|
+
uint64_t nb02,
|
1174
|
+
int64_t ne10,
|
1175
|
+
int64_t ne11,
|
1176
|
+
int64_t ne12,
|
1177
|
+
uint64_t nb10,
|
1178
|
+
uint64_t nb11,
|
1179
|
+
uint64_t nb12,
|
1180
|
+
int64_t ne0,
|
1181
|
+
int64_t ne1,
|
1182
|
+
uint r2,
|
1183
|
+
uint r3,
|
1184
|
+
uint3 tgpig,
|
1185
|
+
uint tiisg) {
|
1149
1186
|
|
1150
1187
|
const int64_t r0 = tgpig.x;
|
1151
1188
|
const int64_t rb = tgpig.y*N_F32_F32;
|
@@ -1398,24 +1435,24 @@ void kernel_mul_mv_f16_f32_impl(
|
|
1398
1435
|
device const char * src0,
|
1399
1436
|
device const char * src1,
|
1400
1437
|
device float * dst,
|
1401
|
-
|
1402
|
-
|
1403
|
-
|
1404
|
-
|
1405
|
-
|
1406
|
-
|
1407
|
-
|
1408
|
-
|
1409
|
-
|
1410
|
-
|
1411
|
-
|
1412
|
-
|
1413
|
-
|
1414
|
-
|
1415
|
-
|
1416
|
-
|
1417
|
-
|
1418
|
-
|
1438
|
+
int64_t ne00,
|
1439
|
+
int64_t ne01,
|
1440
|
+
int64_t ne02,
|
1441
|
+
uint64_t nb00,
|
1442
|
+
uint64_t nb01,
|
1443
|
+
uint64_t nb02,
|
1444
|
+
int64_t ne10,
|
1445
|
+
int64_t ne11,
|
1446
|
+
int64_t ne12,
|
1447
|
+
uint64_t nb10,
|
1448
|
+
uint64_t nb11,
|
1449
|
+
uint64_t nb12,
|
1450
|
+
int64_t ne0,
|
1451
|
+
int64_t ne1,
|
1452
|
+
uint r2,
|
1453
|
+
uint r3,
|
1454
|
+
uint3 tgpig,
|
1455
|
+
uint tiisg) {
|
1419
1456
|
|
1420
1457
|
const int64_t r0 = tgpig.x;
|
1421
1458
|
const int64_t rb = tgpig.y*N_F16_F32;
|
@@ -2700,18 +2737,19 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
2700
2737
|
device const void * src0,
|
2701
2738
|
device const float * src1,
|
2702
2739
|
device float * dst,
|
2703
|
-
|
2704
|
-
|
2705
|
-
|
2706
|
-
|
2707
|
-
|
2708
|
-
|
2709
|
-
|
2710
|
-
|
2711
|
-
|
2712
|
-
|
2713
|
-
|
2714
|
-
|
2740
|
+
int64_t ne00,
|
2741
|
+
int64_t ne01,
|
2742
|
+
int64_t ne02,
|
2743
|
+
int64_t ne10,
|
2744
|
+
int64_t ne12,
|
2745
|
+
int64_t ne0,
|
2746
|
+
int64_t ne1,
|
2747
|
+
uint r2,
|
2748
|
+
uint r3,
|
2749
|
+
threadgroup int8_t * shared_values,
|
2750
|
+
uint3 tgpig,
|
2751
|
+
uint tiisg,
|
2752
|
+
uint sgitg) {
|
2715
2753
|
|
2716
2754
|
const int nb = ne00/QK_K;
|
2717
2755
|
const int r0 = tgpig.x;
|
@@ -2871,7 +2909,7 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
2871
2909
|
uint tiisg[[thread_index_in_simdgroup]],
|
2872
2910
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2873
2911
|
|
2874
|
-
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
2912
|
+
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
2875
2913
|
}
|
2876
2914
|
|
2877
2915
|
#if QK_K == 256
|
@@ -2879,18 +2917,19 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
2879
2917
|
device const void * src0,
|
2880
2918
|
device const float * src1,
|
2881
2919
|
device float * dst,
|
2882
|
-
|
2883
|
-
|
2884
|
-
|
2885
|
-
|
2886
|
-
|
2887
|
-
|
2888
|
-
|
2889
|
-
|
2890
|
-
|
2891
|
-
|
2892
|
-
|
2893
|
-
|
2920
|
+
int64_t ne00,
|
2921
|
+
int64_t ne01,
|
2922
|
+
int64_t ne02,
|
2923
|
+
int64_t ne10,
|
2924
|
+
int64_t ne12,
|
2925
|
+
int64_t ne0,
|
2926
|
+
int64_t ne1,
|
2927
|
+
uint r2,
|
2928
|
+
uint r3,
|
2929
|
+
threadgroup int8_t * shared_values,
|
2930
|
+
uint3 tgpig,
|
2931
|
+
uint tiisg,
|
2932
|
+
uint sgitg) {
|
2894
2933
|
|
2895
2934
|
const int nb = ne00/QK_K;
|
2896
2935
|
|
@@ -3046,6 +3085,7 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
3046
3085
|
constant int64_t & ne1,
|
3047
3086
|
constant uint & r2,
|
3048
3087
|
constant uint & r3,
|
3088
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
3049
3089
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
3050
3090
|
uint tiisg[[thread_index_in_simdgroup]],
|
3051
3091
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -3135,7 +3175,7 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
3135
3175
|
uint tiisg[[thread_index_in_simdgroup]],
|
3136
3176
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3137
3177
|
|
3138
|
-
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3178
|
+
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
3139
3179
|
}
|
3140
3180
|
|
3141
3181
|
#if QK_K == 256
|
@@ -3143,18 +3183,19 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
3143
3183
|
device const void * src0,
|
3144
3184
|
device const float * src1,
|
3145
3185
|
device float * dst,
|
3146
|
-
|
3147
|
-
|
3148
|
-
|
3149
|
-
|
3150
|
-
|
3151
|
-
|
3152
|
-
|
3153
|
-
|
3154
|
-
|
3155
|
-
|
3156
|
-
|
3157
|
-
|
3186
|
+
int64_t ne00,
|
3187
|
+
int64_t ne01,
|
3188
|
+
int64_t ne02,
|
3189
|
+
int64_t ne10,
|
3190
|
+
int64_t ne12,
|
3191
|
+
int64_t ne0,
|
3192
|
+
int64_t ne1,
|
3193
|
+
uint r2,
|
3194
|
+
uint r3,
|
3195
|
+
threadgroup int8_t * shared_values,
|
3196
|
+
uint3 tgpig,
|
3197
|
+
uint tiisg,
|
3198
|
+
uint sgitg) {
|
3158
3199
|
|
3159
3200
|
const uint16_t kmask1 = 0x3f3f;
|
3160
3201
|
const uint16_t kmask2 = 0x0f0f;
|
@@ -3265,6 +3306,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
3265
3306
|
constant int64_t & ne1,
|
3266
3307
|
constant uint & r2,
|
3267
3308
|
constant uint & r3,
|
3309
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
3268
3310
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
3269
3311
|
uint tiisg[[thread_index_in_simdgroup]],
|
3270
3312
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -3373,25 +3415,26 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
3373
3415
|
uint tiisg[[thread_index_in_simdgroup]],
|
3374
3416
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3375
3417
|
|
3376
|
-
kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3418
|
+
kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
3377
3419
|
}
|
3378
3420
|
|
3379
3421
|
void kernel_mul_mv_q5_K_f32_impl(
|
3380
3422
|
device const void * src0,
|
3381
3423
|
device const float * src1,
|
3382
3424
|
device float * dst,
|
3383
|
-
|
3384
|
-
|
3385
|
-
|
3386
|
-
|
3387
|
-
|
3388
|
-
|
3389
|
-
|
3390
|
-
|
3391
|
-
|
3392
|
-
|
3393
|
-
|
3394
|
-
|
3425
|
+
int64_t ne00,
|
3426
|
+
int64_t ne01,
|
3427
|
+
int64_t ne02,
|
3428
|
+
int64_t ne10,
|
3429
|
+
int64_t ne12,
|
3430
|
+
int64_t ne0,
|
3431
|
+
int64_t ne1,
|
3432
|
+
uint r2,
|
3433
|
+
uint r3,
|
3434
|
+
threadgroup int8_t * shared_values,
|
3435
|
+
uint3 tgpig,
|
3436
|
+
uint tiisg,
|
3437
|
+
uint sgitg) {
|
3395
3438
|
|
3396
3439
|
const int nb = ne00/QK_K;
|
3397
3440
|
|
@@ -3579,25 +3622,26 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
3579
3622
|
uint tiisg[[thread_index_in_simdgroup]],
|
3580
3623
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3581
3624
|
|
3582
|
-
kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3625
|
+
kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
3583
3626
|
}
|
3584
3627
|
|
3585
3628
|
void kernel_mul_mv_q6_K_f32_impl(
|
3586
3629
|
device const void * src0,
|
3587
3630
|
device const float * src1,
|
3588
3631
|
device float * dst,
|
3589
|
-
|
3590
|
-
|
3591
|
-
|
3592
|
-
|
3593
|
-
|
3594
|
-
|
3595
|
-
|
3596
|
-
|
3597
|
-
|
3598
|
-
|
3599
|
-
|
3600
|
-
|
3632
|
+
int64_t ne00,
|
3633
|
+
int64_t ne01,
|
3634
|
+
int64_t ne02,
|
3635
|
+
int64_t ne10,
|
3636
|
+
int64_t ne12,
|
3637
|
+
int64_t ne0,
|
3638
|
+
int64_t ne1,
|
3639
|
+
uint r2,
|
3640
|
+
uint r3,
|
3641
|
+
threadgroup int8_t * shared_values,
|
3642
|
+
uint3 tgpig,
|
3643
|
+
uint tiisg,
|
3644
|
+
uint sgitg) {
|
3601
3645
|
|
3602
3646
|
const uint8_t kmask1 = 0x03;
|
3603
3647
|
const uint8_t kmask2 = 0x0C;
|
@@ -3713,7 +3757,7 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
3713
3757
|
uint tiisg[[thread_index_in_simdgroup]],
|
3714
3758
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3715
3759
|
|
3716
|
-
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3760
|
+
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
3717
3761
|
}
|
3718
3762
|
|
3719
3763
|
// ======================= "True" 2-bit
|
@@ -3722,19 +3766,19 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
3722
3766
|
device const void * src0,
|
3723
3767
|
device const float * src1,
|
3724
3768
|
device float * dst,
|
3725
|
-
|
3726
|
-
|
3727
|
-
|
3728
|
-
|
3729
|
-
|
3730
|
-
|
3731
|
-
|
3732
|
-
|
3733
|
-
|
3734
|
-
threadgroup int8_t * shared_values
|
3735
|
-
|
3736
|
-
|
3737
|
-
|
3769
|
+
int64_t ne00,
|
3770
|
+
int64_t ne01,
|
3771
|
+
int64_t ne02,
|
3772
|
+
int64_t ne10,
|
3773
|
+
int64_t ne12,
|
3774
|
+
int64_t ne0,
|
3775
|
+
int64_t ne1,
|
3776
|
+
uint r2,
|
3777
|
+
uint r3,
|
3778
|
+
threadgroup int8_t * shared_values,
|
3779
|
+
uint3 tgpig,
|
3780
|
+
uint tiisg,
|
3781
|
+
uint sgitg) {
|
3738
3782
|
|
3739
3783
|
const int nb = ne00/QK_K;
|
3740
3784
|
const int r0 = tgpig.x;
|
@@ -3851,19 +3895,19 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
3851
3895
|
device const void * src0,
|
3852
3896
|
device const float * src1,
|
3853
3897
|
device float * dst,
|
3854
|
-
|
3855
|
-
|
3856
|
-
|
3857
|
-
|
3858
|
-
|
3859
|
-
|
3860
|
-
|
3861
|
-
|
3862
|
-
|
3863
|
-
threadgroup int8_t * shared_values
|
3864
|
-
|
3865
|
-
|
3866
|
-
|
3898
|
+
int64_t ne00,
|
3899
|
+
int64_t ne01,
|
3900
|
+
int64_t ne02,
|
3901
|
+
int64_t ne10,
|
3902
|
+
int64_t ne12,
|
3903
|
+
int64_t ne0,
|
3904
|
+
int64_t ne1,
|
3905
|
+
uint r2,
|
3906
|
+
uint r3,
|
3907
|
+
threadgroup int8_t * shared_values,
|
3908
|
+
uint3 tgpig,
|
3909
|
+
uint tiisg,
|
3910
|
+
uint sgitg) {
|
3867
3911
|
|
3868
3912
|
const int nb = ne00/QK_K;
|
3869
3913
|
const int r0 = tgpig.x;
|
@@ -3990,19 +4034,19 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
3990
4034
|
device const void * src0,
|
3991
4035
|
device const float * src1,
|
3992
4036
|
device float * dst,
|
3993
|
-
|
3994
|
-
|
3995
|
-
|
3996
|
-
|
3997
|
-
|
3998
|
-
|
3999
|
-
|
4000
|
-
|
4001
|
-
|
4002
|
-
threadgroup int8_t * shared_values
|
4003
|
-
|
4004
|
-
|
4005
|
-
|
4037
|
+
int64_t ne00,
|
4038
|
+
int64_t ne01,
|
4039
|
+
int64_t ne02,
|
4040
|
+
int64_t ne10,
|
4041
|
+
int64_t ne12,
|
4042
|
+
int64_t ne0,
|
4043
|
+
int64_t ne1,
|
4044
|
+
uint r2,
|
4045
|
+
uint r3,
|
4046
|
+
threadgroup int8_t * shared_values,
|
4047
|
+
uint3 tgpig,
|
4048
|
+
uint tiisg,
|
4049
|
+
uint sgitg) {
|
4006
4050
|
|
4007
4051
|
const int nb = ne00/QK_K;
|
4008
4052
|
const int r0 = tgpig.x;
|
@@ -4122,19 +4166,19 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
4122
4166
|
device const void * src0,
|
4123
4167
|
device const float * src1,
|
4124
4168
|
device float * dst,
|
4125
|
-
|
4126
|
-
|
4127
|
-
|
4128
|
-
|
4129
|
-
|
4130
|
-
|
4131
|
-
|
4132
|
-
|
4133
|
-
|
4134
|
-
threadgroup int8_t * shared_values
|
4135
|
-
|
4136
|
-
|
4137
|
-
|
4169
|
+
int64_t ne00,
|
4170
|
+
int64_t ne01,
|
4171
|
+
int64_t ne02,
|
4172
|
+
int64_t ne10,
|
4173
|
+
int64_t ne12,
|
4174
|
+
int64_t ne0,
|
4175
|
+
int64_t ne1,
|
4176
|
+
uint r2,
|
4177
|
+
uint r3,
|
4178
|
+
threadgroup int8_t * shared_values,
|
4179
|
+
uint3 tgpig,
|
4180
|
+
uint tiisg,
|
4181
|
+
uint sgitg) {
|
4138
4182
|
|
4139
4183
|
const int nb = ne00/QK_K;
|
4140
4184
|
const int r0 = tgpig.x;
|
@@ -4254,19 +4298,19 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|
4254
4298
|
device const void * src0,
|
4255
4299
|
device const float * src1,
|
4256
4300
|
device float * dst,
|
4257
|
-
|
4258
|
-
|
4259
|
-
|
4260
|
-
|
4261
|
-
|
4262
|
-
|
4263
|
-
|
4264
|
-
|
4265
|
-
|
4266
|
-
threadgroup int8_t * shared_values
|
4267
|
-
|
4268
|
-
|
4269
|
-
|
4301
|
+
int64_t ne00,
|
4302
|
+
int64_t ne01,
|
4303
|
+
int64_t ne02,
|
4304
|
+
int64_t ne10,
|
4305
|
+
int64_t ne12,
|
4306
|
+
int64_t ne0,
|
4307
|
+
int64_t ne1,
|
4308
|
+
uint r2,
|
4309
|
+
uint r3,
|
4310
|
+
threadgroup int8_t * shared_values,
|
4311
|
+
uint3 tgpig,
|
4312
|
+
uint tiisg,
|
4313
|
+
uint sgitg) {
|
4270
4314
|
|
4271
4315
|
const int nb = ne00/QK_K;
|
4272
4316
|
const int r0 = tgpig.x;
|
@@ -4387,18 +4431,19 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
4387
4431
|
device const void * src0,
|
4388
4432
|
device const float * src1,
|
4389
4433
|
device float * dst,
|
4390
|
-
|
4391
|
-
|
4392
|
-
|
4393
|
-
|
4394
|
-
|
4395
|
-
|
4396
|
-
|
4397
|
-
|
4398
|
-
|
4399
|
-
|
4400
|
-
|
4401
|
-
|
4434
|
+
int64_t ne00,
|
4435
|
+
int64_t ne01,
|
4436
|
+
int64_t ne02,
|
4437
|
+
int64_t ne10,
|
4438
|
+
int64_t ne12,
|
4439
|
+
int64_t ne0,
|
4440
|
+
int64_t ne1,
|
4441
|
+
uint r2,
|
4442
|
+
uint r3,
|
4443
|
+
threadgroup int8_t * shared_value,
|
4444
|
+
uint3 tgpig,
|
4445
|
+
uint tiisg,
|
4446
|
+
uint sgitg) {
|
4402
4447
|
|
4403
4448
|
const int nb = ne00/QK_K;
|
4404
4449
|
const int r0 = tgpig.x;
|
@@ -4476,18 +4521,19 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
4476
4521
|
device const void * src0,
|
4477
4522
|
device const float * src1,
|
4478
4523
|
device float * dst,
|
4479
|
-
|
4480
|
-
|
4481
|
-
|
4482
|
-
|
4483
|
-
|
4484
|
-
|
4485
|
-
|
4486
|
-
|
4487
|
-
|
4488
|
-
|
4489
|
-
|
4490
|
-
|
4524
|
+
int64_t ne00,
|
4525
|
+
int64_t ne01,
|
4526
|
+
int64_t ne02,
|
4527
|
+
int64_t ne10,
|
4528
|
+
int64_t ne12,
|
4529
|
+
int64_t ne0,
|
4530
|
+
int64_t ne1,
|
4531
|
+
uint r2,
|
4532
|
+
uint r3,
|
4533
|
+
threadgroup int8_t * shared_value,
|
4534
|
+
uint3 tgpig,
|
4535
|
+
uint tiisg,
|
4536
|
+
uint sgitg) {
|
4491
4537
|
|
4492
4538
|
const int nb = ne00/QK_K;
|
4493
4539
|
const int r0 = tgpig.x;
|
@@ -4584,20 +4630,21 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
4584
4630
|
device const void * src0,
|
4585
4631
|
device const float * src1,
|
4586
4632
|
device float * dst,
|
4587
|
-
|
4588
|
-
|
4589
|
-
|
4590
|
-
|
4591
|
-
|
4592
|
-
|
4593
|
-
|
4594
|
-
|
4595
|
-
|
4596
|
-
threadgroup
|
4597
|
-
|
4598
|
-
|
4599
|
-
|
4633
|
+
int64_t ne00,
|
4634
|
+
int64_t ne01,
|
4635
|
+
int64_t ne02,
|
4636
|
+
int64_t ne10,
|
4637
|
+
int64_t ne12,
|
4638
|
+
int64_t ne0,
|
4639
|
+
int64_t ne1,
|
4640
|
+
uint r2,
|
4641
|
+
uint r3,
|
4642
|
+
threadgroup int8_t * shared_values_i8,
|
4643
|
+
uint3 tgpig,
|
4644
|
+
uint tiisg,
|
4645
|
+
uint sgitg) {
|
4600
4646
|
|
4647
|
+
threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
|
4601
4648
|
const int nb = ne00/QK4_NL;
|
4602
4649
|
const int r0 = tgpig.x;
|
4603
4650
|
const int r1 = tgpig.y;
|
@@ -4678,20 +4725,21 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
4678
4725
|
device const void * src0,
|
4679
4726
|
device const float * src1,
|
4680
4727
|
device float * dst,
|
4681
|
-
|
4682
|
-
|
4683
|
-
|
4684
|
-
|
4685
|
-
|
4686
|
-
|
4687
|
-
|
4688
|
-
|
4689
|
-
|
4690
|
-
threadgroup
|
4691
|
-
|
4692
|
-
|
4693
|
-
|
4728
|
+
int64_t ne00,
|
4729
|
+
int64_t ne01,
|
4730
|
+
int64_t ne02,
|
4731
|
+
int64_t ne10,
|
4732
|
+
int64_t ne12,
|
4733
|
+
int64_t ne0,
|
4734
|
+
int64_t ne1,
|
4735
|
+
uint r2,
|
4736
|
+
uint r3,
|
4737
|
+
threadgroup int8_t * shared_values_i8,
|
4738
|
+
uint3 tgpig,
|
4739
|
+
uint tiisg,
|
4740
|
+
uint sgitg) {
|
4694
4741
|
|
4742
|
+
threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
|
4695
4743
|
const int nb = ne00/QK_K;
|
4696
4744
|
const int r0 = tgpig.x;
|
4697
4745
|
const int r1 = tgpig.y;
|
@@ -4794,7 +4842,7 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
|
4794
4842
|
uint tiisg[[thread_index_in_simdgroup]],
|
4795
4843
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4796
4844
|
|
4797
|
-
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
4845
|
+
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
4798
4846
|
}
|
4799
4847
|
|
4800
4848
|
[[host_name("kernel_mul_mv_iq1_m_f32")]]
|
@@ -4822,7 +4870,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
|
|
4822
4870
|
uint tiisg[[thread_index_in_simdgroup]],
|
4823
4871
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4824
4872
|
|
4825
|
-
kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
4873
|
+
kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
4826
4874
|
}
|
4827
4875
|
|
4828
4876
|
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
@@ -4846,7 +4894,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
|
|
4846
4894
|
constant int64_t & ne1,
|
4847
4895
|
constant uint & r2,
|
4848
4896
|
constant uint & r3,
|
4849
|
-
threadgroup
|
4897
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
4850
4898
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
4851
4899
|
uint tiisg[[thread_index_in_simdgroup]],
|
4852
4900
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -4875,7 +4923,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
4875
4923
|
constant int64_t & ne1,
|
4876
4924
|
constant uint & r2,
|
4877
4925
|
constant uint & r3,
|
4878
|
-
threadgroup
|
4926
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
4879
4927
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
4880
4928
|
uint tiisg[[thread_index_in_simdgroup]],
|
4881
4929
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -5632,25 +5680,25 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
5632
5680
|
}
|
5633
5681
|
}
|
5634
5682
|
|
5635
|
-
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in
|
5683
|
+
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
|
5636
5684
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
5637
5685
|
void kernel_mul_mm_id_impl(
|
5638
5686
|
device const uchar * src0,
|
5639
5687
|
device const uchar * src1,
|
5640
|
-
threadgroup
|
5688
|
+
threadgroup ushort2 * rowids,
|
5641
5689
|
device float * dst,
|
5642
5690
|
constant int64_t & ne00,
|
5643
5691
|
constant int64_t & ne02,
|
5644
5692
|
constant uint64_t & nb01,
|
5645
5693
|
constant uint64_t & nb02,
|
5694
|
+
constant int64_t & ne11,
|
5646
5695
|
constant int64_t & ne12,
|
5647
5696
|
constant uint64_t & nb10,
|
5648
5697
|
constant uint64_t & nb11,
|
5649
5698
|
constant uint64_t & nb12,
|
5650
5699
|
constant int64_t & ne0,
|
5651
5700
|
int64_t ne1,
|
5652
|
-
|
5653
|
-
constant uint & r3,
|
5701
|
+
int64_t ne0ne1,
|
5654
5702
|
threadgroup uchar * shared_memory,
|
5655
5703
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5656
5704
|
uint tiitg[[thread_index_in_threadgroup]],
|
@@ -5661,7 +5709,6 @@ void kernel_mul_mm_id_impl(
|
|
5661
5709
|
|
5662
5710
|
const uint r0 = tgpig.y;
|
5663
5711
|
const uint r1 = tgpig.x;
|
5664
|
-
const uint im = tgpig.z;
|
5665
5712
|
|
5666
5713
|
if (r1 * BLOCK_SIZE_N >= ne1) return;
|
5667
5714
|
|
@@ -5679,19 +5726,16 @@ void kernel_mul_mm_id_impl(
|
|
5679
5726
|
for (int i = 0; i < 8; i++){
|
5680
5727
|
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
5681
5728
|
}
|
5682
|
-
|
5683
5729
|
short il = (tiitg % THREAD_PER_ROW);
|
5684
5730
|
|
5685
|
-
const uint i12 = im%ne12;
|
5686
|
-
const uint i13 = im/ne12;
|
5687
|
-
|
5688
|
-
uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
|
5689
5731
|
ushort offset1 = il/nl;
|
5690
5732
|
|
5691
|
-
|
5733
|
+
threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
|
5734
|
+
|
5735
|
+
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
|
5692
5736
|
device const float * y = (device const float *)(src1
|
5693
|
-
+ nb12 *
|
5694
|
-
+ nb11 *
|
5737
|
+
+ nb12 * id[1]
|
5738
|
+
+ nb11 * (id[0] % ne11)
|
5695
5739
|
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
5696
5740
|
|
5697
5741
|
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
@@ -5720,11 +5764,11 @@ void kernel_mul_mm_id_impl(
|
|
5720
5764
|
|
5721
5765
|
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
5722
5766
|
for (int i = 0; i < 4; i++) {
|
5723
|
-
simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
|
5767
|
+
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
|
5724
5768
|
}
|
5725
5769
|
simdgroup_barrier(mem_flags::mem_none);
|
5726
5770
|
for (int i = 0; i < 2; i++) {
|
5727
|
-
simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
|
5771
|
+
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
|
5728
5772
|
}
|
5729
5773
|
|
5730
5774
|
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
@@ -5746,11 +5790,13 @@ void kernel_mul_mm_id_impl(
|
|
5746
5790
|
|
5747
5791
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
5748
5792
|
|
5749
|
-
device float * C = dst + (BLOCK_SIZE_M * r0)
|
5793
|
+
device float * C = dst + (BLOCK_SIZE_M * r0);
|
5750
5794
|
if (sgitg == 0) {
|
5751
|
-
for (int
|
5752
|
-
|
5753
|
-
|
5795
|
+
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
5796
|
+
threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
|
5797
|
+
int joff = jid[0] * ne0 + jid[1] * ne0ne1;
|
5798
|
+
for (int i = 0; i < n_rows; i++) {
|
5799
|
+
*(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
|
5754
5800
|
}
|
5755
5801
|
}
|
5756
5802
|
}
|
@@ -5805,11 +5851,14 @@ kernel void kernel_mul_mm_id(
|
|
5805
5851
|
device const uchar * src1,
|
5806
5852
|
device float * dst,
|
5807
5853
|
device const uchar * ids,
|
5854
|
+
constant int64_t & nei0,
|
5855
|
+
constant int64_t & nei1,
|
5808
5856
|
constant uint64_t & nbi1,
|
5809
5857
|
constant int64_t & ne00,
|
5810
5858
|
constant int64_t & ne02,
|
5811
5859
|
constant uint64_t & nb01,
|
5812
5860
|
constant uint64_t & nb02,
|
5861
|
+
constant int64_t & ne11,
|
5813
5862
|
constant int64_t & ne12,
|
5814
5863
|
constant int64_t & ne13,
|
5815
5864
|
constant uint64_t & nb10,
|
@@ -5818,47 +5867,52 @@ kernel void kernel_mul_mm_id(
|
|
5818
5867
|
constant int64_t & ne0,
|
5819
5868
|
constant int64_t & ne1,
|
5820
5869
|
constant uint64_t & nb1,
|
5821
|
-
constant uint & r2,
|
5822
|
-
constant uint & r3,
|
5823
|
-
constant int & idx,
|
5824
5870
|
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
5825
5871
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5826
5872
|
uint tiitg[[thread_index_in_threadgroup]],
|
5827
5873
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5828
5874
|
|
5829
|
-
|
5830
|
-
|
5831
|
-
device const uchar * src0 = src0s + id*nb02;
|
5875
|
+
const int32_t i02 = tgpig.z;
|
5876
|
+
tgpig.z = 0;
|
5832
5877
|
|
5833
|
-
|
5878
|
+
device const uchar * src0 = src0s + i02*nb02;
|
5834
5879
|
|
5835
|
-
// row indices
|
5836
|
-
threadgroup
|
5880
|
+
// row indices
|
5881
|
+
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
|
5837
5882
|
|
5883
|
+
// TODO: parallelize this loop
|
5838
5884
|
int64_t _ne1 = 0;
|
5839
|
-
for (
|
5840
|
-
|
5841
|
-
|
5885
|
+
for (ushort ii1 = 0; ii1 < nei1; ii1++) {
|
5886
|
+
for (ushort ii0 = 0; ii0 < nei0; ii0++) {
|
5887
|
+
int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
|
5888
|
+
if (id == i02) {
|
5889
|
+
//if (tiitg == 0) {
|
5890
|
+
rowids[_ne1] = ushort2(ii0, ii1);
|
5891
|
+
//}
|
5892
|
+
_ne1++;
|
5893
|
+
}
|
5842
5894
|
}
|
5843
5895
|
}
|
5844
5896
|
|
5897
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
5898
|
+
|
5845
5899
|
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
5846
5900
|
src0,
|
5847
5901
|
src1,
|
5848
|
-
|
5902
|
+
rowids,
|
5849
5903
|
dst,
|
5850
5904
|
ne00,
|
5851
5905
|
ne02,
|
5852
5906
|
nb01,
|
5853
5907
|
nb02,
|
5908
|
+
ne11,
|
5854
5909
|
ne12,
|
5855
5910
|
nb10,
|
5856
5911
|
nb11,
|
5857
5912
|
nb12,
|
5858
5913
|
ne0,
|
5859
5914
|
_ne1,
|
5860
|
-
|
5861
|
-
r3,
|
5915
|
+
ne0*ne1,
|
5862
5916
|
shared_memory,
|
5863
5917
|
tgpig,
|
5864
5918
|
tiitg,
|
@@ -5919,24 +5973,7 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_r
|
|
5919
5973
|
// matrix-matrix multiplication
|
5920
5974
|
//
|
5921
5975
|
|
5922
|
-
typedef
|
5923
|
-
device const uchar * src0,
|
5924
|
-
device const uchar * src1,
|
5925
|
-
device float * dst,
|
5926
|
-
constant int64_t & ne00,
|
5927
|
-
constant int64_t & ne02,
|
5928
|
-
constant uint64_t & nb01,
|
5929
|
-
constant uint64_t & nb02,
|
5930
|
-
constant int64_t & ne12,
|
5931
|
-
constant uint64_t & nb10,
|
5932
|
-
constant uint64_t & nb11,
|
5933
|
-
constant uint64_t & nb12,
|
5934
|
-
constant int64_t & ne0,
|
5935
|
-
constant int64_t & ne1,
|
5936
|
-
constant uint & r2,
|
5937
|
-
constant uint & r3,
|
5938
|
-
threadgroup uchar *,
|
5939
|
-
uint3, uint, uint);
|
5976
|
+
typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
|
5940
5977
|
|
5941
5978
|
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
5942
5979
|
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
@@ -5968,29 +6005,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
|
|
5968
6005
|
// indirect matrix-matrix multiplication
|
5969
6006
|
//
|
5970
6007
|
|
5971
|
-
typedef
|
5972
|
-
device const uchar * src0s,
|
5973
|
-
device const uchar * src1,
|
5974
|
-
device float * dst,
|
5975
|
-
device const uchar * ids,
|
5976
|
-
constant uint64_t & nbi1,
|
5977
|
-
constant int64_t & ne00,
|
5978
|
-
constant int64_t & ne02,
|
5979
|
-
constant uint64_t & nb01,
|
5980
|
-
constant uint64_t & nb02,
|
5981
|
-
constant int64_t & ne12,
|
5982
|
-
constant int64_t & ne13,
|
5983
|
-
constant uint64_t & nb10,
|
5984
|
-
constant uint64_t & nb11,
|
5985
|
-
constant uint64_t & nb12,
|
5986
|
-
constant int64_t & ne0,
|
5987
|
-
constant int64_t & ne1,
|
5988
|
-
constant uint64_t & nb1,
|
5989
|
-
constant uint & r2,
|
5990
|
-
constant uint & r3,
|
5991
|
-
constant int & idx,
|
5992
|
-
threadgroup uchar *,
|
5993
|
-
uint3, uint, uint);
|
6008
|
+
typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
|
5994
6009
|
|
5995
6010
|
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
5996
6011
|
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
@@ -6022,12 +6037,119 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
|
|
6022
6037
|
// matrix-vector multiplication
|
6023
6038
|
//
|
6024
6039
|
|
6025
|
-
|
6026
|
-
|
6040
|
+
typedef void (kernel_mul_mv_impl_t)(
|
6041
|
+
device const char * src0,
|
6042
|
+
device const char * src1,
|
6043
|
+
device float * dst,
|
6044
|
+
int64_t ne00,
|
6045
|
+
int64_t ne01,
|
6046
|
+
int64_t ne02,
|
6047
|
+
uint64_t nb00,
|
6048
|
+
uint64_t nb01,
|
6049
|
+
uint64_t nb02,
|
6050
|
+
int64_t ne10,
|
6051
|
+
int64_t ne11,
|
6052
|
+
int64_t ne12,
|
6053
|
+
uint64_t nb10,
|
6054
|
+
uint64_t nb11,
|
6055
|
+
uint64_t nb12,
|
6056
|
+
int64_t ne0,
|
6057
|
+
int64_t ne1,
|
6058
|
+
uint r2,
|
6059
|
+
uint r3,
|
6060
|
+
uint3 tgpig,
|
6061
|
+
uint tiisg);
|
6062
|
+
|
6063
|
+
typedef void (kernel_mul_mv2_impl_t)(
|
6064
|
+
device const void * src0,
|
6065
|
+
device const float * src1,
|
6066
|
+
device float * dst,
|
6067
|
+
int64_t ne00,
|
6068
|
+
int64_t ne01,
|
6069
|
+
int64_t ne02,
|
6070
|
+
int64_t ne10,
|
6071
|
+
int64_t ne12,
|
6072
|
+
int64_t ne0,
|
6073
|
+
int64_t ne1,
|
6074
|
+
uint r2,
|
6075
|
+
uint r3,
|
6076
|
+
threadgroup int8_t * shared_values,
|
6077
|
+
uint3 tgpig,
|
6078
|
+
uint tiisg,
|
6079
|
+
uint sgitg);
|
6080
|
+
|
6081
|
+
template<kernel_mul_mv_impl_t impl_fn>
|
6082
|
+
void mmv_fn(
|
6083
|
+
device const char * src0,
|
6084
|
+
device const char * src1,
|
6085
|
+
device float * dst,
|
6086
|
+
int64_t ne00,
|
6087
|
+
int64_t ne01,
|
6088
|
+
int64_t ne02,
|
6089
|
+
uint64_t nb00,
|
6090
|
+
uint64_t nb01,
|
6091
|
+
uint64_t nb02,
|
6092
|
+
int64_t ne10,
|
6093
|
+
int64_t ne11,
|
6094
|
+
int64_t ne12,
|
6095
|
+
int64_t ne13,
|
6096
|
+
uint64_t nb10,
|
6097
|
+
uint64_t nb11,
|
6098
|
+
uint64_t nb12,
|
6099
|
+
int64_t ne0,
|
6100
|
+
int64_t ne1,
|
6101
|
+
uint64_t nb1,
|
6102
|
+
uint r2,
|
6103
|
+
uint r3,
|
6104
|
+
threadgroup int8_t * shared_values,
|
6105
|
+
uint3 tgpig,
|
6106
|
+
uint tiitg,
|
6107
|
+
uint tiisg,
|
6108
|
+
uint sgitg) {
|
6109
|
+
impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
|
6110
|
+
}
|
6111
|
+
|
6112
|
+
template<kernel_mul_mv2_impl_t impl_fn>
|
6113
|
+
void mmv_fn(
|
6114
|
+
device const char * src0,
|
6115
|
+
device const char * src1,
|
6116
|
+
device float * dst,
|
6117
|
+
int64_t ne00,
|
6118
|
+
int64_t ne01,
|
6119
|
+
int64_t ne02,
|
6120
|
+
uint64_t nb00,
|
6121
|
+
uint64_t nb01,
|
6122
|
+
uint64_t nb02,
|
6123
|
+
int64_t ne10,
|
6124
|
+
int64_t ne11,
|
6125
|
+
int64_t ne12,
|
6126
|
+
int64_t ne13,
|
6127
|
+
uint64_t nb10,
|
6128
|
+
uint64_t nb11,
|
6129
|
+
uint64_t nb12,
|
6130
|
+
int64_t ne0,
|
6131
|
+
int64_t ne1,
|
6132
|
+
uint64_t nb1,
|
6133
|
+
uint r2,
|
6134
|
+
uint r3,
|
6135
|
+
threadgroup int8_t * shared_values,
|
6136
|
+
uint3 tgpig,
|
6137
|
+
uint tiitg,
|
6138
|
+
uint tiisg,
|
6139
|
+
uint sgitg) {
|
6140
|
+
impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
|
6141
|
+
}
|
6142
|
+
|
6143
|
+
typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t;
|
6144
|
+
|
6145
|
+
template<mul_mv_impl_fn_t impl_fn>
|
6146
|
+
kernel void kernel_mul_mv_id(
|
6027
6147
|
device const char * src0s,
|
6028
6148
|
device const char * src1,
|
6029
6149
|
device float * dst,
|
6030
6150
|
device const char * ids,
|
6151
|
+
constant int64_t & nei0,
|
6152
|
+
constant int64_t & nei1,
|
6031
6153
|
constant uint64_t & nbi1,
|
6032
6154
|
constant int64_t & ne00,
|
6033
6155
|
constant int64_t & ne01,
|
@@ -6045,1164 +6167,80 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
6045
6167
|
constant int64_t & ne0,
|
6046
6168
|
constant int64_t & ne1,
|
6047
6169
|
constant uint64_t & nb1,
|
6048
|
-
|
6049
|
-
constant uint & r3,
|
6050
|
-
constant int & idx,
|
6170
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6051
6171
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6052
6172
|
uint tiitg[[thread_index_in_threadgroup]],
|
6053
6173
|
uint tiisg[[thread_index_in_simdgroup]],
|
6054
6174
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6055
|
-
const
|
6056
|
-
|
6057
|
-
|
6058
|
-
|
6059
|
-
|
6060
|
-
device const
|
6061
|
-
|
6062
|
-
|
6063
|
-
|
6064
|
-
|
6065
|
-
|
6066
|
-
|
6067
|
-
|
6068
|
-
|
6069
|
-
|
6070
|
-
|
6071
|
-
|
6072
|
-
|
6073
|
-
|
6074
|
-
|
6075
|
-
|
6076
|
-
|
6077
|
-
|
6078
|
-
|
6079
|
-
|
6080
|
-
|
6081
|
-
|
6175
|
+
const int iid1 = tgpig.z/nei0;
|
6176
|
+
const int idx = tgpig.z%nei0;
|
6177
|
+
|
6178
|
+
tgpig.z = 0;
|
6179
|
+
|
6180
|
+
const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx];
|
6181
|
+
|
6182
|
+
const int64_t i11 = idx % ne11;
|
6183
|
+
const int64_t i12 = iid1;
|
6184
|
+
|
6185
|
+
const int64_t i1 = idx;
|
6186
|
+
const int64_t i2 = i12;
|
6187
|
+
|
6188
|
+
device const char * src0_cur = src0s + i02*nb02;
|
6189
|
+
device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
|
6190
|
+
device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0;
|
6191
|
+
|
6192
|
+
impl_fn(
|
6193
|
+
/* src0 */ src0_cur,
|
6194
|
+
/* src1 */ src1_cur,
|
6195
|
+
/* dst */ dst_cur,
|
6196
|
+
/* ne00 */ ne00,
|
6197
|
+
/* ne01 */ ne01,
|
6198
|
+
/* ne02 */ 1,//ne02,
|
6199
|
+
/* nb00 */ nb00,
|
6200
|
+
/* nb01 */ nb01,
|
6201
|
+
/* nb02 */ nb02,
|
6202
|
+
/* ne10 */ ne10,
|
6203
|
+
/* ne11 */ 1,//ne11,
|
6204
|
+
/* ne12 */ 1,//ne12,
|
6205
|
+
/* ne13 */ 1,//ne13,
|
6206
|
+
/* nb10 */ nb10,
|
6207
|
+
/* nb11 */ nb11,
|
6208
|
+
/* nb12 */ nb12,
|
6209
|
+
/* ne0 */ ne0,
|
6210
|
+
/* ne1 */ 1,//ne1,
|
6211
|
+
/* nb1 */ nb1,
|
6212
|
+
/* r2 */ 1,
|
6213
|
+
/* r3 */ 1,
|
6214
|
+
shared_values,
|
6082
6215
|
tgpig,
|
6083
|
-
|
6216
|
+
tiitg,
|
6217
|
+
tiisg,
|
6218
|
+
sgitg);
|
6084
6219
|
}
|
6085
6220
|
|
6086
|
-
|
6087
|
-
|
6088
|
-
|
6089
|
-
|
6090
|
-
|
6091
|
-
|
6092
|
-
|
6093
|
-
|
6094
|
-
|
6095
|
-
|
6096
|
-
|
6097
|
-
|
6098
|
-
|
6099
|
-
|
6100
|
-
|
6101
|
-
|
6102
|
-
|
6103
|
-
|
6104
|
-
|
6105
|
-
|
6106
|
-
|
6107
|
-
|
6108
|
-
|
6109
|
-
|
6110
|
-
constant uint & r3,
|
6111
|
-
constant int & idx,
|
6112
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6113
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6114
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6115
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6116
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6117
|
-
|
6118
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6119
|
-
|
6120
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6121
|
-
device const char * src0 = src0s + id*nb02;
|
6122
|
-
|
6123
|
-
kernel_mul_mv_f16_f32_impl(
|
6124
|
-
src0,
|
6125
|
-
src1 + bid*nb11,
|
6126
|
-
dst + bid*ne0,
|
6127
|
-
ne00,
|
6128
|
-
ne01,
|
6129
|
-
ne02,
|
6130
|
-
nb00,
|
6131
|
-
nb01,
|
6132
|
-
nb02,
|
6133
|
-
ne10,
|
6134
|
-
ne11,
|
6135
|
-
ne12,
|
6136
|
-
nb10,
|
6137
|
-
nb11,
|
6138
|
-
nb12,
|
6139
|
-
ne0,
|
6140
|
-
ne1,
|
6141
|
-
r2,
|
6142
|
-
r3,
|
6143
|
-
tgpig,
|
6144
|
-
tiisg);
|
6145
|
-
}
|
6146
|
-
|
6147
|
-
[[host_name("kernel_mul_mv_id_q8_0_f32")]]
|
6148
|
-
kernel void kernel_mul_mv_id_q8_0_f32(
|
6149
|
-
device const char * src0s,
|
6150
|
-
device const char * src1,
|
6151
|
-
device float * dst,
|
6152
|
-
device const char * ids,
|
6153
|
-
constant uint64_t & nbi1,
|
6154
|
-
constant int64_t & ne00,
|
6155
|
-
constant int64_t & ne01,
|
6156
|
-
constant int64_t & ne02,
|
6157
|
-
constant uint64_t & nb00,
|
6158
|
-
constant uint64_t & nb01,
|
6159
|
-
constant uint64_t & nb02,
|
6160
|
-
constant int64_t & ne10,
|
6161
|
-
constant int64_t & ne11,
|
6162
|
-
constant int64_t & ne12,
|
6163
|
-
constant int64_t & ne13,
|
6164
|
-
constant uint64_t & nb10,
|
6165
|
-
constant uint64_t & nb11,
|
6166
|
-
constant uint64_t & nb12,
|
6167
|
-
constant int64_t & ne0,
|
6168
|
-
constant int64_t & ne1,
|
6169
|
-
constant uint64_t & nb1,
|
6170
|
-
constant uint & r2,
|
6171
|
-
constant uint & r3,
|
6172
|
-
constant int & idx,
|
6173
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6174
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6175
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6176
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6177
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6178
|
-
|
6179
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6180
|
-
|
6181
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6182
|
-
device const char * src0 = src0s + id*nb02;
|
6183
|
-
|
6184
|
-
kernel_mul_mv_q8_0_f32_impl(
|
6185
|
-
src0,
|
6186
|
-
(device const float *) (src1 + bid*nb11),
|
6187
|
-
dst + bid*ne0,
|
6188
|
-
ne00,
|
6189
|
-
ne01,
|
6190
|
-
ne02,
|
6191
|
-
ne10,
|
6192
|
-
ne12,
|
6193
|
-
ne0,
|
6194
|
-
ne1,
|
6195
|
-
r2,
|
6196
|
-
r3,
|
6197
|
-
tgpig,
|
6198
|
-
tiisg,
|
6199
|
-
sgitg);
|
6200
|
-
}
|
6201
|
-
|
6202
|
-
[[host_name("kernel_mul_mv_id_q4_0_f32")]]
|
6203
|
-
kernel void kernel_mul_mv_id_q4_0_f32(
|
6204
|
-
device const char * src0s,
|
6205
|
-
device const char * src1,
|
6206
|
-
device float * dst,
|
6207
|
-
device const char * ids,
|
6208
|
-
constant uint64_t & nbi1,
|
6209
|
-
constant int64_t & ne00,
|
6210
|
-
constant int64_t & ne01,
|
6211
|
-
constant int64_t & ne02,
|
6212
|
-
constant uint64_t & nb00,
|
6213
|
-
constant uint64_t & nb01,
|
6214
|
-
constant uint64_t & nb02,
|
6215
|
-
constant int64_t & ne10,
|
6216
|
-
constant int64_t & ne11,
|
6217
|
-
constant int64_t & ne12,
|
6218
|
-
constant int64_t & ne13,
|
6219
|
-
constant uint64_t & nb10,
|
6220
|
-
constant uint64_t & nb11,
|
6221
|
-
constant uint64_t & nb12,
|
6222
|
-
constant int64_t & ne0,
|
6223
|
-
constant int64_t & ne1,
|
6224
|
-
constant uint64_t & nb1,
|
6225
|
-
constant uint & r2,
|
6226
|
-
constant uint & r3,
|
6227
|
-
constant int & idx,
|
6228
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6229
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6230
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6231
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6232
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6233
|
-
|
6234
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6235
|
-
|
6236
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6237
|
-
device const char * src0 = src0s + id*nb02;
|
6238
|
-
|
6239
|
-
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
6240
|
-
src0,
|
6241
|
-
(device const float *) (src1 + bid*nb11),
|
6242
|
-
dst + bid*ne0,
|
6243
|
-
ne00,
|
6244
|
-
ne01,
|
6245
|
-
ne02,
|
6246
|
-
ne10,
|
6247
|
-
ne12,
|
6248
|
-
ne0,
|
6249
|
-
ne1,
|
6250
|
-
r2,
|
6251
|
-
r3,
|
6252
|
-
tgpig,
|
6253
|
-
tiisg,
|
6254
|
-
sgitg);
|
6255
|
-
}
|
6256
|
-
|
6257
|
-
[[host_name("kernel_mul_mv_id_q4_1_f32")]]
|
6258
|
-
kernel void kernel_mul_mv_id_q4_1_f32(
|
6259
|
-
device const char * src0s,
|
6260
|
-
device const char * src1,
|
6261
|
-
device float * dst,
|
6262
|
-
device const char * ids,
|
6263
|
-
constant uint64_t & nbi1,
|
6264
|
-
constant int64_t & ne00,
|
6265
|
-
constant int64_t & ne01,
|
6266
|
-
constant int64_t & ne02,
|
6267
|
-
constant uint64_t & nb00,
|
6268
|
-
constant uint64_t & nb01,
|
6269
|
-
constant uint64_t & nb02,
|
6270
|
-
constant int64_t & ne10,
|
6271
|
-
constant int64_t & ne11,
|
6272
|
-
constant int64_t & ne12,
|
6273
|
-
constant int64_t & ne13,
|
6274
|
-
constant uint64_t & nb10,
|
6275
|
-
constant uint64_t & nb11,
|
6276
|
-
constant uint64_t & nb12,
|
6277
|
-
constant int64_t & ne0,
|
6278
|
-
constant int64_t & ne1,
|
6279
|
-
constant uint64_t & nb1,
|
6280
|
-
constant uint & r2,
|
6281
|
-
constant uint & r3,
|
6282
|
-
constant int & idx,
|
6283
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6284
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6285
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6286
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6287
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6288
|
-
|
6289
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6290
|
-
|
6291
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6292
|
-
device const char * src0 = src0s + id*nb02;
|
6293
|
-
|
6294
|
-
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
6295
|
-
src0,
|
6296
|
-
(device const float *) (src1 + bid*nb11),
|
6297
|
-
dst + bid*ne0,
|
6298
|
-
ne00,
|
6299
|
-
ne01,
|
6300
|
-
ne02,
|
6301
|
-
ne10,
|
6302
|
-
ne12,
|
6303
|
-
ne0,
|
6304
|
-
ne1,
|
6305
|
-
r2,
|
6306
|
-
r3,
|
6307
|
-
tgpig,
|
6308
|
-
tiisg,
|
6309
|
-
sgitg);
|
6310
|
-
}
|
6311
|
-
|
6312
|
-
[[host_name("kernel_mul_mv_id_q5_0_f32")]]
|
6313
|
-
kernel void kernel_mul_mv_id_q5_0_f32(
|
6314
|
-
device const char * src0s,
|
6315
|
-
device const char * src1,
|
6316
|
-
device float * dst,
|
6317
|
-
device const char * ids,
|
6318
|
-
constant uint64_t & nbi1,
|
6319
|
-
constant int64_t & ne00,
|
6320
|
-
constant int64_t & ne01,
|
6321
|
-
constant int64_t & ne02,
|
6322
|
-
constant uint64_t & nb00,
|
6323
|
-
constant uint64_t & nb01,
|
6324
|
-
constant uint64_t & nb02,
|
6325
|
-
constant int64_t & ne10,
|
6326
|
-
constant int64_t & ne11,
|
6327
|
-
constant int64_t & ne12,
|
6328
|
-
constant int64_t & ne13,
|
6329
|
-
constant uint64_t & nb10,
|
6330
|
-
constant uint64_t & nb11,
|
6331
|
-
constant uint64_t & nb12,
|
6332
|
-
constant int64_t & ne0,
|
6333
|
-
constant int64_t & ne1,
|
6334
|
-
constant uint64_t & nb1,
|
6335
|
-
constant uint & r2,
|
6336
|
-
constant uint & r3,
|
6337
|
-
constant int & idx,
|
6338
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6339
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6340
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6341
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6342
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6343
|
-
|
6344
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6345
|
-
|
6346
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6347
|
-
device const char * src0 = src0s + id*nb02;
|
6348
|
-
|
6349
|
-
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
6350
|
-
src0,
|
6351
|
-
(device const float *) (src1 + bid*nb11),
|
6352
|
-
dst + bid*ne0,
|
6353
|
-
ne00,
|
6354
|
-
ne01,
|
6355
|
-
ne02,
|
6356
|
-
ne10,
|
6357
|
-
ne12,
|
6358
|
-
ne0,
|
6359
|
-
ne1,
|
6360
|
-
r2,
|
6361
|
-
r3,
|
6362
|
-
tgpig,
|
6363
|
-
tiisg,
|
6364
|
-
sgitg);
|
6365
|
-
}
|
6366
|
-
|
6367
|
-
[[host_name("kernel_mul_mv_id_q5_1_f32")]]
|
6368
|
-
kernel void kernel_mul_mv_id_q5_1_f32(
|
6369
|
-
device const char * src0s,
|
6370
|
-
device const char * src1,
|
6371
|
-
device float * dst,
|
6372
|
-
device const char * ids,
|
6373
|
-
constant uint64_t & nbi1,
|
6374
|
-
constant int64_t & ne00,
|
6375
|
-
constant int64_t & ne01,
|
6376
|
-
constant int64_t & ne02,
|
6377
|
-
constant uint64_t & nb00,
|
6378
|
-
constant uint64_t & nb01,
|
6379
|
-
constant uint64_t & nb02,
|
6380
|
-
constant int64_t & ne10,
|
6381
|
-
constant int64_t & ne11,
|
6382
|
-
constant int64_t & ne12,
|
6383
|
-
constant int64_t & ne13,
|
6384
|
-
constant uint64_t & nb10,
|
6385
|
-
constant uint64_t & nb11,
|
6386
|
-
constant uint64_t & nb12,
|
6387
|
-
constant int64_t & ne0,
|
6388
|
-
constant int64_t & ne1,
|
6389
|
-
constant uint64_t & nb1,
|
6390
|
-
constant uint & r2,
|
6391
|
-
constant uint & r3,
|
6392
|
-
constant int & idx,
|
6393
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6394
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6395
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6396
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6397
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6398
|
-
|
6399
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6400
|
-
|
6401
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6402
|
-
device const char * src0 = src0s + id*nb02;
|
6403
|
-
|
6404
|
-
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
6405
|
-
src0,
|
6406
|
-
(device const float *) (src1 + bid*nb11),
|
6407
|
-
dst + bid*ne0,
|
6408
|
-
ne00,
|
6409
|
-
ne01,
|
6410
|
-
ne02,
|
6411
|
-
ne10,
|
6412
|
-
ne12,
|
6413
|
-
ne0,
|
6414
|
-
ne1,
|
6415
|
-
r2,
|
6416
|
-
r3,
|
6417
|
-
tgpig,
|
6418
|
-
tiisg,
|
6419
|
-
sgitg);
|
6420
|
-
}
|
6421
|
-
|
6422
|
-
[[host_name("kernel_mul_mv_id_q2_K_f32")]]
|
6423
|
-
kernel void kernel_mul_mv_id_q2_K_f32(
|
6424
|
-
device const char * src0s,
|
6425
|
-
device const char * src1,
|
6426
|
-
device float * dst,
|
6427
|
-
device const char * ids,
|
6428
|
-
constant uint64_t & nbi1,
|
6429
|
-
constant int64_t & ne00,
|
6430
|
-
constant int64_t & ne01,
|
6431
|
-
constant int64_t & ne02,
|
6432
|
-
constant uint64_t & nb00,
|
6433
|
-
constant uint64_t & nb01,
|
6434
|
-
constant uint64_t & nb02,
|
6435
|
-
constant int64_t & ne10,
|
6436
|
-
constant int64_t & ne11,
|
6437
|
-
constant int64_t & ne12,
|
6438
|
-
constant int64_t & ne13,
|
6439
|
-
constant uint64_t & nb10,
|
6440
|
-
constant uint64_t & nb11,
|
6441
|
-
constant uint64_t & nb12,
|
6442
|
-
constant int64_t & ne0,
|
6443
|
-
constant int64_t & ne1,
|
6444
|
-
constant uint64_t & nb1,
|
6445
|
-
constant uint & r2,
|
6446
|
-
constant uint & r3,
|
6447
|
-
constant int & idx,
|
6448
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6449
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6450
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6451
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6452
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6453
|
-
|
6454
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6455
|
-
|
6456
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6457
|
-
device const char * src0 = src0s + id*nb02;
|
6458
|
-
|
6459
|
-
kernel_mul_mv_q2_K_f32_impl(
|
6460
|
-
src0,
|
6461
|
-
(device const float *) (src1 + bid*nb11),
|
6462
|
-
dst + bid*ne0,
|
6463
|
-
ne00,
|
6464
|
-
ne01,
|
6465
|
-
ne02,
|
6466
|
-
ne10,
|
6467
|
-
ne12,
|
6468
|
-
ne0,
|
6469
|
-
ne1,
|
6470
|
-
r2,
|
6471
|
-
r3,
|
6472
|
-
tgpig,
|
6473
|
-
tiisg,
|
6474
|
-
sgitg);
|
6475
|
-
}
|
6476
|
-
|
6477
|
-
[[host_name("kernel_mul_mv_id_q3_K_f32")]]
|
6478
|
-
kernel void kernel_mul_mv_id_q3_K_f32(
|
6479
|
-
device const char * src0s,
|
6480
|
-
device const char * src1,
|
6481
|
-
device float * dst,
|
6482
|
-
device const char * ids,
|
6483
|
-
constant uint64_t & nbi1,
|
6484
|
-
constant int64_t & ne00,
|
6485
|
-
constant int64_t & ne01,
|
6486
|
-
constant int64_t & ne02,
|
6487
|
-
constant uint64_t & nb00,
|
6488
|
-
constant uint64_t & nb01,
|
6489
|
-
constant uint64_t & nb02,
|
6490
|
-
constant int64_t & ne10,
|
6491
|
-
constant int64_t & ne11,
|
6492
|
-
constant int64_t & ne12,
|
6493
|
-
constant int64_t & ne13,
|
6494
|
-
constant uint64_t & nb10,
|
6495
|
-
constant uint64_t & nb11,
|
6496
|
-
constant uint64_t & nb12,
|
6497
|
-
constant int64_t & ne0,
|
6498
|
-
constant int64_t & ne1,
|
6499
|
-
constant uint64_t & nb1,
|
6500
|
-
constant uint & r2,
|
6501
|
-
constant uint & r3,
|
6502
|
-
constant int & idx,
|
6503
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6504
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6505
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6506
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6507
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6508
|
-
|
6509
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6510
|
-
|
6511
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6512
|
-
device const char * src0 = src0s + id*nb02;
|
6513
|
-
|
6514
|
-
kernel_mul_mv_q3_K_f32_impl(
|
6515
|
-
src0,
|
6516
|
-
(device const float *) (src1 + bid*nb11),
|
6517
|
-
dst + bid*ne0,
|
6518
|
-
ne00,
|
6519
|
-
ne01,
|
6520
|
-
ne02,
|
6521
|
-
ne10,
|
6522
|
-
ne12,
|
6523
|
-
ne0,
|
6524
|
-
ne1,
|
6525
|
-
r2,
|
6526
|
-
r3,
|
6527
|
-
tgpig,
|
6528
|
-
tiisg,
|
6529
|
-
sgitg);
|
6530
|
-
}
|
6531
|
-
|
6532
|
-
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
|
6533
|
-
kernel void kernel_mul_mv_id_q4_K_f32(
|
6534
|
-
device const char * src0s,
|
6535
|
-
device const char * src1,
|
6536
|
-
device float * dst,
|
6537
|
-
device const char * ids,
|
6538
|
-
constant uint64_t & nbi1,
|
6539
|
-
constant int64_t & ne00,
|
6540
|
-
constant int64_t & ne01,
|
6541
|
-
constant int64_t & ne02,
|
6542
|
-
constant uint64_t & nb00,
|
6543
|
-
constant uint64_t & nb01,
|
6544
|
-
constant uint64_t & nb02,
|
6545
|
-
constant int64_t & ne10,
|
6546
|
-
constant int64_t & ne11,
|
6547
|
-
constant int64_t & ne12,
|
6548
|
-
constant int64_t & ne13,
|
6549
|
-
constant uint64_t & nb10,
|
6550
|
-
constant uint64_t & nb11,
|
6551
|
-
constant uint64_t & nb12,
|
6552
|
-
constant int64_t & ne0,
|
6553
|
-
constant int64_t & ne1,
|
6554
|
-
constant uint64_t & nb1,
|
6555
|
-
constant uint & r2,
|
6556
|
-
constant uint & r3,
|
6557
|
-
constant int & idx,
|
6558
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6559
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6560
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6561
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6562
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6563
|
-
|
6564
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6565
|
-
|
6566
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6567
|
-
device const char * src0 = src0s + id*nb02;
|
6568
|
-
|
6569
|
-
kernel_mul_mv_q4_K_f32_impl(
|
6570
|
-
src0,
|
6571
|
-
(device const float *) (src1 + bid*nb11),
|
6572
|
-
dst + bid*ne0,
|
6573
|
-
ne00,
|
6574
|
-
ne01,
|
6575
|
-
ne02,
|
6576
|
-
ne10,
|
6577
|
-
ne12,
|
6578
|
-
ne0,
|
6579
|
-
ne1,
|
6580
|
-
r2,
|
6581
|
-
r3,
|
6582
|
-
tgpig,
|
6583
|
-
tiisg,
|
6584
|
-
sgitg);
|
6585
|
-
}
|
6586
|
-
|
6587
|
-
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
|
6588
|
-
kernel void kernel_mul_mv_id_q5_K_f32(
|
6589
|
-
device const char * src0s,
|
6590
|
-
device const char * src1,
|
6591
|
-
device float * dst,
|
6592
|
-
device const char * ids,
|
6593
|
-
constant uint64_t & nbi1,
|
6594
|
-
constant int64_t & ne00,
|
6595
|
-
constant int64_t & ne01,
|
6596
|
-
constant int64_t & ne02,
|
6597
|
-
constant uint64_t & nb00,
|
6598
|
-
constant uint64_t & nb01,
|
6599
|
-
constant uint64_t & nb02,
|
6600
|
-
constant int64_t & ne10,
|
6601
|
-
constant int64_t & ne11,
|
6602
|
-
constant int64_t & ne12,
|
6603
|
-
constant int64_t & ne13,
|
6604
|
-
constant uint64_t & nb10,
|
6605
|
-
constant uint64_t & nb11,
|
6606
|
-
constant uint64_t & nb12,
|
6607
|
-
constant int64_t & ne0,
|
6608
|
-
constant int64_t & ne1,
|
6609
|
-
constant uint64_t & nb1,
|
6610
|
-
constant uint & r2,
|
6611
|
-
constant uint & r3,
|
6612
|
-
constant int & idx,
|
6613
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6614
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6615
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6616
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6617
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6618
|
-
|
6619
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6620
|
-
|
6621
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6622
|
-
device const char * src0 = src0s + id*nb02;
|
6623
|
-
|
6624
|
-
kernel_mul_mv_q5_K_f32_impl(
|
6625
|
-
src0,
|
6626
|
-
(device const float *) (src1 + bid*nb11),
|
6627
|
-
dst + bid*ne0,
|
6628
|
-
ne00,
|
6629
|
-
ne01,
|
6630
|
-
ne02,
|
6631
|
-
ne10,
|
6632
|
-
ne12,
|
6633
|
-
ne0,
|
6634
|
-
ne1,
|
6635
|
-
r2,
|
6636
|
-
r3,
|
6637
|
-
tgpig,
|
6638
|
-
tiisg,
|
6639
|
-
sgitg);
|
6640
|
-
}
|
6641
|
-
|
6642
|
-
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
|
6643
|
-
kernel void kernel_mul_mv_id_q6_K_f32(
|
6644
|
-
device const char * src0s,
|
6645
|
-
device const char * src1,
|
6646
|
-
device float * dst,
|
6647
|
-
device const char * ids,
|
6648
|
-
constant uint64_t & nbi1,
|
6649
|
-
constant int64_t & ne00,
|
6650
|
-
constant int64_t & ne01,
|
6651
|
-
constant int64_t & ne02,
|
6652
|
-
constant uint64_t & nb00,
|
6653
|
-
constant uint64_t & nb01,
|
6654
|
-
constant uint64_t & nb02,
|
6655
|
-
constant int64_t & ne10,
|
6656
|
-
constant int64_t & ne11,
|
6657
|
-
constant int64_t & ne12,
|
6658
|
-
constant int64_t & ne13,
|
6659
|
-
constant uint64_t & nb10,
|
6660
|
-
constant uint64_t & nb11,
|
6661
|
-
constant uint64_t & nb12,
|
6662
|
-
constant int64_t & ne0,
|
6663
|
-
constant int64_t & ne1,
|
6664
|
-
constant uint64_t & nb1,
|
6665
|
-
constant uint & r2,
|
6666
|
-
constant uint & r3,
|
6667
|
-
constant int & idx,
|
6668
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6669
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6670
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6671
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6672
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6673
|
-
|
6674
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6675
|
-
|
6676
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6677
|
-
device const char * src0 = src0s + id*nb02;
|
6678
|
-
|
6679
|
-
kernel_mul_mv_q6_K_f32_impl(
|
6680
|
-
src0,
|
6681
|
-
(device const float *) (src1 + bid*nb11),
|
6682
|
-
dst + bid*ne0,
|
6683
|
-
ne00,
|
6684
|
-
ne01,
|
6685
|
-
ne02,
|
6686
|
-
ne10,
|
6687
|
-
ne12,
|
6688
|
-
ne0,
|
6689
|
-
ne1,
|
6690
|
-
r2,
|
6691
|
-
r3,
|
6692
|
-
tgpig,
|
6693
|
-
tiisg,
|
6694
|
-
sgitg);
|
6695
|
-
}
|
6696
|
-
|
6697
|
-
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
|
6698
|
-
kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
6699
|
-
device const char * src0s,
|
6700
|
-
device const char * src1,
|
6701
|
-
device float * dst,
|
6702
|
-
device const char * ids,
|
6703
|
-
constant uint64_t & nbi1,
|
6704
|
-
constant int64_t & ne00,
|
6705
|
-
constant int64_t & ne01,
|
6706
|
-
constant int64_t & ne02,
|
6707
|
-
constant uint64_t & nb00,
|
6708
|
-
constant uint64_t & nb01,
|
6709
|
-
constant uint64_t & nb02,
|
6710
|
-
constant int64_t & ne10,
|
6711
|
-
constant int64_t & ne11,
|
6712
|
-
constant int64_t & ne12,
|
6713
|
-
constant int64_t & ne13,
|
6714
|
-
constant uint64_t & nb10,
|
6715
|
-
constant uint64_t & nb11,
|
6716
|
-
constant uint64_t & nb12,
|
6717
|
-
constant int64_t & ne0,
|
6718
|
-
constant int64_t & ne1,
|
6719
|
-
constant uint64_t & nb1,
|
6720
|
-
constant uint & r2,
|
6721
|
-
constant uint & r3,
|
6722
|
-
constant int & idx,
|
6723
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6724
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6725
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6726
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6727
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6728
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6729
|
-
|
6730
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6731
|
-
|
6732
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6733
|
-
device const char * src0 = src0s + id*nb02;
|
6734
|
-
|
6735
|
-
kernel_mul_mv_iq2_xxs_f32_impl(
|
6736
|
-
src0,
|
6737
|
-
(device const float *) (src1 + bid*nb11),
|
6738
|
-
dst + bid*ne0,
|
6739
|
-
ne00,
|
6740
|
-
ne01,
|
6741
|
-
ne02,
|
6742
|
-
ne10,
|
6743
|
-
ne12,
|
6744
|
-
ne0,
|
6745
|
-
ne1,
|
6746
|
-
r2,
|
6747
|
-
r3,
|
6748
|
-
shared_values,
|
6749
|
-
tgpig,
|
6750
|
-
tiisg,
|
6751
|
-
sgitg);
|
6752
|
-
}
|
6753
|
-
|
6754
|
-
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
|
6755
|
-
kernel void kernel_mul_mv_id_iq2_xs_f32(
|
6756
|
-
device const char * src0s,
|
6757
|
-
device const char * src1,
|
6758
|
-
device float * dst,
|
6759
|
-
device const char * ids,
|
6760
|
-
constant uint64_t & nbi1,
|
6761
|
-
constant int64_t & ne00,
|
6762
|
-
constant int64_t & ne01,
|
6763
|
-
constant int64_t & ne02,
|
6764
|
-
constant uint64_t & nb00,
|
6765
|
-
constant uint64_t & nb01,
|
6766
|
-
constant uint64_t & nb02,
|
6767
|
-
constant int64_t & ne10,
|
6768
|
-
constant int64_t & ne11,
|
6769
|
-
constant int64_t & ne12,
|
6770
|
-
constant int64_t & ne13,
|
6771
|
-
constant uint64_t & nb10,
|
6772
|
-
constant uint64_t & nb11,
|
6773
|
-
constant uint64_t & nb12,
|
6774
|
-
constant int64_t & ne0,
|
6775
|
-
constant int64_t & ne1,
|
6776
|
-
constant uint64_t & nb1,
|
6777
|
-
constant uint & r2,
|
6778
|
-
constant uint & r3,
|
6779
|
-
constant int & idx,
|
6780
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6781
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6782
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6783
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6784
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6785
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6786
|
-
|
6787
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6788
|
-
|
6789
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6790
|
-
device const char * src0 = src0s + id*nb02;
|
6791
|
-
|
6792
|
-
kernel_mul_mv_iq2_xs_f32_impl(
|
6793
|
-
src0,
|
6794
|
-
(device const float *) (src1 + bid*nb11),
|
6795
|
-
dst + bid*ne0,
|
6796
|
-
ne00,
|
6797
|
-
ne01,
|
6798
|
-
ne02,
|
6799
|
-
ne10,
|
6800
|
-
ne12,
|
6801
|
-
ne0,
|
6802
|
-
ne1,
|
6803
|
-
r2,
|
6804
|
-
r3,
|
6805
|
-
shared_values,
|
6806
|
-
tgpig,
|
6807
|
-
tiisg,
|
6808
|
-
sgitg);
|
6809
|
-
}
|
6810
|
-
|
6811
|
-
[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
|
6812
|
-
kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
6813
|
-
device const char * src0s,
|
6814
|
-
device const char * src1,
|
6815
|
-
device float * dst,
|
6816
|
-
device const char * ids,
|
6817
|
-
constant uint64_t & nbi1,
|
6818
|
-
constant int64_t & ne00,
|
6819
|
-
constant int64_t & ne01,
|
6820
|
-
constant int64_t & ne02,
|
6821
|
-
constant uint64_t & nb00,
|
6822
|
-
constant uint64_t & nb01,
|
6823
|
-
constant uint64_t & nb02,
|
6824
|
-
constant int64_t & ne10,
|
6825
|
-
constant int64_t & ne11,
|
6826
|
-
constant int64_t & ne12,
|
6827
|
-
constant int64_t & ne13,
|
6828
|
-
constant uint64_t & nb10,
|
6829
|
-
constant uint64_t & nb11,
|
6830
|
-
constant uint64_t & nb12,
|
6831
|
-
constant int64_t & ne0,
|
6832
|
-
constant int64_t & ne1,
|
6833
|
-
constant uint64_t & nb1,
|
6834
|
-
constant uint & r2,
|
6835
|
-
constant uint & r3,
|
6836
|
-
constant int & idx,
|
6837
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6838
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6839
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6840
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6841
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6842
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6843
|
-
|
6844
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6845
|
-
|
6846
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6847
|
-
device const char * src0 = src0s + id*nb02;
|
6848
|
-
|
6849
|
-
kernel_mul_mv_iq3_xxs_f32_impl(
|
6850
|
-
src0,
|
6851
|
-
(device const float *) (src1 + bid*nb11),
|
6852
|
-
dst + bid*ne0,
|
6853
|
-
ne00,
|
6854
|
-
ne01,
|
6855
|
-
ne02,
|
6856
|
-
ne10,
|
6857
|
-
ne12,
|
6858
|
-
ne0,
|
6859
|
-
ne1,
|
6860
|
-
r2,
|
6861
|
-
r3,
|
6862
|
-
shared_values,
|
6863
|
-
tgpig,
|
6864
|
-
tiisg,
|
6865
|
-
sgitg);
|
6866
|
-
}
|
6867
|
-
|
6868
|
-
[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
|
6869
|
-
kernel void kernel_mul_mv_id_iq3_s_f32(
|
6870
|
-
device const char * src0s,
|
6871
|
-
device const char * src1,
|
6872
|
-
device float * dst,
|
6873
|
-
device const char * ids,
|
6874
|
-
constant uint64_t & nbi1,
|
6875
|
-
constant int64_t & ne00,
|
6876
|
-
constant int64_t & ne01,
|
6877
|
-
constant int64_t & ne02,
|
6878
|
-
constant uint64_t & nb00,
|
6879
|
-
constant uint64_t & nb01,
|
6880
|
-
constant uint64_t & nb02,
|
6881
|
-
constant int64_t & ne10,
|
6882
|
-
constant int64_t & ne11,
|
6883
|
-
constant int64_t & ne12,
|
6884
|
-
constant int64_t & ne13,
|
6885
|
-
constant uint64_t & nb10,
|
6886
|
-
constant uint64_t & nb11,
|
6887
|
-
constant uint64_t & nb12,
|
6888
|
-
constant int64_t & ne0,
|
6889
|
-
constant int64_t & ne1,
|
6890
|
-
constant uint64_t & nb1,
|
6891
|
-
constant uint & r2,
|
6892
|
-
constant uint & r3,
|
6893
|
-
constant int & idx,
|
6894
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6895
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6896
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6897
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6898
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6899
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6900
|
-
|
6901
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6902
|
-
|
6903
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6904
|
-
device const char * src0 = src0s + id*nb02;
|
6905
|
-
|
6906
|
-
kernel_mul_mv_iq3_s_f32_impl(
|
6907
|
-
src0,
|
6908
|
-
(device const float *) (src1 + bid*nb11),
|
6909
|
-
dst + bid*ne0,
|
6910
|
-
ne00,
|
6911
|
-
ne01,
|
6912
|
-
ne02,
|
6913
|
-
ne10,
|
6914
|
-
ne12,
|
6915
|
-
ne0,
|
6916
|
-
ne1,
|
6917
|
-
r2,
|
6918
|
-
r3,
|
6919
|
-
shared_values,
|
6920
|
-
tgpig,
|
6921
|
-
tiisg,
|
6922
|
-
sgitg);
|
6923
|
-
}
|
6924
|
-
|
6925
|
-
[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
|
6926
|
-
kernel void kernel_mul_mv_id_iq2_s_f32(
|
6927
|
-
device const char * src0s,
|
6928
|
-
device const char * src1,
|
6929
|
-
device float * dst,
|
6930
|
-
device const char * ids,
|
6931
|
-
constant uint64_t & nbi1,
|
6932
|
-
constant int64_t & ne00,
|
6933
|
-
constant int64_t & ne01,
|
6934
|
-
constant int64_t & ne02,
|
6935
|
-
constant uint64_t & nb00,
|
6936
|
-
constant uint64_t & nb01,
|
6937
|
-
constant uint64_t & nb02,
|
6938
|
-
constant int64_t & ne10,
|
6939
|
-
constant int64_t & ne11,
|
6940
|
-
constant int64_t & ne12,
|
6941
|
-
constant int64_t & ne13,
|
6942
|
-
constant uint64_t & nb10,
|
6943
|
-
constant uint64_t & nb11,
|
6944
|
-
constant uint64_t & nb12,
|
6945
|
-
constant int64_t & ne0,
|
6946
|
-
constant int64_t & ne1,
|
6947
|
-
constant uint64_t & nb1,
|
6948
|
-
constant uint & r2,
|
6949
|
-
constant uint & r3,
|
6950
|
-
constant int & idx,
|
6951
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6952
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6953
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
6954
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
6955
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6956
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
6957
|
-
|
6958
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
6959
|
-
|
6960
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6961
|
-
device const char * src0 = src0s + id*nb02;
|
6962
|
-
|
6963
|
-
kernel_mul_mv_iq2_s_f32_impl(
|
6964
|
-
src0,
|
6965
|
-
(device const float *) (src1 + bid*nb11),
|
6966
|
-
dst + bid*ne0,
|
6967
|
-
ne00,
|
6968
|
-
ne01,
|
6969
|
-
ne02,
|
6970
|
-
ne10,
|
6971
|
-
ne12,
|
6972
|
-
ne0,
|
6973
|
-
ne1,
|
6974
|
-
r2,
|
6975
|
-
r3,
|
6976
|
-
shared_values,
|
6977
|
-
tgpig,
|
6978
|
-
tiisg,
|
6979
|
-
sgitg);
|
6980
|
-
}
|
6981
|
-
|
6982
|
-
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
6983
|
-
kernel void kernel_mul_mv_id_iq1_s_f32(
|
6984
|
-
device const char * src0s,
|
6985
|
-
device const char * src1,
|
6986
|
-
device float * dst,
|
6987
|
-
device const char * ids,
|
6988
|
-
constant uint64_t & nbi1,
|
6989
|
-
constant int64_t & ne00,
|
6990
|
-
constant int64_t & ne01,
|
6991
|
-
constant int64_t & ne02,
|
6992
|
-
constant uint64_t & nb00,
|
6993
|
-
constant uint64_t & nb01,
|
6994
|
-
constant uint64_t & nb02,
|
6995
|
-
constant int64_t & ne10,
|
6996
|
-
constant int64_t & ne11,
|
6997
|
-
constant int64_t & ne12,
|
6998
|
-
constant int64_t & ne13,
|
6999
|
-
constant uint64_t & nb10,
|
7000
|
-
constant uint64_t & nb11,
|
7001
|
-
constant uint64_t & nb12,
|
7002
|
-
constant int64_t & ne0,
|
7003
|
-
constant int64_t & ne1,
|
7004
|
-
constant uint64_t & nb1,
|
7005
|
-
constant uint & r2,
|
7006
|
-
constant uint & r3,
|
7007
|
-
constant int & idx,
|
7008
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
7009
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
7010
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
7011
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
7012
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
7013
|
-
|
7014
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
7015
|
-
|
7016
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7017
|
-
device const char * src0 = src0s + id*nb02;
|
7018
|
-
|
7019
|
-
kernel_mul_mv_iq1_s_f32_impl(
|
7020
|
-
src0,
|
7021
|
-
(device const float *) (src1 + bid*nb11),
|
7022
|
-
dst + bid*ne0,
|
7023
|
-
ne00,
|
7024
|
-
ne01,
|
7025
|
-
ne02,
|
7026
|
-
ne10,
|
7027
|
-
ne12,
|
7028
|
-
ne0,
|
7029
|
-
ne1,
|
7030
|
-
r2,
|
7031
|
-
r3,
|
7032
|
-
tgpig,
|
7033
|
-
tiisg,
|
7034
|
-
sgitg);
|
7035
|
-
}
|
7036
|
-
|
7037
|
-
[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
|
7038
|
-
kernel void kernel_mul_mv_id_iq1_m_f32(
|
7039
|
-
device const char * src0s,
|
7040
|
-
device const char * src1,
|
7041
|
-
device float * dst,
|
7042
|
-
device const char * ids,
|
7043
|
-
constant uint64_t & nbi1,
|
7044
|
-
constant int64_t & ne00,
|
7045
|
-
constant int64_t & ne01,
|
7046
|
-
constant int64_t & ne02,
|
7047
|
-
constant uint64_t & nb00,
|
7048
|
-
constant uint64_t & nb01,
|
7049
|
-
constant uint64_t & nb02,
|
7050
|
-
constant int64_t & ne10,
|
7051
|
-
constant int64_t & ne11,
|
7052
|
-
constant int64_t & ne12,
|
7053
|
-
constant int64_t & ne13,
|
7054
|
-
constant uint64_t & nb10,
|
7055
|
-
constant uint64_t & nb11,
|
7056
|
-
constant uint64_t & nb12,
|
7057
|
-
constant int64_t & ne0,
|
7058
|
-
constant int64_t & ne1,
|
7059
|
-
constant uint64_t & nb1,
|
7060
|
-
constant uint & r2,
|
7061
|
-
constant uint & r3,
|
7062
|
-
constant int & idx,
|
7063
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
7064
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
7065
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
7066
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
7067
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
7068
|
-
|
7069
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
7070
|
-
|
7071
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7072
|
-
device const char * src0 = src0s + id*nb02;
|
7073
|
-
|
7074
|
-
kernel_mul_mv_iq1_m_f32_impl(
|
7075
|
-
src0,
|
7076
|
-
(device const float *) (src1 + bid*nb11),
|
7077
|
-
dst + bid*ne0,
|
7078
|
-
ne00,
|
7079
|
-
ne01,
|
7080
|
-
ne02,
|
7081
|
-
ne10,
|
7082
|
-
ne12,
|
7083
|
-
ne0,
|
7084
|
-
ne1,
|
7085
|
-
r2,
|
7086
|
-
r3,
|
7087
|
-
tgpig,
|
7088
|
-
tiisg,
|
7089
|
-
sgitg);
|
7090
|
-
}
|
7091
|
-
|
7092
|
-
[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
|
7093
|
-
kernel void kernel_mul_mv_id_iq4_nl_f32(
|
7094
|
-
device const char * src0s,
|
7095
|
-
device const char * src1,
|
7096
|
-
device float * dst,
|
7097
|
-
device const char * ids,
|
7098
|
-
constant uint64_t & nbi1,
|
7099
|
-
constant int64_t & ne00,
|
7100
|
-
constant int64_t & ne01,
|
7101
|
-
constant int64_t & ne02,
|
7102
|
-
constant uint64_t & nb00,
|
7103
|
-
constant uint64_t & nb01,
|
7104
|
-
constant uint64_t & nb02,
|
7105
|
-
constant int64_t & ne10,
|
7106
|
-
constant int64_t & ne11,
|
7107
|
-
constant int64_t & ne12,
|
7108
|
-
constant int64_t & ne13,
|
7109
|
-
constant uint64_t & nb10,
|
7110
|
-
constant uint64_t & nb11,
|
7111
|
-
constant uint64_t & nb12,
|
7112
|
-
constant int64_t & ne0,
|
7113
|
-
constant int64_t & ne1,
|
7114
|
-
constant uint64_t & nb1,
|
7115
|
-
constant uint & r2,
|
7116
|
-
constant uint & r3,
|
7117
|
-
constant int & idx,
|
7118
|
-
threadgroup float * shared_values [[threadgroup(0)]],
|
7119
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
7120
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
7121
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
7122
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
7123
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
7124
|
-
|
7125
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
7126
|
-
|
7127
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7128
|
-
device const char * src0 = src0s + id*nb02;
|
7129
|
-
|
7130
|
-
kernel_mul_mv_iq4_nl_f32_impl(
|
7131
|
-
src0,
|
7132
|
-
(device const float *) (src1 + bid*nb11),
|
7133
|
-
dst + bid*ne0,
|
7134
|
-
ne00,
|
7135
|
-
ne01,
|
7136
|
-
ne02,
|
7137
|
-
ne10,
|
7138
|
-
ne12,
|
7139
|
-
ne0,
|
7140
|
-
ne1,
|
7141
|
-
r2,
|
7142
|
-
r3,
|
7143
|
-
shared_values,
|
7144
|
-
tgpig,
|
7145
|
-
tiisg,
|
7146
|
-
sgitg);
|
7147
|
-
}
|
7148
|
-
|
7149
|
-
[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
|
7150
|
-
kernel void kernel_mul_mv_id_iq4_xs_f32(
|
7151
|
-
device const char * src0s,
|
7152
|
-
device const char * src1,
|
7153
|
-
device float * dst,
|
7154
|
-
device const char * ids,
|
7155
|
-
constant uint64_t & nbi1,
|
7156
|
-
constant int64_t & ne00,
|
7157
|
-
constant int64_t & ne01,
|
7158
|
-
constant int64_t & ne02,
|
7159
|
-
constant uint64_t & nb00,
|
7160
|
-
constant uint64_t & nb01,
|
7161
|
-
constant uint64_t & nb02,
|
7162
|
-
constant int64_t & ne10,
|
7163
|
-
constant int64_t & ne11,
|
7164
|
-
constant int64_t & ne12,
|
7165
|
-
constant int64_t & ne13,
|
7166
|
-
constant uint64_t & nb10,
|
7167
|
-
constant uint64_t & nb11,
|
7168
|
-
constant uint64_t & nb12,
|
7169
|
-
constant int64_t & ne0,
|
7170
|
-
constant int64_t & ne1,
|
7171
|
-
constant uint64_t & nb1,
|
7172
|
-
constant uint & r2,
|
7173
|
-
constant uint & r3,
|
7174
|
-
constant int & idx,
|
7175
|
-
threadgroup float * shared_values [[threadgroup(0)]],
|
7176
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
7177
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
7178
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
7179
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
7180
|
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
7181
|
-
|
7182
|
-
tgpig.z = tgpig.z%(ne12*ne13);
|
7183
|
-
|
7184
|
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7185
|
-
device const char * src0 = src0s + id*nb02;
|
7186
|
-
|
7187
|
-
#if QK_K == 64
|
7188
|
-
kernel_mul_mv_iq4_nl_f32_impl(
|
7189
|
-
#else
|
7190
|
-
kernel_mul_mv_iq4_xs_f32_impl(
|
6221
|
+
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t;
|
6222
|
+
|
6223
|
+
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
|
6224
|
+
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
|
6225
|
+
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
|
6226
|
+
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
6227
|
+
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
6228
|
+
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
6229
|
+
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
6230
|
+
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
|
6231
|
+
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
|
6232
|
+
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
|
6233
|
+
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
|
6234
|
+
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
|
6235
|
+
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
|
6236
|
+
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
|
6237
|
+
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
|
6238
|
+
template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
|
6239
|
+
template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
|
6240
|
+
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
|
6241
|
+
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
|
6242
|
+
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
|
6243
|
+
#if QK_K != 64
|
6244
|
+
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
|
7191
6245
|
#endif
|
7192
|
-
|
7193
|
-
(device const float *) (src1 + bid*nb11),
|
7194
|
-
dst + bid*ne0,
|
7195
|
-
ne00,
|
7196
|
-
ne01,
|
7197
|
-
ne02,
|
7198
|
-
ne10,
|
7199
|
-
ne12,
|
7200
|
-
ne0,
|
7201
|
-
ne1,
|
7202
|
-
r2,
|
7203
|
-
r3,
|
7204
|
-
shared_values,
|
7205
|
-
tgpig,
|
7206
|
-
tiisg,
|
7207
|
-
sgitg);
|
7208
|
-
}
|
6246
|
+
|