llama_cpp 0.14.5 → 0.14.7

Sign up to get free protection for your applications and to get access to all the features.
@@ -213,6 +213,15 @@ kernel void kernel_scale_4(
213
213
  dst[tpig] = src0[tpig] * scale;
214
214
  }
215
215
 
216
+ kernel void kernel_clamp(
217
+ device const float * src0,
218
+ device float * dst,
219
+ constant float & min,
220
+ constant float & max,
221
+ uint tpig[[thread_position_in_grid]]) {
222
+ dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
223
+ }
224
+
216
225
  kernel void kernel_relu(
217
226
  device const float * src0,
218
227
  device float * dst,
@@ -233,6 +242,15 @@ constant float GELU_QUICK_COEF = -1.702f;
233
242
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
234
243
 
235
244
  kernel void kernel_gelu(
245
+ device const float * src0,
246
+ device float * dst,
247
+ uint tpig[[thread_position_in_grid]]) {
248
+ device const float & x = src0[tpig];
249
+
250
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
251
+ }
252
+
253
+ kernel void kernel_gelu_4(
236
254
  device const float4 * src0,
237
255
  device float4 * dst,
238
256
  uint tpig[[thread_position_in_grid]]) {
@@ -246,6 +264,15 @@ kernel void kernel_gelu(
246
264
  }
247
265
 
248
266
  kernel void kernel_gelu_quick(
267
+ device const float * src0,
268
+ device float * dst,
269
+ uint tpig[[thread_position_in_grid]]) {
270
+ device const float & x = src0[tpig];
271
+
272
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
273
+ }
274
+
275
+ kernel void kernel_gelu_quick_4(
249
276
  device const float4 * src0,
250
277
  device float4 * dst,
251
278
  uint tpig[[thread_position_in_grid]]) {
@@ -255,6 +282,14 @@ kernel void kernel_gelu_quick(
255
282
  }
256
283
 
257
284
  kernel void kernel_silu(
285
+ device const float * src0,
286
+ device float * dst,
287
+ uint tpig[[thread_position_in_grid]]) {
288
+ device const float & x = src0[tpig];
289
+ dst[tpig] = x / (1.0f + exp(-x));
290
+ }
291
+
292
+ kernel void kernel_silu_4(
258
293
  device const float4 * src0,
259
294
  device float4 * dst,
260
295
  uint tpig[[thread_position_in_grid]]) {
@@ -866,6 +901,7 @@ void mul_vec_q_n_f32_impl(
866
901
  int64_t ne1,
867
902
  uint r2,
868
903
  uint r3,
904
+ threadgroup int8_t * shared_values,
869
905
  uint3 tgpig, uint tiisg, uint sgitg) {
870
906
  const int nb = ne00/QK4_0;
871
907
 
@@ -942,7 +978,7 @@ kernel void kernel_mul_mv_q4_0_f32(
942
978
  uint3 tgpig[[threadgroup_position_in_grid]],
943
979
  uint tiisg[[thread_index_in_simdgroup]],
944
980
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
945
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
981
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
946
982
  }
947
983
 
948
984
  kernel void kernel_mul_mv_q4_1_f32(
@@ -968,7 +1004,7 @@ kernel void kernel_mul_mv_q4_1_f32(
968
1004
  uint3 tgpig[[threadgroup_position_in_grid]],
969
1005
  uint tiisg[[thread_index_in_simdgroup]],
970
1006
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
971
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1007
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
972
1008
  }
973
1009
 
974
1010
  kernel void kernel_mul_mv_q5_0_f32(
@@ -994,7 +1030,7 @@ kernel void kernel_mul_mv_q5_0_f32(
994
1030
  uint3 tgpig[[threadgroup_position_in_grid]],
995
1031
  uint tiisg[[thread_index_in_simdgroup]],
996
1032
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
997
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1033
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
998
1034
  }
999
1035
 
1000
1036
  kernel void kernel_mul_mv_q5_1_f32(
@@ -1020,7 +1056,7 @@ kernel void kernel_mul_mv_q5_1_f32(
1020
1056
  uint3 tgpig[[threadgroup_position_in_grid]],
1021
1057
  uint tiisg[[thread_index_in_simdgroup]],
1022
1058
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1023
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1059
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1024
1060
  }
1025
1061
 
1026
1062
 
@@ -1030,18 +1066,19 @@ void kernel_mul_mv_q8_0_f32_impl(
1030
1066
  device const void * src0,
1031
1067
  device const float * src1,
1032
1068
  device float * dst,
1033
- constant int64_t & ne00,
1034
- constant int64_t & ne01,
1035
- constant int64_t & ne02,
1036
- constant int64_t & ne10,
1037
- constant int64_t & ne12,
1038
- constant int64_t & ne0,
1039
- constant int64_t & ne1,
1040
- constant uint & r2,
1041
- constant uint & r3,
1042
- uint3 tgpig[[threadgroup_position_in_grid]],
1043
- uint tiisg[[thread_index_in_simdgroup]],
1044
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
1069
+ int64_t ne00,
1070
+ int64_t ne01,
1071
+ int64_t ne02,
1072
+ int64_t ne10,
1073
+ int64_t ne12,
1074
+ int64_t ne0,
1075
+ int64_t ne1,
1076
+ uint r2,
1077
+ uint r3,
1078
+ threadgroup int8_t * shared_values,
1079
+ uint3 tgpig,
1080
+ uint tiisg,
1081
+ uint sgitg) {
1045
1082
  const int nr = N_DST;
1046
1083
  const int nsg = N_SIMDGROUP;
1047
1084
  const int nw = N_SIMDWIDTH;
@@ -1119,7 +1156,7 @@ kernel void kernel_mul_mv_q8_0_f32(
1119
1156
  uint3 tgpig[[threadgroup_position_in_grid]],
1120
1157
  uint tiisg[[thread_index_in_simdgroup]],
1121
1158
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1122
- kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1159
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1123
1160
  }
1124
1161
 
1125
1162
  #define N_F32_F32 4
@@ -1128,24 +1165,24 @@ void kernel_mul_mv_f32_f32_impl(
1128
1165
  device const char * src0,
1129
1166
  device const char * src1,
1130
1167
  device float * dst,
1131
- constant int64_t & ne00,
1132
- constant int64_t & ne01,
1133
- constant int64_t & ne02,
1134
- constant uint64_t & nb00,
1135
- constant uint64_t & nb01,
1136
- constant uint64_t & nb02,
1137
- constant int64_t & ne10,
1138
- constant int64_t & ne11,
1139
- constant int64_t & ne12,
1140
- constant uint64_t & nb10,
1141
- constant uint64_t & nb11,
1142
- constant uint64_t & nb12,
1143
- constant int64_t & ne0,
1144
- constant int64_t & ne1,
1145
- constant uint & r2,
1146
- constant uint & r3,
1147
- uint3 tgpig[[threadgroup_position_in_grid]],
1148
- uint tiisg[[thread_index_in_simdgroup]]) {
1168
+ int64_t ne00,
1169
+ int64_t ne01,
1170
+ int64_t ne02,
1171
+ uint64_t nb00,
1172
+ uint64_t nb01,
1173
+ uint64_t nb02,
1174
+ int64_t ne10,
1175
+ int64_t ne11,
1176
+ int64_t ne12,
1177
+ uint64_t nb10,
1178
+ uint64_t nb11,
1179
+ uint64_t nb12,
1180
+ int64_t ne0,
1181
+ int64_t ne1,
1182
+ uint r2,
1183
+ uint r3,
1184
+ uint3 tgpig,
1185
+ uint tiisg) {
1149
1186
 
1150
1187
  const int64_t r0 = tgpig.x;
1151
1188
  const int64_t rb = tgpig.y*N_F32_F32;
@@ -1398,24 +1435,24 @@ void kernel_mul_mv_f16_f32_impl(
1398
1435
  device const char * src0,
1399
1436
  device const char * src1,
1400
1437
  device float * dst,
1401
- constant int64_t & ne00,
1402
- constant int64_t & ne01,
1403
- constant int64_t & ne02,
1404
- constant uint64_t & nb00,
1405
- constant uint64_t & nb01,
1406
- constant uint64_t & nb02,
1407
- constant int64_t & ne10,
1408
- constant int64_t & ne11,
1409
- constant int64_t & ne12,
1410
- constant uint64_t & nb10,
1411
- constant uint64_t & nb11,
1412
- constant uint64_t & nb12,
1413
- constant int64_t & ne0,
1414
- constant int64_t & ne1,
1415
- constant uint & r2,
1416
- constant uint & r3,
1417
- uint3 tgpig[[threadgroup_position_in_grid]],
1418
- uint tiisg[[thread_index_in_simdgroup]]) {
1438
+ int64_t ne00,
1439
+ int64_t ne01,
1440
+ int64_t ne02,
1441
+ uint64_t nb00,
1442
+ uint64_t nb01,
1443
+ uint64_t nb02,
1444
+ int64_t ne10,
1445
+ int64_t ne11,
1446
+ int64_t ne12,
1447
+ uint64_t nb10,
1448
+ uint64_t nb11,
1449
+ uint64_t nb12,
1450
+ int64_t ne0,
1451
+ int64_t ne1,
1452
+ uint r2,
1453
+ uint r3,
1454
+ uint3 tgpig,
1455
+ uint tiisg) {
1419
1456
 
1420
1457
  const int64_t r0 = tgpig.x;
1421
1458
  const int64_t rb = tgpig.y*N_F16_F32;
@@ -2700,18 +2737,19 @@ void kernel_mul_mv_q2_K_f32_impl(
2700
2737
  device const void * src0,
2701
2738
  device const float * src1,
2702
2739
  device float * dst,
2703
- constant int64_t & ne00,
2704
- constant int64_t & ne01,
2705
- constant int64_t & ne02,
2706
- constant int64_t & ne10,
2707
- constant int64_t & ne12,
2708
- constant int64_t & ne0,
2709
- constant int64_t & ne1,
2710
- constant uint & r2,
2711
- constant uint & r3,
2712
- uint3 tgpig[[threadgroup_position_in_grid]],
2713
- uint tiisg[[thread_index_in_simdgroup]],
2714
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2740
+ int64_t ne00,
2741
+ int64_t ne01,
2742
+ int64_t ne02,
2743
+ int64_t ne10,
2744
+ int64_t ne12,
2745
+ int64_t ne0,
2746
+ int64_t ne1,
2747
+ uint r2,
2748
+ uint r3,
2749
+ threadgroup int8_t * shared_values,
2750
+ uint3 tgpig,
2751
+ uint tiisg,
2752
+ uint sgitg) {
2715
2753
 
2716
2754
  const int nb = ne00/QK_K;
2717
2755
  const int r0 = tgpig.x;
@@ -2871,7 +2909,7 @@ kernel void kernel_mul_mv_q2_K_f32(
2871
2909
  uint tiisg[[thread_index_in_simdgroup]],
2872
2910
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
2873
2911
 
2874
- kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2912
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
2875
2913
  }
2876
2914
 
2877
2915
  #if QK_K == 256
@@ -2879,18 +2917,19 @@ void kernel_mul_mv_q3_K_f32_impl(
2879
2917
  device const void * src0,
2880
2918
  device const float * src1,
2881
2919
  device float * dst,
2882
- constant int64_t & ne00,
2883
- constant int64_t & ne01,
2884
- constant int64_t & ne02,
2885
- constant int64_t & ne10,
2886
- constant int64_t & ne12,
2887
- constant int64_t & ne0,
2888
- constant int64_t & ne1,
2889
- constant uint & r2,
2890
- constant uint & r3,
2891
- uint3 tgpig[[threadgroup_position_in_grid]],
2892
- uint tiisg[[thread_index_in_simdgroup]],
2893
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2920
+ int64_t ne00,
2921
+ int64_t ne01,
2922
+ int64_t ne02,
2923
+ int64_t ne10,
2924
+ int64_t ne12,
2925
+ int64_t ne0,
2926
+ int64_t ne1,
2927
+ uint r2,
2928
+ uint r3,
2929
+ threadgroup int8_t * shared_values,
2930
+ uint3 tgpig,
2931
+ uint tiisg,
2932
+ uint sgitg) {
2894
2933
 
2895
2934
  const int nb = ne00/QK_K;
2896
2935
 
@@ -3046,6 +3085,7 @@ void kernel_mul_mv_q3_K_f32_impl(
3046
3085
  constant int64_t & ne1,
3047
3086
  constant uint & r2,
3048
3087
  constant uint & r3,
3088
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3049
3089
  uint3 tgpig[[threadgroup_position_in_grid]],
3050
3090
  uint tiisg[[thread_index_in_simdgroup]],
3051
3091
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3135,7 +3175,7 @@ kernel void kernel_mul_mv_q3_K_f32(
3135
3175
  uint tiisg[[thread_index_in_simdgroup]],
3136
3176
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3137
3177
 
3138
- kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3178
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3139
3179
  }
3140
3180
 
3141
3181
  #if QK_K == 256
@@ -3143,18 +3183,19 @@ void kernel_mul_mv_q4_K_f32_impl(
3143
3183
  device const void * src0,
3144
3184
  device const float * src1,
3145
3185
  device float * dst,
3146
- constant int64_t & ne00,
3147
- constant int64_t & ne01,
3148
- constant int64_t & ne02,
3149
- constant int64_t & ne10,
3150
- constant int64_t & ne12,
3151
- constant int64_t & ne0,
3152
- constant int64_t & ne1,
3153
- constant uint & r2,
3154
- constant uint & r3,
3155
- uint3 tgpig[[threadgroup_position_in_grid]],
3156
- uint tiisg[[thread_index_in_simdgroup]],
3157
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3186
+ int64_t ne00,
3187
+ int64_t ne01,
3188
+ int64_t ne02,
3189
+ int64_t ne10,
3190
+ int64_t ne12,
3191
+ int64_t ne0,
3192
+ int64_t ne1,
3193
+ uint r2,
3194
+ uint r3,
3195
+ threadgroup int8_t * shared_values,
3196
+ uint3 tgpig,
3197
+ uint tiisg,
3198
+ uint sgitg) {
3158
3199
 
3159
3200
  const uint16_t kmask1 = 0x3f3f;
3160
3201
  const uint16_t kmask2 = 0x0f0f;
@@ -3265,6 +3306,7 @@ void kernel_mul_mv_q4_K_f32_impl(
3265
3306
  constant int64_t & ne1,
3266
3307
  constant uint & r2,
3267
3308
  constant uint & r3,
3309
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3268
3310
  uint3 tgpig[[threadgroup_position_in_grid]],
3269
3311
  uint tiisg[[thread_index_in_simdgroup]],
3270
3312
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3373,25 +3415,26 @@ kernel void kernel_mul_mv_q4_K_f32(
3373
3415
  uint tiisg[[thread_index_in_simdgroup]],
3374
3416
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3375
3417
 
3376
- kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3418
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3377
3419
  }
3378
3420
 
3379
3421
  void kernel_mul_mv_q5_K_f32_impl(
3380
3422
  device const void * src0,
3381
3423
  device const float * src1,
3382
3424
  device float * dst,
3383
- constant int64_t & ne00,
3384
- constant int64_t & ne01,
3385
- constant int64_t & ne02,
3386
- constant int64_t & ne10,
3387
- constant int64_t & ne12,
3388
- constant int64_t & ne0,
3389
- constant int64_t & ne1,
3390
- constant uint & r2,
3391
- constant uint & r3,
3392
- uint3 tgpig[[threadgroup_position_in_grid]],
3393
- uint tiisg[[thread_index_in_simdgroup]],
3394
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3425
+ int64_t ne00,
3426
+ int64_t ne01,
3427
+ int64_t ne02,
3428
+ int64_t ne10,
3429
+ int64_t ne12,
3430
+ int64_t ne0,
3431
+ int64_t ne1,
3432
+ uint r2,
3433
+ uint r3,
3434
+ threadgroup int8_t * shared_values,
3435
+ uint3 tgpig,
3436
+ uint tiisg,
3437
+ uint sgitg) {
3395
3438
 
3396
3439
  const int nb = ne00/QK_K;
3397
3440
 
@@ -3579,25 +3622,26 @@ kernel void kernel_mul_mv_q5_K_f32(
3579
3622
  uint tiisg[[thread_index_in_simdgroup]],
3580
3623
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3581
3624
 
3582
- kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3625
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3583
3626
  }
3584
3627
 
3585
3628
  void kernel_mul_mv_q6_K_f32_impl(
3586
3629
  device const void * src0,
3587
3630
  device const float * src1,
3588
3631
  device float * dst,
3589
- constant int64_t & ne00,
3590
- constant int64_t & ne01,
3591
- constant int64_t & ne02,
3592
- constant int64_t & ne10,
3593
- constant int64_t & ne12,
3594
- constant int64_t & ne0,
3595
- constant int64_t & ne1,
3596
- constant uint & r2,
3597
- constant uint & r3,
3598
- uint3 tgpig[[threadgroup_position_in_grid]],
3599
- uint tiisg[[thread_index_in_simdgroup]],
3600
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3632
+ int64_t ne00,
3633
+ int64_t ne01,
3634
+ int64_t ne02,
3635
+ int64_t ne10,
3636
+ int64_t ne12,
3637
+ int64_t ne0,
3638
+ int64_t ne1,
3639
+ uint r2,
3640
+ uint r3,
3641
+ threadgroup int8_t * shared_values,
3642
+ uint3 tgpig,
3643
+ uint tiisg,
3644
+ uint sgitg) {
3601
3645
 
3602
3646
  const uint8_t kmask1 = 0x03;
3603
3647
  const uint8_t kmask2 = 0x0C;
@@ -3713,7 +3757,7 @@ kernel void kernel_mul_mv_q6_K_f32(
3713
3757
  uint tiisg[[thread_index_in_simdgroup]],
3714
3758
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3715
3759
 
3716
- kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3760
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3717
3761
  }
3718
3762
 
3719
3763
  // ======================= "True" 2-bit
@@ -3722,19 +3766,19 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
3722
3766
  device const void * src0,
3723
3767
  device const float * src1,
3724
3768
  device float * dst,
3725
- constant int64_t & ne00,
3726
- constant int64_t & ne01,
3727
- constant int64_t & ne02,
3728
- constant int64_t & ne10,
3729
- constant int64_t & ne12,
3730
- constant int64_t & ne0,
3731
- constant int64_t & ne1,
3732
- constant uint & r2,
3733
- constant uint & r3,
3734
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3735
- uint3 tgpig[[threadgroup_position_in_grid]],
3736
- uint tiisg[[thread_index_in_simdgroup]],
3737
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3769
+ int64_t ne00,
3770
+ int64_t ne01,
3771
+ int64_t ne02,
3772
+ int64_t ne10,
3773
+ int64_t ne12,
3774
+ int64_t ne0,
3775
+ int64_t ne1,
3776
+ uint r2,
3777
+ uint r3,
3778
+ threadgroup int8_t * shared_values,
3779
+ uint3 tgpig,
3780
+ uint tiisg,
3781
+ uint sgitg) {
3738
3782
 
3739
3783
  const int nb = ne00/QK_K;
3740
3784
  const int r0 = tgpig.x;
@@ -3851,19 +3895,19 @@ void kernel_mul_mv_iq2_xs_f32_impl(
3851
3895
  device const void * src0,
3852
3896
  device const float * src1,
3853
3897
  device float * dst,
3854
- constant int64_t & ne00,
3855
- constant int64_t & ne01,
3856
- constant int64_t & ne02,
3857
- constant int64_t & ne10,
3858
- constant int64_t & ne12,
3859
- constant int64_t & ne0,
3860
- constant int64_t & ne1,
3861
- constant uint & r2,
3862
- constant uint & r3,
3863
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3864
- uint3 tgpig[[threadgroup_position_in_grid]],
3865
- uint tiisg[[thread_index_in_simdgroup]],
3866
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3898
+ int64_t ne00,
3899
+ int64_t ne01,
3900
+ int64_t ne02,
3901
+ int64_t ne10,
3902
+ int64_t ne12,
3903
+ int64_t ne0,
3904
+ int64_t ne1,
3905
+ uint r2,
3906
+ uint r3,
3907
+ threadgroup int8_t * shared_values,
3908
+ uint3 tgpig,
3909
+ uint tiisg,
3910
+ uint sgitg) {
3867
3911
 
3868
3912
  const int nb = ne00/QK_K;
3869
3913
  const int r0 = tgpig.x;
@@ -3990,19 +4034,19 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
3990
4034
  device const void * src0,
3991
4035
  device const float * src1,
3992
4036
  device float * dst,
3993
- constant int64_t & ne00,
3994
- constant int64_t & ne01,
3995
- constant int64_t & ne02,
3996
- constant int64_t & ne10,
3997
- constant int64_t & ne12,
3998
- constant int64_t & ne0,
3999
- constant int64_t & ne1,
4000
- constant uint & r2,
4001
- constant uint & r3,
4002
- threadgroup int8_t * shared_values [[threadgroup(0)]],
4003
- uint3 tgpig[[threadgroup_position_in_grid]],
4004
- uint tiisg[[thread_index_in_simdgroup]],
4005
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4037
+ int64_t ne00,
4038
+ int64_t ne01,
4039
+ int64_t ne02,
4040
+ int64_t ne10,
4041
+ int64_t ne12,
4042
+ int64_t ne0,
4043
+ int64_t ne1,
4044
+ uint r2,
4045
+ uint r3,
4046
+ threadgroup int8_t * shared_values,
4047
+ uint3 tgpig,
4048
+ uint tiisg,
4049
+ uint sgitg) {
4006
4050
 
4007
4051
  const int nb = ne00/QK_K;
4008
4052
  const int r0 = tgpig.x;
@@ -4122,19 +4166,19 @@ void kernel_mul_mv_iq3_s_f32_impl(
4122
4166
  device const void * src0,
4123
4167
  device const float * src1,
4124
4168
  device float * dst,
4125
- constant int64_t & ne00,
4126
- constant int64_t & ne01,
4127
- constant int64_t & ne02,
4128
- constant int64_t & ne10,
4129
- constant int64_t & ne12,
4130
- constant int64_t & ne0,
4131
- constant int64_t & ne1,
4132
- constant uint & r2,
4133
- constant uint & r3,
4134
- threadgroup int8_t * shared_values [[threadgroup(0)]],
4135
- uint3 tgpig[[threadgroup_position_in_grid]],
4136
- uint tiisg[[thread_index_in_simdgroup]],
4137
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4169
+ int64_t ne00,
4170
+ int64_t ne01,
4171
+ int64_t ne02,
4172
+ int64_t ne10,
4173
+ int64_t ne12,
4174
+ int64_t ne0,
4175
+ int64_t ne1,
4176
+ uint r2,
4177
+ uint r3,
4178
+ threadgroup int8_t * shared_values,
4179
+ uint3 tgpig,
4180
+ uint tiisg,
4181
+ uint sgitg) {
4138
4182
 
4139
4183
  const int nb = ne00/QK_K;
4140
4184
  const int r0 = tgpig.x;
@@ -4254,19 +4298,19 @@ void kernel_mul_mv_iq2_s_f32_impl(
4254
4298
  device const void * src0,
4255
4299
  device const float * src1,
4256
4300
  device float * dst,
4257
- constant int64_t & ne00,
4258
- constant int64_t & ne01,
4259
- constant int64_t & ne02,
4260
- constant int64_t & ne10,
4261
- constant int64_t & ne12,
4262
- constant int64_t & ne0,
4263
- constant int64_t & ne1,
4264
- constant uint & r2,
4265
- constant uint & r3,
4266
- threadgroup int8_t * shared_values [[threadgroup(0)]],
4267
- uint3 tgpig[[threadgroup_position_in_grid]],
4268
- uint tiisg[[thread_index_in_simdgroup]],
4269
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4301
+ int64_t ne00,
4302
+ int64_t ne01,
4303
+ int64_t ne02,
4304
+ int64_t ne10,
4305
+ int64_t ne12,
4306
+ int64_t ne0,
4307
+ int64_t ne1,
4308
+ uint r2,
4309
+ uint r3,
4310
+ threadgroup int8_t * shared_values,
4311
+ uint3 tgpig,
4312
+ uint tiisg,
4313
+ uint sgitg) {
4270
4314
 
4271
4315
  const int nb = ne00/QK_K;
4272
4316
  const int r0 = tgpig.x;
@@ -4387,18 +4431,19 @@ void kernel_mul_mv_iq1_s_f32_impl(
4387
4431
  device const void * src0,
4388
4432
  device const float * src1,
4389
4433
  device float * dst,
4390
- constant int64_t & ne00,
4391
- constant int64_t & ne01,
4392
- constant int64_t & ne02,
4393
- constant int64_t & ne10,
4394
- constant int64_t & ne12,
4395
- constant int64_t & ne0,
4396
- constant int64_t & ne1,
4397
- constant uint & r2,
4398
- constant uint & r3,
4399
- uint3 tgpig[[threadgroup_position_in_grid]],
4400
- uint tiisg[[thread_index_in_simdgroup]],
4401
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4434
+ int64_t ne00,
4435
+ int64_t ne01,
4436
+ int64_t ne02,
4437
+ int64_t ne10,
4438
+ int64_t ne12,
4439
+ int64_t ne0,
4440
+ int64_t ne1,
4441
+ uint r2,
4442
+ uint r3,
4443
+ threadgroup int8_t * shared_value,
4444
+ uint3 tgpig,
4445
+ uint tiisg,
4446
+ uint sgitg) {
4402
4447
 
4403
4448
  const int nb = ne00/QK_K;
4404
4449
  const int r0 = tgpig.x;
@@ -4476,18 +4521,19 @@ void kernel_mul_mv_iq1_m_f32_impl(
4476
4521
  device const void * src0,
4477
4522
  device const float * src1,
4478
4523
  device float * dst,
4479
- constant int64_t & ne00,
4480
- constant int64_t & ne01,
4481
- constant int64_t & ne02,
4482
- constant int64_t & ne10,
4483
- constant int64_t & ne12,
4484
- constant int64_t & ne0,
4485
- constant int64_t & ne1,
4486
- constant uint & r2,
4487
- constant uint & r3,
4488
- uint3 tgpig[[threadgroup_position_in_grid]],
4489
- uint tiisg[[thread_index_in_simdgroup]],
4490
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4524
+ int64_t ne00,
4525
+ int64_t ne01,
4526
+ int64_t ne02,
4527
+ int64_t ne10,
4528
+ int64_t ne12,
4529
+ int64_t ne0,
4530
+ int64_t ne1,
4531
+ uint r2,
4532
+ uint r3,
4533
+ threadgroup int8_t * shared_value,
4534
+ uint3 tgpig,
4535
+ uint tiisg,
4536
+ uint sgitg) {
4491
4537
 
4492
4538
  const int nb = ne00/QK_K;
4493
4539
  const int r0 = tgpig.x;
@@ -4584,20 +4630,21 @@ void kernel_mul_mv_iq4_nl_f32_impl(
4584
4630
  device const void * src0,
4585
4631
  device const float * src1,
4586
4632
  device float * dst,
4587
- constant int64_t & ne00,
4588
- constant int64_t & ne01,
4589
- constant int64_t & ne02,
4590
- constant int64_t & ne10,
4591
- constant int64_t & ne12,
4592
- constant int64_t & ne0,
4593
- constant int64_t & ne1,
4594
- constant uint & r2,
4595
- constant uint & r3,
4596
- threadgroup float * shared_values [[threadgroup(0)]],
4597
- uint3 tgpig[[threadgroup_position_in_grid]],
4598
- uint tiisg[[thread_index_in_simdgroup]],
4599
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4633
+ int64_t ne00,
4634
+ int64_t ne01,
4635
+ int64_t ne02,
4636
+ int64_t ne10,
4637
+ int64_t ne12,
4638
+ int64_t ne0,
4639
+ int64_t ne1,
4640
+ uint r2,
4641
+ uint r3,
4642
+ threadgroup int8_t * shared_values_i8,
4643
+ uint3 tgpig,
4644
+ uint tiisg,
4645
+ uint sgitg) {
4600
4646
 
4647
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
4601
4648
  const int nb = ne00/QK4_NL;
4602
4649
  const int r0 = tgpig.x;
4603
4650
  const int r1 = tgpig.y;
@@ -4678,20 +4725,21 @@ void kernel_mul_mv_iq4_xs_f32_impl(
4678
4725
  device const void * src0,
4679
4726
  device const float * src1,
4680
4727
  device float * dst,
4681
- constant int64_t & ne00,
4682
- constant int64_t & ne01,
4683
- constant int64_t & ne02,
4684
- constant int64_t & ne10,
4685
- constant int64_t & ne12,
4686
- constant int64_t & ne0,
4687
- constant int64_t & ne1,
4688
- constant uint & r2,
4689
- constant uint & r3,
4690
- threadgroup float * shared_values [[threadgroup(0)]],
4691
- uint3 tgpig[[threadgroup_position_in_grid]],
4692
- uint tiisg[[thread_index_in_simdgroup]],
4693
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4728
+ int64_t ne00,
4729
+ int64_t ne01,
4730
+ int64_t ne02,
4731
+ int64_t ne10,
4732
+ int64_t ne12,
4733
+ int64_t ne0,
4734
+ int64_t ne1,
4735
+ uint r2,
4736
+ uint r3,
4737
+ threadgroup int8_t * shared_values_i8,
4738
+ uint3 tgpig,
4739
+ uint tiisg,
4740
+ uint sgitg) {
4694
4741
 
4742
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
4695
4743
  const int nb = ne00/QK_K;
4696
4744
  const int r0 = tgpig.x;
4697
4745
  const int r1 = tgpig.y;
@@ -4794,7 +4842,7 @@ kernel void kernel_mul_mv_iq1_s_f32(
4794
4842
  uint tiisg[[thread_index_in_simdgroup]],
4795
4843
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4796
4844
 
4797
- kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
4845
+ kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
4798
4846
  }
4799
4847
 
4800
4848
  [[host_name("kernel_mul_mv_iq1_m_f32")]]
@@ -4822,7 +4870,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
4822
4870
  uint tiisg[[thread_index_in_simdgroup]],
4823
4871
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4824
4872
 
4825
- kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
4873
+ kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
4826
4874
  }
4827
4875
 
4828
4876
  [[host_name("kernel_mul_mv_iq4_nl_f32")]]
@@ -4846,7 +4894,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
4846
4894
  constant int64_t & ne1,
4847
4895
  constant uint & r2,
4848
4896
  constant uint & r3,
4849
- threadgroup float * shared_values [[threadgroup(0)]],
4897
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4850
4898
  uint3 tgpig[[threadgroup_position_in_grid]],
4851
4899
  uint tiisg[[thread_index_in_simdgroup]],
4852
4900
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -4875,7 +4923,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
4875
4923
  constant int64_t & ne1,
4876
4924
  constant uint & r2,
4877
4925
  constant uint & r3,
4878
- threadgroup float * shared_values [[threadgroup(0)]],
4926
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4879
4927
  uint3 tgpig[[threadgroup_position_in_grid]],
4880
4928
  uint tiisg[[thread_index_in_simdgroup]],
4881
4929
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -5632,25 +5680,25 @@ void kernel_mul_mm_impl(device const uchar * src0,
5632
5680
  }
5633
5681
  }
5634
5682
 
5635
- // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
5683
+ // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
5636
5684
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
5637
5685
  void kernel_mul_mm_id_impl(
5638
5686
  device const uchar * src0,
5639
5687
  device const uchar * src1,
5640
- threadgroup short * src1ids,
5688
+ threadgroup ushort2 * rowids,
5641
5689
  device float * dst,
5642
5690
  constant int64_t & ne00,
5643
5691
  constant int64_t & ne02,
5644
5692
  constant uint64_t & nb01,
5645
5693
  constant uint64_t & nb02,
5694
+ constant int64_t & ne11,
5646
5695
  constant int64_t & ne12,
5647
5696
  constant uint64_t & nb10,
5648
5697
  constant uint64_t & nb11,
5649
5698
  constant uint64_t & nb12,
5650
5699
  constant int64_t & ne0,
5651
5700
  int64_t ne1,
5652
- constant uint & r2,
5653
- constant uint & r3,
5701
+ int64_t ne0ne1,
5654
5702
  threadgroup uchar * shared_memory,
5655
5703
  uint3 tgpig[[threadgroup_position_in_grid]],
5656
5704
  uint tiitg[[thread_index_in_threadgroup]],
@@ -5661,7 +5709,6 @@ void kernel_mul_mm_id_impl(
5661
5709
 
5662
5710
  const uint r0 = tgpig.y;
5663
5711
  const uint r1 = tgpig.x;
5664
- const uint im = tgpig.z;
5665
5712
 
5666
5713
  if (r1 * BLOCK_SIZE_N >= ne1) return;
5667
5714
 
@@ -5679,19 +5726,16 @@ void kernel_mul_mm_id_impl(
5679
5726
  for (int i = 0; i < 8; i++){
5680
5727
  c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
5681
5728
  }
5682
-
5683
5729
  short il = (tiitg % THREAD_PER_ROW);
5684
5730
 
5685
- const uint i12 = im%ne12;
5686
- const uint i13 = im/ne12;
5687
-
5688
- uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
5689
5731
  ushort offset1 = il/nl;
5690
5732
 
5691
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
5733
+ threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
5734
+
5735
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
5692
5736
  device const float * y = (device const float *)(src1
5693
- + nb12 * im
5694
- + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
5737
+ + nb12 * id[1]
5738
+ + nb11 * (id[0] % ne11)
5695
5739
  + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
5696
5740
 
5697
5741
  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
@@ -5720,11 +5764,11 @@ void kernel_mul_mm_id_impl(
5720
5764
 
5721
5765
  for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
5722
5766
  for (int i = 0; i < 4; i++) {
5723
- simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
5767
+ simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
5724
5768
  }
5725
5769
  simdgroup_barrier(mem_flags::mem_none);
5726
5770
  for (int i = 0; i < 2; i++) {
5727
- simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
5771
+ simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
5728
5772
  }
5729
5773
 
5730
5774
  lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
@@ -5746,11 +5790,13 @@ void kernel_mul_mm_id_impl(
5746
5790
 
5747
5791
  threadgroup_barrier(mem_flags::mem_threadgroup);
5748
5792
 
5749
- device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
5793
+ device float * C = dst + (BLOCK_SIZE_M * r0);
5750
5794
  if (sgitg == 0) {
5751
- for (int i = 0; i < n_rows; i++) {
5752
- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
5753
- *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
5795
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
5796
+ threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
5797
+ int joff = jid[0] * ne0 + jid[1] * ne0ne1;
5798
+ for (int i = 0; i < n_rows; i++) {
5799
+ *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
5754
5800
  }
5755
5801
  }
5756
5802
  }
@@ -5805,11 +5851,14 @@ kernel void kernel_mul_mm_id(
5805
5851
  device const uchar * src1,
5806
5852
  device float * dst,
5807
5853
  device const uchar * ids,
5854
+ constant int64_t & nei0,
5855
+ constant int64_t & nei1,
5808
5856
  constant uint64_t & nbi1,
5809
5857
  constant int64_t & ne00,
5810
5858
  constant int64_t & ne02,
5811
5859
  constant uint64_t & nb01,
5812
5860
  constant uint64_t & nb02,
5861
+ constant int64_t & ne11,
5813
5862
  constant int64_t & ne12,
5814
5863
  constant int64_t & ne13,
5815
5864
  constant uint64_t & nb10,
@@ -5818,47 +5867,52 @@ kernel void kernel_mul_mm_id(
5818
5867
  constant int64_t & ne0,
5819
5868
  constant int64_t & ne1,
5820
5869
  constant uint64_t & nb1,
5821
- constant uint & r2,
5822
- constant uint & r3,
5823
- constant int & idx,
5824
5870
  threadgroup uchar * shared_memory [[threadgroup(0)]],
5825
5871
  uint3 tgpig[[threadgroup_position_in_grid]],
5826
5872
  uint tiitg[[thread_index_in_threadgroup]],
5827
5873
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5828
5874
 
5829
- // expert id
5830
- const int32_t id = tgpig.z/(ne12*ne13);
5831
- device const uchar * src0 = src0s + id*nb02;
5875
+ const int32_t i02 = tgpig.z;
5876
+ tgpig.z = 0;
5832
5877
 
5833
- tgpig.z = tgpig.z%(ne12*ne13);
5878
+ device const uchar * src0 = src0s + i02*nb02;
5834
5879
 
5835
- // row indices of src1 for expert id
5836
- threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192);
5880
+ // row indices
5881
+ threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
5837
5882
 
5883
+ // TODO: parallelize this loop
5838
5884
  int64_t _ne1 = 0;
5839
- for (int64_t i1 = 0; i1 < ne1; i1++) {
5840
- if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
5841
- src1ids[_ne1++] = i1;
5885
+ for (ushort ii1 = 0; ii1 < nei1; ii1++) {
5886
+ for (ushort ii0 = 0; ii0 < nei0; ii0++) {
5887
+ int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
5888
+ if (id == i02) {
5889
+ //if (tiitg == 0) {
5890
+ rowids[_ne1] = ushort2(ii0, ii1);
5891
+ //}
5892
+ _ne1++;
5893
+ }
5842
5894
  }
5843
5895
  }
5844
5896
 
5897
+ threadgroup_barrier(mem_flags::mem_threadgroup);
5898
+
5845
5899
  kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
5846
5900
  src0,
5847
5901
  src1,
5848
- src1ids,
5902
+ rowids,
5849
5903
  dst,
5850
5904
  ne00,
5851
5905
  ne02,
5852
5906
  nb01,
5853
5907
  nb02,
5908
+ ne11,
5854
5909
  ne12,
5855
5910
  nb10,
5856
5911
  nb11,
5857
5912
  nb12,
5858
5913
  ne0,
5859
5914
  _ne1,
5860
- r2,
5861
- r3,
5915
+ ne0*ne1,
5862
5916
  shared_memory,
5863
5917
  tgpig,
5864
5918
  tiitg,
@@ -5919,24 +5973,7 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_r
5919
5973
  // matrix-matrix multiplication
5920
5974
  //
5921
5975
 
5922
- typedef void (mat_mm_t)(
5923
- device const uchar * src0,
5924
- device const uchar * src1,
5925
- device float * dst,
5926
- constant int64_t & ne00,
5927
- constant int64_t & ne02,
5928
- constant uint64_t & nb01,
5929
- constant uint64_t & nb02,
5930
- constant int64_t & ne12,
5931
- constant uint64_t & nb10,
5932
- constant uint64_t & nb11,
5933
- constant uint64_t & nb12,
5934
- constant int64_t & ne0,
5935
- constant int64_t & ne1,
5936
- constant uint & r2,
5937
- constant uint & r3,
5938
- threadgroup uchar *,
5939
- uint3, uint, uint);
5976
+ typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
5940
5977
 
5941
5978
  template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
5942
5979
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
@@ -5968,29 +6005,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
5968
6005
  // indirect matrix-matrix multiplication
5969
6006
  //
5970
6007
 
5971
- typedef void (mat_mm_id_t)(
5972
- device const uchar * src0s,
5973
- device const uchar * src1,
5974
- device float * dst,
5975
- device const uchar * ids,
5976
- constant uint64_t & nbi1,
5977
- constant int64_t & ne00,
5978
- constant int64_t & ne02,
5979
- constant uint64_t & nb01,
5980
- constant uint64_t & nb02,
5981
- constant int64_t & ne12,
5982
- constant int64_t & ne13,
5983
- constant uint64_t & nb10,
5984
- constant uint64_t & nb11,
5985
- constant uint64_t & nb12,
5986
- constant int64_t & ne0,
5987
- constant int64_t & ne1,
5988
- constant uint64_t & nb1,
5989
- constant uint & r2,
5990
- constant uint & r3,
5991
- constant int & idx,
5992
- threadgroup uchar *,
5993
- uint3, uint, uint);
6008
+ typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
5994
6009
 
5995
6010
  template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
5996
6011
  template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
@@ -6022,12 +6037,119 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
6022
6037
  // matrix-vector multiplication
6023
6038
  //
6024
6039
 
6025
- [[host_name("kernel_mul_mv_id_f32_f32")]]
6026
- kernel void kernel_mul_mv_id_f32_f32(
6040
+ typedef void (kernel_mul_mv_impl_t)(
6041
+ device const char * src0,
6042
+ device const char * src1,
6043
+ device float * dst,
6044
+ int64_t ne00,
6045
+ int64_t ne01,
6046
+ int64_t ne02,
6047
+ uint64_t nb00,
6048
+ uint64_t nb01,
6049
+ uint64_t nb02,
6050
+ int64_t ne10,
6051
+ int64_t ne11,
6052
+ int64_t ne12,
6053
+ uint64_t nb10,
6054
+ uint64_t nb11,
6055
+ uint64_t nb12,
6056
+ int64_t ne0,
6057
+ int64_t ne1,
6058
+ uint r2,
6059
+ uint r3,
6060
+ uint3 tgpig,
6061
+ uint tiisg);
6062
+
6063
+ typedef void (kernel_mul_mv2_impl_t)(
6064
+ device const void * src0,
6065
+ device const float * src1,
6066
+ device float * dst,
6067
+ int64_t ne00,
6068
+ int64_t ne01,
6069
+ int64_t ne02,
6070
+ int64_t ne10,
6071
+ int64_t ne12,
6072
+ int64_t ne0,
6073
+ int64_t ne1,
6074
+ uint r2,
6075
+ uint r3,
6076
+ threadgroup int8_t * shared_values,
6077
+ uint3 tgpig,
6078
+ uint tiisg,
6079
+ uint sgitg);
6080
+
6081
+ template<kernel_mul_mv_impl_t impl_fn>
6082
+ void mmv_fn(
6083
+ device const char * src0,
6084
+ device const char * src1,
6085
+ device float * dst,
6086
+ int64_t ne00,
6087
+ int64_t ne01,
6088
+ int64_t ne02,
6089
+ uint64_t nb00,
6090
+ uint64_t nb01,
6091
+ uint64_t nb02,
6092
+ int64_t ne10,
6093
+ int64_t ne11,
6094
+ int64_t ne12,
6095
+ int64_t ne13,
6096
+ uint64_t nb10,
6097
+ uint64_t nb11,
6098
+ uint64_t nb12,
6099
+ int64_t ne0,
6100
+ int64_t ne1,
6101
+ uint64_t nb1,
6102
+ uint r2,
6103
+ uint r3,
6104
+ threadgroup int8_t * shared_values,
6105
+ uint3 tgpig,
6106
+ uint tiitg,
6107
+ uint tiisg,
6108
+ uint sgitg) {
6109
+ impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
6110
+ }
6111
+
6112
+ template<kernel_mul_mv2_impl_t impl_fn>
6113
+ void mmv_fn(
6114
+ device const char * src0,
6115
+ device const char * src1,
6116
+ device float * dst,
6117
+ int64_t ne00,
6118
+ int64_t ne01,
6119
+ int64_t ne02,
6120
+ uint64_t nb00,
6121
+ uint64_t nb01,
6122
+ uint64_t nb02,
6123
+ int64_t ne10,
6124
+ int64_t ne11,
6125
+ int64_t ne12,
6126
+ int64_t ne13,
6127
+ uint64_t nb10,
6128
+ uint64_t nb11,
6129
+ uint64_t nb12,
6130
+ int64_t ne0,
6131
+ int64_t ne1,
6132
+ uint64_t nb1,
6133
+ uint r2,
6134
+ uint r3,
6135
+ threadgroup int8_t * shared_values,
6136
+ uint3 tgpig,
6137
+ uint tiitg,
6138
+ uint tiisg,
6139
+ uint sgitg) {
6140
+ impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
6141
+ }
6142
+
6143
+ typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t;
6144
+
6145
+ template<mul_mv_impl_fn_t impl_fn>
6146
+ kernel void kernel_mul_mv_id(
6027
6147
  device const char * src0s,
6028
6148
  device const char * src1,
6029
6149
  device float * dst,
6030
6150
  device const char * ids,
6151
+ constant int64_t & nei0,
6152
+ constant int64_t & nei1,
6031
6153
  constant uint64_t & nbi1,
6032
6154
  constant int64_t & ne00,
6033
6155
  constant int64_t & ne01,
@@ -6045,1164 +6167,80 @@ kernel void kernel_mul_mv_id_f32_f32(
6045
6167
  constant int64_t & ne0,
6046
6168
  constant int64_t & ne1,
6047
6169
  constant uint64_t & nb1,
6048
- constant uint & r2,
6049
- constant uint & r3,
6050
- constant int & idx,
6170
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
6051
6171
  uint3 tgpig[[threadgroup_position_in_grid]],
6052
6172
  uint tiitg[[thread_index_in_threadgroup]],
6053
6173
  uint tiisg[[thread_index_in_simdgroup]],
6054
6174
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6055
- const int64_t bid = tgpig.z/(ne12*ne13);
6056
-
6057
- tgpig.z = tgpig.z%(ne12*ne13);
6058
-
6059
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6060
- device const char * src0 = src0s + id*nb02;
6061
-
6062
- kernel_mul_mv_f32_f32_impl(
6063
- src0,
6064
- src1 + bid*nb11,
6065
- dst + bid*ne0,
6066
- ne00,
6067
- ne01,
6068
- ne02,
6069
- nb00,
6070
- nb01,
6071
- nb02,
6072
- ne10,
6073
- ne11,
6074
- ne12,
6075
- nb10,
6076
- nb11,
6077
- nb12,
6078
- ne0,
6079
- ne1,
6080
- r2,
6081
- r3,
6175
+ const int iid1 = tgpig.z/nei0;
6176
+ const int idx = tgpig.z%nei0;
6177
+
6178
+ tgpig.z = 0;
6179
+
6180
+ const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx];
6181
+
6182
+ const int64_t i11 = idx % ne11;
6183
+ const int64_t i12 = iid1;
6184
+
6185
+ const int64_t i1 = idx;
6186
+ const int64_t i2 = i12;
6187
+
6188
+ device const char * src0_cur = src0s + i02*nb02;
6189
+ device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
6190
+ device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0;
6191
+
6192
+ impl_fn(
6193
+ /* src0 */ src0_cur,
6194
+ /* src1 */ src1_cur,
6195
+ /* dst */ dst_cur,
6196
+ /* ne00 */ ne00,
6197
+ /* ne01 */ ne01,
6198
+ /* ne02 */ 1,//ne02,
6199
+ /* nb00 */ nb00,
6200
+ /* nb01 */ nb01,
6201
+ /* nb02 */ nb02,
6202
+ /* ne10 */ ne10,
6203
+ /* ne11 */ 1,//ne11,
6204
+ /* ne12 */ 1,//ne12,
6205
+ /* ne13 */ 1,//ne13,
6206
+ /* nb10 */ nb10,
6207
+ /* nb11 */ nb11,
6208
+ /* nb12 */ nb12,
6209
+ /* ne0 */ ne0,
6210
+ /* ne1 */ 1,//ne1,
6211
+ /* nb1 */ nb1,
6212
+ /* r2 */ 1,
6213
+ /* r3 */ 1,
6214
+ shared_values,
6082
6215
  tgpig,
6083
- tiisg);
6216
+ tiitg,
6217
+ tiisg,
6218
+ sgitg);
6084
6219
  }
6085
6220
 
6086
- [[host_name("kernel_mul_mv_id_f16_f32")]]
6087
- kernel void kernel_mul_mv_id_f16_f32(
6088
- device const char * src0s,
6089
- device const char * src1,
6090
- device float * dst,
6091
- device const char * ids,
6092
- constant uint64_t & nbi1,
6093
- constant int64_t & ne00,
6094
- constant int64_t & ne01,
6095
- constant int64_t & ne02,
6096
- constant uint64_t & nb00,
6097
- constant uint64_t & nb01,
6098
- constant uint64_t & nb02,
6099
- constant int64_t & ne10,
6100
- constant int64_t & ne11,
6101
- constant int64_t & ne12,
6102
- constant int64_t & ne13,
6103
- constant uint64_t & nb10,
6104
- constant uint64_t & nb11,
6105
- constant uint64_t & nb12,
6106
- constant int64_t & ne0,
6107
- constant int64_t & ne1,
6108
- constant uint64_t & nb1,
6109
- constant uint & r2,
6110
- constant uint & r3,
6111
- constant int & idx,
6112
- uint3 tgpig[[threadgroup_position_in_grid]],
6113
- uint tiitg[[thread_index_in_threadgroup]],
6114
- uint tiisg[[thread_index_in_simdgroup]],
6115
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6116
- const int64_t bid = tgpig.z/(ne12*ne13);
6117
-
6118
- tgpig.z = tgpig.z%(ne12*ne13);
6119
-
6120
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6121
- device const char * src0 = src0s + id*nb02;
6122
-
6123
- kernel_mul_mv_f16_f32_impl(
6124
- src0,
6125
- src1 + bid*nb11,
6126
- dst + bid*ne0,
6127
- ne00,
6128
- ne01,
6129
- ne02,
6130
- nb00,
6131
- nb01,
6132
- nb02,
6133
- ne10,
6134
- ne11,
6135
- ne12,
6136
- nb10,
6137
- nb11,
6138
- nb12,
6139
- ne0,
6140
- ne1,
6141
- r2,
6142
- r3,
6143
- tgpig,
6144
- tiisg);
6145
- }
6146
-
6147
- [[host_name("kernel_mul_mv_id_q8_0_f32")]]
6148
- kernel void kernel_mul_mv_id_q8_0_f32(
6149
- device const char * src0s,
6150
- device const char * src1,
6151
- device float * dst,
6152
- device const char * ids,
6153
- constant uint64_t & nbi1,
6154
- constant int64_t & ne00,
6155
- constant int64_t & ne01,
6156
- constant int64_t & ne02,
6157
- constant uint64_t & nb00,
6158
- constant uint64_t & nb01,
6159
- constant uint64_t & nb02,
6160
- constant int64_t & ne10,
6161
- constant int64_t & ne11,
6162
- constant int64_t & ne12,
6163
- constant int64_t & ne13,
6164
- constant uint64_t & nb10,
6165
- constant uint64_t & nb11,
6166
- constant uint64_t & nb12,
6167
- constant int64_t & ne0,
6168
- constant int64_t & ne1,
6169
- constant uint64_t & nb1,
6170
- constant uint & r2,
6171
- constant uint & r3,
6172
- constant int & idx,
6173
- uint3 tgpig[[threadgroup_position_in_grid]],
6174
- uint tiitg[[thread_index_in_threadgroup]],
6175
- uint tiisg[[thread_index_in_simdgroup]],
6176
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6177
- const int64_t bid = tgpig.z/(ne12*ne13);
6178
-
6179
- tgpig.z = tgpig.z%(ne12*ne13);
6180
-
6181
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6182
- device const char * src0 = src0s + id*nb02;
6183
-
6184
- kernel_mul_mv_q8_0_f32_impl(
6185
- src0,
6186
- (device const float *) (src1 + bid*nb11),
6187
- dst + bid*ne0,
6188
- ne00,
6189
- ne01,
6190
- ne02,
6191
- ne10,
6192
- ne12,
6193
- ne0,
6194
- ne1,
6195
- r2,
6196
- r3,
6197
- tgpig,
6198
- tiisg,
6199
- sgitg);
6200
- }
6201
-
6202
- [[host_name("kernel_mul_mv_id_q4_0_f32")]]
6203
- kernel void kernel_mul_mv_id_q4_0_f32(
6204
- device const char * src0s,
6205
- device const char * src1,
6206
- device float * dst,
6207
- device const char * ids,
6208
- constant uint64_t & nbi1,
6209
- constant int64_t & ne00,
6210
- constant int64_t & ne01,
6211
- constant int64_t & ne02,
6212
- constant uint64_t & nb00,
6213
- constant uint64_t & nb01,
6214
- constant uint64_t & nb02,
6215
- constant int64_t & ne10,
6216
- constant int64_t & ne11,
6217
- constant int64_t & ne12,
6218
- constant int64_t & ne13,
6219
- constant uint64_t & nb10,
6220
- constant uint64_t & nb11,
6221
- constant uint64_t & nb12,
6222
- constant int64_t & ne0,
6223
- constant int64_t & ne1,
6224
- constant uint64_t & nb1,
6225
- constant uint & r2,
6226
- constant uint & r3,
6227
- constant int & idx,
6228
- uint3 tgpig[[threadgroup_position_in_grid]],
6229
- uint tiitg[[thread_index_in_threadgroup]],
6230
- uint tiisg[[thread_index_in_simdgroup]],
6231
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6232
- const int64_t bid = tgpig.z/(ne12*ne13);
6233
-
6234
- tgpig.z = tgpig.z%(ne12*ne13);
6235
-
6236
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6237
- device const char * src0 = src0s + id*nb02;
6238
-
6239
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6240
- src0,
6241
- (device const float *) (src1 + bid*nb11),
6242
- dst + bid*ne0,
6243
- ne00,
6244
- ne01,
6245
- ne02,
6246
- ne10,
6247
- ne12,
6248
- ne0,
6249
- ne1,
6250
- r2,
6251
- r3,
6252
- tgpig,
6253
- tiisg,
6254
- sgitg);
6255
- }
6256
-
6257
- [[host_name("kernel_mul_mv_id_q4_1_f32")]]
6258
- kernel void kernel_mul_mv_id_q4_1_f32(
6259
- device const char * src0s,
6260
- device const char * src1,
6261
- device float * dst,
6262
- device const char * ids,
6263
- constant uint64_t & nbi1,
6264
- constant int64_t & ne00,
6265
- constant int64_t & ne01,
6266
- constant int64_t & ne02,
6267
- constant uint64_t & nb00,
6268
- constant uint64_t & nb01,
6269
- constant uint64_t & nb02,
6270
- constant int64_t & ne10,
6271
- constant int64_t & ne11,
6272
- constant int64_t & ne12,
6273
- constant int64_t & ne13,
6274
- constant uint64_t & nb10,
6275
- constant uint64_t & nb11,
6276
- constant uint64_t & nb12,
6277
- constant int64_t & ne0,
6278
- constant int64_t & ne1,
6279
- constant uint64_t & nb1,
6280
- constant uint & r2,
6281
- constant uint & r3,
6282
- constant int & idx,
6283
- uint3 tgpig[[threadgroup_position_in_grid]],
6284
- uint tiitg[[thread_index_in_threadgroup]],
6285
- uint tiisg[[thread_index_in_simdgroup]],
6286
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6287
- const int64_t bid = tgpig.z/(ne12*ne13);
6288
-
6289
- tgpig.z = tgpig.z%(ne12*ne13);
6290
-
6291
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6292
- device const char * src0 = src0s + id*nb02;
6293
-
6294
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6295
- src0,
6296
- (device const float *) (src1 + bid*nb11),
6297
- dst + bid*ne0,
6298
- ne00,
6299
- ne01,
6300
- ne02,
6301
- ne10,
6302
- ne12,
6303
- ne0,
6304
- ne1,
6305
- r2,
6306
- r3,
6307
- tgpig,
6308
- tiisg,
6309
- sgitg);
6310
- }
6311
-
6312
- [[host_name("kernel_mul_mv_id_q5_0_f32")]]
6313
- kernel void kernel_mul_mv_id_q5_0_f32(
6314
- device const char * src0s,
6315
- device const char * src1,
6316
- device float * dst,
6317
- device const char * ids,
6318
- constant uint64_t & nbi1,
6319
- constant int64_t & ne00,
6320
- constant int64_t & ne01,
6321
- constant int64_t & ne02,
6322
- constant uint64_t & nb00,
6323
- constant uint64_t & nb01,
6324
- constant uint64_t & nb02,
6325
- constant int64_t & ne10,
6326
- constant int64_t & ne11,
6327
- constant int64_t & ne12,
6328
- constant int64_t & ne13,
6329
- constant uint64_t & nb10,
6330
- constant uint64_t & nb11,
6331
- constant uint64_t & nb12,
6332
- constant int64_t & ne0,
6333
- constant int64_t & ne1,
6334
- constant uint64_t & nb1,
6335
- constant uint & r2,
6336
- constant uint & r3,
6337
- constant int & idx,
6338
- uint3 tgpig[[threadgroup_position_in_grid]],
6339
- uint tiitg[[thread_index_in_threadgroup]],
6340
- uint tiisg[[thread_index_in_simdgroup]],
6341
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6342
- const int64_t bid = tgpig.z/(ne12*ne13);
6343
-
6344
- tgpig.z = tgpig.z%(ne12*ne13);
6345
-
6346
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6347
- device const char * src0 = src0s + id*nb02;
6348
-
6349
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6350
- src0,
6351
- (device const float *) (src1 + bid*nb11),
6352
- dst + bid*ne0,
6353
- ne00,
6354
- ne01,
6355
- ne02,
6356
- ne10,
6357
- ne12,
6358
- ne0,
6359
- ne1,
6360
- r2,
6361
- r3,
6362
- tgpig,
6363
- tiisg,
6364
- sgitg);
6365
- }
6366
-
6367
- [[host_name("kernel_mul_mv_id_q5_1_f32")]]
6368
- kernel void kernel_mul_mv_id_q5_1_f32(
6369
- device const char * src0s,
6370
- device const char * src1,
6371
- device float * dst,
6372
- device const char * ids,
6373
- constant uint64_t & nbi1,
6374
- constant int64_t & ne00,
6375
- constant int64_t & ne01,
6376
- constant int64_t & ne02,
6377
- constant uint64_t & nb00,
6378
- constant uint64_t & nb01,
6379
- constant uint64_t & nb02,
6380
- constant int64_t & ne10,
6381
- constant int64_t & ne11,
6382
- constant int64_t & ne12,
6383
- constant int64_t & ne13,
6384
- constant uint64_t & nb10,
6385
- constant uint64_t & nb11,
6386
- constant uint64_t & nb12,
6387
- constant int64_t & ne0,
6388
- constant int64_t & ne1,
6389
- constant uint64_t & nb1,
6390
- constant uint & r2,
6391
- constant uint & r3,
6392
- constant int & idx,
6393
- uint3 tgpig[[threadgroup_position_in_grid]],
6394
- uint tiitg[[thread_index_in_threadgroup]],
6395
- uint tiisg[[thread_index_in_simdgroup]],
6396
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6397
- const int64_t bid = tgpig.z/(ne12*ne13);
6398
-
6399
- tgpig.z = tgpig.z%(ne12*ne13);
6400
-
6401
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6402
- device const char * src0 = src0s + id*nb02;
6403
-
6404
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6405
- src0,
6406
- (device const float *) (src1 + bid*nb11),
6407
- dst + bid*ne0,
6408
- ne00,
6409
- ne01,
6410
- ne02,
6411
- ne10,
6412
- ne12,
6413
- ne0,
6414
- ne1,
6415
- r2,
6416
- r3,
6417
- tgpig,
6418
- tiisg,
6419
- sgitg);
6420
- }
6421
-
6422
- [[host_name("kernel_mul_mv_id_q2_K_f32")]]
6423
- kernel void kernel_mul_mv_id_q2_K_f32(
6424
- device const char * src0s,
6425
- device const char * src1,
6426
- device float * dst,
6427
- device const char * ids,
6428
- constant uint64_t & nbi1,
6429
- constant int64_t & ne00,
6430
- constant int64_t & ne01,
6431
- constant int64_t & ne02,
6432
- constant uint64_t & nb00,
6433
- constant uint64_t & nb01,
6434
- constant uint64_t & nb02,
6435
- constant int64_t & ne10,
6436
- constant int64_t & ne11,
6437
- constant int64_t & ne12,
6438
- constant int64_t & ne13,
6439
- constant uint64_t & nb10,
6440
- constant uint64_t & nb11,
6441
- constant uint64_t & nb12,
6442
- constant int64_t & ne0,
6443
- constant int64_t & ne1,
6444
- constant uint64_t & nb1,
6445
- constant uint & r2,
6446
- constant uint & r3,
6447
- constant int & idx,
6448
- uint3 tgpig[[threadgroup_position_in_grid]],
6449
- uint tiitg[[thread_index_in_threadgroup]],
6450
- uint tiisg[[thread_index_in_simdgroup]],
6451
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6452
- const int64_t bid = tgpig.z/(ne12*ne13);
6453
-
6454
- tgpig.z = tgpig.z%(ne12*ne13);
6455
-
6456
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6457
- device const char * src0 = src0s + id*nb02;
6458
-
6459
- kernel_mul_mv_q2_K_f32_impl(
6460
- src0,
6461
- (device const float *) (src1 + bid*nb11),
6462
- dst + bid*ne0,
6463
- ne00,
6464
- ne01,
6465
- ne02,
6466
- ne10,
6467
- ne12,
6468
- ne0,
6469
- ne1,
6470
- r2,
6471
- r3,
6472
- tgpig,
6473
- tiisg,
6474
- sgitg);
6475
- }
6476
-
6477
- [[host_name("kernel_mul_mv_id_q3_K_f32")]]
6478
- kernel void kernel_mul_mv_id_q3_K_f32(
6479
- device const char * src0s,
6480
- device const char * src1,
6481
- device float * dst,
6482
- device const char * ids,
6483
- constant uint64_t & nbi1,
6484
- constant int64_t & ne00,
6485
- constant int64_t & ne01,
6486
- constant int64_t & ne02,
6487
- constant uint64_t & nb00,
6488
- constant uint64_t & nb01,
6489
- constant uint64_t & nb02,
6490
- constant int64_t & ne10,
6491
- constant int64_t & ne11,
6492
- constant int64_t & ne12,
6493
- constant int64_t & ne13,
6494
- constant uint64_t & nb10,
6495
- constant uint64_t & nb11,
6496
- constant uint64_t & nb12,
6497
- constant int64_t & ne0,
6498
- constant int64_t & ne1,
6499
- constant uint64_t & nb1,
6500
- constant uint & r2,
6501
- constant uint & r3,
6502
- constant int & idx,
6503
- uint3 tgpig[[threadgroup_position_in_grid]],
6504
- uint tiitg[[thread_index_in_threadgroup]],
6505
- uint tiisg[[thread_index_in_simdgroup]],
6506
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6507
- const int64_t bid = tgpig.z/(ne12*ne13);
6508
-
6509
- tgpig.z = tgpig.z%(ne12*ne13);
6510
-
6511
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6512
- device const char * src0 = src0s + id*nb02;
6513
-
6514
- kernel_mul_mv_q3_K_f32_impl(
6515
- src0,
6516
- (device const float *) (src1 + bid*nb11),
6517
- dst + bid*ne0,
6518
- ne00,
6519
- ne01,
6520
- ne02,
6521
- ne10,
6522
- ne12,
6523
- ne0,
6524
- ne1,
6525
- r2,
6526
- r3,
6527
- tgpig,
6528
- tiisg,
6529
- sgitg);
6530
- }
6531
-
6532
- [[host_name("kernel_mul_mv_id_q4_K_f32")]]
6533
- kernel void kernel_mul_mv_id_q4_K_f32(
6534
- device const char * src0s,
6535
- device const char * src1,
6536
- device float * dst,
6537
- device const char * ids,
6538
- constant uint64_t & nbi1,
6539
- constant int64_t & ne00,
6540
- constant int64_t & ne01,
6541
- constant int64_t & ne02,
6542
- constant uint64_t & nb00,
6543
- constant uint64_t & nb01,
6544
- constant uint64_t & nb02,
6545
- constant int64_t & ne10,
6546
- constant int64_t & ne11,
6547
- constant int64_t & ne12,
6548
- constant int64_t & ne13,
6549
- constant uint64_t & nb10,
6550
- constant uint64_t & nb11,
6551
- constant uint64_t & nb12,
6552
- constant int64_t & ne0,
6553
- constant int64_t & ne1,
6554
- constant uint64_t & nb1,
6555
- constant uint & r2,
6556
- constant uint & r3,
6557
- constant int & idx,
6558
- uint3 tgpig[[threadgroup_position_in_grid]],
6559
- uint tiitg[[thread_index_in_threadgroup]],
6560
- uint tiisg[[thread_index_in_simdgroup]],
6561
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6562
- const int64_t bid = tgpig.z/(ne12*ne13);
6563
-
6564
- tgpig.z = tgpig.z%(ne12*ne13);
6565
-
6566
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6567
- device const char * src0 = src0s + id*nb02;
6568
-
6569
- kernel_mul_mv_q4_K_f32_impl(
6570
- src0,
6571
- (device const float *) (src1 + bid*nb11),
6572
- dst + bid*ne0,
6573
- ne00,
6574
- ne01,
6575
- ne02,
6576
- ne10,
6577
- ne12,
6578
- ne0,
6579
- ne1,
6580
- r2,
6581
- r3,
6582
- tgpig,
6583
- tiisg,
6584
- sgitg);
6585
- }
6586
-
6587
- [[host_name("kernel_mul_mv_id_q5_K_f32")]]
6588
- kernel void kernel_mul_mv_id_q5_K_f32(
6589
- device const char * src0s,
6590
- device const char * src1,
6591
- device float * dst,
6592
- device const char * ids,
6593
- constant uint64_t & nbi1,
6594
- constant int64_t & ne00,
6595
- constant int64_t & ne01,
6596
- constant int64_t & ne02,
6597
- constant uint64_t & nb00,
6598
- constant uint64_t & nb01,
6599
- constant uint64_t & nb02,
6600
- constant int64_t & ne10,
6601
- constant int64_t & ne11,
6602
- constant int64_t & ne12,
6603
- constant int64_t & ne13,
6604
- constant uint64_t & nb10,
6605
- constant uint64_t & nb11,
6606
- constant uint64_t & nb12,
6607
- constant int64_t & ne0,
6608
- constant int64_t & ne1,
6609
- constant uint64_t & nb1,
6610
- constant uint & r2,
6611
- constant uint & r3,
6612
- constant int & idx,
6613
- uint3 tgpig[[threadgroup_position_in_grid]],
6614
- uint tiitg[[thread_index_in_threadgroup]],
6615
- uint tiisg[[thread_index_in_simdgroup]],
6616
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6617
- const int64_t bid = tgpig.z/(ne12*ne13);
6618
-
6619
- tgpig.z = tgpig.z%(ne12*ne13);
6620
-
6621
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6622
- device const char * src0 = src0s + id*nb02;
6623
-
6624
- kernel_mul_mv_q5_K_f32_impl(
6625
- src0,
6626
- (device const float *) (src1 + bid*nb11),
6627
- dst + bid*ne0,
6628
- ne00,
6629
- ne01,
6630
- ne02,
6631
- ne10,
6632
- ne12,
6633
- ne0,
6634
- ne1,
6635
- r2,
6636
- r3,
6637
- tgpig,
6638
- tiisg,
6639
- sgitg);
6640
- }
6641
-
6642
- [[host_name("kernel_mul_mv_id_q6_K_f32")]]
6643
- kernel void kernel_mul_mv_id_q6_K_f32(
6644
- device const char * src0s,
6645
- device const char * src1,
6646
- device float * dst,
6647
- device const char * ids,
6648
- constant uint64_t & nbi1,
6649
- constant int64_t & ne00,
6650
- constant int64_t & ne01,
6651
- constant int64_t & ne02,
6652
- constant uint64_t & nb00,
6653
- constant uint64_t & nb01,
6654
- constant uint64_t & nb02,
6655
- constant int64_t & ne10,
6656
- constant int64_t & ne11,
6657
- constant int64_t & ne12,
6658
- constant int64_t & ne13,
6659
- constant uint64_t & nb10,
6660
- constant uint64_t & nb11,
6661
- constant uint64_t & nb12,
6662
- constant int64_t & ne0,
6663
- constant int64_t & ne1,
6664
- constant uint64_t & nb1,
6665
- constant uint & r2,
6666
- constant uint & r3,
6667
- constant int & idx,
6668
- uint3 tgpig[[threadgroup_position_in_grid]],
6669
- uint tiitg[[thread_index_in_threadgroup]],
6670
- uint tiisg[[thread_index_in_simdgroup]],
6671
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6672
- const int64_t bid = tgpig.z/(ne12*ne13);
6673
-
6674
- tgpig.z = tgpig.z%(ne12*ne13);
6675
-
6676
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6677
- device const char * src0 = src0s + id*nb02;
6678
-
6679
- kernel_mul_mv_q6_K_f32_impl(
6680
- src0,
6681
- (device const float *) (src1 + bid*nb11),
6682
- dst + bid*ne0,
6683
- ne00,
6684
- ne01,
6685
- ne02,
6686
- ne10,
6687
- ne12,
6688
- ne0,
6689
- ne1,
6690
- r2,
6691
- r3,
6692
- tgpig,
6693
- tiisg,
6694
- sgitg);
6695
- }
6696
-
6697
- [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
6698
- kernel void kernel_mul_mv_id_iq2_xxs_f32(
6699
- device const char * src0s,
6700
- device const char * src1,
6701
- device float * dst,
6702
- device const char * ids,
6703
- constant uint64_t & nbi1,
6704
- constant int64_t & ne00,
6705
- constant int64_t & ne01,
6706
- constant int64_t & ne02,
6707
- constant uint64_t & nb00,
6708
- constant uint64_t & nb01,
6709
- constant uint64_t & nb02,
6710
- constant int64_t & ne10,
6711
- constant int64_t & ne11,
6712
- constant int64_t & ne12,
6713
- constant int64_t & ne13,
6714
- constant uint64_t & nb10,
6715
- constant uint64_t & nb11,
6716
- constant uint64_t & nb12,
6717
- constant int64_t & ne0,
6718
- constant int64_t & ne1,
6719
- constant uint64_t & nb1,
6720
- constant uint & r2,
6721
- constant uint & r3,
6722
- constant int & idx,
6723
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6724
- uint3 tgpig[[threadgroup_position_in_grid]],
6725
- uint tiitg[[thread_index_in_threadgroup]],
6726
- uint tiisg[[thread_index_in_simdgroup]],
6727
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6728
- const int64_t bid = tgpig.z/(ne12*ne13);
6729
-
6730
- tgpig.z = tgpig.z%(ne12*ne13);
6731
-
6732
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6733
- device const char * src0 = src0s + id*nb02;
6734
-
6735
- kernel_mul_mv_iq2_xxs_f32_impl(
6736
- src0,
6737
- (device const float *) (src1 + bid*nb11),
6738
- dst + bid*ne0,
6739
- ne00,
6740
- ne01,
6741
- ne02,
6742
- ne10,
6743
- ne12,
6744
- ne0,
6745
- ne1,
6746
- r2,
6747
- r3,
6748
- shared_values,
6749
- tgpig,
6750
- tiisg,
6751
- sgitg);
6752
- }
6753
-
6754
- [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
6755
- kernel void kernel_mul_mv_id_iq2_xs_f32(
6756
- device const char * src0s,
6757
- device const char * src1,
6758
- device float * dst,
6759
- device const char * ids,
6760
- constant uint64_t & nbi1,
6761
- constant int64_t & ne00,
6762
- constant int64_t & ne01,
6763
- constant int64_t & ne02,
6764
- constant uint64_t & nb00,
6765
- constant uint64_t & nb01,
6766
- constant uint64_t & nb02,
6767
- constant int64_t & ne10,
6768
- constant int64_t & ne11,
6769
- constant int64_t & ne12,
6770
- constant int64_t & ne13,
6771
- constant uint64_t & nb10,
6772
- constant uint64_t & nb11,
6773
- constant uint64_t & nb12,
6774
- constant int64_t & ne0,
6775
- constant int64_t & ne1,
6776
- constant uint64_t & nb1,
6777
- constant uint & r2,
6778
- constant uint & r3,
6779
- constant int & idx,
6780
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6781
- uint3 tgpig[[threadgroup_position_in_grid]],
6782
- uint tiitg[[thread_index_in_threadgroup]],
6783
- uint tiisg[[thread_index_in_simdgroup]],
6784
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6785
- const int64_t bid = tgpig.z/(ne12*ne13);
6786
-
6787
- tgpig.z = tgpig.z%(ne12*ne13);
6788
-
6789
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6790
- device const char * src0 = src0s + id*nb02;
6791
-
6792
- kernel_mul_mv_iq2_xs_f32_impl(
6793
- src0,
6794
- (device const float *) (src1 + bid*nb11),
6795
- dst + bid*ne0,
6796
- ne00,
6797
- ne01,
6798
- ne02,
6799
- ne10,
6800
- ne12,
6801
- ne0,
6802
- ne1,
6803
- r2,
6804
- r3,
6805
- shared_values,
6806
- tgpig,
6807
- tiisg,
6808
- sgitg);
6809
- }
6810
-
6811
- [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
6812
- kernel void kernel_mul_mv_id_iq3_xxs_f32(
6813
- device const char * src0s,
6814
- device const char * src1,
6815
- device float * dst,
6816
- device const char * ids,
6817
- constant uint64_t & nbi1,
6818
- constant int64_t & ne00,
6819
- constant int64_t & ne01,
6820
- constant int64_t & ne02,
6821
- constant uint64_t & nb00,
6822
- constant uint64_t & nb01,
6823
- constant uint64_t & nb02,
6824
- constant int64_t & ne10,
6825
- constant int64_t & ne11,
6826
- constant int64_t & ne12,
6827
- constant int64_t & ne13,
6828
- constant uint64_t & nb10,
6829
- constant uint64_t & nb11,
6830
- constant uint64_t & nb12,
6831
- constant int64_t & ne0,
6832
- constant int64_t & ne1,
6833
- constant uint64_t & nb1,
6834
- constant uint & r2,
6835
- constant uint & r3,
6836
- constant int & idx,
6837
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6838
- uint3 tgpig[[threadgroup_position_in_grid]],
6839
- uint tiitg[[thread_index_in_threadgroup]],
6840
- uint tiisg[[thread_index_in_simdgroup]],
6841
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6842
- const int64_t bid = tgpig.z/(ne12*ne13);
6843
-
6844
- tgpig.z = tgpig.z%(ne12*ne13);
6845
-
6846
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6847
- device const char * src0 = src0s + id*nb02;
6848
-
6849
- kernel_mul_mv_iq3_xxs_f32_impl(
6850
- src0,
6851
- (device const float *) (src1 + bid*nb11),
6852
- dst + bid*ne0,
6853
- ne00,
6854
- ne01,
6855
- ne02,
6856
- ne10,
6857
- ne12,
6858
- ne0,
6859
- ne1,
6860
- r2,
6861
- r3,
6862
- shared_values,
6863
- tgpig,
6864
- tiisg,
6865
- sgitg);
6866
- }
6867
-
6868
- [[host_name("kernel_mul_mv_id_iq3_s_f32")]]
6869
- kernel void kernel_mul_mv_id_iq3_s_f32(
6870
- device const char * src0s,
6871
- device const char * src1,
6872
- device float * dst,
6873
- device const char * ids,
6874
- constant uint64_t & nbi1,
6875
- constant int64_t & ne00,
6876
- constant int64_t & ne01,
6877
- constant int64_t & ne02,
6878
- constant uint64_t & nb00,
6879
- constant uint64_t & nb01,
6880
- constant uint64_t & nb02,
6881
- constant int64_t & ne10,
6882
- constant int64_t & ne11,
6883
- constant int64_t & ne12,
6884
- constant int64_t & ne13,
6885
- constant uint64_t & nb10,
6886
- constant uint64_t & nb11,
6887
- constant uint64_t & nb12,
6888
- constant int64_t & ne0,
6889
- constant int64_t & ne1,
6890
- constant uint64_t & nb1,
6891
- constant uint & r2,
6892
- constant uint & r3,
6893
- constant int & idx,
6894
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6895
- uint3 tgpig[[threadgroup_position_in_grid]],
6896
- uint tiitg[[thread_index_in_threadgroup]],
6897
- uint tiisg[[thread_index_in_simdgroup]],
6898
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6899
- const int64_t bid = tgpig.z/(ne12*ne13);
6900
-
6901
- tgpig.z = tgpig.z%(ne12*ne13);
6902
-
6903
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6904
- device const char * src0 = src0s + id*nb02;
6905
-
6906
- kernel_mul_mv_iq3_s_f32_impl(
6907
- src0,
6908
- (device const float *) (src1 + bid*nb11),
6909
- dst + bid*ne0,
6910
- ne00,
6911
- ne01,
6912
- ne02,
6913
- ne10,
6914
- ne12,
6915
- ne0,
6916
- ne1,
6917
- r2,
6918
- r3,
6919
- shared_values,
6920
- tgpig,
6921
- tiisg,
6922
- sgitg);
6923
- }
6924
-
6925
- [[host_name("kernel_mul_mv_id_iq2_s_f32")]]
6926
- kernel void kernel_mul_mv_id_iq2_s_f32(
6927
- device const char * src0s,
6928
- device const char * src1,
6929
- device float * dst,
6930
- device const char * ids,
6931
- constant uint64_t & nbi1,
6932
- constant int64_t & ne00,
6933
- constant int64_t & ne01,
6934
- constant int64_t & ne02,
6935
- constant uint64_t & nb00,
6936
- constant uint64_t & nb01,
6937
- constant uint64_t & nb02,
6938
- constant int64_t & ne10,
6939
- constant int64_t & ne11,
6940
- constant int64_t & ne12,
6941
- constant int64_t & ne13,
6942
- constant uint64_t & nb10,
6943
- constant uint64_t & nb11,
6944
- constant uint64_t & nb12,
6945
- constant int64_t & ne0,
6946
- constant int64_t & ne1,
6947
- constant uint64_t & nb1,
6948
- constant uint & r2,
6949
- constant uint & r3,
6950
- constant int & idx,
6951
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6952
- uint3 tgpig[[threadgroup_position_in_grid]],
6953
- uint tiitg[[thread_index_in_threadgroup]],
6954
- uint tiisg[[thread_index_in_simdgroup]],
6955
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6956
- const int64_t bid = tgpig.z/(ne12*ne13);
6957
-
6958
- tgpig.z = tgpig.z%(ne12*ne13);
6959
-
6960
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6961
- device const char * src0 = src0s + id*nb02;
6962
-
6963
- kernel_mul_mv_iq2_s_f32_impl(
6964
- src0,
6965
- (device const float *) (src1 + bid*nb11),
6966
- dst + bid*ne0,
6967
- ne00,
6968
- ne01,
6969
- ne02,
6970
- ne10,
6971
- ne12,
6972
- ne0,
6973
- ne1,
6974
- r2,
6975
- r3,
6976
- shared_values,
6977
- tgpig,
6978
- tiisg,
6979
- sgitg);
6980
- }
6981
-
6982
- [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
6983
- kernel void kernel_mul_mv_id_iq1_s_f32(
6984
- device const char * src0s,
6985
- device const char * src1,
6986
- device float * dst,
6987
- device const char * ids,
6988
- constant uint64_t & nbi1,
6989
- constant int64_t & ne00,
6990
- constant int64_t & ne01,
6991
- constant int64_t & ne02,
6992
- constant uint64_t & nb00,
6993
- constant uint64_t & nb01,
6994
- constant uint64_t & nb02,
6995
- constant int64_t & ne10,
6996
- constant int64_t & ne11,
6997
- constant int64_t & ne12,
6998
- constant int64_t & ne13,
6999
- constant uint64_t & nb10,
7000
- constant uint64_t & nb11,
7001
- constant uint64_t & nb12,
7002
- constant int64_t & ne0,
7003
- constant int64_t & ne1,
7004
- constant uint64_t & nb1,
7005
- constant uint & r2,
7006
- constant uint & r3,
7007
- constant int & idx,
7008
- uint3 tgpig[[threadgroup_position_in_grid]],
7009
- uint tiitg[[thread_index_in_threadgroup]],
7010
- uint tiisg[[thread_index_in_simdgroup]],
7011
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
7012
- const int64_t bid = tgpig.z/(ne12*ne13);
7013
-
7014
- tgpig.z = tgpig.z%(ne12*ne13);
7015
-
7016
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7017
- device const char * src0 = src0s + id*nb02;
7018
-
7019
- kernel_mul_mv_iq1_s_f32_impl(
7020
- src0,
7021
- (device const float *) (src1 + bid*nb11),
7022
- dst + bid*ne0,
7023
- ne00,
7024
- ne01,
7025
- ne02,
7026
- ne10,
7027
- ne12,
7028
- ne0,
7029
- ne1,
7030
- r2,
7031
- r3,
7032
- tgpig,
7033
- tiisg,
7034
- sgitg);
7035
- }
7036
-
7037
- [[host_name("kernel_mul_mv_id_iq1_m_f32")]]
7038
- kernel void kernel_mul_mv_id_iq1_m_f32(
7039
- device const char * src0s,
7040
- device const char * src1,
7041
- device float * dst,
7042
- device const char * ids,
7043
- constant uint64_t & nbi1,
7044
- constant int64_t & ne00,
7045
- constant int64_t & ne01,
7046
- constant int64_t & ne02,
7047
- constant uint64_t & nb00,
7048
- constant uint64_t & nb01,
7049
- constant uint64_t & nb02,
7050
- constant int64_t & ne10,
7051
- constant int64_t & ne11,
7052
- constant int64_t & ne12,
7053
- constant int64_t & ne13,
7054
- constant uint64_t & nb10,
7055
- constant uint64_t & nb11,
7056
- constant uint64_t & nb12,
7057
- constant int64_t & ne0,
7058
- constant int64_t & ne1,
7059
- constant uint64_t & nb1,
7060
- constant uint & r2,
7061
- constant uint & r3,
7062
- constant int & idx,
7063
- uint3 tgpig[[threadgroup_position_in_grid]],
7064
- uint tiitg[[thread_index_in_threadgroup]],
7065
- uint tiisg[[thread_index_in_simdgroup]],
7066
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
7067
- const int64_t bid = tgpig.z/(ne12*ne13);
7068
-
7069
- tgpig.z = tgpig.z%(ne12*ne13);
7070
-
7071
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7072
- device const char * src0 = src0s + id*nb02;
7073
-
7074
- kernel_mul_mv_iq1_m_f32_impl(
7075
- src0,
7076
- (device const float *) (src1 + bid*nb11),
7077
- dst + bid*ne0,
7078
- ne00,
7079
- ne01,
7080
- ne02,
7081
- ne10,
7082
- ne12,
7083
- ne0,
7084
- ne1,
7085
- r2,
7086
- r3,
7087
- tgpig,
7088
- tiisg,
7089
- sgitg);
7090
- }
7091
-
7092
- [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
7093
- kernel void kernel_mul_mv_id_iq4_nl_f32(
7094
- device const char * src0s,
7095
- device const char * src1,
7096
- device float * dst,
7097
- device const char * ids,
7098
- constant uint64_t & nbi1,
7099
- constant int64_t & ne00,
7100
- constant int64_t & ne01,
7101
- constant int64_t & ne02,
7102
- constant uint64_t & nb00,
7103
- constant uint64_t & nb01,
7104
- constant uint64_t & nb02,
7105
- constant int64_t & ne10,
7106
- constant int64_t & ne11,
7107
- constant int64_t & ne12,
7108
- constant int64_t & ne13,
7109
- constant uint64_t & nb10,
7110
- constant uint64_t & nb11,
7111
- constant uint64_t & nb12,
7112
- constant int64_t & ne0,
7113
- constant int64_t & ne1,
7114
- constant uint64_t & nb1,
7115
- constant uint & r2,
7116
- constant uint & r3,
7117
- constant int & idx,
7118
- threadgroup float * shared_values [[threadgroup(0)]],
7119
- uint3 tgpig[[threadgroup_position_in_grid]],
7120
- uint tiitg[[thread_index_in_threadgroup]],
7121
- uint tiisg[[thread_index_in_simdgroup]],
7122
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
7123
- const int64_t bid = tgpig.z/(ne12*ne13);
7124
-
7125
- tgpig.z = tgpig.z%(ne12*ne13);
7126
-
7127
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7128
- device const char * src0 = src0s + id*nb02;
7129
-
7130
- kernel_mul_mv_iq4_nl_f32_impl(
7131
- src0,
7132
- (device const float *) (src1 + bid*nb11),
7133
- dst + bid*ne0,
7134
- ne00,
7135
- ne01,
7136
- ne02,
7137
- ne10,
7138
- ne12,
7139
- ne0,
7140
- ne1,
7141
- r2,
7142
- r3,
7143
- shared_values,
7144
- tgpig,
7145
- tiisg,
7146
- sgitg);
7147
- }
7148
-
7149
- [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
7150
- kernel void kernel_mul_mv_id_iq4_xs_f32(
7151
- device const char * src0s,
7152
- device const char * src1,
7153
- device float * dst,
7154
- device const char * ids,
7155
- constant uint64_t & nbi1,
7156
- constant int64_t & ne00,
7157
- constant int64_t & ne01,
7158
- constant int64_t & ne02,
7159
- constant uint64_t & nb00,
7160
- constant uint64_t & nb01,
7161
- constant uint64_t & nb02,
7162
- constant int64_t & ne10,
7163
- constant int64_t & ne11,
7164
- constant int64_t & ne12,
7165
- constant int64_t & ne13,
7166
- constant uint64_t & nb10,
7167
- constant uint64_t & nb11,
7168
- constant uint64_t & nb12,
7169
- constant int64_t & ne0,
7170
- constant int64_t & ne1,
7171
- constant uint64_t & nb1,
7172
- constant uint & r2,
7173
- constant uint & r3,
7174
- constant int & idx,
7175
- threadgroup float * shared_values [[threadgroup(0)]],
7176
- uint3 tgpig[[threadgroup_position_in_grid]],
7177
- uint tiitg[[thread_index_in_threadgroup]],
7178
- uint tiisg[[thread_index_in_simdgroup]],
7179
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
7180
- const int64_t bid = tgpig.z/(ne12*ne13);
7181
-
7182
- tgpig.z = tgpig.z%(ne12*ne13);
7183
-
7184
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7185
- device const char * src0 = src0s + id*nb02;
7186
-
7187
- #if QK_K == 64
7188
- kernel_mul_mv_iq4_nl_f32_impl(
7189
- #else
7190
- kernel_mul_mv_iq4_xs_f32_impl(
6221
+ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t;
6222
+
6223
+ template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
6224
+ template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
6225
+ template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
6226
+ template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6227
+ template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6228
+ template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6229
+ template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6230
+ template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
6231
+ template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
6232
+ template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
6233
+ template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
6234
+ template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
6235
+ template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
6236
+ template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
6237
+ template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
6238
+ template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
6239
+ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
6240
+ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
6241
+ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
6242
+ template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
6243
+ #if QK_K != 64
6244
+ template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
7191
6245
  #endif
7192
- src0,
7193
- (device const float *) (src1 + bid*nb11),
7194
- dst + bid*ne0,
7195
- ne00,
7196
- ne01,
7197
- ne02,
7198
- ne10,
7199
- ne12,
7200
- ne0,
7201
- ne1,
7202
- r2,
7203
- r3,
7204
- shared_values,
7205
- tgpig,
7206
- tiisg,
7207
- sgitg);
7208
- }
6246
+