llama_cpp 0.10.0 → 0.10.2

Sign up to get free protection for your applications and to get access to all the features.
@@ -79,6 +79,7 @@ kernel void kernel_add(
79
79
  constant int64_t & nb1,
80
80
  constant int64_t & nb2,
81
81
  constant int64_t & nb3,
82
+ constant int64_t & offs,
82
83
  uint3 tgpig[[threadgroup_position_in_grid]],
83
84
  uint3 tpitg[[thread_position_in_threadgroup]],
84
85
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -90,9 +91,9 @@ kernel void kernel_add(
90
91
  const int64_t i12 = i02 % ne12;
91
92
  const int64_t i11 = i01 % ne11;
92
93
 
93
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
94
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
94
95
  device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
95
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
96
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
96
97
 
97
98
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
98
99
  const int i10 = i0 % ne10;
@@ -204,7 +205,7 @@ kernel void kernel_add_row(
204
205
  device const float4 * src0,
205
206
  device const float4 * src1,
206
207
  device float4 * dst,
207
- constant int64_t & nb [[buffer(27)]],
208
+ constant int64_t & nb [[buffer(28)]],
208
209
  uint tpig[[thread_position_in_grid]]) {
209
210
  dst[tpig] = src0[tpig] + src1[tpig % nb];
210
211
  }
@@ -213,7 +214,7 @@ kernel void kernel_mul_row(
213
214
  device const float4 * src0,
214
215
  device const float4 * src1,
215
216
  device float4 * dst,
216
- constant int64_t & nb [[buffer(27)]],
217
+ constant int64_t & nb [[buffer(28)]],
217
218
  uint tpig[[thread_position_in_grid]]) {
218
219
  dst[tpig] = src0[tpig] * src1[tpig % nb];
219
220
  }
@@ -222,7 +223,7 @@ kernel void kernel_div_row(
222
223
  device const float4 * src0,
223
224
  device const float4 * src1,
224
225
  device float4 * dst,
225
- constant int64_t & nb [[buffer(27)]],
226
+ constant int64_t & nb [[buffer(28)]],
226
227
  uint tpig[[thread_position_in_grid]]) {
227
228
  dst[tpig] = src0[tpig] / src1[tpig % nb];
228
229
  }
@@ -243,19 +244,53 @@ kernel void kernel_scale_4(
243
244
  dst[tpig] = src0[tpig] * scale;
244
245
  }
245
246
 
246
- kernel void kernel_silu(
247
- device const float4 * src0,
248
- device float4 * dst,
247
+ kernel void kernel_relu(
248
+ device const float * src0,
249
+ device float * dst,
249
250
  uint tpig[[thread_position_in_grid]]) {
250
- device const float4 & x = src0[tpig];
251
- dst[tpig] = x / (1.0f + exp(-x));
251
+ dst[tpig] = max(0.0f, src0[tpig]);
252
252
  }
253
253
 
254
- kernel void kernel_relu(
254
+ kernel void kernel_tanh(
255
255
  device const float * src0,
256
256
  device float * dst,
257
257
  uint tpig[[thread_position_in_grid]]) {
258
- dst[tpig] = max(0.0f, src0[tpig]);
258
+ device const float & x = src0[tpig];
259
+ dst[tpig] = precise::tanh(x);
260
+ }
261
+
262
+ constant float GELU_COEF_A = 0.044715f;
263
+ constant float GELU_QUICK_COEF = -1.702f;
264
+ constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
265
+
266
+ kernel void kernel_gelu(
267
+ device const float4 * src0,
268
+ device float4 * dst,
269
+ uint tpig[[thread_position_in_grid]]) {
270
+ device const float4 & x = src0[tpig];
271
+
272
+ // BEWARE !!!
273
+ // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
274
+ // This was observed with Falcon 7B and 40B models
275
+ //
276
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
277
+ }
278
+
279
+ kernel void kernel_gelu_quick(
280
+ device const float4 * src0,
281
+ device float4 * dst,
282
+ uint tpig[[thread_position_in_grid]]) {
283
+ device const float4 & x = src0[tpig];
284
+
285
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
286
+ }
287
+
288
+ kernel void kernel_silu(
289
+ device const float4 * src0,
290
+ device float4 * dst,
291
+ uint tpig[[thread_position_in_grid]]) {
292
+ device const float4 & x = src0[tpig];
293
+ dst[tpig] = x / (1.0f + exp(-x));
259
294
  }
260
295
 
261
296
  kernel void kernel_sqr(
@@ -313,22 +348,6 @@ kernel void kernel_sum_rows(
313
348
  dst_row[0] = row_sum;
314
349
  }
315
350
 
316
- constant float GELU_COEF_A = 0.044715f;
317
- constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
318
-
319
- kernel void kernel_gelu(
320
- device const float4 * src0,
321
- device float4 * dst,
322
- uint tpig[[thread_position_in_grid]]) {
323
- device const float4 & x = src0[tpig];
324
-
325
- // BEWARE !!!
326
- // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
327
- // This was observed with Falcon 7B and 40B models
328
- //
329
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
330
- }
331
-
332
351
  kernel void kernel_soft_max(
333
352
  device const float * src0,
334
353
  device const float * src1,
@@ -347,9 +366,9 @@ kernel void kernel_soft_max(
347
366
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
348
367
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
349
368
 
350
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
351
- device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
352
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
369
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
370
+ device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
371
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
353
372
 
354
373
  // parallel max
355
374
  float lmax = -INFINITY;
@@ -385,7 +404,12 @@ kernel void kernel_soft_max(
385
404
  pdst[i00] = exp_psrc0;
386
405
  }
387
406
 
407
+ // This barrier fixes a failing test
408
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
409
+ threadgroup_barrier(mem_flags::mem_none);
410
+
388
411
  float sum = simd_sum(lsum);
412
+
389
413
  if (ntg > N_SIMDWIDTH) {
390
414
  if (sgitg == 0) {
391
415
  buf[tiisg] = 0.0f;
@@ -428,9 +452,9 @@ kernel void kernel_soft_max_4(
428
452
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
429
453
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
430
454
 
431
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
432
- device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
433
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
455
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
456
+ device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
457
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
434
458
 
435
459
  // parallel max
436
460
  float4 lmax4 = -INFINITY;
@@ -468,7 +492,13 @@ kernel void kernel_soft_max_4(
468
492
  }
469
493
 
470
494
  const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
495
+
496
+ // This barrier fixes a failing test
497
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
498
+ threadgroup_barrier(mem_flags::mem_none);
499
+
471
500
  float sum = simd_sum(lsum);
501
+
472
502
  if (ntg > N_SIMDWIDTH) {
473
503
  if (sgitg == 0) {
474
504
  buf[tiisg] = 0.0f;
@@ -639,6 +669,94 @@ kernel void kernel_rms_norm(
639
669
  }
640
670
  }
641
671
 
672
+ kernel void kernel_group_norm(
673
+ device const float * src0,
674
+ device float * dst,
675
+ constant int64_t & ne00,
676
+ constant int64_t & ne01,
677
+ constant int64_t & ne02,
678
+ constant uint64_t & nb00,
679
+ constant uint64_t & nb01,
680
+ constant uint64_t & nb02,
681
+ constant int32_t & n_groups,
682
+ constant float & eps,
683
+ threadgroup float * buf [[threadgroup(0)]],
684
+ uint tgpig[[threadgroup_position_in_grid]],
685
+ uint tpitg[[thread_position_in_threadgroup]],
686
+ uint sgitg[[simdgroup_index_in_threadgroup]],
687
+ uint tiisg[[thread_index_in_simdgroup]],
688
+ uint ntg[[threads_per_threadgroup]]) {
689
+ const int64_t ne = ne00*ne01*ne02;
690
+ const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
691
+
692
+ int start = tgpig * gs;
693
+ int end = start + gs;
694
+
695
+ start += tpitg;
696
+
697
+ if (end >= ne) {
698
+ end = ne;
699
+ }
700
+
701
+ float tmp = 0.0f; // partial sum for thread in warp
702
+
703
+ for (int j = start; j < end; j += ntg) {
704
+ tmp += src0[j];
705
+ }
706
+
707
+ threadgroup_barrier(mem_flags::mem_threadgroup);
708
+ tmp = simd_sum(tmp);
709
+ if (ntg > N_SIMDWIDTH) {
710
+ if (sgitg == 0) {
711
+ buf[tiisg] = 0.0f;
712
+ }
713
+
714
+ threadgroup_barrier(mem_flags::mem_threadgroup);
715
+
716
+ if (tiisg == 0) {
717
+ buf[sgitg] = tmp;
718
+ }
719
+
720
+ threadgroup_barrier(mem_flags::mem_threadgroup);
721
+
722
+ tmp = buf[tiisg];
723
+ tmp = simd_sum(tmp);
724
+ }
725
+
726
+ const float mean = tmp / gs;
727
+ tmp = 0.0f;
728
+
729
+ for (int j = start; j < end; j += ntg) {
730
+ float xi = src0[j] - mean;
731
+ dst[j] = xi;
732
+ tmp += xi * xi;
733
+ }
734
+
735
+ tmp = simd_sum(tmp);
736
+ if (ntg > N_SIMDWIDTH) {
737
+ if (sgitg == 0) {
738
+ buf[tiisg] = 0.0f;
739
+ }
740
+
741
+ threadgroup_barrier(mem_flags::mem_threadgroup);
742
+
743
+ if (tiisg == 0) {
744
+ buf[sgitg] = tmp;
745
+ }
746
+
747
+ threadgroup_barrier(mem_flags::mem_threadgroup);
748
+
749
+ tmp = buf[tiisg];
750
+ tmp = simd_sum(tmp);
751
+ }
752
+
753
+ const float variance = tmp / gs;
754
+ const float scale = 1.0f/sqrt(variance + eps);
755
+ for (int j = start; j < end; j += ntg) {
756
+ dst[j] *= scale;
757
+ }
758
+ }
759
+
642
760
  // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
643
761
  // il indicates where the q4 quants begin (0 or QK4_0/4)
644
762
  // we assume that the yl's have been multiplied with the appropriate scale factor
@@ -731,7 +849,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
731
849
  // giard against the number of rows not being divisible by
732
850
  // N_DST, so this is another explicit assumption of the implementation.
733
851
  template<typename block_q_type, int nr, int nsg, int nw>
734
- void mul_vec_q_n_f32(
852
+ void mul_vec_q_n_f32_impl(
735
853
  device const void * src0,
736
854
  device const float * src1,
737
855
  device float * dst,
@@ -813,7 +931,7 @@ kernel void kernel_mul_mv_q4_0_f32(
813
931
  uint3 tgpig[[threadgroup_position_in_grid]],
814
932
  uint tiisg[[thread_index_in_simdgroup]],
815
933
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
816
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
934
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
817
935
  }
818
936
 
819
937
  kernel void kernel_mul_mv_q4_1_f32(
@@ -832,7 +950,7 @@ kernel void kernel_mul_mv_q4_1_f32(
832
950
  uint3 tgpig[[threadgroup_position_in_grid]],
833
951
  uint tiisg[[thread_index_in_simdgroup]],
834
952
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
835
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
953
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
836
954
  }
837
955
 
838
956
  kernel void kernel_mul_mv_q5_0_f32(
@@ -851,7 +969,7 @@ kernel void kernel_mul_mv_q5_0_f32(
851
969
  uint3 tgpig[[threadgroup_position_in_grid]],
852
970
  uint tiisg[[thread_index_in_simdgroup]],
853
971
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
854
- mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
972
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
855
973
  }
856
974
 
857
975
  kernel void kernel_mul_mv_q5_1_f32(
@@ -870,28 +988,28 @@ kernel void kernel_mul_mv_q5_1_f32(
870
988
  uint3 tgpig[[threadgroup_position_in_grid]],
871
989
  uint tiisg[[thread_index_in_simdgroup]],
872
990
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
873
- mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
991
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
874
992
  }
875
993
 
876
994
 
877
995
  #define NB_Q8_0 8
878
996
 
879
- kernel void kernel_mul_mv_q8_0_f32(
997
+ void kernel_mul_mv_q8_0_f32_impl(
880
998
  device const void * src0,
881
999
  device const float * src1,
882
1000
  device float * dst,
883
1001
  constant int64_t & ne00,
884
- constant int64_t & ne01[[buffer(4)]],
885
- constant int64_t & ne02[[buffer(5)]],
886
- constant int64_t & ne10[[buffer(9)]],
887
- constant int64_t & ne12[[buffer(11)]],
888
- constant int64_t & ne0 [[buffer(15)]],
889
- constant int64_t & ne1 [[buffer(16)]],
890
- constant uint & r2 [[buffer(17)]],
891
- constant uint & r3 [[buffer(18)]],
1002
+ constant int64_t & ne01,
1003
+ constant int64_t & ne02,
1004
+ constant int64_t & ne10,
1005
+ constant int64_t & ne12,
1006
+ constant int64_t & ne0,
1007
+ constant int64_t & ne1,
1008
+ constant uint & r2,
1009
+ constant uint & r3,
892
1010
  uint3 tgpig[[threadgroup_position_in_grid]],
893
- uint tiisg[[thread_index_in_simdgroup]],
894
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
1011
+ uint tiisg[[thread_index_in_simdgroup]],
1012
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
895
1013
  const int nr = N_DST;
896
1014
  const int nsg = N_SIMDGROUP;
897
1015
  const int nw = N_SIMDWIDTH;
@@ -945,9 +1063,29 @@ kernel void kernel_mul_mv_q8_0_f32(
945
1063
  }
946
1064
  }
947
1065
 
1066
+ [[host_name("kernel_mul_mv_q8_0_f32")]]
1067
+ kernel void kernel_mul_mv_q8_0_f32(
1068
+ device const void * src0,
1069
+ device const float * src1,
1070
+ device float * dst,
1071
+ constant int64_t & ne00,
1072
+ constant int64_t & ne01,
1073
+ constant int64_t & ne02,
1074
+ constant int64_t & ne10,
1075
+ constant int64_t & ne12,
1076
+ constant int64_t & ne0,
1077
+ constant int64_t & ne1,
1078
+ constant uint & r2 [[buffer(17)]],
1079
+ constant uint & r3 [[buffer(18)]],
1080
+ uint3 tgpig[[threadgroup_position_in_grid]],
1081
+ uint tiisg[[thread_index_in_simdgroup]],
1082
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1083
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1084
+ }
1085
+
948
1086
  #define N_F32_F32 4
949
1087
 
950
- kernel void kernel_mul_mv_f32_f32(
1088
+ void kernel_mul_mv_f32_f32_impl(
951
1089
  device const char * src0,
952
1090
  device const char * src1,
953
1091
  device float * dst,
@@ -965,8 +1103,8 @@ kernel void kernel_mul_mv_f32_f32(
965
1103
  constant uint64_t & nb12,
966
1104
  constant int64_t & ne0,
967
1105
  constant int64_t & ne1,
968
- constant uint & r2 [[buffer(17)]],
969
- constant uint & r3 [[buffer(18)]],
1106
+ constant uint & r2,
1107
+ constant uint & r3,
970
1108
  uint3 tgpig[[threadgroup_position_in_grid]],
971
1109
  uint tiisg[[thread_index_in_simdgroup]]) {
972
1110
 
@@ -1025,6 +1163,32 @@ kernel void kernel_mul_mv_f32_f32(
1025
1163
  }
1026
1164
  }
1027
1165
 
1166
+ [[host_name("kernel_mul_mv_f32_f32")]]
1167
+ kernel void kernel_mul_mv_f32_f32(
1168
+ device const char * src0,
1169
+ device const char * src1,
1170
+ device float * dst,
1171
+ constant int64_t & ne00,
1172
+ constant int64_t & ne01,
1173
+ constant int64_t & ne02,
1174
+ constant uint64_t & nb00,
1175
+ constant uint64_t & nb01,
1176
+ constant uint64_t & nb02,
1177
+ constant int64_t & ne10,
1178
+ constant int64_t & ne11,
1179
+ constant int64_t & ne12,
1180
+ constant uint64_t & nb10,
1181
+ constant uint64_t & nb11,
1182
+ constant uint64_t & nb12,
1183
+ constant int64_t & ne0,
1184
+ constant int64_t & ne1,
1185
+ constant uint & r2 [[buffer(17)]],
1186
+ constant uint & r3 [[buffer(18)]],
1187
+ uint3 tgpig[[threadgroup_position_in_grid]],
1188
+ uint tiisg[[thread_index_in_simdgroup]]) {
1189
+ kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1190
+ }
1191
+
1028
1192
  #define N_F16_F16 4
1029
1193
 
1030
1194
  kernel void kernel_mul_mv_f16_f16(
@@ -1105,7 +1269,7 @@ kernel void kernel_mul_mv_f16_f16(
1105
1269
  }
1106
1270
  }
1107
1271
 
1108
- kernel void kernel_mul_mv_f16_f32_1row(
1272
+ void kernel_mul_mv_f16_f32_1row_impl(
1109
1273
  device const char * src0,
1110
1274
  device const char * src1,
1111
1275
  device float * dst,
@@ -1123,8 +1287,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
1123
1287
  constant uint64_t & nb12,
1124
1288
  constant int64_t & ne0,
1125
1289
  constant int64_t & ne1,
1126
- constant uint & r2 [[buffer(17)]],
1127
- constant uint & r3 [[buffer(18)]],
1290
+ constant uint & r2,
1291
+ constant uint & r3,
1128
1292
  uint3 tgpig[[threadgroup_position_in_grid]],
1129
1293
  uint tiisg[[thread_index_in_simdgroup]]) {
1130
1294
 
@@ -1161,12 +1325,37 @@ kernel void kernel_mul_mv_f16_f32_1row(
1161
1325
  dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1162
1326
  }
1163
1327
  }
1328
+ }
1164
1329
 
1330
+ [[host_name("kernel_mul_mv_f16_f32_1row")]]
1331
+ kernel void kernel_mul_mv_f16_f32_1row(
1332
+ device const char * src0,
1333
+ device const char * src1,
1334
+ device float * dst,
1335
+ constant int64_t & ne00,
1336
+ constant int64_t & ne01,
1337
+ constant int64_t & ne02,
1338
+ constant uint64_t & nb00,
1339
+ constant uint64_t & nb01,
1340
+ constant uint64_t & nb02,
1341
+ constant int64_t & ne10,
1342
+ constant int64_t & ne11,
1343
+ constant int64_t & ne12,
1344
+ constant uint64_t & nb10,
1345
+ constant uint64_t & nb11,
1346
+ constant uint64_t & nb12,
1347
+ constant int64_t & ne0,
1348
+ constant int64_t & ne1,
1349
+ constant uint & r2 [[buffer(17)]],
1350
+ constant uint & r3 [[buffer(18)]],
1351
+ uint3 tgpig[[threadgroup_position_in_grid]],
1352
+ uint tiisg[[thread_index_in_simdgroup]]) {
1353
+ kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1165
1354
  }
1166
1355
 
1167
1356
  #define N_F16_F32 4
1168
1357
 
1169
- kernel void kernel_mul_mv_f16_f32(
1358
+ void kernel_mul_mv_f16_f32_impl(
1170
1359
  device const char * src0,
1171
1360
  device const char * src1,
1172
1361
  device float * dst,
@@ -1184,8 +1373,8 @@ kernel void kernel_mul_mv_f16_f32(
1184
1373
  constant uint64_t & nb12,
1185
1374
  constant int64_t & ne0,
1186
1375
  constant int64_t & ne1,
1187
- constant uint & r2 [[buffer(17)]],
1188
- constant uint & r3 [[buffer(18)]],
1376
+ constant uint & r2,
1377
+ constant uint & r3,
1189
1378
  uint3 tgpig[[threadgroup_position_in_grid]],
1190
1379
  uint tiisg[[thread_index_in_simdgroup]]) {
1191
1380
 
@@ -1244,6 +1433,32 @@ kernel void kernel_mul_mv_f16_f32(
1244
1433
  }
1245
1434
  }
1246
1435
 
1436
+ [[host_name("kernel_mul_mv_f16_f32")]]
1437
+ kernel void kernel_mul_mv_f16_f32(
1438
+ device const char * src0,
1439
+ device const char * src1,
1440
+ device float * dst,
1441
+ constant int64_t & ne00,
1442
+ constant int64_t & ne01,
1443
+ constant int64_t & ne02,
1444
+ constant uint64_t & nb00,
1445
+ constant uint64_t & nb01,
1446
+ constant uint64_t & nb02,
1447
+ constant int64_t & ne10,
1448
+ constant int64_t & ne11,
1449
+ constant int64_t & ne12,
1450
+ constant uint64_t & nb10,
1451
+ constant uint64_t & nb11,
1452
+ constant uint64_t & nb12,
1453
+ constant int64_t & ne0,
1454
+ constant int64_t & ne1,
1455
+ constant uint & r2 [[buffer(17)]],
1456
+ constant uint & r3 [[buffer(18)]],
1457
+ uint3 tgpig[[threadgroup_position_in_grid]],
1458
+ uint tiisg[[thread_index_in_simdgroup]]) {
1459
+ kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1460
+ }
1461
+
1247
1462
  // Assumes row size (ne00) is a multiple of 4
1248
1463
  kernel void kernel_mul_mv_f16_f32_l4(
1249
1464
  device const char * src0,
@@ -1487,8 +1702,9 @@ kernel void kernel_rope(
1487
1702
  dst_data[1] = x0*sin_theta + x1*cos_theta;
1488
1703
  }
1489
1704
  } else {
1490
- for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
1491
- for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
1705
+ for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
1706
+ if (ic < n_dims) {
1707
+ const int64_t ib = 0;
1492
1708
 
1493
1709
  // simplified from `(ib * n_dims + ic) * inv_ndims`
1494
1710
  const float cur_rot = inv_ndims*ic - ib;
@@ -1507,6 +1723,14 @@ kernel void kernel_rope(
1507
1723
 
1508
1724
  dst_data[0] = x0*cos_theta - x1*sin_theta;
1509
1725
  dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1726
+ } else {
1727
+ const int64_t i0 = ic;
1728
+
1729
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1730
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1731
+
1732
+ dst_data[0] = src[0];
1733
+ dst_data[1] = src[1];
1510
1734
  }
1511
1735
  }
1512
1736
  }
@@ -1548,21 +1772,112 @@ kernel void kernel_im2col_f16(
1548
1772
  }
1549
1773
  }
1550
1774
 
1551
- // bitonic sort implementation following the CUDA kernels as reference
1552
- typedef void (argsort_t)(
1553
- device const float * x,
1554
- device int32_t * dst,
1555
- constant int64_t & ncols,
1556
- uint3 tgpig[[threadgroup_position_in_grid]],
1557
- uint3 tpitg[[thread_position_in_threadgroup]]);
1558
-
1559
- template<ggml_sort_order order>
1560
- kernel void kernel_argsort_f32_i32(
1561
- device const float * x,
1562
- device int32_t * dst,
1563
- constant int64_t & ncols,
1564
- uint3 tgpig[[threadgroup_position_in_grid]],
1565
- uint3 tpitg[[thread_position_in_threadgroup]]) {
1775
+ kernel void kernel_upscale_f32(
1776
+ device const char * src0,
1777
+ device char * dst,
1778
+ constant int64_t & ne00,
1779
+ constant int64_t & ne01,
1780
+ constant int64_t & ne02,
1781
+ constant int64_t & ne03,
1782
+ constant uint64_t & nb00,
1783
+ constant uint64_t & nb01,
1784
+ constant uint64_t & nb02,
1785
+ constant uint64_t & nb03,
1786
+ constant int64_t & ne0,
1787
+ constant int64_t & ne1,
1788
+ constant int64_t & ne2,
1789
+ constant int64_t & ne3,
1790
+ constant uint64_t & nb0,
1791
+ constant uint64_t & nb1,
1792
+ constant uint64_t & nb2,
1793
+ constant uint64_t & nb3,
1794
+ constant int32_t & sf,
1795
+ uint3 tgpig[[threadgroup_position_in_grid]],
1796
+ uint3 tpitg[[thread_position_in_threadgroup]],
1797
+ uint3 ntg[[threads_per_threadgroup]]) {
1798
+
1799
+ const int64_t i3 = tgpig.z;
1800
+ const int64_t i2 = tgpig.y;
1801
+ const int64_t i1 = tgpig.x;
1802
+
1803
+ const int64_t i03 = i3;
1804
+ const int64_t i02 = i2;
1805
+ const int64_t i01 = i1/sf;
1806
+
1807
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
1808
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
1809
+
1810
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1811
+ dst_ptr[i0] = src0_ptr[i0/sf];
1812
+ }
1813
+ }
1814
+
1815
+ kernel void kernel_pad_f32(
1816
+ device const char * src0,
1817
+ device char * dst,
1818
+ constant int64_t & ne00,
1819
+ constant int64_t & ne01,
1820
+ constant int64_t & ne02,
1821
+ constant int64_t & ne03,
1822
+ constant uint64_t & nb00,
1823
+ constant uint64_t & nb01,
1824
+ constant uint64_t & nb02,
1825
+ constant uint64_t & nb03,
1826
+ constant int64_t & ne0,
1827
+ constant int64_t & ne1,
1828
+ constant int64_t & ne2,
1829
+ constant int64_t & ne3,
1830
+ constant uint64_t & nb0,
1831
+ constant uint64_t & nb1,
1832
+ constant uint64_t & nb2,
1833
+ constant uint64_t & nb3,
1834
+ uint3 tgpig[[threadgroup_position_in_grid]],
1835
+ uint3 tpitg[[thread_position_in_threadgroup]],
1836
+ uint3 ntg[[threads_per_threadgroup]]) {
1837
+
1838
+ const int64_t i3 = tgpig.z;
1839
+ const int64_t i2 = tgpig.y;
1840
+ const int64_t i1 = tgpig.x;
1841
+
1842
+ const int64_t i03 = i3;
1843
+ const int64_t i02 = i2;
1844
+ const int64_t i01 = i1;
1845
+
1846
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
1847
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
1848
+
1849
+ if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
1850
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1851
+ if (i0 < ne00) {
1852
+ dst_ptr[i0] = src0_ptr[i0];
1853
+ } else {
1854
+ dst_ptr[i0] = 0.0f;
1855
+ }
1856
+ }
1857
+
1858
+ return;
1859
+ }
1860
+
1861
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1862
+ dst_ptr[i0] = 0.0f;
1863
+ }
1864
+ }
1865
+
1866
+ // bitonic sort implementation following the CUDA kernels as reference
1867
+ typedef void (argsort_t)(
1868
+ device const float * x,
1869
+ device int32_t * dst,
1870
+ constant int64_t & ncols,
1871
+ uint3 tgpig[[threadgroup_position_in_grid]],
1872
+ uint3 tpitg[[thread_position_in_threadgroup]]);
1873
+
1874
+ template<ggml_sort_order order>
1875
+ kernel void kernel_argsort_f32_i32(
1876
+ device const float * x,
1877
+ device int32_t * dst,
1878
+ constant int64_t & ncols,
1879
+ uint3 tgpig[[threadgroup_position_in_grid]],
1880
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
1566
1881
  // bitonic sort
1567
1882
  int col = tpitg[0];
1568
1883
  int row = tgpig[1];
@@ -1600,9 +1915,17 @@ kernel void kernel_argsort_f32_i32(
1600
1915
  template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
1601
1916
  template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
1602
1917
 
1918
+ kernel void kernel_leaky_relu_f32(
1919
+ device const float * src0,
1920
+ device float * dst,
1921
+ constant float & slope,
1922
+ uint tpig[[thread_position_in_grid]]) {
1923
+ dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
1924
+ }
1925
+
1603
1926
  kernel void kernel_cpy_f16_f16(
1604
- device const half * src0,
1605
- device half * dst,
1927
+ device const half * src0,
1928
+ device half * dst,
1606
1929
  constant int64_t & ne00,
1607
1930
  constant int64_t & ne01,
1608
1931
  constant int64_t & ne02,
@@ -1641,6 +1964,47 @@ kernel void kernel_cpy_f16_f16(
1641
1964
  }
1642
1965
  }
1643
1966
 
1967
+ kernel void kernel_cpy_f16_f32(
1968
+ device const half * src0,
1969
+ device float * dst,
1970
+ constant int64_t & ne00,
1971
+ constant int64_t & ne01,
1972
+ constant int64_t & ne02,
1973
+ constant int64_t & ne03,
1974
+ constant uint64_t & nb00,
1975
+ constant uint64_t & nb01,
1976
+ constant uint64_t & nb02,
1977
+ constant uint64_t & nb03,
1978
+ constant int64_t & ne0,
1979
+ constant int64_t & ne1,
1980
+ constant int64_t & ne2,
1981
+ constant int64_t & ne3,
1982
+ constant uint64_t & nb0,
1983
+ constant uint64_t & nb1,
1984
+ constant uint64_t & nb2,
1985
+ constant uint64_t & nb3,
1986
+ uint3 tgpig[[threadgroup_position_in_grid]],
1987
+ uint3 tpitg[[thread_position_in_threadgroup]],
1988
+ uint3 ntg[[threads_per_threadgroup]]) {
1989
+ const int64_t i03 = tgpig[2];
1990
+ const int64_t i02 = tgpig[1];
1991
+ const int64_t i01 = tgpig[0];
1992
+
1993
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1994
+
1995
+ const int64_t i3 = n / (ne2*ne1*ne0);
1996
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1997
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1998
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1999
+
2000
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2001
+
2002
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2003
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2004
+ dst_data[i00] = src[0];
2005
+ }
2006
+ }
2007
+
1644
2008
  kernel void kernel_cpy_f32_f16(
1645
2009
  device const float * src0,
1646
2010
  device half * dst,
@@ -1917,9 +2281,9 @@ kernel void kernel_cpy_f32_q4_1(
1917
2281
  }
1918
2282
 
1919
2283
  kernel void kernel_concat(
1920
- device const char * src0,
1921
- device const char * src1,
1922
- device char * dst,
2284
+ device const char * src0,
2285
+ device const char * src1,
2286
+ device char * dst,
1923
2287
  constant int64_t & ne00,
1924
2288
  constant int64_t & ne01,
1925
2289
  constant int64_t & ne02,
@@ -1956,7 +2320,7 @@ kernel void kernel_concat(
1956
2320
  const int64_t i12 = i02 % ne12;
1957
2321
  const int64_t i11 = i01 % ne11;
1958
2322
 
1959
- device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
2323
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
1960
2324
  device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
1961
2325
  device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
1962
2326
 
@@ -2064,19 +2428,19 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
2064
2428
 
2065
2429
  //====================================== dot products =========================
2066
2430
 
2067
- kernel void kernel_mul_mv_q2_K_f32(
2431
+ void kernel_mul_mv_q2_K_f32_impl(
2068
2432
  device const void * src0,
2069
2433
  device const float * src1,
2070
2434
  device float * dst,
2071
2435
  constant int64_t & ne00,
2072
- constant int64_t & ne01[[buffer(4)]],
2073
- constant int64_t & ne02[[buffer(5)]],
2074
- constant int64_t & ne10[[buffer(9)]],
2075
- constant int64_t & ne12[[buffer(11)]],
2076
- constant int64_t & ne0 [[buffer(15)]],
2077
- constant int64_t & ne1 [[buffer(16)]],
2078
- constant uint & r2 [[buffer(17)]],
2079
- constant uint & r3 [[buffer(18)]],
2436
+ constant int64_t & ne01,
2437
+ constant int64_t & ne02,
2438
+ constant int64_t & ne10,
2439
+ constant int64_t & ne12,
2440
+ constant int64_t & ne0,
2441
+ constant int64_t & ne1,
2442
+ constant uint & r2,
2443
+ constant uint & r3,
2080
2444
  uint3 tgpig[[threadgroup_position_in_grid]],
2081
2445
  uint tiisg[[thread_index_in_simdgroup]],
2082
2446
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2214,8 +2578,8 @@ kernel void kernel_mul_mv_q2_K_f32(
2214
2578
  }
2215
2579
  }
2216
2580
 
2217
- #if QK_K == 256
2218
- kernel void kernel_mul_mv_q3_K_f32(
2581
+ [[host_name("kernel_mul_mv_q2_K_f32")]]
2582
+ kernel void kernel_mul_mv_q2_K_f32(
2219
2583
  device const void * src0,
2220
2584
  device const float * src1,
2221
2585
  device float * dst,
@@ -2229,8 +2593,29 @@ kernel void kernel_mul_mv_q3_K_f32(
2229
2593
  constant uint & r2 [[buffer(17)]],
2230
2594
  constant uint & r3 [[buffer(18)]],
2231
2595
  uint3 tgpig[[threadgroup_position_in_grid]],
2232
- uint tiisg[[thread_index_in_simdgroup]],
2233
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2596
+ uint tiisg[[thread_index_in_simdgroup]],
2597
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2598
+
2599
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2600
+ }
2601
+
2602
+ #if QK_K == 256
2603
+ void kernel_mul_mv_q3_K_f32_impl(
2604
+ device const void * src0,
2605
+ device const float * src1,
2606
+ device float * dst,
2607
+ constant int64_t & ne00,
2608
+ constant int64_t & ne01,
2609
+ constant int64_t & ne02,
2610
+ constant int64_t & ne10,
2611
+ constant int64_t & ne12,
2612
+ constant int64_t & ne0,
2613
+ constant int64_t & ne1,
2614
+ constant uint & r2,
2615
+ constant uint & r3,
2616
+ uint3 tgpig[[threadgroup_position_in_grid]],
2617
+ uint tiisg[[thread_index_in_simdgroup]],
2618
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2234
2619
 
2235
2620
  const int nb = ne00/QK_K;
2236
2621
 
@@ -2373,19 +2758,19 @@ kernel void kernel_mul_mv_q3_K_f32(
2373
2758
  }
2374
2759
  }
2375
2760
  #else
2376
- kernel void kernel_mul_mv_q3_K_f32(
2761
+ void kernel_mul_mv_q3_K_f32_impl(
2377
2762
  device const void * src0,
2378
2763
  device const float * src1,
2379
2764
  device float * dst,
2380
2765
  constant int64_t & ne00,
2381
- constant int64_t & ne01[[buffer(4)]],
2382
- constant int64_t & ne02[[buffer(5)]],
2383
- constant int64_t & ne10[[buffer(9)]],
2384
- constant int64_t & ne12[[buffer(11)]],
2385
- constant int64_t & ne0 [[buffer(15)]],
2386
- constant int64_t & ne1 [[buffer(16)]],
2387
- constant uint & r2 [[buffer(17)]],
2388
- constant uint & r3 [[buffer(18)]],
2766
+ constant int64_t & ne01,
2767
+ constant int64_t & ne02,
2768
+ constant int64_t & ne10,
2769
+ constant int64_t & ne12,
2770
+ constant int64_t & ne0,
2771
+ constant int64_t & ne1,
2772
+ constant uint & r2,
2773
+ constant uint & r3,
2389
2774
  uint3 tgpig[[threadgroup_position_in_grid]],
2390
2775
  uint tiisg[[thread_index_in_simdgroup]],
2391
2776
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2450,20 +2835,41 @@ kernel void kernel_mul_mv_q3_K_f32(
2450
2835
  }
2451
2836
  #endif
2452
2837
 
2838
+ [[host_name("kernel_mul_mv_q3_K_f32")]]
2839
+ kernel void kernel_mul_mv_q3_K_f32(
2840
+ device const void * src0,
2841
+ device const float * src1,
2842
+ device float * dst,
2843
+ constant int64_t & ne00,
2844
+ constant int64_t & ne01[[buffer(4)]],
2845
+ constant int64_t & ne02[[buffer(5)]],
2846
+ constant int64_t & ne10[[buffer(9)]],
2847
+ constant int64_t & ne12[[buffer(11)]],
2848
+ constant int64_t & ne0 [[buffer(15)]],
2849
+ constant int64_t & ne1 [[buffer(16)]],
2850
+ constant uint & r2 [[buffer(17)]],
2851
+ constant uint & r3 [[buffer(18)]],
2852
+ uint3 tgpig[[threadgroup_position_in_grid]],
2853
+ uint tiisg[[thread_index_in_simdgroup]],
2854
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2855
+
2856
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2857
+ }
2858
+
2453
2859
  #if QK_K == 256
2454
- kernel void kernel_mul_mv_q4_K_f32(
2860
+ void kernel_mul_mv_q4_K_f32_impl(
2455
2861
  device const void * src0,
2456
2862
  device const float * src1,
2457
2863
  device float * dst,
2458
2864
  constant int64_t & ne00,
2459
- constant int64_t & ne01 [[buffer(4)]],
2460
- constant int64_t & ne02 [[buffer(5)]],
2461
- constant int64_t & ne10 [[buffer(9)]],
2462
- constant int64_t & ne12 [[buffer(11)]],
2463
- constant int64_t & ne0 [[buffer(15)]],
2464
- constant int64_t & ne1 [[buffer(16)]],
2465
- constant uint & r2 [[buffer(17)]],
2466
- constant uint & r3 [[buffer(18)]],
2865
+ constant int64_t & ne01,
2866
+ constant int64_t & ne02,
2867
+ constant int64_t & ne10,
2868
+ constant int64_t & ne12,
2869
+ constant int64_t & ne0,
2870
+ constant int64_t & ne1,
2871
+ constant uint & r2,
2872
+ constant uint & r3,
2467
2873
  uint3 tgpig[[threadgroup_position_in_grid]],
2468
2874
  uint tiisg[[thread_index_in_simdgroup]],
2469
2875
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2564,19 +2970,19 @@ kernel void kernel_mul_mv_q4_K_f32(
2564
2970
  }
2565
2971
  }
2566
2972
  #else
2567
- kernel void kernel_mul_mv_q4_K_f32(
2973
+ void kernel_mul_mv_q4_K_f32_impl(
2568
2974
  device const void * src0,
2569
2975
  device const float * src1,
2570
2976
  device float * dst,
2571
2977
  constant int64_t & ne00,
2572
- constant int64_t & ne01[[buffer(4)]],
2573
- constant int64_t & ne02[[buffer(5)]],
2574
- constant int64_t & ne10[[buffer(9)]],
2575
- constant int64_t & ne12[[buffer(11)]],
2576
- constant int64_t & ne0 [[buffer(15)]],
2577
- constant int64_t & ne1 [[buffer(16)]],
2578
- constant uint & r2 [[buffer(17)]],
2579
- constant uint & r3 [[buffer(18)]],
2978
+ constant int64_t & ne01,
2979
+ constant int64_t & ne02,
2980
+ constant int64_t & ne10,
2981
+ constant int64_t & ne12,
2982
+ constant int64_t & ne0,
2983
+ constant int64_t & ne1,
2984
+ constant uint & r2,
2985
+ constant uint & r3,
2580
2986
  uint3 tgpig[[threadgroup_position_in_grid]],
2581
2987
  uint tiisg[[thread_index_in_simdgroup]],
2582
2988
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2660,7 +3066,8 @@ kernel void kernel_mul_mv_q4_K_f32(
2660
3066
  }
2661
3067
  #endif
2662
3068
 
2663
- kernel void kernel_mul_mv_q5_K_f32(
3069
+ [[host_name("kernel_mul_mv_q4_K_f32")]]
3070
+ kernel void kernel_mul_mv_q4_K_f32(
2664
3071
  device const void * src0,
2665
3072
  device const float * src1,
2666
3073
  device float * dst,
@@ -2677,6 +3084,26 @@ kernel void kernel_mul_mv_q5_K_f32(
2677
3084
  uint tiisg[[thread_index_in_simdgroup]],
2678
3085
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
2679
3086
 
3087
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3088
+ }
3089
+
3090
+ void kernel_mul_mv_q5_K_f32_impl(
3091
+ device const void * src0,
3092
+ device const float * src1,
3093
+ device float * dst,
3094
+ constant int64_t & ne00,
3095
+ constant int64_t & ne01,
3096
+ constant int64_t & ne02,
3097
+ constant int64_t & ne10,
3098
+ constant int64_t & ne12,
3099
+ constant int64_t & ne0,
3100
+ constant int64_t & ne1,
3101
+ constant uint & r2,
3102
+ constant uint & r3,
3103
+ uint3 tgpig[[threadgroup_position_in_grid]],
3104
+ uint tiisg[[thread_index_in_simdgroup]],
3105
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3106
+
2680
3107
  const int nb = ne00/QK_K;
2681
3108
 
2682
3109
  const int64_t r0 = tgpig.x;
@@ -2836,10 +3263,10 @@ kernel void kernel_mul_mv_q5_K_f32(
2836
3263
  dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
2837
3264
  }
2838
3265
  }
2839
-
2840
3266
  }
2841
3267
 
2842
- kernel void kernel_mul_mv_q6_K_f32(
3268
+ [[host_name("kernel_mul_mv_q5_K_f32")]]
3269
+ kernel void kernel_mul_mv_q5_K_f32(
2843
3270
  device const void * src0,
2844
3271
  device const float * src1,
2845
3272
  device float * dst,
@@ -2853,21 +3280,41 @@ kernel void kernel_mul_mv_q6_K_f32(
2853
3280
  constant uint & r2 [[buffer(17)]],
2854
3281
  constant uint & r3 [[buffer(18)]],
2855
3282
  uint3 tgpig[[threadgroup_position_in_grid]],
2856
- uint tiisg[[thread_index_in_simdgroup]],
2857
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2858
-
2859
- const uint8_t kmask1 = 0x03;
2860
- const uint8_t kmask2 = 0x0C;
2861
- const uint8_t kmask3 = 0x30;
2862
- const uint8_t kmask4 = 0xC0;
2863
-
2864
- const int nb = ne00/QK_K;
2865
-
2866
- const int64_t r0 = tgpig.x;
2867
- const int64_t r1 = tgpig.y;
2868
- const int im = tgpig.z;
3283
+ uint tiisg[[thread_index_in_simdgroup]],
3284
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2869
3285
 
2870
- const int row = 2 * r0 + sgitg;
3286
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3287
+ }
3288
+
3289
+ void kernel_mul_mv_q6_K_f32_impl(
3290
+ device const void * src0,
3291
+ device const float * src1,
3292
+ device float * dst,
3293
+ constant int64_t & ne00,
3294
+ constant int64_t & ne01,
3295
+ constant int64_t & ne02,
3296
+ constant int64_t & ne10,
3297
+ constant int64_t & ne12,
3298
+ constant int64_t & ne0,
3299
+ constant int64_t & ne1,
3300
+ constant uint & r2,
3301
+ constant uint & r3,
3302
+ uint3 tgpig[[threadgroup_position_in_grid]],
3303
+ uint tiisg[[thread_index_in_simdgroup]],
3304
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3305
+
3306
+ const uint8_t kmask1 = 0x03;
3307
+ const uint8_t kmask2 = 0x0C;
3308
+ const uint8_t kmask3 = 0x30;
3309
+ const uint8_t kmask4 = 0xC0;
3310
+
3311
+ const int nb = ne00/QK_K;
3312
+
3313
+ const int64_t r0 = tgpig.x;
3314
+ const int64_t r1 = tgpig.y;
3315
+ const int im = tgpig.z;
3316
+
3317
+ const int row = 2 * r0 + sgitg;
2871
3318
 
2872
3319
  const uint i12 = im%ne12;
2873
3320
  const uint i13 = im/ne12;
@@ -2945,6 +3392,27 @@ kernel void kernel_mul_mv_q6_K_f32(
2945
3392
  }
2946
3393
  }
2947
3394
 
3395
+ [[host_name("kernel_mul_mv_q6_K_f32")]]
3396
+ kernel void kernel_mul_mv_q6_K_f32(
3397
+ device const void * src0,
3398
+ device const float * src1,
3399
+ device float * dst,
3400
+ constant int64_t & ne00,
3401
+ constant int64_t & ne01[[buffer(4)]],
3402
+ constant int64_t & ne02[[buffer(5)]],
3403
+ constant int64_t & ne10[[buffer(9)]],
3404
+ constant int64_t & ne12[[buffer(11)]],
3405
+ constant int64_t & ne0 [[buffer(15)]],
3406
+ constant int64_t & ne1 [[buffer(16)]],
3407
+ constant uint & r2 [[buffer(17)]],
3408
+ constant uint & r3 [[buffer(18)]],
3409
+ uint3 tgpig[[threadgroup_position_in_grid]],
3410
+ uint tiisg[[thread_index_in_simdgroup]],
3411
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3412
+
3413
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3414
+ }
3415
+
2948
3416
  //============================= templates and their specializations =============================
2949
3417
 
2950
3418
  // NOTE: this is not dequantizing - we are simply fitting the template
@@ -3062,10 +3530,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
3062
3530
 
3063
3531
  template <typename type4x4>
3064
3532
  void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
3065
- const half d = xb->d;
3066
- const half min = xb->dmin;
3533
+ const float d = xb->d;
3534
+ const float min = xb->dmin;
3067
3535
  device const uint8_t * q = (device const uint8_t *)xb->qs;
3068
- half dl, ml;
3536
+ float dl, ml;
3069
3537
  uint8_t sc = xb->scales[il];
3070
3538
 
3071
3539
  #if QK_K == 256
@@ -3135,10 +3603,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
3135
3603
  q = q + (il/4) * 32 + 16 * (il&1);
3136
3604
  il = il & 3;
3137
3605
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3138
- const half d = il < 2 ? xb->d : xb->d / 16.h;
3139
- const half min = xb->dmin;
3140
- const half dl = d * sc[0];
3141
- const half ml = min * sc[1];
3606
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
3607
+ const float min = xb->dmin;
3608
+ const float dl = d * sc[0];
3609
+ const float ml = min * sc[1];
3142
3610
  #else
3143
3611
  q = q + 16 * (il&1);
3144
3612
  device const uint8_t * s = xb->scales;
@@ -3165,13 +3633,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
3165
3633
  uint8_t ul = 1 << (il/2);
3166
3634
  il = il & 3;
3167
3635
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3168
- const half d = il < 2 ? xb->d : xb->d / 16.h;
3169
- const half min = xb->dmin;
3170
- const half dl = d * sc[0];
3171
- const half ml = min * sc[1];
3636
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
3637
+ const float min = xb->dmin;
3638
+ const float dl = d * sc[0];
3639
+ const float ml = min * sc[1];
3172
3640
 
3173
- const ushort mask = il<2 ? 0x0F : 0xF0;
3174
- const half qh_val = il<2 ? 16.h : 256.h;
3641
+ const ushort mask = il<2 ? 0x0F : 0xF0;
3642
+ const float qh_val = il<2 ? 16.f : 256.f;
3175
3643
  for (int i = 0; i < 16; ++i) {
3176
3644
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
3177
3645
  }
@@ -3219,22 +3687,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
3219
3687
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
3220
3688
  kernel void kernel_get_rows(
3221
3689
  device const void * src0,
3222
- device const int * src1,
3690
+ device const char * src1,
3223
3691
  device float * dst,
3224
3692
  constant int64_t & ne00,
3225
3693
  constant uint64_t & nb01,
3694
+ constant uint64_t & nb02,
3695
+ constant int64_t & ne10,
3696
+ constant uint64_t & nb10,
3697
+ constant uint64_t & nb11,
3226
3698
  constant uint64_t & nb1,
3227
- uint tgpig[[threadgroup_position_in_grid]],
3699
+ constant uint64_t & nb2,
3700
+ uint3 tgpig[[threadgroup_position_in_grid]],
3228
3701
  uint tiitg[[thread_index_in_threadgroup]],
3229
- uint tptg[[threads_per_threadgroup]]) {
3230
- const int i = tgpig;
3231
- const int r = ((device int32_t *) src1)[i];
3702
+ uint3 tptg [[threads_per_threadgroup]]) {
3703
+ //const int64_t i = tgpig;
3704
+ //const int64_t r = ((device int32_t *) src1)[i];
3705
+
3706
+ const int64_t i10 = tgpig.x;
3707
+ const int64_t i11 = tgpig.y;
3708
+
3709
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3232
3710
 
3233
- for (int ind = tiitg; ind < ne00/16; ind += tptg) {
3711
+ const int64_t i02 = i11;
3712
+
3713
+ for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
3234
3714
  float4x4 temp;
3235
3715
  dequantize_func(
3236
- ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
3237
- *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
3716
+ ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
3717
+ *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
3718
+ }
3719
+ }
3720
+
3721
+ kernel void kernel_get_rows_f32(
3722
+ device const void * src0,
3723
+ device const char * src1,
3724
+ device float * dst,
3725
+ constant int64_t & ne00,
3726
+ constant uint64_t & nb01,
3727
+ constant uint64_t & nb02,
3728
+ constant int64_t & ne10,
3729
+ constant uint64_t & nb10,
3730
+ constant uint64_t & nb11,
3731
+ constant uint64_t & nb1,
3732
+ constant uint64_t & nb2,
3733
+ uint3 tgpig[[threadgroup_position_in_grid]],
3734
+ uint tiitg[[thread_index_in_threadgroup]],
3735
+ uint3 tptg [[threads_per_threadgroup]]) {
3736
+ const int64_t i10 = tgpig.x;
3737
+ const int64_t i11 = tgpig.y;
3738
+
3739
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3740
+
3741
+ const int64_t i02 = i11;
3742
+
3743
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3744
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3745
+ ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
3746
+ }
3747
+ }
3748
+
3749
+ kernel void kernel_get_rows_f16(
3750
+ device const void * src0,
3751
+ device const char * src1,
3752
+ device float * dst,
3753
+ constant int64_t & ne00,
3754
+ constant uint64_t & nb01,
3755
+ constant uint64_t & nb02,
3756
+ constant int64_t & ne10,
3757
+ constant uint64_t & nb10,
3758
+ constant uint64_t & nb11,
3759
+ constant uint64_t & nb1,
3760
+ constant uint64_t & nb2,
3761
+ uint3 tgpig[[threadgroup_position_in_grid]],
3762
+ uint tiitg[[thread_index_in_threadgroup]],
3763
+ uint3 tptg [[threads_per_threadgroup]]) {
3764
+ const int64_t i10 = tgpig.x;
3765
+ const int64_t i11 = tgpig.y;
3766
+
3767
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3768
+
3769
+ const int64_t i02 = i11;
3770
+
3771
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3772
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3773
+ ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
3238
3774
  }
3239
3775
  }
3240
3776
 
@@ -3426,19 +3962,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
3426
3962
 
3427
3963
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3428
3964
  kernel void kernel_mul_mm_id(
3429
- device const int32_t * ids,
3965
+ device const uchar * ids,
3430
3966
  device const uchar * src1,
3431
- device float * dst,
3967
+ device uchar * dst,
3968
+ constant int64_t & nbi1,
3432
3969
  constant int64_t & ne00,
3433
3970
  constant int64_t & ne02,
3434
3971
  constant int64_t & nb01,
3435
3972
  constant int64_t & nb02,
3436
3973
  constant int64_t & ne12,
3974
+ constant int64_t & ne13,
3437
3975
  constant int64_t & nb10,
3438
3976
  constant int64_t & nb11,
3439
3977
  constant int64_t & nb12,
3440
3978
  constant int64_t & ne0,
3441
3979
  constant int64_t & ne1,
3980
+ constant int64_t & nb1,
3442
3981
  constant uint & r2,
3443
3982
  constant uint & r3,
3444
3983
  constant int & idx,
@@ -3456,10 +3995,16 @@ kernel void kernel_mul_mm_id(
3456
3995
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3457
3996
  device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
3458
3997
 
3998
+ const int64_t bid = tgpig.z/(ne12*ne13);
3999
+
4000
+ tgpig.z = tgpig.z%(ne12*ne13);
4001
+
4002
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4003
+
3459
4004
  kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3460
- src0[ids[idx]],
3461
- src1,
3462
- dst,
4005
+ src0[id],
4006
+ src1 + bid*nb11,
4007
+ (device float *) (dst + bid*nb1),
3463
4008
  ne00,
3464
4009
  ne02,
3465
4010
  nb01,
@@ -3484,17 +4029,26 @@ kernel void kernel_mul_mm_id(
3484
4029
  #define QK_NL 4
3485
4030
  #endif
3486
4031
 
4032
+ //
4033
+ // get rows
4034
+ //
4035
+
3487
4036
  typedef void (get_rows_t)(
3488
4037
  device const void * src0,
3489
- device const int * src1,
4038
+ device const char * src1,
3490
4039
  device float * dst,
3491
4040
  constant int64_t & ne00,
3492
4041
  constant uint64_t & nb01,
4042
+ constant uint64_t & nb02,
4043
+ constant int64_t & ne10,
4044
+ constant uint64_t & nb10,
4045
+ constant uint64_t & nb11,
3493
4046
  constant uint64_t & nb1,
3494
- uint, uint, uint);
4047
+ constant uint64_t & nb2,
4048
+ uint3, uint, uint3);
3495
4049
 
3496
- template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
3497
- template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
4050
+ //template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
4051
+ //template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
3498
4052
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
3499
4053
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
3500
4054
  template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
@@ -3506,6 +4060,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
3506
4060
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
3507
4061
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
3508
4062
 
4063
+ //
4064
+ // matrix-matrix multiplication
4065
+ //
4066
+
3509
4067
  typedef void (mat_mm_t)(
3510
4068
  device const uchar * src0,
3511
4069
  device const uchar * src1,
@@ -3538,20 +4096,27 @@ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
3538
4096
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
3539
4097
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
3540
4098
 
4099
+ //
4100
+ // indirect matrix-matrix multiplication
4101
+ //
4102
+
3541
4103
  typedef void (mat_mm_id_t)(
3542
- device const int32_t * ids,
4104
+ device const uchar * ids,
3543
4105
  device const uchar * src1,
3544
- device float * dst,
4106
+ device uchar * dst,
4107
+ constant int64_t & nbi1,
3545
4108
  constant int64_t & ne00,
3546
4109
  constant int64_t & ne02,
3547
4110
  constant int64_t & nb01,
3548
4111
  constant int64_t & nb02,
3549
4112
  constant int64_t & ne12,
4113
+ constant int64_t & ne13,
3550
4114
  constant int64_t & nb10,
3551
4115
  constant int64_t & nb11,
3552
4116
  constant int64_t & nb12,
3553
4117
  constant int64_t & ne0,
3554
4118
  constant int64_t & ne1,
4119
+ constant int64_t & nb1,
3555
4120
  constant uint & r2,
3556
4121
  constant uint & r3,
3557
4122
  constant int & idx,
@@ -3578,3 +4143,775 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
3578
4143
  template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
3579
4144
  template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
3580
4145
  template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
4146
+
4147
+ //
4148
+ // matrix-vector multiplication
4149
+ //
4150
+
4151
+ [[host_name("kernel_mul_mv_id_f32_f32")]]
4152
+ kernel void kernel_mul_mv_id_f32_f32(
4153
+ device const char * ids,
4154
+ device const char * src1,
4155
+ device uchar * dst,
4156
+ constant int64_t & nbi1,
4157
+ constant int64_t & ne00,
4158
+ constant int64_t & ne01,
4159
+ constant int64_t & ne02,
4160
+ constant uint64_t & nb00,
4161
+ constant uint64_t & nb01,
4162
+ constant uint64_t & nb02,
4163
+ constant int64_t & ne10,
4164
+ constant int64_t & ne11,
4165
+ constant int64_t & ne12,
4166
+ constant int64_t & ne13,
4167
+ constant uint64_t & nb10,
4168
+ constant uint64_t & nb11,
4169
+ constant uint64_t & nb12,
4170
+ constant int64_t & ne0,
4171
+ constant int64_t & ne1,
4172
+ constant int64_t & nb1,
4173
+ constant uint & r2,
4174
+ constant uint & r3,
4175
+ constant int & idx,
4176
+ device const char * src00,
4177
+ device const char * src01,
4178
+ device const char * src02,
4179
+ device const char * src03,
4180
+ device const char * src04,
4181
+ device const char * src05,
4182
+ device const char * src06,
4183
+ device const char * src07,
4184
+ uint3 tgpig[[threadgroup_position_in_grid]],
4185
+ uint tiitg[[thread_index_in_threadgroup]],
4186
+ uint tiisg[[thread_index_in_simdgroup]],
4187
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4188
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4189
+
4190
+ const int64_t bid = tgpig.z/(ne12*ne13);
4191
+
4192
+ tgpig.z = tgpig.z%(ne12*ne13);
4193
+
4194
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4195
+
4196
+ kernel_mul_mv_f32_f32_impl(
4197
+ src0[id],
4198
+ src1 + bid*nb11,
4199
+ (device float *) (dst + bid*nb1),
4200
+ ne00,
4201
+ ne01,
4202
+ ne02,
4203
+ nb00,
4204
+ nb01,
4205
+ nb02,
4206
+ ne10,
4207
+ ne11,
4208
+ ne12,
4209
+ nb10,
4210
+ nb11,
4211
+ nb12,
4212
+ ne0,
4213
+ ne1,
4214
+ r2,
4215
+ r3,
4216
+ tgpig,
4217
+ tiisg);
4218
+ }
4219
+
4220
+ [[host_name("kernel_mul_mv_id_f16_f32")]]
4221
+ kernel void kernel_mul_mv_id_f16_f32(
4222
+ device const char * ids,
4223
+ device const char * src1,
4224
+ device uchar * dst,
4225
+ constant int64_t & nbi1,
4226
+ constant int64_t & ne00,
4227
+ constant int64_t & ne01,
4228
+ constant int64_t & ne02,
4229
+ constant uint64_t & nb00,
4230
+ constant uint64_t & nb01,
4231
+ constant uint64_t & nb02,
4232
+ constant int64_t & ne10,
4233
+ constant int64_t & ne11,
4234
+ constant int64_t & ne12,
4235
+ constant int64_t & ne13,
4236
+ constant uint64_t & nb10,
4237
+ constant uint64_t & nb11,
4238
+ constant uint64_t & nb12,
4239
+ constant int64_t & ne0,
4240
+ constant int64_t & ne1,
4241
+ constant int64_t & nb1,
4242
+ constant uint & r2,
4243
+ constant uint & r3,
4244
+ constant int & idx,
4245
+ device const char * src00,
4246
+ device const char * src01,
4247
+ device const char * src02,
4248
+ device const char * src03,
4249
+ device const char * src04,
4250
+ device const char * src05,
4251
+ device const char * src06,
4252
+ device const char * src07,
4253
+ uint3 tgpig[[threadgroup_position_in_grid]],
4254
+ uint tiitg[[thread_index_in_threadgroup]],
4255
+ uint tiisg[[thread_index_in_simdgroup]],
4256
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4257
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4258
+
4259
+ const int64_t bid = tgpig.z/(ne12*ne13);
4260
+
4261
+ tgpig.z = tgpig.z%(ne12*ne13);
4262
+
4263
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4264
+
4265
+ kernel_mul_mv_f16_f32_impl(
4266
+ src0[id],
4267
+ src1 + bid*nb11,
4268
+ (device float *) (dst + bid*nb1),
4269
+ ne00,
4270
+ ne01,
4271
+ ne02,
4272
+ nb00,
4273
+ nb01,
4274
+ nb02,
4275
+ ne10,
4276
+ ne11,
4277
+ ne12,
4278
+ nb10,
4279
+ nb11,
4280
+ nb12,
4281
+ ne0,
4282
+ ne1,
4283
+ r2,
4284
+ r3,
4285
+ tgpig,
4286
+ tiisg);
4287
+ }
4288
+
4289
+ [[host_name("kernel_mul_mv_id_q8_0_f32")]]
4290
+ kernel void kernel_mul_mv_id_q8_0_f32(
4291
+ device const char * ids,
4292
+ device const char * src1,
4293
+ device uchar * dst,
4294
+ constant int64_t & nbi1,
4295
+ constant int64_t & ne00,
4296
+ constant int64_t & ne01,
4297
+ constant int64_t & ne02,
4298
+ constant uint64_t & nb00,
4299
+ constant uint64_t & nb01,
4300
+ constant uint64_t & nb02,
4301
+ constant int64_t & ne10,
4302
+ constant int64_t & ne11,
4303
+ constant int64_t & ne12,
4304
+ constant int64_t & ne13,
4305
+ constant uint64_t & nb10,
4306
+ constant uint64_t & nb11,
4307
+ constant uint64_t & nb12,
4308
+ constant int64_t & ne0,
4309
+ constant int64_t & ne1,
4310
+ constant int64_t & nb1,
4311
+ constant uint & r2,
4312
+ constant uint & r3,
4313
+ constant int & idx,
4314
+ device const char * src00,
4315
+ device const char * src01,
4316
+ device const char * src02,
4317
+ device const char * src03,
4318
+ device const char * src04,
4319
+ device const char * src05,
4320
+ device const char * src06,
4321
+ device const char * src07,
4322
+ uint3 tgpig[[threadgroup_position_in_grid]],
4323
+ uint tiitg[[thread_index_in_threadgroup]],
4324
+ uint tiisg[[thread_index_in_simdgroup]],
4325
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4326
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4327
+
4328
+ const int64_t bid = tgpig.z/(ne12*ne13);
4329
+
4330
+ tgpig.z = tgpig.z%(ne12*ne13);
4331
+
4332
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4333
+
4334
+ kernel_mul_mv_q8_0_f32_impl(
4335
+ src0[id],
4336
+ (device const float *) (src1 + bid*nb11),
4337
+ (device float *) ( dst + bid*nb1),
4338
+ ne00,
4339
+ ne01,
4340
+ ne02,
4341
+ ne10,
4342
+ ne12,
4343
+ ne0,
4344
+ ne1,
4345
+ r2,
4346
+ r3,
4347
+ tgpig,
4348
+ tiisg,
4349
+ sgitg);
4350
+ }
4351
+
4352
+ [[host_name("kernel_mul_mv_id_q4_0_f32")]]
4353
+ kernel void kernel_mul_mv_id_q4_0_f32(
4354
+ device const char * ids,
4355
+ device const char * src1,
4356
+ device uchar * dst,
4357
+ constant int64_t & nbi1,
4358
+ constant int64_t & ne00,
4359
+ constant int64_t & ne01,
4360
+ constant int64_t & ne02,
4361
+ constant uint64_t & nb00,
4362
+ constant uint64_t & nb01,
4363
+ constant uint64_t & nb02,
4364
+ constant int64_t & ne10,
4365
+ constant int64_t & ne11,
4366
+ constant int64_t & ne12,
4367
+ constant int64_t & ne13,
4368
+ constant uint64_t & nb10,
4369
+ constant uint64_t & nb11,
4370
+ constant uint64_t & nb12,
4371
+ constant int64_t & ne0,
4372
+ constant int64_t & ne1,
4373
+ constant int64_t & nb1,
4374
+ constant uint & r2,
4375
+ constant uint & r3,
4376
+ constant int & idx,
4377
+ device const char * src00,
4378
+ device const char * src01,
4379
+ device const char * src02,
4380
+ device const char * src03,
4381
+ device const char * src04,
4382
+ device const char * src05,
4383
+ device const char * src06,
4384
+ device const char * src07,
4385
+ uint3 tgpig[[threadgroup_position_in_grid]],
4386
+ uint tiitg[[thread_index_in_threadgroup]],
4387
+ uint tiisg[[thread_index_in_simdgroup]],
4388
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4389
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4390
+
4391
+ const int64_t bid = tgpig.z/(ne12*ne13);
4392
+
4393
+ tgpig.z = tgpig.z%(ne12*ne13);
4394
+
4395
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4396
+
4397
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4398
+ src0[id],
4399
+ (device const float *) (src1 + bid*nb11),
4400
+ (device float *) ( dst + bid*nb1),
4401
+ ne00,
4402
+ ne01,
4403
+ ne02,
4404
+ ne10,
4405
+ ne12,
4406
+ ne0,
4407
+ ne1,
4408
+ r2,
4409
+ r3,
4410
+ tgpig,
4411
+ tiisg,
4412
+ sgitg);
4413
+ }
4414
+
4415
+ [[host_name("kernel_mul_mv_id_q4_1_f32")]]
4416
+ kernel void kernel_mul_mv_id_q4_1_f32(
4417
+ device const char * ids,
4418
+ device const char * src1,
4419
+ device uchar * dst,
4420
+ constant int64_t & nbi1,
4421
+ constant int64_t & ne00,
4422
+ constant int64_t & ne01,
4423
+ constant int64_t & ne02,
4424
+ constant uint64_t & nb00,
4425
+ constant uint64_t & nb01,
4426
+ constant uint64_t & nb02,
4427
+ constant int64_t & ne10,
4428
+ constant int64_t & ne11,
4429
+ constant int64_t & ne12,
4430
+ constant int64_t & ne13,
4431
+ constant uint64_t & nb10,
4432
+ constant uint64_t & nb11,
4433
+ constant uint64_t & nb12,
4434
+ constant int64_t & ne0,
4435
+ constant int64_t & ne1,
4436
+ constant int64_t & nb1,
4437
+ constant uint & r2,
4438
+ constant uint & r3,
4439
+ constant int & idx,
4440
+ device const char * src00,
4441
+ device const char * src01,
4442
+ device const char * src02,
4443
+ device const char * src03,
4444
+ device const char * src04,
4445
+ device const char * src05,
4446
+ device const char * src06,
4447
+ device const char * src07,
4448
+ uint3 tgpig[[threadgroup_position_in_grid]],
4449
+ uint tiitg[[thread_index_in_threadgroup]],
4450
+ uint tiisg[[thread_index_in_simdgroup]],
4451
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4452
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4453
+
4454
+ const int64_t bid = tgpig.z/(ne12*ne13);
4455
+
4456
+ tgpig.z = tgpig.z%(ne12*ne13);
4457
+
4458
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4459
+
4460
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4461
+ src0[id],
4462
+ (device const float *) (src1 + bid*nb11),
4463
+ (device float *) ( dst + bid*nb1),
4464
+ ne00,
4465
+ ne01,
4466
+ ne02,
4467
+ ne10,
4468
+ ne12,
4469
+ ne0,
4470
+ ne1,
4471
+ r2,
4472
+ r3,
4473
+ tgpig,
4474
+ tiisg,
4475
+ sgitg);
4476
+ }
4477
+
4478
// Indirect mat-vec kernel for block_q5_0 weights.
// One of the eight bound matrices src00..src07 is selected per row of `ids`:
// the row is tgpig.z / (ne12*ne13), and the int32 stored at element `idx` of
// that row (row byte stride nbi1) picks the matrix. The z position inside the
// row (tgpig.z % (ne12*ne13)) is forwarded to mul_vec_q_n_f32_impl, together
// with src1/dst advanced by that row's byte strides (nb11 / nb1).
// NOTE(review): the selector is not range-checked here; it must lie in [0, 8)
// and name a bound buffer — presumably guaranteed by the host, confirm.
// nb00/nb01/nb02/ne11/nb10/nb12 and tiitg are accepted but unused in this body.
[[host_name("kernel_mul_mv_id_q5_0_f32")]]
kernel void kernel_mul_mv_id_q5_0_f32(
        device const char * ids,
        device const char * src1,
        device      uchar * dst,
        constant  int64_t & nbi1,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant  int64_t & ne10,
        constant  int64_t & ne11,
        constant  int64_t & ne12,
        constant  int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant  int64_t & nb1,
        constant     uint & r2,
        constant     uint & r3,
        constant      int & idx,
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate matrices, in binding order
    device const char * mats[8] = { src00, src01, src02, src03, src04, src05, src06, src07 };

    const int64_t per_row = ne12*ne13;      // z-slices belonging to one ids row
    const int64_t row     = tgpig.z / per_row;

    uint3 pos = tgpig;
    pos.z = tgpig.z % per_row;              // z position within the selected row

    device const int32_t * row_ids = (device const int32_t *) (ids + row*nbi1);
    const int32_t sel = row_ids[idx];

    device const float * x = (device const float *) (src1 + row*nb11);
    device       float * y = (device       float *) (dst  + row*nb1);

    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
        mats[sel], x, y,
        ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3,
        pos, tiisg, sgitg);
}
4540
+
4541
// Indirect mat-vec kernel for block_q5_1 weights.
// One of the eight bound matrices src00..src07 is selected per row of `ids`:
// the row is tgpig.z / (ne12*ne13), and the int32 stored at element `idx` of
// that row (row byte stride nbi1) picks the matrix. The z position inside the
// row (tgpig.z % (ne12*ne13)) is forwarded to mul_vec_q_n_f32_impl, together
// with src1/dst advanced by that row's byte strides (nb11 / nb1).
// NOTE(review): the selector is not range-checked here; it must lie in [0, 8)
// and name a bound buffer — presumably guaranteed by the host, confirm.
// nb00/nb01/nb02/ne11/nb10/nb12 and tiitg are accepted but unused in this body.
[[host_name("kernel_mul_mv_id_q5_1_f32")]]
kernel void kernel_mul_mv_id_q5_1_f32(
        device const char * ids,
        device const char * src1,
        device      uchar * dst,
        constant  int64_t & nbi1,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant  int64_t & ne10,
        constant  int64_t & ne11,
        constant  int64_t & ne12,
        constant  int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant  int64_t & nb1,
        constant     uint & r2,
        constant     uint & r3,
        constant      int & idx,
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate matrices, in binding order
    device const char * mats[8] = { src00, src01, src02, src03, src04, src05, src06, src07 };

    const int64_t per_row = ne12*ne13;      // z-slices belonging to one ids row
    const int64_t row     = tgpig.z / per_row;

    uint3 pos = tgpig;
    pos.z = tgpig.z % per_row;              // z position within the selected row

    device const int32_t * row_ids = (device const int32_t *) (ids + row*nbi1);
    const int32_t sel = row_ids[idx];

    device const float * x = (device const float *) (src1 + row*nb11);
    device       float * y = (device       float *) (dst  + row*nb1);

    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
        mats[sel], x, y,
        ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3,
        pos, tiisg, sgitg);
}
4603
+
4604
// Indirect mat-vec kernel for q2_K weights.
// One of the eight bound matrices src00..src07 is selected per row of `ids`:
// the row is tgpig.z / (ne12*ne13), and the int32 stored at element `idx` of
// that row (row byte stride nbi1) picks the matrix. The z position inside the
// row (tgpig.z % (ne12*ne13)) is forwarded to kernel_mul_mv_q2_K_f32_impl,
// together with src1/dst advanced by that row's byte strides (nb11 / nb1).
// NOTE(review): the selector is not range-checked here; it must lie in [0, 8)
// and name a bound buffer — presumably guaranteed by the host, confirm.
// nb00/nb01/nb02/ne11/nb10/nb12 and tiitg are accepted but unused in this body.
[[host_name("kernel_mul_mv_id_q2_K_f32")]]
kernel void kernel_mul_mv_id_q2_K_f32(
        device const char * ids,
        device const char * src1,
        device      uchar * dst,
        constant  int64_t & nbi1,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant  int64_t & ne10,
        constant  int64_t & ne11,
        constant  int64_t & ne12,
        constant  int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant  int64_t & nb1,
        constant     uint & r2,
        constant     uint & r3,
        constant      int & idx,
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate matrices, in binding order
    device const char * mats[8] = { src00, src01, src02, src03, src04, src05, src06, src07 };

    const int64_t per_row = ne12*ne13;      // z-slices belonging to one ids row
    const int64_t row     = tgpig.z / per_row;

    uint3 pos = tgpig;
    pos.z = tgpig.z % per_row;              // z position within the selected row

    device const int32_t * row_ids = (device const int32_t *) (ids + row*nbi1);
    const int32_t sel = row_ids[idx];

    device const float * x = (device const float *) (src1 + row*nb11);
    device       float * y = (device       float *) (dst  + row*nb1);

    kernel_mul_mv_q2_K_f32_impl(
        mats[sel], x, y,
        ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3,
        pos, tiisg, sgitg);
}
4666
+
4667
// Indirect mat-vec kernel for q3_K weights.
// One of the eight bound matrices src00..src07 is selected per row of `ids`:
// the row is tgpig.z / (ne12*ne13), and the int32 stored at element `idx` of
// that row (row byte stride nbi1) picks the matrix. The z position inside the
// row (tgpig.z % (ne12*ne13)) is forwarded to kernel_mul_mv_q3_K_f32_impl,
// together with src1/dst advanced by that row's byte strides (nb11 / nb1).
// NOTE(review): the selector is not range-checked here; it must lie in [0, 8)
// and name a bound buffer — presumably guaranteed by the host, confirm.
// nb00/nb01/nb02/ne11/nb10/nb12 and tiitg are accepted but unused in this body.
[[host_name("kernel_mul_mv_id_q3_K_f32")]]
kernel void kernel_mul_mv_id_q3_K_f32(
        device const char * ids,
        device const char * src1,
        device      uchar * dst,
        constant  int64_t & nbi1,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant  int64_t & ne10,
        constant  int64_t & ne11,
        constant  int64_t & ne12,
        constant  int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant  int64_t & nb1,
        constant     uint & r2,
        constant     uint & r3,
        constant      int & idx,
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate matrices, in binding order
    device const char * mats[8] = { src00, src01, src02, src03, src04, src05, src06, src07 };

    const int64_t per_row = ne12*ne13;      // z-slices belonging to one ids row
    const int64_t row     = tgpig.z / per_row;

    uint3 pos = tgpig;
    pos.z = tgpig.z % per_row;              // z position within the selected row

    device const int32_t * row_ids = (device const int32_t *) (ids + row*nbi1);
    const int32_t sel = row_ids[idx];

    device const float * x = (device const float *) (src1 + row*nb11);
    device       float * y = (device       float *) (dst  + row*nb1);

    kernel_mul_mv_q3_K_f32_impl(
        mats[sel], x, y,
        ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3,
        pos, tiisg, sgitg);
}
4729
+
4730
// Indirect mat-vec kernel for q4_K weights.
// One of the eight bound matrices src00..src07 is selected per row of `ids`:
// the row is tgpig.z / (ne12*ne13), and the int32 stored at element `idx` of
// that row (row byte stride nbi1) picks the matrix. The z position inside the
// row (tgpig.z % (ne12*ne13)) is forwarded to kernel_mul_mv_q4_K_f32_impl,
// together with src1/dst advanced by that row's byte strides (nb11 / nb1).
// NOTE(review): the selector is not range-checked here; it must lie in [0, 8)
// and name a bound buffer — presumably guaranteed by the host, confirm.
// nb00/nb01/nb02/ne11/nb10/nb12 and tiitg are accepted but unused in this body.
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
kernel void kernel_mul_mv_id_q4_K_f32(
        device const char * ids,
        device const char * src1,
        device      uchar * dst,
        constant  int64_t & nbi1,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant  int64_t & ne10,
        constant  int64_t & ne11,
        constant  int64_t & ne12,
        constant  int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant  int64_t & nb1,
        constant     uint & r2,
        constant     uint & r3,
        constant      int & idx,
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate matrices, in binding order
    device const char * mats[8] = { src00, src01, src02, src03, src04, src05, src06, src07 };

    const int64_t per_row = ne12*ne13;      // z-slices belonging to one ids row
    const int64_t row     = tgpig.z / per_row;

    uint3 pos = tgpig;
    pos.z = tgpig.z % per_row;              // z position within the selected row

    device const int32_t * row_ids = (device const int32_t *) (ids + row*nbi1);
    const int32_t sel = row_ids[idx];

    device const float * x = (device const float *) (src1 + row*nb11);
    device       float * y = (device       float *) (dst  + row*nb1);

    kernel_mul_mv_q4_K_f32_impl(
        mats[sel], x, y,
        ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3,
        pos, tiisg, sgitg);
}
4792
+
4793
// Indirect mat-vec kernel for q5_K weights.
// One of the eight bound matrices src00..src07 is selected per row of `ids`:
// the row is tgpig.z / (ne12*ne13), and the int32 stored at element `idx` of
// that row (row byte stride nbi1) picks the matrix. The z position inside the
// row (tgpig.z % (ne12*ne13)) is forwarded to kernel_mul_mv_q5_K_f32_impl,
// together with src1/dst advanced by that row's byte strides (nb11 / nb1).
// NOTE(review): the selector is not range-checked here; it must lie in [0, 8)
// and name a bound buffer — presumably guaranteed by the host, confirm.
// nb00/nb01/nb02/ne11/nb10/nb12 and tiitg are accepted but unused in this body.
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
kernel void kernel_mul_mv_id_q5_K_f32(
        device const char * ids,
        device const char * src1,
        device      uchar * dst,
        constant  int64_t & nbi1,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant  int64_t & ne10,
        constant  int64_t & ne11,
        constant  int64_t & ne12,
        constant  int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant  int64_t & nb1,
        constant     uint & r2,
        constant     uint & r3,
        constant      int & idx,
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate matrices, in binding order
    device const char * mats[8] = { src00, src01, src02, src03, src04, src05, src06, src07 };

    const int64_t per_row = ne12*ne13;      // z-slices belonging to one ids row
    const int64_t row     = tgpig.z / per_row;

    uint3 pos = tgpig;
    pos.z = tgpig.z % per_row;              // z position within the selected row

    device const int32_t * row_ids = (device const int32_t *) (ids + row*nbi1);
    const int32_t sel = row_ids[idx];

    device const float * x = (device const float *) (src1 + row*nb11);
    device       float * y = (device       float *) (dst  + row*nb1);

    kernel_mul_mv_q5_K_f32_impl(
        mats[sel], x, y,
        ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3,
        pos, tiisg, sgitg);
}
4855
+
4856
// Indirect mat-vec kernel for q6_K weights.
// One of the eight bound matrices src00..src07 is selected per row of `ids`:
// the row is tgpig.z / (ne12*ne13), and the int32 stored at element `idx` of
// that row (row byte stride nbi1) picks the matrix. The z position inside the
// row (tgpig.z % (ne12*ne13)) is forwarded to kernel_mul_mv_q6_K_f32_impl,
// together with src1/dst advanced by that row's byte strides (nb11 / nb1).
// NOTE(review): the selector is not range-checked here; it must lie in [0, 8)
// and name a bound buffer — presumably guaranteed by the host, confirm.
// nb00/nb01/nb02/ne11/nb10/nb12 and tiitg are accepted but unused in this body.
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
kernel void kernel_mul_mv_id_q6_K_f32(
        device const char * ids,
        device const char * src1,
        device      uchar * dst,
        constant  int64_t & nbi1,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant  int64_t & ne10,
        constant  int64_t & ne11,
        constant  int64_t & ne12,
        constant  int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant  int64_t & nb1,
        constant     uint & r2,
        constant     uint & r3,
        constant      int & idx,
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate matrices, in binding order
    device const char * mats[8] = { src00, src01, src02, src03, src04, src05, src06, src07 };

    const int64_t per_row = ne12*ne13;      // z-slices belonging to one ids row
    const int64_t row     = tgpig.z / per_row;

    uint3 pos = tgpig;
    pos.z = tgpig.z % per_row;              // z position within the selected row

    device const int32_t * row_ids = (device const int32_t *) (ids + row*nbi1);
    const int32_t sel = row_ids[idx];

    device const float * x = (device const float *) (src1 + row*nb11);
    device       float * y = (device       float *) (dst  + row*nb1);

    kernel_mul_mv_q6_K_f32_impl(
        mats[sel], x, y,
        ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3,
        pos, tiisg, sgitg);
}