llama_cpp 0.14.5 → 0.14.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -213,6 +213,15 @@ kernel void kernel_scale_4(
213
213
  dst[tpig] = src0[tpig] * scale;
214
214
  }
215
215
 
216
+ kernel void kernel_clamp(
217
+ device const float * src0,
218
+ device float * dst,
219
+ constant float & min,
220
+ constant float & max,
221
+ uint tpig[[thread_position_in_grid]]) {
222
+ dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
223
+ }
224
+
216
225
  kernel void kernel_relu(
217
226
  device const float * src0,
218
227
  device float * dst,
@@ -233,6 +242,15 @@ constant float GELU_QUICK_COEF = -1.702f;
233
242
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
234
243
 
235
244
  kernel void kernel_gelu(
245
+ device const float * src0,
246
+ device float * dst,
247
+ uint tpig[[thread_position_in_grid]]) {
248
+ device const float & x = src0[tpig];
249
+
250
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
251
+ }
252
+
253
+ kernel void kernel_gelu_4(
236
254
  device const float4 * src0,
237
255
  device float4 * dst,
238
256
  uint tpig[[thread_position_in_grid]]) {
@@ -246,6 +264,15 @@ kernel void kernel_gelu(
246
264
  }
247
265
 
248
266
  kernel void kernel_gelu_quick(
267
+ device const float * src0,
268
+ device float * dst,
269
+ uint tpig[[thread_position_in_grid]]) {
270
+ device const float & x = src0[tpig];
271
+
272
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
273
+ }
274
+
275
+ kernel void kernel_gelu_quick_4(
249
276
  device const float4 * src0,
250
277
  device float4 * dst,
251
278
  uint tpig[[thread_position_in_grid]]) {
@@ -255,6 +282,14 @@ kernel void kernel_gelu_quick(
255
282
  }
256
283
 
257
284
  kernel void kernel_silu(
285
+ device const float * src0,
286
+ device float * dst,
287
+ uint tpig[[thread_position_in_grid]]) {
288
+ device const float & x = src0[tpig];
289
+ dst[tpig] = x / (1.0f + exp(-x));
290
+ }
291
+
292
+ kernel void kernel_silu_4(
258
293
  device const float4 * src0,
259
294
  device float4 * dst,
260
295
  uint tpig[[thread_position_in_grid]]) {
@@ -866,6 +901,7 @@ void mul_vec_q_n_f32_impl(
866
901
  int64_t ne1,
867
902
  uint r2,
868
903
  uint r3,
904
+ threadgroup int8_t * shared_values,
869
905
  uint3 tgpig, uint tiisg, uint sgitg) {
870
906
  const int nb = ne00/QK4_0;
871
907
 
@@ -942,7 +978,7 @@ kernel void kernel_mul_mv_q4_0_f32(
942
978
  uint3 tgpig[[threadgroup_position_in_grid]],
943
979
  uint tiisg[[thread_index_in_simdgroup]],
944
980
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
945
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
981
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
946
982
  }
947
983
 
948
984
  kernel void kernel_mul_mv_q4_1_f32(
@@ -968,7 +1004,7 @@ kernel void kernel_mul_mv_q4_1_f32(
968
1004
  uint3 tgpig[[threadgroup_position_in_grid]],
969
1005
  uint tiisg[[thread_index_in_simdgroup]],
970
1006
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
971
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1007
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
972
1008
  }
973
1009
 
974
1010
  kernel void kernel_mul_mv_q5_0_f32(
@@ -994,7 +1030,7 @@ kernel void kernel_mul_mv_q5_0_f32(
994
1030
  uint3 tgpig[[threadgroup_position_in_grid]],
995
1031
  uint tiisg[[thread_index_in_simdgroup]],
996
1032
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
997
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1033
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
998
1034
  }
999
1035
 
1000
1036
  kernel void kernel_mul_mv_q5_1_f32(
@@ -1020,7 +1056,7 @@ kernel void kernel_mul_mv_q5_1_f32(
1020
1056
  uint3 tgpig[[threadgroup_position_in_grid]],
1021
1057
  uint tiisg[[thread_index_in_simdgroup]],
1022
1058
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1023
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1059
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1024
1060
  }
1025
1061
 
1026
1062
 
@@ -1030,18 +1066,19 @@ void kernel_mul_mv_q8_0_f32_impl(
1030
1066
  device const void * src0,
1031
1067
  device const float * src1,
1032
1068
  device float * dst,
1033
- constant int64_t & ne00,
1034
- constant int64_t & ne01,
1035
- constant int64_t & ne02,
1036
- constant int64_t & ne10,
1037
- constant int64_t & ne12,
1038
- constant int64_t & ne0,
1039
- constant int64_t & ne1,
1040
- constant uint & r2,
1041
- constant uint & r3,
1042
- uint3 tgpig[[threadgroup_position_in_grid]],
1043
- uint tiisg[[thread_index_in_simdgroup]],
1044
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
1069
+ int64_t ne00,
1070
+ int64_t ne01,
1071
+ int64_t ne02,
1072
+ int64_t ne10,
1073
+ int64_t ne12,
1074
+ int64_t ne0,
1075
+ int64_t ne1,
1076
+ uint r2,
1077
+ uint r3,
1078
+ threadgroup int8_t * shared_values,
1079
+ uint3 tgpig,
1080
+ uint tiisg,
1081
+ uint sgitg) {
1045
1082
  const int nr = N_DST;
1046
1083
  const int nsg = N_SIMDGROUP;
1047
1084
  const int nw = N_SIMDWIDTH;
@@ -1119,7 +1156,7 @@ kernel void kernel_mul_mv_q8_0_f32(
1119
1156
  uint3 tgpig[[threadgroup_position_in_grid]],
1120
1157
  uint tiisg[[thread_index_in_simdgroup]],
1121
1158
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1122
- kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1159
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1123
1160
  }
1124
1161
 
1125
1162
  #define N_F32_F32 4
@@ -1128,24 +1165,24 @@ void kernel_mul_mv_f32_f32_impl(
1128
1165
  device const char * src0,
1129
1166
  device const char * src1,
1130
1167
  device float * dst,
1131
- constant int64_t & ne00,
1132
- constant int64_t & ne01,
1133
- constant int64_t & ne02,
1134
- constant uint64_t & nb00,
1135
- constant uint64_t & nb01,
1136
- constant uint64_t & nb02,
1137
- constant int64_t & ne10,
1138
- constant int64_t & ne11,
1139
- constant int64_t & ne12,
1140
- constant uint64_t & nb10,
1141
- constant uint64_t & nb11,
1142
- constant uint64_t & nb12,
1143
- constant int64_t & ne0,
1144
- constant int64_t & ne1,
1145
- constant uint & r2,
1146
- constant uint & r3,
1147
- uint3 tgpig[[threadgroup_position_in_grid]],
1148
- uint tiisg[[thread_index_in_simdgroup]]) {
1168
+ int64_t ne00,
1169
+ int64_t ne01,
1170
+ int64_t ne02,
1171
+ uint64_t nb00,
1172
+ uint64_t nb01,
1173
+ uint64_t nb02,
1174
+ int64_t ne10,
1175
+ int64_t ne11,
1176
+ int64_t ne12,
1177
+ uint64_t nb10,
1178
+ uint64_t nb11,
1179
+ uint64_t nb12,
1180
+ int64_t ne0,
1181
+ int64_t ne1,
1182
+ uint r2,
1183
+ uint r3,
1184
+ uint3 tgpig,
1185
+ uint tiisg) {
1149
1186
 
1150
1187
  const int64_t r0 = tgpig.x;
1151
1188
  const int64_t rb = tgpig.y*N_F32_F32;
@@ -1398,24 +1435,24 @@ void kernel_mul_mv_f16_f32_impl(
1398
1435
  device const char * src0,
1399
1436
  device const char * src1,
1400
1437
  device float * dst,
1401
- constant int64_t & ne00,
1402
- constant int64_t & ne01,
1403
- constant int64_t & ne02,
1404
- constant uint64_t & nb00,
1405
- constant uint64_t & nb01,
1406
- constant uint64_t & nb02,
1407
- constant int64_t & ne10,
1408
- constant int64_t & ne11,
1409
- constant int64_t & ne12,
1410
- constant uint64_t & nb10,
1411
- constant uint64_t & nb11,
1412
- constant uint64_t & nb12,
1413
- constant int64_t & ne0,
1414
- constant int64_t & ne1,
1415
- constant uint & r2,
1416
- constant uint & r3,
1417
- uint3 tgpig[[threadgroup_position_in_grid]],
1418
- uint tiisg[[thread_index_in_simdgroup]]) {
1438
+ int64_t ne00,
1439
+ int64_t ne01,
1440
+ int64_t ne02,
1441
+ uint64_t nb00,
1442
+ uint64_t nb01,
1443
+ uint64_t nb02,
1444
+ int64_t ne10,
1445
+ int64_t ne11,
1446
+ int64_t ne12,
1447
+ uint64_t nb10,
1448
+ uint64_t nb11,
1449
+ uint64_t nb12,
1450
+ int64_t ne0,
1451
+ int64_t ne1,
1452
+ uint r2,
1453
+ uint r3,
1454
+ uint3 tgpig,
1455
+ uint tiisg) {
1419
1456
 
1420
1457
  const int64_t r0 = tgpig.x;
1421
1458
  const int64_t rb = tgpig.y*N_F16_F32;
@@ -2700,18 +2737,19 @@ void kernel_mul_mv_q2_K_f32_impl(
2700
2737
  device const void * src0,
2701
2738
  device const float * src1,
2702
2739
  device float * dst,
2703
- constant int64_t & ne00,
2704
- constant int64_t & ne01,
2705
- constant int64_t & ne02,
2706
- constant int64_t & ne10,
2707
- constant int64_t & ne12,
2708
- constant int64_t & ne0,
2709
- constant int64_t & ne1,
2710
- constant uint & r2,
2711
- constant uint & r3,
2712
- uint3 tgpig[[threadgroup_position_in_grid]],
2713
- uint tiisg[[thread_index_in_simdgroup]],
2714
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2740
+ int64_t ne00,
2741
+ int64_t ne01,
2742
+ int64_t ne02,
2743
+ int64_t ne10,
2744
+ int64_t ne12,
2745
+ int64_t ne0,
2746
+ int64_t ne1,
2747
+ uint r2,
2748
+ uint r3,
2749
+ threadgroup int8_t * shared_values,
2750
+ uint3 tgpig,
2751
+ uint tiisg,
2752
+ uint sgitg) {
2715
2753
 
2716
2754
  const int nb = ne00/QK_K;
2717
2755
  const int r0 = tgpig.x;
@@ -2871,7 +2909,7 @@ kernel void kernel_mul_mv_q2_K_f32(
2871
2909
  uint tiisg[[thread_index_in_simdgroup]],
2872
2910
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
2873
2911
 
2874
- kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2912
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
2875
2913
  }
2876
2914
 
2877
2915
  #if QK_K == 256
@@ -2879,18 +2917,19 @@ void kernel_mul_mv_q3_K_f32_impl(
2879
2917
  device const void * src0,
2880
2918
  device const float * src1,
2881
2919
  device float * dst,
2882
- constant int64_t & ne00,
2883
- constant int64_t & ne01,
2884
- constant int64_t & ne02,
2885
- constant int64_t & ne10,
2886
- constant int64_t & ne12,
2887
- constant int64_t & ne0,
2888
- constant int64_t & ne1,
2889
- constant uint & r2,
2890
- constant uint & r3,
2891
- uint3 tgpig[[threadgroup_position_in_grid]],
2892
- uint tiisg[[thread_index_in_simdgroup]],
2893
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2920
+ int64_t ne00,
2921
+ int64_t ne01,
2922
+ int64_t ne02,
2923
+ int64_t ne10,
2924
+ int64_t ne12,
2925
+ int64_t ne0,
2926
+ int64_t ne1,
2927
+ uint r2,
2928
+ uint r3,
2929
+ threadgroup int8_t * shared_values,
2930
+ uint3 tgpig,
2931
+ uint tiisg,
2932
+ uint sgitg) {
2894
2933
 
2895
2934
  const int nb = ne00/QK_K;
2896
2935
 
@@ -3046,6 +3085,7 @@ void kernel_mul_mv_q3_K_f32_impl(
3046
3085
  constant int64_t & ne1,
3047
3086
  constant uint & r2,
3048
3087
  constant uint & r3,
3088
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3049
3089
  uint3 tgpig[[threadgroup_position_in_grid]],
3050
3090
  uint tiisg[[thread_index_in_simdgroup]],
3051
3091
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3135,7 +3175,7 @@ kernel void kernel_mul_mv_q3_K_f32(
3135
3175
  uint tiisg[[thread_index_in_simdgroup]],
3136
3176
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3137
3177
 
3138
- kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3178
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3139
3179
  }
3140
3180
 
3141
3181
  #if QK_K == 256
@@ -3143,18 +3183,19 @@ void kernel_mul_mv_q4_K_f32_impl(
3143
3183
  device const void * src0,
3144
3184
  device const float * src1,
3145
3185
  device float * dst,
3146
- constant int64_t & ne00,
3147
- constant int64_t & ne01,
3148
- constant int64_t & ne02,
3149
- constant int64_t & ne10,
3150
- constant int64_t & ne12,
3151
- constant int64_t & ne0,
3152
- constant int64_t & ne1,
3153
- constant uint & r2,
3154
- constant uint & r3,
3155
- uint3 tgpig[[threadgroup_position_in_grid]],
3156
- uint tiisg[[thread_index_in_simdgroup]],
3157
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3186
+ int64_t ne00,
3187
+ int64_t ne01,
3188
+ int64_t ne02,
3189
+ int64_t ne10,
3190
+ int64_t ne12,
3191
+ int64_t ne0,
3192
+ int64_t ne1,
3193
+ uint r2,
3194
+ uint r3,
3195
+ threadgroup int8_t * shared_values,
3196
+ uint3 tgpig,
3197
+ uint tiisg,
3198
+ uint sgitg) {
3158
3199
 
3159
3200
  const uint16_t kmask1 = 0x3f3f;
3160
3201
  const uint16_t kmask2 = 0x0f0f;
@@ -3265,6 +3306,7 @@ void kernel_mul_mv_q4_K_f32_impl(
3265
3306
  constant int64_t & ne1,
3266
3307
  constant uint & r2,
3267
3308
  constant uint & r3,
3309
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3268
3310
  uint3 tgpig[[threadgroup_position_in_grid]],
3269
3311
  uint tiisg[[thread_index_in_simdgroup]],
3270
3312
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3373,25 +3415,26 @@ kernel void kernel_mul_mv_q4_K_f32(
3373
3415
  uint tiisg[[thread_index_in_simdgroup]],
3374
3416
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3375
3417
 
3376
- kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3418
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3377
3419
  }
3378
3420
 
3379
3421
  void kernel_mul_mv_q5_K_f32_impl(
3380
3422
  device const void * src0,
3381
3423
  device const float * src1,
3382
3424
  device float * dst,
3383
- constant int64_t & ne00,
3384
- constant int64_t & ne01,
3385
- constant int64_t & ne02,
3386
- constant int64_t & ne10,
3387
- constant int64_t & ne12,
3388
- constant int64_t & ne0,
3389
- constant int64_t & ne1,
3390
- constant uint & r2,
3391
- constant uint & r3,
3392
- uint3 tgpig[[threadgroup_position_in_grid]],
3393
- uint tiisg[[thread_index_in_simdgroup]],
3394
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3425
+ int64_t ne00,
3426
+ int64_t ne01,
3427
+ int64_t ne02,
3428
+ int64_t ne10,
3429
+ int64_t ne12,
3430
+ int64_t ne0,
3431
+ int64_t ne1,
3432
+ uint r2,
3433
+ uint r3,
3434
+ threadgroup int8_t * shared_values,
3435
+ uint3 tgpig,
3436
+ uint tiisg,
3437
+ uint sgitg) {
3395
3438
 
3396
3439
  const int nb = ne00/QK_K;
3397
3440
 
@@ -3579,25 +3622,26 @@ kernel void kernel_mul_mv_q5_K_f32(
3579
3622
  uint tiisg[[thread_index_in_simdgroup]],
3580
3623
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3581
3624
 
3582
- kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3625
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3583
3626
  }
3584
3627
 
3585
3628
  void kernel_mul_mv_q6_K_f32_impl(
3586
3629
  device const void * src0,
3587
3630
  device const float * src1,
3588
3631
  device float * dst,
3589
- constant int64_t & ne00,
3590
- constant int64_t & ne01,
3591
- constant int64_t & ne02,
3592
- constant int64_t & ne10,
3593
- constant int64_t & ne12,
3594
- constant int64_t & ne0,
3595
- constant int64_t & ne1,
3596
- constant uint & r2,
3597
- constant uint & r3,
3598
- uint3 tgpig[[threadgroup_position_in_grid]],
3599
- uint tiisg[[thread_index_in_simdgroup]],
3600
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3632
+ int64_t ne00,
3633
+ int64_t ne01,
3634
+ int64_t ne02,
3635
+ int64_t ne10,
3636
+ int64_t ne12,
3637
+ int64_t ne0,
3638
+ int64_t ne1,
3639
+ uint r2,
3640
+ uint r3,
3641
+ threadgroup int8_t * shared_values,
3642
+ uint3 tgpig,
3643
+ uint tiisg,
3644
+ uint sgitg) {
3601
3645
 
3602
3646
  const uint8_t kmask1 = 0x03;
3603
3647
  const uint8_t kmask2 = 0x0C;
@@ -3713,7 +3757,7 @@ kernel void kernel_mul_mv_q6_K_f32(
3713
3757
  uint tiisg[[thread_index_in_simdgroup]],
3714
3758
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3715
3759
 
3716
- kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3760
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3717
3761
  }
3718
3762
 
3719
3763
  // ======================= "True" 2-bit
@@ -3722,19 +3766,19 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
3722
3766
  device const void * src0,
3723
3767
  device const float * src1,
3724
3768
  device float * dst,
3725
- constant int64_t & ne00,
3726
- constant int64_t & ne01,
3727
- constant int64_t & ne02,
3728
- constant int64_t & ne10,
3729
- constant int64_t & ne12,
3730
- constant int64_t & ne0,
3731
- constant int64_t & ne1,
3732
- constant uint & r2,
3733
- constant uint & r3,
3734
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3735
- uint3 tgpig[[threadgroup_position_in_grid]],
3736
- uint tiisg[[thread_index_in_simdgroup]],
3737
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3769
+ int64_t ne00,
3770
+ int64_t ne01,
3771
+ int64_t ne02,
3772
+ int64_t ne10,
3773
+ int64_t ne12,
3774
+ int64_t ne0,
3775
+ int64_t ne1,
3776
+ uint r2,
3777
+ uint r3,
3778
+ threadgroup int8_t * shared_values,
3779
+ uint3 tgpig,
3780
+ uint tiisg,
3781
+ uint sgitg) {
3738
3782
 
3739
3783
  const int nb = ne00/QK_K;
3740
3784
  const int r0 = tgpig.x;
@@ -3851,19 +3895,19 @@ void kernel_mul_mv_iq2_xs_f32_impl(
3851
3895
  device const void * src0,
3852
3896
  device const float * src1,
3853
3897
  device float * dst,
3854
- constant int64_t & ne00,
3855
- constant int64_t & ne01,
3856
- constant int64_t & ne02,
3857
- constant int64_t & ne10,
3858
- constant int64_t & ne12,
3859
- constant int64_t & ne0,
3860
- constant int64_t & ne1,
3861
- constant uint & r2,
3862
- constant uint & r3,
3863
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3864
- uint3 tgpig[[threadgroup_position_in_grid]],
3865
- uint tiisg[[thread_index_in_simdgroup]],
3866
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3898
+ int64_t ne00,
3899
+ int64_t ne01,
3900
+ int64_t ne02,
3901
+ int64_t ne10,
3902
+ int64_t ne12,
3903
+ int64_t ne0,
3904
+ int64_t ne1,
3905
+ uint r2,
3906
+ uint r3,
3907
+ threadgroup int8_t * shared_values,
3908
+ uint3 tgpig,
3909
+ uint tiisg,
3910
+ uint sgitg) {
3867
3911
 
3868
3912
  const int nb = ne00/QK_K;
3869
3913
  const int r0 = tgpig.x;
@@ -3990,19 +4034,19 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
3990
4034
  device const void * src0,
3991
4035
  device const float * src1,
3992
4036
  device float * dst,
3993
- constant int64_t & ne00,
3994
- constant int64_t & ne01,
3995
- constant int64_t & ne02,
3996
- constant int64_t & ne10,
3997
- constant int64_t & ne12,
3998
- constant int64_t & ne0,
3999
- constant int64_t & ne1,
4000
- constant uint & r2,
4001
- constant uint & r3,
4002
- threadgroup int8_t * shared_values [[threadgroup(0)]],
4003
- uint3 tgpig[[threadgroup_position_in_grid]],
4004
- uint tiisg[[thread_index_in_simdgroup]],
4005
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4037
+ int64_t ne00,
4038
+ int64_t ne01,
4039
+ int64_t ne02,
4040
+ int64_t ne10,
4041
+ int64_t ne12,
4042
+ int64_t ne0,
4043
+ int64_t ne1,
4044
+ uint r2,
4045
+ uint r3,
4046
+ threadgroup int8_t * shared_values,
4047
+ uint3 tgpig,
4048
+ uint tiisg,
4049
+ uint sgitg) {
4006
4050
 
4007
4051
  const int nb = ne00/QK_K;
4008
4052
  const int r0 = tgpig.x;
@@ -4122,19 +4166,19 @@ void kernel_mul_mv_iq3_s_f32_impl(
4122
4166
  device const void * src0,
4123
4167
  device const float * src1,
4124
4168
  device float * dst,
4125
- constant int64_t & ne00,
4126
- constant int64_t & ne01,
4127
- constant int64_t & ne02,
4128
- constant int64_t & ne10,
4129
- constant int64_t & ne12,
4130
- constant int64_t & ne0,
4131
- constant int64_t & ne1,
4132
- constant uint & r2,
4133
- constant uint & r3,
4134
- threadgroup int8_t * shared_values [[threadgroup(0)]],
4135
- uint3 tgpig[[threadgroup_position_in_grid]],
4136
- uint tiisg[[thread_index_in_simdgroup]],
4137
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4169
+ int64_t ne00,
4170
+ int64_t ne01,
4171
+ int64_t ne02,
4172
+ int64_t ne10,
4173
+ int64_t ne12,
4174
+ int64_t ne0,
4175
+ int64_t ne1,
4176
+ uint r2,
4177
+ uint r3,
4178
+ threadgroup int8_t * shared_values,
4179
+ uint3 tgpig,
4180
+ uint tiisg,
4181
+ uint sgitg) {
4138
4182
 
4139
4183
  const int nb = ne00/QK_K;
4140
4184
  const int r0 = tgpig.x;
@@ -4254,19 +4298,19 @@ void kernel_mul_mv_iq2_s_f32_impl(
4254
4298
  device const void * src0,
4255
4299
  device const float * src1,
4256
4300
  device float * dst,
4257
- constant int64_t & ne00,
4258
- constant int64_t & ne01,
4259
- constant int64_t & ne02,
4260
- constant int64_t & ne10,
4261
- constant int64_t & ne12,
4262
- constant int64_t & ne0,
4263
- constant int64_t & ne1,
4264
- constant uint & r2,
4265
- constant uint & r3,
4266
- threadgroup int8_t * shared_values [[threadgroup(0)]],
4267
- uint3 tgpig[[threadgroup_position_in_grid]],
4268
- uint tiisg[[thread_index_in_simdgroup]],
4269
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4301
+ int64_t ne00,
4302
+ int64_t ne01,
4303
+ int64_t ne02,
4304
+ int64_t ne10,
4305
+ int64_t ne12,
4306
+ int64_t ne0,
4307
+ int64_t ne1,
4308
+ uint r2,
4309
+ uint r3,
4310
+ threadgroup int8_t * shared_values,
4311
+ uint3 tgpig,
4312
+ uint tiisg,
4313
+ uint sgitg) {
4270
4314
 
4271
4315
  const int nb = ne00/QK_K;
4272
4316
  const int r0 = tgpig.x;
@@ -4387,18 +4431,19 @@ void kernel_mul_mv_iq1_s_f32_impl(
4387
4431
  device const void * src0,
4388
4432
  device const float * src1,
4389
4433
  device float * dst,
4390
- constant int64_t & ne00,
4391
- constant int64_t & ne01,
4392
- constant int64_t & ne02,
4393
- constant int64_t & ne10,
4394
- constant int64_t & ne12,
4395
- constant int64_t & ne0,
4396
- constant int64_t & ne1,
4397
- constant uint & r2,
4398
- constant uint & r3,
4399
- uint3 tgpig[[threadgroup_position_in_grid]],
4400
- uint tiisg[[thread_index_in_simdgroup]],
4401
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4434
+ int64_t ne00,
4435
+ int64_t ne01,
4436
+ int64_t ne02,
4437
+ int64_t ne10,
4438
+ int64_t ne12,
4439
+ int64_t ne0,
4440
+ int64_t ne1,
4441
+ uint r2,
4442
+ uint r3,
4443
+ threadgroup int8_t * shared_value,
4444
+ uint3 tgpig,
4445
+ uint tiisg,
4446
+ uint sgitg) {
4402
4447
 
4403
4448
  const int nb = ne00/QK_K;
4404
4449
  const int r0 = tgpig.x;
@@ -4476,18 +4521,19 @@ void kernel_mul_mv_iq1_m_f32_impl(
4476
4521
  device const void * src0,
4477
4522
  device const float * src1,
4478
4523
  device float * dst,
4479
- constant int64_t & ne00,
4480
- constant int64_t & ne01,
4481
- constant int64_t & ne02,
4482
- constant int64_t & ne10,
4483
- constant int64_t & ne12,
4484
- constant int64_t & ne0,
4485
- constant int64_t & ne1,
4486
- constant uint & r2,
4487
- constant uint & r3,
4488
- uint3 tgpig[[threadgroup_position_in_grid]],
4489
- uint tiisg[[thread_index_in_simdgroup]],
4490
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4524
+ int64_t ne00,
4525
+ int64_t ne01,
4526
+ int64_t ne02,
4527
+ int64_t ne10,
4528
+ int64_t ne12,
4529
+ int64_t ne0,
4530
+ int64_t ne1,
4531
+ uint r2,
4532
+ uint r3,
4533
+ threadgroup int8_t * shared_value,
4534
+ uint3 tgpig,
4535
+ uint tiisg,
4536
+ uint sgitg) {
4491
4537
 
4492
4538
  const int nb = ne00/QK_K;
4493
4539
  const int r0 = tgpig.x;
@@ -4584,20 +4630,21 @@ void kernel_mul_mv_iq4_nl_f32_impl(
4584
4630
  device const void * src0,
4585
4631
  device const float * src1,
4586
4632
  device float * dst,
4587
- constant int64_t & ne00,
4588
- constant int64_t & ne01,
4589
- constant int64_t & ne02,
4590
- constant int64_t & ne10,
4591
- constant int64_t & ne12,
4592
- constant int64_t & ne0,
4593
- constant int64_t & ne1,
4594
- constant uint & r2,
4595
- constant uint & r3,
4596
- threadgroup float * shared_values [[threadgroup(0)]],
4597
- uint3 tgpig[[threadgroup_position_in_grid]],
4598
- uint tiisg[[thread_index_in_simdgroup]],
4599
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4633
+ int64_t ne00,
4634
+ int64_t ne01,
4635
+ int64_t ne02,
4636
+ int64_t ne10,
4637
+ int64_t ne12,
4638
+ int64_t ne0,
4639
+ int64_t ne1,
4640
+ uint r2,
4641
+ uint r3,
4642
+ threadgroup int8_t * shared_values_i8,
4643
+ uint3 tgpig,
4644
+ uint tiisg,
4645
+ uint sgitg) {
4600
4646
 
4647
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
4601
4648
  const int nb = ne00/QK4_NL;
4602
4649
  const int r0 = tgpig.x;
4603
4650
  const int r1 = tgpig.y;
@@ -4678,20 +4725,21 @@ void kernel_mul_mv_iq4_xs_f32_impl(
4678
4725
  device const void * src0,
4679
4726
  device const float * src1,
4680
4727
  device float * dst,
4681
- constant int64_t & ne00,
4682
- constant int64_t & ne01,
4683
- constant int64_t & ne02,
4684
- constant int64_t & ne10,
4685
- constant int64_t & ne12,
4686
- constant int64_t & ne0,
4687
- constant int64_t & ne1,
4688
- constant uint & r2,
4689
- constant uint & r3,
4690
- threadgroup float * shared_values [[threadgroup(0)]],
4691
- uint3 tgpig[[threadgroup_position_in_grid]],
4692
- uint tiisg[[thread_index_in_simdgroup]],
4693
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4728
+ int64_t ne00,
4729
+ int64_t ne01,
4730
+ int64_t ne02,
4731
+ int64_t ne10,
4732
+ int64_t ne12,
4733
+ int64_t ne0,
4734
+ int64_t ne1,
4735
+ uint r2,
4736
+ uint r3,
4737
+ threadgroup int8_t * shared_values_i8,
4738
+ uint3 tgpig,
4739
+ uint tiisg,
4740
+ uint sgitg) {
4694
4741
 
4742
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
4695
4743
  const int nb = ne00/QK_K;
4696
4744
  const int r0 = tgpig.x;
4697
4745
  const int r1 = tgpig.y;
@@ -4794,7 +4842,7 @@ kernel void kernel_mul_mv_iq1_s_f32(
4794
4842
  uint tiisg[[thread_index_in_simdgroup]],
4795
4843
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4796
4844
 
4797
- kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
4845
+ kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
4798
4846
  }
4799
4847
 
4800
4848
  [[host_name("kernel_mul_mv_iq1_m_f32")]]
@@ -4822,7 +4870,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
4822
4870
  uint tiisg[[thread_index_in_simdgroup]],
4823
4871
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4824
4872
 
4825
- kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
4873
+ kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
4826
4874
  }
4827
4875
 
4828
4876
  [[host_name("kernel_mul_mv_iq4_nl_f32")]]
@@ -4846,7 +4894,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
4846
4894
  constant int64_t & ne1,
4847
4895
  constant uint & r2,
4848
4896
  constant uint & r3,
4849
- threadgroup float * shared_values [[threadgroup(0)]],
4897
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4850
4898
  uint3 tgpig[[threadgroup_position_in_grid]],
4851
4899
  uint tiisg[[thread_index_in_simdgroup]],
4852
4900
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -4875,7 +4923,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
4875
4923
  constant int64_t & ne1,
4876
4924
  constant uint & r2,
4877
4925
  constant uint & r3,
4878
- threadgroup float * shared_values [[threadgroup(0)]],
4926
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4879
4927
  uint3 tgpig[[threadgroup_position_in_grid]],
4880
4928
  uint tiisg[[thread_index_in_simdgroup]],
4881
4929
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -5632,25 +5680,25 @@ void kernel_mul_mm_impl(device const uchar * src0,
5632
5680
  }
5633
5681
  }
5634
5682
 
5635
- // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
5683
+ // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
5636
5684
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
5637
5685
  void kernel_mul_mm_id_impl(
5638
5686
  device const uchar * src0,
5639
5687
  device const uchar * src1,
5640
- threadgroup short * src1ids,
5688
+ threadgroup ushort2 * rowids,
5641
5689
  device float * dst,
5642
5690
  constant int64_t & ne00,
5643
5691
  constant int64_t & ne02,
5644
5692
  constant uint64_t & nb01,
5645
5693
  constant uint64_t & nb02,
5694
+ constant int64_t & ne11,
5646
5695
  constant int64_t & ne12,
5647
5696
  constant uint64_t & nb10,
5648
5697
  constant uint64_t & nb11,
5649
5698
  constant uint64_t & nb12,
5650
5699
  constant int64_t & ne0,
5651
5700
  int64_t ne1,
5652
- constant uint & r2,
5653
- constant uint & r3,
5701
+ int64_t ne0ne1,
5654
5702
  threadgroup uchar * shared_memory,
5655
5703
  uint3 tgpig[[threadgroup_position_in_grid]],
5656
5704
  uint tiitg[[thread_index_in_threadgroup]],
@@ -5661,7 +5709,6 @@ void kernel_mul_mm_id_impl(
5661
5709
 
5662
5710
  const uint r0 = tgpig.y;
5663
5711
  const uint r1 = tgpig.x;
5664
- const uint im = tgpig.z;
5665
5712
 
5666
5713
  if (r1 * BLOCK_SIZE_N >= ne1) return;
5667
5714
 
@@ -5679,19 +5726,16 @@ void kernel_mul_mm_id_impl(
5679
5726
  for (int i = 0; i < 8; i++){
5680
5727
  c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
5681
5728
  }
5682
-
5683
5729
  short il = (tiitg % THREAD_PER_ROW);
5684
5730
 
5685
- const uint i12 = im%ne12;
5686
- const uint i13 = im/ne12;
5687
-
5688
- uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
5689
5731
  ushort offset1 = il/nl;
5690
5732
 
5691
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
5733
+ threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
5734
+
5735
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
5692
5736
  device const float * y = (device const float *)(src1
5693
- + nb12 * im
5694
- + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
5737
+ + nb12 * id[1]
5738
+ + nb11 * (id[0] % ne11)
5695
5739
  + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
5696
5740
 
5697
5741
  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
@@ -5720,11 +5764,11 @@ void kernel_mul_mm_id_impl(
5720
5764
 
5721
5765
  for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
5722
5766
  for (int i = 0; i < 4; i++) {
5723
- simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
5767
+ simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
5724
5768
  }
5725
5769
  simdgroup_barrier(mem_flags::mem_none);
5726
5770
  for (int i = 0; i < 2; i++) {
5727
- simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
5771
+ simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
5728
5772
  }
5729
5773
 
5730
5774
  lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
@@ -5746,11 +5790,13 @@ void kernel_mul_mm_id_impl(
5746
5790
 
5747
5791
  threadgroup_barrier(mem_flags::mem_threadgroup);
5748
5792
 
5749
- device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
5793
+ device float * C = dst + (BLOCK_SIZE_M * r0);
5750
5794
  if (sgitg == 0) {
5751
- for (int i = 0; i < n_rows; i++) {
5752
- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
5753
- *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
5795
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
5796
+ threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
5797
+ int joff = jid[0] * ne0 + jid[1] * ne0ne1;
5798
+ for (int i = 0; i < n_rows; i++) {
5799
+ *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
5754
5800
  }
5755
5801
  }
5756
5802
  }
@@ -5805,11 +5851,14 @@ kernel void kernel_mul_mm_id(
5805
5851
  device const uchar * src1,
5806
5852
  device float * dst,
5807
5853
  device const uchar * ids,
5854
+ constant int64_t & nei0,
5855
+ constant int64_t & nei1,
5808
5856
  constant uint64_t & nbi1,
5809
5857
  constant int64_t & ne00,
5810
5858
  constant int64_t & ne02,
5811
5859
  constant uint64_t & nb01,
5812
5860
  constant uint64_t & nb02,
5861
+ constant int64_t & ne11,
5813
5862
  constant int64_t & ne12,
5814
5863
  constant int64_t & ne13,
5815
5864
  constant uint64_t & nb10,
@@ -5818,47 +5867,52 @@ kernel void kernel_mul_mm_id(
5818
5867
  constant int64_t & ne0,
5819
5868
  constant int64_t & ne1,
5820
5869
  constant uint64_t & nb1,
5821
- constant uint & r2,
5822
- constant uint & r3,
5823
- constant int & idx,
5824
5870
  threadgroup uchar * shared_memory [[threadgroup(0)]],
5825
5871
  uint3 tgpig[[threadgroup_position_in_grid]],
5826
5872
  uint tiitg[[thread_index_in_threadgroup]],
5827
5873
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5828
5874
 
5829
- // expert id
5830
- const int32_t id = tgpig.z/(ne12*ne13);
5831
- device const uchar * src0 = src0s + id*nb02;
5875
+ const int32_t i02 = tgpig.z;
5876
+ tgpig.z = 0;
5832
5877
 
5833
- tgpig.z = tgpig.z%(ne12*ne13);
5878
+ device const uchar * src0 = src0s + i02*nb02;
5834
5879
 
5835
- // row indices of src1 for expert id
5836
- threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192);
5880
+ // row indices
5881
+ threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
5837
5882
 
5883
+ // TODO: parallelize this loop
5838
5884
  int64_t _ne1 = 0;
5839
- for (int64_t i1 = 0; i1 < ne1; i1++) {
5840
- if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
5841
- src1ids[_ne1++] = i1;
5885
+ for (ushort ii1 = 0; ii1 < nei1; ii1++) {
5886
+ for (ushort ii0 = 0; ii0 < nei0; ii0++) {
5887
+ int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
5888
+ if (id == i02) {
5889
+ //if (tiitg == 0) {
5890
+ rowids[_ne1] = ushort2(ii0, ii1);
5891
+ //}
5892
+ _ne1++;
5893
+ }
5842
5894
  }
5843
5895
  }
5844
5896
 
5897
+ threadgroup_barrier(mem_flags::mem_threadgroup);
5898
+
5845
5899
  kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
5846
5900
  src0,
5847
5901
  src1,
5848
- src1ids,
5902
+ rowids,
5849
5903
  dst,
5850
5904
  ne00,
5851
5905
  ne02,
5852
5906
  nb01,
5853
5907
  nb02,
5908
+ ne11,
5854
5909
  ne12,
5855
5910
  nb10,
5856
5911
  nb11,
5857
5912
  nb12,
5858
5913
  ne0,
5859
5914
  _ne1,
5860
- r2,
5861
- r3,
5915
+ ne0*ne1,
5862
5916
  shared_memory,
5863
5917
  tgpig,
5864
5918
  tiitg,
@@ -5919,24 +5973,7 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_r
5919
5973
  // matrix-matrix multiplication
5920
5974
  //
5921
5975
 
5922
- typedef void (mat_mm_t)(
5923
- device const uchar * src0,
5924
- device const uchar * src1,
5925
- device float * dst,
5926
- constant int64_t & ne00,
5927
- constant int64_t & ne02,
5928
- constant uint64_t & nb01,
5929
- constant uint64_t & nb02,
5930
- constant int64_t & ne12,
5931
- constant uint64_t & nb10,
5932
- constant uint64_t & nb11,
5933
- constant uint64_t & nb12,
5934
- constant int64_t & ne0,
5935
- constant int64_t & ne1,
5936
- constant uint & r2,
5937
- constant uint & r3,
5938
- threadgroup uchar *,
5939
- uint3, uint, uint);
5976
// mat_mm_t is the function type of the kernel_mul_mm template, deduced from its
// <float4x4, 1, dequantize_f32> instantiation instead of being spelled out by hand.
// It is the type named in the explicit `template [[host_name(...)]] kernel mat_mm_t ...`
// instantiations that follow.
+ typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
5940
5977
 
5941
5978
  template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
5942
5979
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
@@ -5968,29 +6005,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
5968
6005
  // indirect matrix-matrix multiplication
5969
6006
  //
5970
6007
 
5971
- typedef void (mat_mm_id_t)(
5972
- device const uchar * src0s,
5973
- device const uchar * src1,
5974
- device float * dst,
5975
- device const uchar * ids,
5976
- constant uint64_t & nbi1,
5977
- constant int64_t & ne00,
5978
- constant int64_t & ne02,
5979
- constant uint64_t & nb01,
5980
- constant uint64_t & nb02,
5981
- constant int64_t & ne12,
5982
- constant int64_t & ne13,
5983
- constant uint64_t & nb10,
5984
- constant uint64_t & nb11,
5985
- constant uint64_t & nb12,
5986
- constant int64_t & ne0,
5987
- constant int64_t & ne1,
5988
- constant uint64_t & nb1,
5989
- constant uint & r2,
5990
- constant uint & r3,
5991
- constant int & idx,
5992
- threadgroup uchar *,
5993
- uint3, uint, uint);
6008
// mat_mm_id_t is the function type of the kernel_mul_mm_id template, deduced from its
// <float4x4, 1, dequantize_f32> instantiation instead of being spelled out by hand.
// It is the type named in the explicit `template [[host_name(...)]] kernel mat_mm_id_t ...`
// instantiations that follow.
+ typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
5994
6009
 
5995
6010
  template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
5996
6011
  template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
@@ -6022,12 +6037,119 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
6022
6037
  // matrix-vector multiplication
6023
6038
  //
6024
6039
 
6025
- [[host_name("kernel_mul_mv_id_f32_f32")]]
6026
- kernel void kernel_mul_mv_id_f32_f32(
6040
// Function type for matrix-vector "impl" routines that receive both operands as
// `device const char *` together with explicit nb00..nb02 / nb10..nb12 arguments.
// It is the non-type template parameter type of the first mmv_fn overload, which
// forwards exactly this argument list.
// NOTE(review): the nb* values look like byte strides (other code in this file adds
// them to char pointers) -- confirm against the impl functions.
+ typedef void (kernel_mul_mv_impl_t)(
6041
+ device const char * src0,
6042
+ device const char * src1,
6043
+ device float * dst,
6044
+ int64_t ne00,
6045
+ int64_t ne01,
6046
+ int64_t ne02,
6047
+ uint64_t nb00,
6048
+ uint64_t nb01,
6049
+ uint64_t nb02,
6050
+ int64_t ne10,
6051
+ int64_t ne11,
6052
+ int64_t ne12,
6053
+ uint64_t nb10,
6054
+ uint64_t nb11,
6055
+ uint64_t nb12,
6056
+ int64_t ne0,
6057
+ int64_t ne1,
6058
+ uint r2,
6059
+ uint r3,
6060
+ uint3 tgpig,
6061
+ uint tiisg);
6062
+
6063
// Function type for matrix-vector "impl" routines that receive src0 as
// `device const void *` and src1 as `device const float *`. Unlike
// kernel_mul_mv_impl_t there are no nb* arguments and no ne11; instead the routine
// also receives a threadgroup int8_t pointer (shared_values) and sgitg.
// It is the non-type template parameter type of the second mmv_fn overload.
+ typedef void (kernel_mul_mv2_impl_t)(
6064
+ device const void * src0,
6065
+ device const float * src1,
6066
+ device float * dst,
6067
+ int64_t ne00,
6068
+ int64_t ne01,
6069
+ int64_t ne02,
6070
+ int64_t ne10,
6071
+ int64_t ne12,
6072
+ int64_t ne0,
6073
+ int64_t ne1,
6074
+ uint r2,
6075
+ uint r3,
6076
+ threadgroup int8_t * shared_values,
6077
+ uint3 tgpig,
6078
+ uint tiisg,
6079
+ uint sgitg);
6080
+
6081
// Adapter overload for kernel_mul_mv_impl_t routines: exposes impl_fn behind the
// parameter list shared by both mmv_fn overloads and forwards the call.
// ne13, nb1, shared_values, tiitg and sgitg are accepted but not passed on, because
// kernel_mul_mv_impl_t has no corresponding parameters.
+ template<kernel_mul_mv_impl_t impl_fn>
6082
+ void mmv_fn(
6083
+ device const char * src0,
6084
+ device const char * src1,
6085
+ device float * dst,
6086
+ int64_t ne00,
6087
+ int64_t ne01,
6088
+ int64_t ne02,
6089
+ uint64_t nb00,
6090
+ uint64_t nb01,
6091
+ uint64_t nb02,
6092
+ int64_t ne10,
6093
+ int64_t ne11,
6094
+ int64_t ne12,
6095
+ int64_t ne13,
6096
+ uint64_t nb10,
6097
+ uint64_t nb11,
6098
+ uint64_t nb12,
6099
+ int64_t ne0,
6100
+ int64_t ne1,
6101
+ uint64_t nb1,
6102
+ uint r2,
6103
+ uint r3,
6104
+ threadgroup int8_t * shared_values,
6105
+ uint3 tgpig,
6106
+ uint tiitg,
6107
+ uint tiisg,
6108
+ uint sgitg) {
6109
// Forward in the exact parameter order declared by kernel_mul_mv_impl_t.
+ impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
6110
+ }
6111
+
6112
// Adapter overload for kernel_mul_mv2_impl_t routines: exposes impl_fn behind the
// parameter list shared by both mmv_fn overloads and forwards the call.
// src1 is cast from `device const char *` to `const device float *`.
// nb00..nb02, ne11, ne13, nb10..nb12, nb1 and tiitg are accepted but not passed on,
// because kernel_mul_mv2_impl_t has no corresponding parameters.
+ template<kernel_mul_mv2_impl_t impl_fn>
6113
+ void mmv_fn(
6114
+ device const char * src0,
6115
+ device const char * src1,
6116
+ device float * dst,
6117
+ int64_t ne00,
6118
+ int64_t ne01,
6119
+ int64_t ne02,
6120
+ uint64_t nb00,
6121
+ uint64_t nb01,
6122
+ uint64_t nb02,
6123
+ int64_t ne10,
6124
+ int64_t ne11,
6125
+ int64_t ne12,
6126
+ int64_t ne13,
6127
+ uint64_t nb10,
6128
+ uint64_t nb11,
6129
+ uint64_t nb12,
6130
+ int64_t ne0,
6131
+ int64_t ne1,
6132
+ uint64_t nb1,
6133
+ uint r2,
6134
+ uint r3,
6135
+ threadgroup int8_t * shared_values,
6136
+ uint3 tgpig,
6137
+ uint tiitg,
6138
+ uint tiisg,
6139
+ uint sgitg) {
6140
// Forward in the exact parameter order declared by kernel_mul_mv2_impl_t.
+ impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
6141
+ }
6142
+
6143
// mul_mv_impl_fn_t is the function type of an mmv_fn instantiation, deduced here from
// mmv_fn<kernel_mul_mv_f32_f32_impl>. Both mmv_fn overloads declare the same parameter
// list, and this type is the non-type template parameter of kernel_mul_mv_id.
+ typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t;
6144
+
6145
+ template<mul_mv_impl_fn_t impl_fn>
6146
+ kernel void kernel_mul_mv_id(
6027
6147
  device const char * src0s,
6028
6148
  device const char * src1,
6029
6149
  device float * dst,
6030
6150
  device const char * ids,
6151
+ constant int64_t & nei0,
6152
+ constant int64_t & nei1,
6031
6153
  constant uint64_t & nbi1,
6032
6154
  constant int64_t & ne00,
6033
6155
  constant int64_t & ne01,
@@ -6045,1164 +6167,80 @@ kernel void kernel_mul_mv_id_f32_f32(
6045
6167
  constant int64_t & ne0,
6046
6168
  constant int64_t & ne1,
6047
6169
  constant uint64_t & nb1,
6048
- constant uint & r2,
6049
- constant uint & r3,
6050
- constant int & idx,
6170
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
6051
6171
  uint3 tgpig[[threadgroup_position_in_grid]],
6052
6172
  uint tiitg[[thread_index_in_threadgroup]],
6053
6173
  uint tiisg[[thread_index_in_simdgroup]],
6054
6174
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6055
- const int64_t bid = tgpig.z/(ne12*ne13);
6056
-
6057
- tgpig.z = tgpig.z%(ne12*ne13);
6058
-
6059
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6060
- device const char * src0 = src0s + id*nb02;
6061
-
6062
- kernel_mul_mv_f32_f32_impl(
6063
- src0,
6064
- src1 + bid*nb11,
6065
- dst + bid*ne0,
6066
- ne00,
6067
- ne01,
6068
- ne02,
6069
- nb00,
6070
- nb01,
6071
- nb02,
6072
- ne10,
6073
- ne11,
6074
- ne12,
6075
- nb10,
6076
- nb11,
6077
- nb12,
6078
- ne0,
6079
- ne1,
6080
- r2,
6081
- r3,
6175
+ const int iid1 = tgpig.z/nei0;
6176
+ const int idx = tgpig.z%nei0;
6177
+
6178
+ tgpig.z = 0;
6179
+
6180
+ const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx];
6181
+
6182
+ const int64_t i11 = idx % ne11;
6183
+ const int64_t i12 = iid1;
6184
+
6185
+ const int64_t i1 = idx;
6186
+ const int64_t i2 = i12;
6187
+
6188
+ device const char * src0_cur = src0s + i02*nb02;
6189
+ device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
6190
+ device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0;
6191
+
6192
+ impl_fn(
6193
+ /* src0 */ src0_cur,
6194
+ /* src1 */ src1_cur,
6195
+ /* dst */ dst_cur,
6196
+ /* ne00 */ ne00,
6197
+ /* ne01 */ ne01,
6198
+ /* ne02 */ 1,//ne02,
6199
+ /* nb00 */ nb00,
6200
+ /* nb01 */ nb01,
6201
+ /* nb02 */ nb02,
6202
+ /* ne10 */ ne10,
6203
+ /* ne11 */ 1,//ne11,
6204
+ /* ne12 */ 1,//ne12,
6205
+ /* ne13 */ 1,//ne13,
6206
+ /* nb10 */ nb10,
6207
+ /* nb11 */ nb11,
6208
+ /* nb12 */ nb12,
6209
+ /* ne0 */ ne0,
6210
+ /* ne1 */ 1,//ne1,
6211
+ /* nb1 */ nb1,
6212
+ /* r2 */ 1,
6213
+ /* r3 */ 1,
6214
+ shared_values,
6082
6215
  tgpig,
6083
- tiisg);
6216
+ tiitg,
6217
+ tiisg,
6218
+ sgitg);
6084
6219
  }
6085
6220
 
6086
- [[host_name("kernel_mul_mv_id_f16_f32")]]
6087
- kernel void kernel_mul_mv_id_f16_f32(
6088
- device const char * src0s,
6089
- device const char * src1,
6090
- device float * dst,
6091
- device const char * ids,
6092
- constant uint64_t & nbi1,
6093
- constant int64_t & ne00,
6094
- constant int64_t & ne01,
6095
- constant int64_t & ne02,
6096
- constant uint64_t & nb00,
6097
- constant uint64_t & nb01,
6098
- constant uint64_t & nb02,
6099
- constant int64_t & ne10,
6100
- constant int64_t & ne11,
6101
- constant int64_t & ne12,
6102
- constant int64_t & ne13,
6103
- constant uint64_t & nb10,
6104
- constant uint64_t & nb11,
6105
- constant uint64_t & nb12,
6106
- constant int64_t & ne0,
6107
- constant int64_t & ne1,
6108
- constant uint64_t & nb1,
6109
- constant uint & r2,
6110
- constant uint & r3,
6111
- constant int & idx,
6112
- uint3 tgpig[[threadgroup_position_in_grid]],
6113
- uint tiitg[[thread_index_in_threadgroup]],
6114
- uint tiisg[[thread_index_in_simdgroup]],
6115
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6116
- const int64_t bid = tgpig.z/(ne12*ne13);
6117
-
6118
- tgpig.z = tgpig.z%(ne12*ne13);
6119
-
6120
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6121
- device const char * src0 = src0s + id*nb02;
6122
-
6123
- kernel_mul_mv_f16_f32_impl(
6124
- src0,
6125
- src1 + bid*nb11,
6126
- dst + bid*ne0,
6127
- ne00,
6128
- ne01,
6129
- ne02,
6130
- nb00,
6131
- nb01,
6132
- nb02,
6133
- ne10,
6134
- ne11,
6135
- ne12,
6136
- nb10,
6137
- nb11,
6138
- nb12,
6139
- ne0,
6140
- ne1,
6141
- r2,
6142
- r3,
6143
- tgpig,
6144
- tiisg);
6145
- }
6146
-
6147
- [[host_name("kernel_mul_mv_id_q8_0_f32")]]
6148
- kernel void kernel_mul_mv_id_q8_0_f32(
6149
- device const char * src0s,
6150
- device const char * src1,
6151
- device float * dst,
6152
- device const char * ids,
6153
- constant uint64_t & nbi1,
6154
- constant int64_t & ne00,
6155
- constant int64_t & ne01,
6156
- constant int64_t & ne02,
6157
- constant uint64_t & nb00,
6158
- constant uint64_t & nb01,
6159
- constant uint64_t & nb02,
6160
- constant int64_t & ne10,
6161
- constant int64_t & ne11,
6162
- constant int64_t & ne12,
6163
- constant int64_t & ne13,
6164
- constant uint64_t & nb10,
6165
- constant uint64_t & nb11,
6166
- constant uint64_t & nb12,
6167
- constant int64_t & ne0,
6168
- constant int64_t & ne1,
6169
- constant uint64_t & nb1,
6170
- constant uint & r2,
6171
- constant uint & r3,
6172
- constant int & idx,
6173
- uint3 tgpig[[threadgroup_position_in_grid]],
6174
- uint tiitg[[thread_index_in_threadgroup]],
6175
- uint tiisg[[thread_index_in_simdgroup]],
6176
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6177
- const int64_t bid = tgpig.z/(ne12*ne13);
6178
-
6179
- tgpig.z = tgpig.z%(ne12*ne13);
6180
-
6181
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6182
- device const char * src0 = src0s + id*nb02;
6183
-
6184
- kernel_mul_mv_q8_0_f32_impl(
6185
- src0,
6186
- (device const float *) (src1 + bid*nb11),
6187
- dst + bid*ne0,
6188
- ne00,
6189
- ne01,
6190
- ne02,
6191
- ne10,
6192
- ne12,
6193
- ne0,
6194
- ne1,
6195
- r2,
6196
- r3,
6197
- tgpig,
6198
- tiisg,
6199
- sgitg);
6200
- }
6201
-
6202
- [[host_name("kernel_mul_mv_id_q4_0_f32")]]
6203
- kernel void kernel_mul_mv_id_q4_0_f32(
6204
- device const char * src0s,
6205
- device const char * src1,
6206
- device float * dst,
6207
- device const char * ids,
6208
- constant uint64_t & nbi1,
6209
- constant int64_t & ne00,
6210
- constant int64_t & ne01,
6211
- constant int64_t & ne02,
6212
- constant uint64_t & nb00,
6213
- constant uint64_t & nb01,
6214
- constant uint64_t & nb02,
6215
- constant int64_t & ne10,
6216
- constant int64_t & ne11,
6217
- constant int64_t & ne12,
6218
- constant int64_t & ne13,
6219
- constant uint64_t & nb10,
6220
- constant uint64_t & nb11,
6221
- constant uint64_t & nb12,
6222
- constant int64_t & ne0,
6223
- constant int64_t & ne1,
6224
- constant uint64_t & nb1,
6225
- constant uint & r2,
6226
- constant uint & r3,
6227
- constant int & idx,
6228
- uint3 tgpig[[threadgroup_position_in_grid]],
6229
- uint tiitg[[thread_index_in_threadgroup]],
6230
- uint tiisg[[thread_index_in_simdgroup]],
6231
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6232
- const int64_t bid = tgpig.z/(ne12*ne13);
6233
-
6234
- tgpig.z = tgpig.z%(ne12*ne13);
6235
-
6236
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6237
- device const char * src0 = src0s + id*nb02;
6238
-
6239
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6240
- src0,
6241
- (device const float *) (src1 + bid*nb11),
6242
- dst + bid*ne0,
6243
- ne00,
6244
- ne01,
6245
- ne02,
6246
- ne10,
6247
- ne12,
6248
- ne0,
6249
- ne1,
6250
- r2,
6251
- r3,
6252
- tgpig,
6253
- tiisg,
6254
- sgitg);
6255
- }
6256
-
6257
- [[host_name("kernel_mul_mv_id_q4_1_f32")]]
6258
- kernel void kernel_mul_mv_id_q4_1_f32(
6259
- device const char * src0s,
6260
- device const char * src1,
6261
- device float * dst,
6262
- device const char * ids,
6263
- constant uint64_t & nbi1,
6264
- constant int64_t & ne00,
6265
- constant int64_t & ne01,
6266
- constant int64_t & ne02,
6267
- constant uint64_t & nb00,
6268
- constant uint64_t & nb01,
6269
- constant uint64_t & nb02,
6270
- constant int64_t & ne10,
6271
- constant int64_t & ne11,
6272
- constant int64_t & ne12,
6273
- constant int64_t & ne13,
6274
- constant uint64_t & nb10,
6275
- constant uint64_t & nb11,
6276
- constant uint64_t & nb12,
6277
- constant int64_t & ne0,
6278
- constant int64_t & ne1,
6279
- constant uint64_t & nb1,
6280
- constant uint & r2,
6281
- constant uint & r3,
6282
- constant int & idx,
6283
- uint3 tgpig[[threadgroup_position_in_grid]],
6284
- uint tiitg[[thread_index_in_threadgroup]],
6285
- uint tiisg[[thread_index_in_simdgroup]],
6286
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6287
- const int64_t bid = tgpig.z/(ne12*ne13);
6288
-
6289
- tgpig.z = tgpig.z%(ne12*ne13);
6290
-
6291
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6292
- device const char * src0 = src0s + id*nb02;
6293
-
6294
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6295
- src0,
6296
- (device const float *) (src1 + bid*nb11),
6297
- dst + bid*ne0,
6298
- ne00,
6299
- ne01,
6300
- ne02,
6301
- ne10,
6302
- ne12,
6303
- ne0,
6304
- ne1,
6305
- r2,
6306
- r3,
6307
- tgpig,
6308
- tiisg,
6309
- sgitg);
6310
- }
6311
-
6312
- [[host_name("kernel_mul_mv_id_q5_0_f32")]]
6313
- kernel void kernel_mul_mv_id_q5_0_f32(
6314
- device const char * src0s,
6315
- device const char * src1,
6316
- device float * dst,
6317
- device const char * ids,
6318
- constant uint64_t & nbi1,
6319
- constant int64_t & ne00,
6320
- constant int64_t & ne01,
6321
- constant int64_t & ne02,
6322
- constant uint64_t & nb00,
6323
- constant uint64_t & nb01,
6324
- constant uint64_t & nb02,
6325
- constant int64_t & ne10,
6326
- constant int64_t & ne11,
6327
- constant int64_t & ne12,
6328
- constant int64_t & ne13,
6329
- constant uint64_t & nb10,
6330
- constant uint64_t & nb11,
6331
- constant uint64_t & nb12,
6332
- constant int64_t & ne0,
6333
- constant int64_t & ne1,
6334
- constant uint64_t & nb1,
6335
- constant uint & r2,
6336
- constant uint & r3,
6337
- constant int & idx,
6338
- uint3 tgpig[[threadgroup_position_in_grid]],
6339
- uint tiitg[[thread_index_in_threadgroup]],
6340
- uint tiisg[[thread_index_in_simdgroup]],
6341
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6342
- const int64_t bid = tgpig.z/(ne12*ne13);
6343
-
6344
- tgpig.z = tgpig.z%(ne12*ne13);
6345
-
6346
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6347
- device const char * src0 = src0s + id*nb02;
6348
-
6349
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6350
- src0,
6351
- (device const float *) (src1 + bid*nb11),
6352
- dst + bid*ne0,
6353
- ne00,
6354
- ne01,
6355
- ne02,
6356
- ne10,
6357
- ne12,
6358
- ne0,
6359
- ne1,
6360
- r2,
6361
- r3,
6362
- tgpig,
6363
- tiisg,
6364
- sgitg);
6365
- }
6366
-
6367
- [[host_name("kernel_mul_mv_id_q5_1_f32")]]
6368
- kernel void kernel_mul_mv_id_q5_1_f32(
6369
- device const char * src0s,
6370
- device const char * src1,
6371
- device float * dst,
6372
- device const char * ids,
6373
- constant uint64_t & nbi1,
6374
- constant int64_t & ne00,
6375
- constant int64_t & ne01,
6376
- constant int64_t & ne02,
6377
- constant uint64_t & nb00,
6378
- constant uint64_t & nb01,
6379
- constant uint64_t & nb02,
6380
- constant int64_t & ne10,
6381
- constant int64_t & ne11,
6382
- constant int64_t & ne12,
6383
- constant int64_t & ne13,
6384
- constant uint64_t & nb10,
6385
- constant uint64_t & nb11,
6386
- constant uint64_t & nb12,
6387
- constant int64_t & ne0,
6388
- constant int64_t & ne1,
6389
- constant uint64_t & nb1,
6390
- constant uint & r2,
6391
- constant uint & r3,
6392
- constant int & idx,
6393
- uint3 tgpig[[threadgroup_position_in_grid]],
6394
- uint tiitg[[thread_index_in_threadgroup]],
6395
- uint tiisg[[thread_index_in_simdgroup]],
6396
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6397
- const int64_t bid = tgpig.z/(ne12*ne13);
6398
-
6399
- tgpig.z = tgpig.z%(ne12*ne13);
6400
-
6401
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6402
- device const char * src0 = src0s + id*nb02;
6403
-
6404
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6405
- src0,
6406
- (device const float *) (src1 + bid*nb11),
6407
- dst + bid*ne0,
6408
- ne00,
6409
- ne01,
6410
- ne02,
6411
- ne10,
6412
- ne12,
6413
- ne0,
6414
- ne1,
6415
- r2,
6416
- r3,
6417
- tgpig,
6418
- tiisg,
6419
- sgitg);
6420
- }
6421
-
6422
- [[host_name("kernel_mul_mv_id_q2_K_f32")]]
6423
- kernel void kernel_mul_mv_id_q2_K_f32(
6424
- device const char * src0s,
6425
- device const char * src1,
6426
- device float * dst,
6427
- device const char * ids,
6428
- constant uint64_t & nbi1,
6429
- constant int64_t & ne00,
6430
- constant int64_t & ne01,
6431
- constant int64_t & ne02,
6432
- constant uint64_t & nb00,
6433
- constant uint64_t & nb01,
6434
- constant uint64_t & nb02,
6435
- constant int64_t & ne10,
6436
- constant int64_t & ne11,
6437
- constant int64_t & ne12,
6438
- constant int64_t & ne13,
6439
- constant uint64_t & nb10,
6440
- constant uint64_t & nb11,
6441
- constant uint64_t & nb12,
6442
- constant int64_t & ne0,
6443
- constant int64_t & ne1,
6444
- constant uint64_t & nb1,
6445
- constant uint & r2,
6446
- constant uint & r3,
6447
- constant int & idx,
6448
- uint3 tgpig[[threadgroup_position_in_grid]],
6449
- uint tiitg[[thread_index_in_threadgroup]],
6450
- uint tiisg[[thread_index_in_simdgroup]],
6451
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6452
- const int64_t bid = tgpig.z/(ne12*ne13);
6453
-
6454
- tgpig.z = tgpig.z%(ne12*ne13);
6455
-
6456
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6457
- device const char * src0 = src0s + id*nb02;
6458
-
6459
- kernel_mul_mv_q2_K_f32_impl(
6460
- src0,
6461
- (device const float *) (src1 + bid*nb11),
6462
- dst + bid*ne0,
6463
- ne00,
6464
- ne01,
6465
- ne02,
6466
- ne10,
6467
- ne12,
6468
- ne0,
6469
- ne1,
6470
- r2,
6471
- r3,
6472
- tgpig,
6473
- tiisg,
6474
- sgitg);
6475
- }
6476
-
6477
- [[host_name("kernel_mul_mv_id_q3_K_f32")]]
6478
- kernel void kernel_mul_mv_id_q3_K_f32(
6479
- device const char * src0s,
6480
- device const char * src1,
6481
- device float * dst,
6482
- device const char * ids,
6483
- constant uint64_t & nbi1,
6484
- constant int64_t & ne00,
6485
- constant int64_t & ne01,
6486
- constant int64_t & ne02,
6487
- constant uint64_t & nb00,
6488
- constant uint64_t & nb01,
6489
- constant uint64_t & nb02,
6490
- constant int64_t & ne10,
6491
- constant int64_t & ne11,
6492
- constant int64_t & ne12,
6493
- constant int64_t & ne13,
6494
- constant uint64_t & nb10,
6495
- constant uint64_t & nb11,
6496
- constant uint64_t & nb12,
6497
- constant int64_t & ne0,
6498
- constant int64_t & ne1,
6499
- constant uint64_t & nb1,
6500
- constant uint & r2,
6501
- constant uint & r3,
6502
- constant int & idx,
6503
- uint3 tgpig[[threadgroup_position_in_grid]],
6504
- uint tiitg[[thread_index_in_threadgroup]],
6505
- uint tiisg[[thread_index_in_simdgroup]],
6506
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6507
- const int64_t bid = tgpig.z/(ne12*ne13);
6508
-
6509
- tgpig.z = tgpig.z%(ne12*ne13);
6510
-
6511
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6512
- device const char * src0 = src0s + id*nb02;
6513
-
6514
- kernel_mul_mv_q3_K_f32_impl(
6515
- src0,
6516
- (device const float *) (src1 + bid*nb11),
6517
- dst + bid*ne0,
6518
- ne00,
6519
- ne01,
6520
- ne02,
6521
- ne10,
6522
- ne12,
6523
- ne0,
6524
- ne1,
6525
- r2,
6526
- r3,
6527
- tgpig,
6528
- tiisg,
6529
- sgitg);
6530
- }
6531
-
6532
- [[host_name("kernel_mul_mv_id_q4_K_f32")]]
6533
- kernel void kernel_mul_mv_id_q4_K_f32(
6534
- device const char * src0s,
6535
- device const char * src1,
6536
- device float * dst,
6537
- device const char * ids,
6538
- constant uint64_t & nbi1,
6539
- constant int64_t & ne00,
6540
- constant int64_t & ne01,
6541
- constant int64_t & ne02,
6542
- constant uint64_t & nb00,
6543
- constant uint64_t & nb01,
6544
- constant uint64_t & nb02,
6545
- constant int64_t & ne10,
6546
- constant int64_t & ne11,
6547
- constant int64_t & ne12,
6548
- constant int64_t & ne13,
6549
- constant uint64_t & nb10,
6550
- constant uint64_t & nb11,
6551
- constant uint64_t & nb12,
6552
- constant int64_t & ne0,
6553
- constant int64_t & ne1,
6554
- constant uint64_t & nb1,
6555
- constant uint & r2,
6556
- constant uint & r3,
6557
- constant int & idx,
6558
- uint3 tgpig[[threadgroup_position_in_grid]],
6559
- uint tiitg[[thread_index_in_threadgroup]],
6560
- uint tiisg[[thread_index_in_simdgroup]],
6561
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6562
- const int64_t bid = tgpig.z/(ne12*ne13);
6563
-
6564
- tgpig.z = tgpig.z%(ne12*ne13);
6565
-
6566
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6567
- device const char * src0 = src0s + id*nb02;
6568
-
6569
- kernel_mul_mv_q4_K_f32_impl(
6570
- src0,
6571
- (device const float *) (src1 + bid*nb11),
6572
- dst + bid*ne0,
6573
- ne00,
6574
- ne01,
6575
- ne02,
6576
- ne10,
6577
- ne12,
6578
- ne0,
6579
- ne1,
6580
- r2,
6581
- r3,
6582
- tgpig,
6583
- tiisg,
6584
- sgitg);
6585
- }
6586
-
6587
- [[host_name("kernel_mul_mv_id_q5_K_f32")]]
6588
- kernel void kernel_mul_mv_id_q5_K_f32(
6589
- device const char * src0s,
6590
- device const char * src1,
6591
- device float * dst,
6592
- device const char * ids,
6593
- constant uint64_t & nbi1,
6594
- constant int64_t & ne00,
6595
- constant int64_t & ne01,
6596
- constant int64_t & ne02,
6597
- constant uint64_t & nb00,
6598
- constant uint64_t & nb01,
6599
- constant uint64_t & nb02,
6600
- constant int64_t & ne10,
6601
- constant int64_t & ne11,
6602
- constant int64_t & ne12,
6603
- constant int64_t & ne13,
6604
- constant uint64_t & nb10,
6605
- constant uint64_t & nb11,
6606
- constant uint64_t & nb12,
6607
- constant int64_t & ne0,
6608
- constant int64_t & ne1,
6609
- constant uint64_t & nb1,
6610
- constant uint & r2,
6611
- constant uint & r3,
6612
- constant int & idx,
6613
- uint3 tgpig[[threadgroup_position_in_grid]],
6614
- uint tiitg[[thread_index_in_threadgroup]],
6615
- uint tiisg[[thread_index_in_simdgroup]],
6616
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6617
- const int64_t bid = tgpig.z/(ne12*ne13);
6618
-
6619
- tgpig.z = tgpig.z%(ne12*ne13);
6620
-
6621
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6622
- device const char * src0 = src0s + id*nb02;
6623
-
6624
- kernel_mul_mv_q5_K_f32_impl(
6625
- src0,
6626
- (device const float *) (src1 + bid*nb11),
6627
- dst + bid*ne0,
6628
- ne00,
6629
- ne01,
6630
- ne02,
6631
- ne10,
6632
- ne12,
6633
- ne0,
6634
- ne1,
6635
- r2,
6636
- r3,
6637
- tgpig,
6638
- tiisg,
6639
- sgitg);
6640
- }
6641
-
6642
- [[host_name("kernel_mul_mv_id_q6_K_f32")]]
6643
- kernel void kernel_mul_mv_id_q6_K_f32(
6644
- device const char * src0s,
6645
- device const char * src1,
6646
- device float * dst,
6647
- device const char * ids,
6648
- constant uint64_t & nbi1,
6649
- constant int64_t & ne00,
6650
- constant int64_t & ne01,
6651
- constant int64_t & ne02,
6652
- constant uint64_t & nb00,
6653
- constant uint64_t & nb01,
6654
- constant uint64_t & nb02,
6655
- constant int64_t & ne10,
6656
- constant int64_t & ne11,
6657
- constant int64_t & ne12,
6658
- constant int64_t & ne13,
6659
- constant uint64_t & nb10,
6660
- constant uint64_t & nb11,
6661
- constant uint64_t & nb12,
6662
- constant int64_t & ne0,
6663
- constant int64_t & ne1,
6664
- constant uint64_t & nb1,
6665
- constant uint & r2,
6666
- constant uint & r3,
6667
- constant int & idx,
6668
- uint3 tgpig[[threadgroup_position_in_grid]],
6669
- uint tiitg[[thread_index_in_threadgroup]],
6670
- uint tiisg[[thread_index_in_simdgroup]],
6671
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6672
- const int64_t bid = tgpig.z/(ne12*ne13);
6673
-
6674
- tgpig.z = tgpig.z%(ne12*ne13);
6675
-
6676
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6677
- device const char * src0 = src0s + id*nb02;
6678
-
6679
- kernel_mul_mv_q6_K_f32_impl(
6680
- src0,
6681
- (device const float *) (src1 + bid*nb11),
6682
- dst + bid*ne0,
6683
- ne00,
6684
- ne01,
6685
- ne02,
6686
- ne10,
6687
- ne12,
6688
- ne0,
6689
- ne1,
6690
- r2,
6691
- r3,
6692
- tgpig,
6693
- tiisg,
6694
- sgitg);
6695
- }
6696
-
6697
- [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
6698
- kernel void kernel_mul_mv_id_iq2_xxs_f32(
6699
- device const char * src0s,
6700
- device const char * src1,
6701
- device float * dst,
6702
- device const char * ids,
6703
- constant uint64_t & nbi1,
6704
- constant int64_t & ne00,
6705
- constant int64_t & ne01,
6706
- constant int64_t & ne02,
6707
- constant uint64_t & nb00,
6708
- constant uint64_t & nb01,
6709
- constant uint64_t & nb02,
6710
- constant int64_t & ne10,
6711
- constant int64_t & ne11,
6712
- constant int64_t & ne12,
6713
- constant int64_t & ne13,
6714
- constant uint64_t & nb10,
6715
- constant uint64_t & nb11,
6716
- constant uint64_t & nb12,
6717
- constant int64_t & ne0,
6718
- constant int64_t & ne1,
6719
- constant uint64_t & nb1,
6720
- constant uint & r2,
6721
- constant uint & r3,
6722
- constant int & idx,
6723
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6724
- uint3 tgpig[[threadgroup_position_in_grid]],
6725
- uint tiitg[[thread_index_in_threadgroup]],
6726
- uint tiisg[[thread_index_in_simdgroup]],
6727
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6728
- const int64_t bid = tgpig.z/(ne12*ne13);
6729
-
6730
- tgpig.z = tgpig.z%(ne12*ne13);
6731
-
6732
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6733
- device const char * src0 = src0s + id*nb02;
6734
-
6735
- kernel_mul_mv_iq2_xxs_f32_impl(
6736
- src0,
6737
- (device const float *) (src1 + bid*nb11),
6738
- dst + bid*ne0,
6739
- ne00,
6740
- ne01,
6741
- ne02,
6742
- ne10,
6743
- ne12,
6744
- ne0,
6745
- ne1,
6746
- r2,
6747
- r3,
6748
- shared_values,
6749
- tgpig,
6750
- tiisg,
6751
- sgitg);
6752
- }
6753
-
6754
- [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
6755
- kernel void kernel_mul_mv_id_iq2_xs_f32(
6756
- device const char * src0s,
6757
- device const char * src1,
6758
- device float * dst,
6759
- device const char * ids,
6760
- constant uint64_t & nbi1,
6761
- constant int64_t & ne00,
6762
- constant int64_t & ne01,
6763
- constant int64_t & ne02,
6764
- constant uint64_t & nb00,
6765
- constant uint64_t & nb01,
6766
- constant uint64_t & nb02,
6767
- constant int64_t & ne10,
6768
- constant int64_t & ne11,
6769
- constant int64_t & ne12,
6770
- constant int64_t & ne13,
6771
- constant uint64_t & nb10,
6772
- constant uint64_t & nb11,
6773
- constant uint64_t & nb12,
6774
- constant int64_t & ne0,
6775
- constant int64_t & ne1,
6776
- constant uint64_t & nb1,
6777
- constant uint & r2,
6778
- constant uint & r3,
6779
- constant int & idx,
6780
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6781
- uint3 tgpig[[threadgroup_position_in_grid]],
6782
- uint tiitg[[thread_index_in_threadgroup]],
6783
- uint tiisg[[thread_index_in_simdgroup]],
6784
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6785
- const int64_t bid = tgpig.z/(ne12*ne13);
6786
-
6787
- tgpig.z = tgpig.z%(ne12*ne13);
6788
-
6789
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6790
- device const char * src0 = src0s + id*nb02;
6791
-
6792
- kernel_mul_mv_iq2_xs_f32_impl(
6793
- src0,
6794
- (device const float *) (src1 + bid*nb11),
6795
- dst + bid*ne0,
6796
- ne00,
6797
- ne01,
6798
- ne02,
6799
- ne10,
6800
- ne12,
6801
- ne0,
6802
- ne1,
6803
- r2,
6804
- r3,
6805
- shared_values,
6806
- tgpig,
6807
- tiisg,
6808
- sgitg);
6809
- }
6810
-
6811
- [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
6812
- kernel void kernel_mul_mv_id_iq3_xxs_f32(
6813
- device const char * src0s,
6814
- device const char * src1,
6815
- device float * dst,
6816
- device const char * ids,
6817
- constant uint64_t & nbi1,
6818
- constant int64_t & ne00,
6819
- constant int64_t & ne01,
6820
- constant int64_t & ne02,
6821
- constant uint64_t & nb00,
6822
- constant uint64_t & nb01,
6823
- constant uint64_t & nb02,
6824
- constant int64_t & ne10,
6825
- constant int64_t & ne11,
6826
- constant int64_t & ne12,
6827
- constant int64_t & ne13,
6828
- constant uint64_t & nb10,
6829
- constant uint64_t & nb11,
6830
- constant uint64_t & nb12,
6831
- constant int64_t & ne0,
6832
- constant int64_t & ne1,
6833
- constant uint64_t & nb1,
6834
- constant uint & r2,
6835
- constant uint & r3,
6836
- constant int & idx,
6837
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6838
- uint3 tgpig[[threadgroup_position_in_grid]],
6839
- uint tiitg[[thread_index_in_threadgroup]],
6840
- uint tiisg[[thread_index_in_simdgroup]],
6841
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6842
- const int64_t bid = tgpig.z/(ne12*ne13);
6843
-
6844
- tgpig.z = tgpig.z%(ne12*ne13);
6845
-
6846
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6847
- device const char * src0 = src0s + id*nb02;
6848
-
6849
- kernel_mul_mv_iq3_xxs_f32_impl(
6850
- src0,
6851
- (device const float *) (src1 + bid*nb11),
6852
- dst + bid*ne0,
6853
- ne00,
6854
- ne01,
6855
- ne02,
6856
- ne10,
6857
- ne12,
6858
- ne0,
6859
- ne1,
6860
- r2,
6861
- r3,
6862
- shared_values,
6863
- tgpig,
6864
- tiisg,
6865
- sgitg);
6866
- }
6867
-
6868
- [[host_name("kernel_mul_mv_id_iq3_s_f32")]]
6869
- kernel void kernel_mul_mv_id_iq3_s_f32(
6870
- device const char * src0s,
6871
- device const char * src1,
6872
- device float * dst,
6873
- device const char * ids,
6874
- constant uint64_t & nbi1,
6875
- constant int64_t & ne00,
6876
- constant int64_t & ne01,
6877
- constant int64_t & ne02,
6878
- constant uint64_t & nb00,
6879
- constant uint64_t & nb01,
6880
- constant uint64_t & nb02,
6881
- constant int64_t & ne10,
6882
- constant int64_t & ne11,
6883
- constant int64_t & ne12,
6884
- constant int64_t & ne13,
6885
- constant uint64_t & nb10,
6886
- constant uint64_t & nb11,
6887
- constant uint64_t & nb12,
6888
- constant int64_t & ne0,
6889
- constant int64_t & ne1,
6890
- constant uint64_t & nb1,
6891
- constant uint & r2,
6892
- constant uint & r3,
6893
- constant int & idx,
6894
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6895
- uint3 tgpig[[threadgroup_position_in_grid]],
6896
- uint tiitg[[thread_index_in_threadgroup]],
6897
- uint tiisg[[thread_index_in_simdgroup]],
6898
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6899
- const int64_t bid = tgpig.z/(ne12*ne13);
6900
-
6901
- tgpig.z = tgpig.z%(ne12*ne13);
6902
-
6903
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6904
- device const char * src0 = src0s + id*nb02;
6905
-
6906
- kernel_mul_mv_iq3_s_f32_impl(
6907
- src0,
6908
- (device const float *) (src1 + bid*nb11),
6909
- dst + bid*ne0,
6910
- ne00,
6911
- ne01,
6912
- ne02,
6913
- ne10,
6914
- ne12,
6915
- ne0,
6916
- ne1,
6917
- r2,
6918
- r3,
6919
- shared_values,
6920
- tgpig,
6921
- tiisg,
6922
- sgitg);
6923
- }
6924
-
6925
- [[host_name("kernel_mul_mv_id_iq2_s_f32")]]
6926
- kernel void kernel_mul_mv_id_iq2_s_f32(
6927
- device const char * src0s,
6928
- device const char * src1,
6929
- device float * dst,
6930
- device const char * ids,
6931
- constant uint64_t & nbi1,
6932
- constant int64_t & ne00,
6933
- constant int64_t & ne01,
6934
- constant int64_t & ne02,
6935
- constant uint64_t & nb00,
6936
- constant uint64_t & nb01,
6937
- constant uint64_t & nb02,
6938
- constant int64_t & ne10,
6939
- constant int64_t & ne11,
6940
- constant int64_t & ne12,
6941
- constant int64_t & ne13,
6942
- constant uint64_t & nb10,
6943
- constant uint64_t & nb11,
6944
- constant uint64_t & nb12,
6945
- constant int64_t & ne0,
6946
- constant int64_t & ne1,
6947
- constant uint64_t & nb1,
6948
- constant uint & r2,
6949
- constant uint & r3,
6950
- constant int & idx,
6951
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6952
- uint3 tgpig[[threadgroup_position_in_grid]],
6953
- uint tiitg[[thread_index_in_threadgroup]],
6954
- uint tiisg[[thread_index_in_simdgroup]],
6955
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6956
- const int64_t bid = tgpig.z/(ne12*ne13);
6957
-
6958
- tgpig.z = tgpig.z%(ne12*ne13);
6959
-
6960
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6961
- device const char * src0 = src0s + id*nb02;
6962
-
6963
- kernel_mul_mv_iq2_s_f32_impl(
6964
- src0,
6965
- (device const float *) (src1 + bid*nb11),
6966
- dst + bid*ne0,
6967
- ne00,
6968
- ne01,
6969
- ne02,
6970
- ne10,
6971
- ne12,
6972
- ne0,
6973
- ne1,
6974
- r2,
6975
- r3,
6976
- shared_values,
6977
- tgpig,
6978
- tiisg,
6979
- sgitg);
6980
- }
6981
-
6982
- [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
6983
- kernel void kernel_mul_mv_id_iq1_s_f32(
6984
- device const char * src0s,
6985
- device const char * src1,
6986
- device float * dst,
6987
- device const char * ids,
6988
- constant uint64_t & nbi1,
6989
- constant int64_t & ne00,
6990
- constant int64_t & ne01,
6991
- constant int64_t & ne02,
6992
- constant uint64_t & nb00,
6993
- constant uint64_t & nb01,
6994
- constant uint64_t & nb02,
6995
- constant int64_t & ne10,
6996
- constant int64_t & ne11,
6997
- constant int64_t & ne12,
6998
- constant int64_t & ne13,
6999
- constant uint64_t & nb10,
7000
- constant uint64_t & nb11,
7001
- constant uint64_t & nb12,
7002
- constant int64_t & ne0,
7003
- constant int64_t & ne1,
7004
- constant uint64_t & nb1,
7005
- constant uint & r2,
7006
- constant uint & r3,
7007
- constant int & idx,
7008
- uint3 tgpig[[threadgroup_position_in_grid]],
7009
- uint tiitg[[thread_index_in_threadgroup]],
7010
- uint tiisg[[thread_index_in_simdgroup]],
7011
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
7012
- const int64_t bid = tgpig.z/(ne12*ne13);
7013
-
7014
- tgpig.z = tgpig.z%(ne12*ne13);
7015
-
7016
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7017
- device const char * src0 = src0s + id*nb02;
7018
-
7019
- kernel_mul_mv_iq1_s_f32_impl(
7020
- src0,
7021
- (device const float *) (src1 + bid*nb11),
7022
- dst + bid*ne0,
7023
- ne00,
7024
- ne01,
7025
- ne02,
7026
- ne10,
7027
- ne12,
7028
- ne0,
7029
- ne1,
7030
- r2,
7031
- r3,
7032
- tgpig,
7033
- tiisg,
7034
- sgitg);
7035
- }
7036
-
7037
- [[host_name("kernel_mul_mv_id_iq1_m_f32")]]
7038
- kernel void kernel_mul_mv_id_iq1_m_f32(
7039
- device const char * src0s,
7040
- device const char * src1,
7041
- device float * dst,
7042
- device const char * ids,
7043
- constant uint64_t & nbi1,
7044
- constant int64_t & ne00,
7045
- constant int64_t & ne01,
7046
- constant int64_t & ne02,
7047
- constant uint64_t & nb00,
7048
- constant uint64_t & nb01,
7049
- constant uint64_t & nb02,
7050
- constant int64_t & ne10,
7051
- constant int64_t & ne11,
7052
- constant int64_t & ne12,
7053
- constant int64_t & ne13,
7054
- constant uint64_t & nb10,
7055
- constant uint64_t & nb11,
7056
- constant uint64_t & nb12,
7057
- constant int64_t & ne0,
7058
- constant int64_t & ne1,
7059
- constant uint64_t & nb1,
7060
- constant uint & r2,
7061
- constant uint & r3,
7062
- constant int & idx,
7063
- uint3 tgpig[[threadgroup_position_in_grid]],
7064
- uint tiitg[[thread_index_in_threadgroup]],
7065
- uint tiisg[[thread_index_in_simdgroup]],
7066
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
7067
- const int64_t bid = tgpig.z/(ne12*ne13);
7068
-
7069
- tgpig.z = tgpig.z%(ne12*ne13);
7070
-
7071
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7072
- device const char * src0 = src0s + id*nb02;
7073
-
7074
- kernel_mul_mv_iq1_m_f32_impl(
7075
- src0,
7076
- (device const float *) (src1 + bid*nb11),
7077
- dst + bid*ne0,
7078
- ne00,
7079
- ne01,
7080
- ne02,
7081
- ne10,
7082
- ne12,
7083
- ne0,
7084
- ne1,
7085
- r2,
7086
- r3,
7087
- tgpig,
7088
- tiisg,
7089
- sgitg);
7090
- }
7091
-
7092
- [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
7093
- kernel void kernel_mul_mv_id_iq4_nl_f32(
7094
- device const char * src0s,
7095
- device const char * src1,
7096
- device float * dst,
7097
- device const char * ids,
7098
- constant uint64_t & nbi1,
7099
- constant int64_t & ne00,
7100
- constant int64_t & ne01,
7101
- constant int64_t & ne02,
7102
- constant uint64_t & nb00,
7103
- constant uint64_t & nb01,
7104
- constant uint64_t & nb02,
7105
- constant int64_t & ne10,
7106
- constant int64_t & ne11,
7107
- constant int64_t & ne12,
7108
- constant int64_t & ne13,
7109
- constant uint64_t & nb10,
7110
- constant uint64_t & nb11,
7111
- constant uint64_t & nb12,
7112
- constant int64_t & ne0,
7113
- constant int64_t & ne1,
7114
- constant uint64_t & nb1,
7115
- constant uint & r2,
7116
- constant uint & r3,
7117
- constant int & idx,
7118
- threadgroup float * shared_values [[threadgroup(0)]],
7119
- uint3 tgpig[[threadgroup_position_in_grid]],
7120
- uint tiitg[[thread_index_in_threadgroup]],
7121
- uint tiisg[[thread_index_in_simdgroup]],
7122
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
7123
- const int64_t bid = tgpig.z/(ne12*ne13);
7124
-
7125
- tgpig.z = tgpig.z%(ne12*ne13);
7126
-
7127
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7128
- device const char * src0 = src0s + id*nb02;
7129
-
7130
- kernel_mul_mv_iq4_nl_f32_impl(
7131
- src0,
7132
- (device const float *) (src1 + bid*nb11),
7133
- dst + bid*ne0,
7134
- ne00,
7135
- ne01,
7136
- ne02,
7137
- ne10,
7138
- ne12,
7139
- ne0,
7140
- ne1,
7141
- r2,
7142
- r3,
7143
- shared_values,
7144
- tgpig,
7145
- tiisg,
7146
- sgitg);
7147
- }
7148
-
7149
- [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
7150
- kernel void kernel_mul_mv_id_iq4_xs_f32(
7151
- device const char * src0s,
7152
- device const char * src1,
7153
- device float * dst,
7154
- device const char * ids,
7155
- constant uint64_t & nbi1,
7156
- constant int64_t & ne00,
7157
- constant int64_t & ne01,
7158
- constant int64_t & ne02,
7159
- constant uint64_t & nb00,
7160
- constant uint64_t & nb01,
7161
- constant uint64_t & nb02,
7162
- constant int64_t & ne10,
7163
- constant int64_t & ne11,
7164
- constant int64_t & ne12,
7165
- constant int64_t & ne13,
7166
- constant uint64_t & nb10,
7167
- constant uint64_t & nb11,
7168
- constant uint64_t & nb12,
7169
- constant int64_t & ne0,
7170
- constant int64_t & ne1,
7171
- constant uint64_t & nb1,
7172
- constant uint & r2,
7173
- constant uint & r3,
7174
- constant int & idx,
7175
- threadgroup float * shared_values [[threadgroup(0)]],
7176
- uint3 tgpig[[threadgroup_position_in_grid]],
7177
- uint tiitg[[thread_index_in_threadgroup]],
7178
- uint tiisg[[thread_index_in_simdgroup]],
7179
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
7180
- const int64_t bid = tgpig.z/(ne12*ne13);
7181
-
7182
- tgpig.z = tgpig.z%(ne12*ne13);
7183
-
7184
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7185
- device const char * src0 = src0s + id*nb02;
7186
-
7187
- #if QK_K == 64
7188
- kernel_mul_mv_iq4_nl_f32_impl(
7189
- #else
7190
- kernel_mul_mv_iq4_xs_f32_impl(
6221
+ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t;
6222
+
6223
+ template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
6224
+ template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
6225
+ template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
6226
+ template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6227
+ template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6228
+ template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6229
+ template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6230
+ template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
6231
+ template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
6232
+ template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
6233
+ template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
6234
+ template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
6235
+ template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
6236
+ template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
6237
+ template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
6238
+ template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
6239
+ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
6240
+ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
6241
+ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
6242
+ template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
6243
+ #if QK_K != 64
6244
+ template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
7191
6245
  #endif
7192
- src0,
7193
- (device const float *) (src1 + bid*nb11),
7194
- dst + bid*ne0,
7195
- ne00,
7196
- ne01,
7197
- ne02,
7198
- ne10,
7199
- ne12,
7200
- ne0,
7201
- ne1,
7202
- r2,
7203
- r3,
7204
- shared_values,
7205
- tgpig,
7206
- tiisg,
7207
- sgitg);
7208
- }
6246
+