llama_cpp 0.10.0 → 0.10.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -79,6 +79,7 @@ kernel void kernel_add(
79
79
  constant int64_t & nb1,
80
80
  constant int64_t & nb2,
81
81
  constant int64_t & nb3,
82
+ constant int64_t & offs,
82
83
  uint3 tgpig[[threadgroup_position_in_grid]],
83
84
  uint3 tpitg[[thread_position_in_threadgroup]],
84
85
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -90,9 +91,9 @@ kernel void kernel_add(
90
91
  const int64_t i12 = i02 % ne12;
91
92
  const int64_t i11 = i01 % ne11;
92
93
 
93
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
94
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
94
95
  device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
95
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
96
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
96
97
 
97
98
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
98
99
  const int i10 = i0 % ne10;
@@ -204,7 +205,7 @@ kernel void kernel_add_row(
204
205
  device const float4 * src0,
205
206
  device const float4 * src1,
206
207
  device float4 * dst,
207
- constant int64_t & nb [[buffer(27)]],
208
+ constant int64_t & nb [[buffer(28)]],
208
209
  uint tpig[[thread_position_in_grid]]) {
209
210
  dst[tpig] = src0[tpig] + src1[tpig % nb];
210
211
  }
@@ -213,7 +214,7 @@ kernel void kernel_mul_row(
213
214
  device const float4 * src0,
214
215
  device const float4 * src1,
215
216
  device float4 * dst,
216
- constant int64_t & nb [[buffer(27)]],
217
+ constant int64_t & nb [[buffer(28)]],
217
218
  uint tpig[[thread_position_in_grid]]) {
218
219
  dst[tpig] = src0[tpig] * src1[tpig % nb];
219
220
  }
@@ -222,7 +223,7 @@ kernel void kernel_div_row(
222
223
  device const float4 * src0,
223
224
  device const float4 * src1,
224
225
  device float4 * dst,
225
- constant int64_t & nb [[buffer(27)]],
226
+ constant int64_t & nb [[buffer(28)]],
226
227
  uint tpig[[thread_position_in_grid]]) {
227
228
  dst[tpig] = src0[tpig] / src1[tpig % nb];
228
229
  }
@@ -243,19 +244,53 @@ kernel void kernel_scale_4(
243
244
  dst[tpig] = src0[tpig] * scale;
244
245
  }
245
246
 
246
- kernel void kernel_silu(
247
- device const float4 * src0,
248
- device float4 * dst,
247
+ kernel void kernel_relu(
248
+ device const float * src0,
249
+ device float * dst,
249
250
  uint tpig[[thread_position_in_grid]]) {
250
- device const float4 & x = src0[tpig];
251
- dst[tpig] = x / (1.0f + exp(-x));
251
+ dst[tpig] = max(0.0f, src0[tpig]);
252
252
  }
253
253
 
254
- kernel void kernel_relu(
254
+ kernel void kernel_tanh(
255
255
  device const float * src0,
256
256
  device float * dst,
257
257
  uint tpig[[thread_position_in_grid]]) {
258
- dst[tpig] = max(0.0f, src0[tpig]);
258
+ device const float & x = src0[tpig];
259
+ dst[tpig] = precise::tanh(x);
260
+ }
261
+
262
+ constant float GELU_COEF_A = 0.044715f;
263
+ constant float GELU_QUICK_COEF = -1.702f;
264
+ constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
265
+
266
+ kernel void kernel_gelu(
267
+ device const float4 * src0,
268
+ device float4 * dst,
269
+ uint tpig[[thread_position_in_grid]]) {
270
+ device const float4 & x = src0[tpig];
271
+
272
+ // BEWARE !!!
273
+ // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
274
+ // This was observed with Falcon 7B and 40B models
275
+ //
276
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
277
+ }
278
+
279
+ kernel void kernel_gelu_quick(
280
+ device const float4 * src0,
281
+ device float4 * dst,
282
+ uint tpig[[thread_position_in_grid]]) {
283
+ device const float4 & x = src0[tpig];
284
+
285
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
286
+ }
287
+
288
+ kernel void kernel_silu(
289
+ device const float4 * src0,
290
+ device float4 * dst,
291
+ uint tpig[[thread_position_in_grid]]) {
292
+ device const float4 & x = src0[tpig];
293
+ dst[tpig] = x / (1.0f + exp(-x));
259
294
  }
260
295
 
261
296
  kernel void kernel_sqr(
@@ -313,22 +348,6 @@ kernel void kernel_sum_rows(
313
348
  dst_row[0] = row_sum;
314
349
  }
315
350
 
316
- constant float GELU_COEF_A = 0.044715f;
317
- constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
318
-
319
- kernel void kernel_gelu(
320
- device const float4 * src0,
321
- device float4 * dst,
322
- uint tpig[[thread_position_in_grid]]) {
323
- device const float4 & x = src0[tpig];
324
-
325
- // BEWARE !!!
326
- // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
327
- // This was observed with Falcon 7B and 40B models
328
- //
329
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
330
- }
331
-
332
351
  kernel void kernel_soft_max(
333
352
  device const float * src0,
334
353
  device const float * src1,
@@ -347,9 +366,9 @@ kernel void kernel_soft_max(
347
366
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
348
367
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
349
368
 
350
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
351
- device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
352
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
369
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
370
+ device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
371
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
353
372
 
354
373
  // parallel max
355
374
  float lmax = -INFINITY;
@@ -385,7 +404,12 @@ kernel void kernel_soft_max(
385
404
  pdst[i00] = exp_psrc0;
386
405
  }
387
406
 
407
+ // This barrier fixes a failing test
408
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
409
+ threadgroup_barrier(mem_flags::mem_none);
410
+
388
411
  float sum = simd_sum(lsum);
412
+
389
413
  if (ntg > N_SIMDWIDTH) {
390
414
  if (sgitg == 0) {
391
415
  buf[tiisg] = 0.0f;
@@ -428,9 +452,9 @@ kernel void kernel_soft_max_4(
428
452
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
429
453
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
430
454
 
431
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
432
- device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
433
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
455
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
456
+ device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
457
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
434
458
 
435
459
  // parallel max
436
460
  float4 lmax4 = -INFINITY;
@@ -468,7 +492,13 @@ kernel void kernel_soft_max_4(
468
492
  }
469
493
 
470
494
  const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
495
+
496
+ // This barrier fixes a failing test
497
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
498
+ threadgroup_barrier(mem_flags::mem_none);
499
+
471
500
  float sum = simd_sum(lsum);
501
+
472
502
  if (ntg > N_SIMDWIDTH) {
473
503
  if (sgitg == 0) {
474
504
  buf[tiisg] = 0.0f;
@@ -639,6 +669,94 @@ kernel void kernel_rms_norm(
639
669
  }
640
670
  }
641
671
 
672
+ kernel void kernel_group_norm(
673
+ device const float * src0,
674
+ device float * dst,
675
+ constant int64_t & ne00,
676
+ constant int64_t & ne01,
677
+ constant int64_t & ne02,
678
+ constant uint64_t & nb00,
679
+ constant uint64_t & nb01,
680
+ constant uint64_t & nb02,
681
+ constant int32_t & n_groups,
682
+ constant float & eps,
683
+ threadgroup float * buf [[threadgroup(0)]],
684
+ uint tgpig[[threadgroup_position_in_grid]],
685
+ uint tpitg[[thread_position_in_threadgroup]],
686
+ uint sgitg[[simdgroup_index_in_threadgroup]],
687
+ uint tiisg[[thread_index_in_simdgroup]],
688
+ uint ntg[[threads_per_threadgroup]]) {
689
+ const int64_t ne = ne00*ne01*ne02;
690
+ const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
691
+
692
+ int start = tgpig * gs;
693
+ int end = start + gs;
694
+
695
+ start += tpitg;
696
+
697
+ if (end >= ne) {
698
+ end = ne;
699
+ }
700
+
701
+ float tmp = 0.0f; // partial sum for thread in warp
702
+
703
+ for (int j = start; j < end; j += ntg) {
704
+ tmp += src0[j];
705
+ }
706
+
707
+ threadgroup_barrier(mem_flags::mem_threadgroup);
708
+ tmp = simd_sum(tmp);
709
+ if (ntg > N_SIMDWIDTH) {
710
+ if (sgitg == 0) {
711
+ buf[tiisg] = 0.0f;
712
+ }
713
+
714
+ threadgroup_barrier(mem_flags::mem_threadgroup);
715
+
716
+ if (tiisg == 0) {
717
+ buf[sgitg] = tmp;
718
+ }
719
+
720
+ threadgroup_barrier(mem_flags::mem_threadgroup);
721
+
722
+ tmp = buf[tiisg];
723
+ tmp = simd_sum(tmp);
724
+ }
725
+
726
+ const float mean = tmp / gs;
727
+ tmp = 0.0f;
728
+
729
+ for (int j = start; j < end; j += ntg) {
730
+ float xi = src0[j] - mean;
731
+ dst[j] = xi;
732
+ tmp += xi * xi;
733
+ }
734
+
735
+ tmp = simd_sum(tmp);
736
+ if (ntg > N_SIMDWIDTH) {
737
+ if (sgitg == 0) {
738
+ buf[tiisg] = 0.0f;
739
+ }
740
+
741
+ threadgroup_barrier(mem_flags::mem_threadgroup);
742
+
743
+ if (tiisg == 0) {
744
+ buf[sgitg] = tmp;
745
+ }
746
+
747
+ threadgroup_barrier(mem_flags::mem_threadgroup);
748
+
749
+ tmp = buf[tiisg];
750
+ tmp = simd_sum(tmp);
751
+ }
752
+
753
+ const float variance = tmp / gs;
754
+ const float scale = 1.0f/sqrt(variance + eps);
755
+ for (int j = start; j < end; j += ntg) {
756
+ dst[j] *= scale;
757
+ }
758
+ }
759
+
642
760
  // function to calculate the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i])
643
761
  // il indicates where the q4 quants begin (0 or QK4_0/4)
644
762
  // we assume that the yl's have been multiplied with the appropriate scale factor
@@ -731,7 +849,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
731
849
  // guard against the number of rows not being divisible by
732
850
  // N_DST, so this is another explicit assumption of the implementation.
733
851
  template<typename block_q_type, int nr, int nsg, int nw>
734
- void mul_vec_q_n_f32(
852
+ void mul_vec_q_n_f32_impl(
735
853
  device const void * src0,
736
854
  device const float * src1,
737
855
  device float * dst,
@@ -813,7 +931,7 @@ kernel void kernel_mul_mv_q4_0_f32(
813
931
  uint3 tgpig[[threadgroup_position_in_grid]],
814
932
  uint tiisg[[thread_index_in_simdgroup]],
815
933
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
816
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
934
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
817
935
  }
818
936
 
819
937
  kernel void kernel_mul_mv_q4_1_f32(
@@ -832,7 +950,7 @@ kernel void kernel_mul_mv_q4_1_f32(
832
950
  uint3 tgpig[[threadgroup_position_in_grid]],
833
951
  uint tiisg[[thread_index_in_simdgroup]],
834
952
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
835
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
953
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
836
954
  }
837
955
 
838
956
  kernel void kernel_mul_mv_q5_0_f32(
@@ -851,7 +969,7 @@ kernel void kernel_mul_mv_q5_0_f32(
851
969
  uint3 tgpig[[threadgroup_position_in_grid]],
852
970
  uint tiisg[[thread_index_in_simdgroup]],
853
971
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
854
- mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
972
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
855
973
  }
856
974
 
857
975
  kernel void kernel_mul_mv_q5_1_f32(
@@ -870,28 +988,28 @@ kernel void kernel_mul_mv_q5_1_f32(
870
988
  uint3 tgpig[[threadgroup_position_in_grid]],
871
989
  uint tiisg[[thread_index_in_simdgroup]],
872
990
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
873
- mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
991
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
874
992
  }
875
993
 
876
994
 
877
995
  #define NB_Q8_0 8
878
996
 
879
- kernel void kernel_mul_mv_q8_0_f32(
997
+ void kernel_mul_mv_q8_0_f32_impl(
880
998
  device const void * src0,
881
999
  device const float * src1,
882
1000
  device float * dst,
883
1001
  constant int64_t & ne00,
884
- constant int64_t & ne01[[buffer(4)]],
885
- constant int64_t & ne02[[buffer(5)]],
886
- constant int64_t & ne10[[buffer(9)]],
887
- constant int64_t & ne12[[buffer(11)]],
888
- constant int64_t & ne0 [[buffer(15)]],
889
- constant int64_t & ne1 [[buffer(16)]],
890
- constant uint & r2 [[buffer(17)]],
891
- constant uint & r3 [[buffer(18)]],
1002
+ constant int64_t & ne01,
1003
+ constant int64_t & ne02,
1004
+ constant int64_t & ne10,
1005
+ constant int64_t & ne12,
1006
+ constant int64_t & ne0,
1007
+ constant int64_t & ne1,
1008
+ constant uint & r2,
1009
+ constant uint & r3,
892
1010
  uint3 tgpig[[threadgroup_position_in_grid]],
893
- uint tiisg[[thread_index_in_simdgroup]],
894
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
1011
+ uint tiisg[[thread_index_in_simdgroup]],
1012
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
895
1013
  const int nr = N_DST;
896
1014
  const int nsg = N_SIMDGROUP;
897
1015
  const int nw = N_SIMDWIDTH;
@@ -945,9 +1063,29 @@ kernel void kernel_mul_mv_q8_0_f32(
945
1063
  }
946
1064
  }
947
1065
 
1066
+ [[host_name("kernel_mul_mv_q8_0_f32")]]
1067
+ kernel void kernel_mul_mv_q8_0_f32(
1068
+ device const void * src0,
1069
+ device const float * src1,
1070
+ device float * dst,
1071
+ constant int64_t & ne00,
1072
+ constant int64_t & ne01,
1073
+ constant int64_t & ne02,
1074
+ constant int64_t & ne10,
1075
+ constant int64_t & ne12,
1076
+ constant int64_t & ne0,
1077
+ constant int64_t & ne1,
1078
+ constant uint & r2 [[buffer(17)]],
1079
+ constant uint & r3 [[buffer(18)]],
1080
+ uint3 tgpig[[threadgroup_position_in_grid]],
1081
+ uint tiisg[[thread_index_in_simdgroup]],
1082
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1083
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1084
+ }
1085
+
948
1086
  #define N_F32_F32 4
949
1087
 
950
- kernel void kernel_mul_mv_f32_f32(
1088
+ void kernel_mul_mv_f32_f32_impl(
951
1089
  device const char * src0,
952
1090
  device const char * src1,
953
1091
  device float * dst,
@@ -965,8 +1103,8 @@ kernel void kernel_mul_mv_f32_f32(
965
1103
  constant uint64_t & nb12,
966
1104
  constant int64_t & ne0,
967
1105
  constant int64_t & ne1,
968
- constant uint & r2 [[buffer(17)]],
969
- constant uint & r3 [[buffer(18)]],
1106
+ constant uint & r2,
1107
+ constant uint & r3,
970
1108
  uint3 tgpig[[threadgroup_position_in_grid]],
971
1109
  uint tiisg[[thread_index_in_simdgroup]]) {
972
1110
 
@@ -1025,6 +1163,32 @@ kernel void kernel_mul_mv_f32_f32(
1025
1163
  }
1026
1164
  }
1027
1165
 
1166
+ [[host_name("kernel_mul_mv_f32_f32")]]
1167
+ kernel void kernel_mul_mv_f32_f32(
1168
+ device const char * src0,
1169
+ device const char * src1,
1170
+ device float * dst,
1171
+ constant int64_t & ne00,
1172
+ constant int64_t & ne01,
1173
+ constant int64_t & ne02,
1174
+ constant uint64_t & nb00,
1175
+ constant uint64_t & nb01,
1176
+ constant uint64_t & nb02,
1177
+ constant int64_t & ne10,
1178
+ constant int64_t & ne11,
1179
+ constant int64_t & ne12,
1180
+ constant uint64_t & nb10,
1181
+ constant uint64_t & nb11,
1182
+ constant uint64_t & nb12,
1183
+ constant int64_t & ne0,
1184
+ constant int64_t & ne1,
1185
+ constant uint & r2 [[buffer(17)]],
1186
+ constant uint & r3 [[buffer(18)]],
1187
+ uint3 tgpig[[threadgroup_position_in_grid]],
1188
+ uint tiisg[[thread_index_in_simdgroup]]) {
1189
+ kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1190
+ }
1191
+
1028
1192
  #define N_F16_F16 4
1029
1193
 
1030
1194
  kernel void kernel_mul_mv_f16_f16(
@@ -1105,7 +1269,7 @@ kernel void kernel_mul_mv_f16_f16(
1105
1269
  }
1106
1270
  }
1107
1271
 
1108
- kernel void kernel_mul_mv_f16_f32_1row(
1272
+ void kernel_mul_mv_f16_f32_1row_impl(
1109
1273
  device const char * src0,
1110
1274
  device const char * src1,
1111
1275
  device float * dst,
@@ -1123,8 +1287,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
1123
1287
  constant uint64_t & nb12,
1124
1288
  constant int64_t & ne0,
1125
1289
  constant int64_t & ne1,
1126
- constant uint & r2 [[buffer(17)]],
1127
- constant uint & r3 [[buffer(18)]],
1290
+ constant uint & r2,
1291
+ constant uint & r3,
1128
1292
  uint3 tgpig[[threadgroup_position_in_grid]],
1129
1293
  uint tiisg[[thread_index_in_simdgroup]]) {
1130
1294
 
@@ -1161,12 +1325,37 @@ kernel void kernel_mul_mv_f16_f32_1row(
1161
1325
  dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1162
1326
  }
1163
1327
  }
1328
+ }
1164
1329
 
1330
+ [[host_name("kernel_mul_mv_f16_f32_1row")]]
1331
+ kernel void kernel_mul_mv_f16_f32_1row(
1332
+ device const char * src0,
1333
+ device const char * src1,
1334
+ device float * dst,
1335
+ constant int64_t & ne00,
1336
+ constant int64_t & ne01,
1337
+ constant int64_t & ne02,
1338
+ constant uint64_t & nb00,
1339
+ constant uint64_t & nb01,
1340
+ constant uint64_t & nb02,
1341
+ constant int64_t & ne10,
1342
+ constant int64_t & ne11,
1343
+ constant int64_t & ne12,
1344
+ constant uint64_t & nb10,
1345
+ constant uint64_t & nb11,
1346
+ constant uint64_t & nb12,
1347
+ constant int64_t & ne0,
1348
+ constant int64_t & ne1,
1349
+ constant uint & r2 [[buffer(17)]],
1350
+ constant uint & r3 [[buffer(18)]],
1351
+ uint3 tgpig[[threadgroup_position_in_grid]],
1352
+ uint tiisg[[thread_index_in_simdgroup]]) {
1353
+ kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1165
1354
  }
1166
1355
 
1167
1356
  #define N_F16_F32 4
1168
1357
 
1169
- kernel void kernel_mul_mv_f16_f32(
1358
+ void kernel_mul_mv_f16_f32_impl(
1170
1359
  device const char * src0,
1171
1360
  device const char * src1,
1172
1361
  device float * dst,
@@ -1184,8 +1373,8 @@ kernel void kernel_mul_mv_f16_f32(
1184
1373
  constant uint64_t & nb12,
1185
1374
  constant int64_t & ne0,
1186
1375
  constant int64_t & ne1,
1187
- constant uint & r2 [[buffer(17)]],
1188
- constant uint & r3 [[buffer(18)]],
1376
+ constant uint & r2,
1377
+ constant uint & r3,
1189
1378
  uint3 tgpig[[threadgroup_position_in_grid]],
1190
1379
  uint tiisg[[thread_index_in_simdgroup]]) {
1191
1380
 
@@ -1244,6 +1433,32 @@ kernel void kernel_mul_mv_f16_f32(
1244
1433
  }
1245
1434
  }
1246
1435
 
1436
+ [[host_name("kernel_mul_mv_f16_f32")]]
1437
+ kernel void kernel_mul_mv_f16_f32(
1438
+ device const char * src0,
1439
+ device const char * src1,
1440
+ device float * dst,
1441
+ constant int64_t & ne00,
1442
+ constant int64_t & ne01,
1443
+ constant int64_t & ne02,
1444
+ constant uint64_t & nb00,
1445
+ constant uint64_t & nb01,
1446
+ constant uint64_t & nb02,
1447
+ constant int64_t & ne10,
1448
+ constant int64_t & ne11,
1449
+ constant int64_t & ne12,
1450
+ constant uint64_t & nb10,
1451
+ constant uint64_t & nb11,
1452
+ constant uint64_t & nb12,
1453
+ constant int64_t & ne0,
1454
+ constant int64_t & ne1,
1455
+ constant uint & r2 [[buffer(17)]],
1456
+ constant uint & r3 [[buffer(18)]],
1457
+ uint3 tgpig[[threadgroup_position_in_grid]],
1458
+ uint tiisg[[thread_index_in_simdgroup]]) {
1459
+ kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1460
+ }
1461
+
1247
1462
  // Assumes row size (ne00) is a multiple of 4
1248
1463
  kernel void kernel_mul_mv_f16_f32_l4(
1249
1464
  device const char * src0,
@@ -1548,25 +1763,116 @@ kernel void kernel_im2col_f16(
1548
1763
  }
1549
1764
  }
1550
1765
 
1551
- // bitonic sort implementation following the CUDA kernels as reference
1552
- typedef void (argsort_t)(
1553
- device const float * x,
1554
- device int32_t * dst,
1555
- constant int64_t & ncols,
1556
- uint3 tgpig[[threadgroup_position_in_grid]],
1557
- uint3 tpitg[[thread_position_in_threadgroup]]);
1558
-
1559
- template<ggml_sort_order order>
1560
- kernel void kernel_argsort_f32_i32(
1561
- device const float * x,
1562
- device int32_t * dst,
1563
- constant int64_t & ncols,
1564
- uint3 tgpig[[threadgroup_position_in_grid]],
1565
- uint3 tpitg[[thread_position_in_threadgroup]]) {
1566
- // bitonic sort
1567
- int col = tpitg[0];
1568
- int row = tgpig[1];
1569
-
1766
+ kernel void kernel_upscale_f32(
1767
+ device const char * src0,
1768
+ device char * dst,
1769
+ constant int64_t & ne00,
1770
+ constant int64_t & ne01,
1771
+ constant int64_t & ne02,
1772
+ constant int64_t & ne03,
1773
+ constant uint64_t & nb00,
1774
+ constant uint64_t & nb01,
1775
+ constant uint64_t & nb02,
1776
+ constant uint64_t & nb03,
1777
+ constant int64_t & ne0,
1778
+ constant int64_t & ne1,
1779
+ constant int64_t & ne2,
1780
+ constant int64_t & ne3,
1781
+ constant uint64_t & nb0,
1782
+ constant uint64_t & nb1,
1783
+ constant uint64_t & nb2,
1784
+ constant uint64_t & nb3,
1785
+ constant int32_t & sf,
1786
+ uint3 tgpig[[threadgroup_position_in_grid]],
1787
+ uint3 tpitg[[thread_position_in_threadgroup]],
1788
+ uint3 ntg[[threads_per_threadgroup]]) {
1789
+
1790
+ const int64_t i3 = tgpig.z;
1791
+ const int64_t i2 = tgpig.y;
1792
+ const int64_t i1 = tgpig.x;
1793
+
1794
+ const int64_t i03 = i3;
1795
+ const int64_t i02 = i2;
1796
+ const int64_t i01 = i1/sf;
1797
+
1798
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
1799
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
1800
+
1801
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1802
+ dst_ptr[i0] = src0_ptr[i0/sf];
1803
+ }
1804
+ }
1805
+
1806
+ kernel void kernel_pad_f32(
1807
+ device const char * src0,
1808
+ device char * dst,
1809
+ constant int64_t & ne00,
1810
+ constant int64_t & ne01,
1811
+ constant int64_t & ne02,
1812
+ constant int64_t & ne03,
1813
+ constant uint64_t & nb00,
1814
+ constant uint64_t & nb01,
1815
+ constant uint64_t & nb02,
1816
+ constant uint64_t & nb03,
1817
+ constant int64_t & ne0,
1818
+ constant int64_t & ne1,
1819
+ constant int64_t & ne2,
1820
+ constant int64_t & ne3,
1821
+ constant uint64_t & nb0,
1822
+ constant uint64_t & nb1,
1823
+ constant uint64_t & nb2,
1824
+ constant uint64_t & nb3,
1825
+ uint3 tgpig[[threadgroup_position_in_grid]],
1826
+ uint3 tpitg[[thread_position_in_threadgroup]],
1827
+ uint3 ntg[[threads_per_threadgroup]]) {
1828
+
1829
+ const int64_t i3 = tgpig.z;
1830
+ const int64_t i2 = tgpig.y;
1831
+ const int64_t i1 = tgpig.x;
1832
+
1833
+ const int64_t i03 = i3;
1834
+ const int64_t i02 = i2;
1835
+ const int64_t i01 = i1;
1836
+
1837
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
1838
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
1839
+
1840
+ if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
1841
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1842
+ if (i0 < ne00) {
1843
+ dst_ptr[i0] = src0_ptr[i0];
1844
+ } else {
1845
+ dst_ptr[i0] = 0.0f;
1846
+ }
1847
+ }
1848
+
1849
+ return;
1850
+ }
1851
+
1852
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1853
+ dst_ptr[i0] = 0.0f;
1854
+ }
1855
+ }
1856
+
1857
+ // bitonic sort implementation following the CUDA kernels as reference
1858
+ typedef void (argsort_t)(
1859
+ device const float * x,
1860
+ device int32_t * dst,
1861
+ constant int64_t & ncols,
1862
+ uint3 tgpig[[threadgroup_position_in_grid]],
1863
+ uint3 tpitg[[thread_position_in_threadgroup]]);
1864
+
1865
+ template<ggml_sort_order order>
1866
+ kernel void kernel_argsort_f32_i32(
1867
+ device const float * x,
1868
+ device int32_t * dst,
1869
+ constant int64_t & ncols,
1870
+ uint3 tgpig[[threadgroup_position_in_grid]],
1871
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
1872
+ // bitonic sort
1873
+ int col = tpitg[0];
1874
+ int row = tgpig[1];
1875
+
1570
1876
  if (col >= ncols) return;
1571
1877
 
1572
1878
  device const float * x_row = x + row * ncols;
@@ -1600,9 +1906,17 @@ kernel void kernel_argsort_f32_i32(
1600
1906
  template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
1601
1907
  template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
1602
1908
 
1909
+ kernel void kernel_leaky_relu_f32(
1910
+ device const float * src0,
1911
+ device float * dst,
1912
+ constant float & slope,
1913
+ uint tpig[[thread_position_in_grid]]) {
1914
+ dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
1915
+ }
1916
+
1603
1917
  kernel void kernel_cpy_f16_f16(
1604
- device const half * src0,
1605
- device half * dst,
1918
+ device const half * src0,
1919
+ device half * dst,
1606
1920
  constant int64_t & ne00,
1607
1921
  constant int64_t & ne01,
1608
1922
  constant int64_t & ne02,
@@ -1641,6 +1955,47 @@ kernel void kernel_cpy_f16_f16(
1641
1955
  }
1642
1956
  }
1643
1957
 
1958
+ kernel void kernel_cpy_f16_f32(
1959
+ device const half * src0,
1960
+ device float * dst,
1961
+ constant int64_t & ne00,
1962
+ constant int64_t & ne01,
1963
+ constant int64_t & ne02,
1964
+ constant int64_t & ne03,
1965
+ constant uint64_t & nb00,
1966
+ constant uint64_t & nb01,
1967
+ constant uint64_t & nb02,
1968
+ constant uint64_t & nb03,
1969
+ constant int64_t & ne0,
1970
+ constant int64_t & ne1,
1971
+ constant int64_t & ne2,
1972
+ constant int64_t & ne3,
1973
+ constant uint64_t & nb0,
1974
+ constant uint64_t & nb1,
1975
+ constant uint64_t & nb2,
1976
+ constant uint64_t & nb3,
1977
+ uint3 tgpig[[threadgroup_position_in_grid]],
1978
+ uint3 tpitg[[thread_position_in_threadgroup]],
1979
+ uint3 ntg[[threads_per_threadgroup]]) {
1980
+ const int64_t i03 = tgpig[2];
1981
+ const int64_t i02 = tgpig[1];
1982
+ const int64_t i01 = tgpig[0];
1983
+
1984
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1985
+
1986
+ const int64_t i3 = n / (ne2*ne1*ne0);
1987
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1988
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1989
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1990
+
1991
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1992
+
1993
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1994
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1995
+ dst_data[i00] = src[0];
1996
+ }
1997
+ }
1998
+
1644
1999
  kernel void kernel_cpy_f32_f16(
1645
2000
  device const float * src0,
1646
2001
  device half * dst,
@@ -1917,9 +2272,9 @@ kernel void kernel_cpy_f32_q4_1(
1917
2272
  }
1918
2273
 
1919
2274
  kernel void kernel_concat(
1920
- device const char * src0,
1921
- device const char * src1,
1922
- device char * dst,
2275
+ device const char * src0,
2276
+ device const char * src1,
2277
+ device char * dst,
1923
2278
  constant int64_t & ne00,
1924
2279
  constant int64_t & ne01,
1925
2280
  constant int64_t & ne02,
@@ -1956,7 +2311,7 @@ kernel void kernel_concat(
1956
2311
  const int64_t i12 = i02 % ne12;
1957
2312
  const int64_t i11 = i01 % ne11;
1958
2313
 
1959
- device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
2314
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
1960
2315
  device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
1961
2316
  device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
1962
2317
 
@@ -2064,19 +2419,19 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
2064
2419
 
2065
2420
  //====================================== dot products =========================
2066
2421
 
2067
- kernel void kernel_mul_mv_q2_K_f32(
2422
+ void kernel_mul_mv_q2_K_f32_impl(
2068
2423
  device const void * src0,
2069
2424
  device const float * src1,
2070
2425
  device float * dst,
2071
2426
  constant int64_t & ne00,
2072
- constant int64_t & ne01[[buffer(4)]],
2073
- constant int64_t & ne02[[buffer(5)]],
2074
- constant int64_t & ne10[[buffer(9)]],
2075
- constant int64_t & ne12[[buffer(11)]],
2076
- constant int64_t & ne0 [[buffer(15)]],
2077
- constant int64_t & ne1 [[buffer(16)]],
2078
- constant uint & r2 [[buffer(17)]],
2079
- constant uint & r3 [[buffer(18)]],
2427
+ constant int64_t & ne01,
2428
+ constant int64_t & ne02,
2429
+ constant int64_t & ne10,
2430
+ constant int64_t & ne12,
2431
+ constant int64_t & ne0,
2432
+ constant int64_t & ne1,
2433
+ constant uint & r2,
2434
+ constant uint & r3,
2080
2435
  uint3 tgpig[[threadgroup_position_in_grid]],
2081
2436
  uint tiisg[[thread_index_in_simdgroup]],
2082
2437
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2214,8 +2569,8 @@ kernel void kernel_mul_mv_q2_K_f32(
2214
2569
  }
2215
2570
  }
2216
2571
 
2217
- #if QK_K == 256
2218
- kernel void kernel_mul_mv_q3_K_f32(
2572
+ [[host_name("kernel_mul_mv_q2_K_f32")]]
2573
+ kernel void kernel_mul_mv_q2_K_f32(
2219
2574
  device const void * src0,
2220
2575
  device const float * src1,
2221
2576
  device float * dst,
@@ -2229,8 +2584,29 @@ kernel void kernel_mul_mv_q3_K_f32(
2229
2584
  constant uint & r2 [[buffer(17)]],
2230
2585
  constant uint & r3 [[buffer(18)]],
2231
2586
  uint3 tgpig[[threadgroup_position_in_grid]],
2232
- uint tiisg[[thread_index_in_simdgroup]],
2233
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2587
+ uint tiisg[[thread_index_in_simdgroup]],
2588
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2589
+
2590
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2591
+ }
2592
+
2593
+ #if QK_K == 256
2594
+ void kernel_mul_mv_q3_K_f32_impl(
2595
+ device const void * src0,
2596
+ device const float * src1,
2597
+ device float * dst,
2598
+ constant int64_t & ne00,
2599
+ constant int64_t & ne01,
2600
+ constant int64_t & ne02,
2601
+ constant int64_t & ne10,
2602
+ constant int64_t & ne12,
2603
+ constant int64_t & ne0,
2604
+ constant int64_t & ne1,
2605
+ constant uint & r2,
2606
+ constant uint & r3,
2607
+ uint3 tgpig[[threadgroup_position_in_grid]],
2608
+ uint tiisg[[thread_index_in_simdgroup]],
2609
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2234
2610
 
2235
2611
  const int nb = ne00/QK_K;
2236
2612
 
@@ -2373,19 +2749,19 @@ kernel void kernel_mul_mv_q3_K_f32(
2373
2749
  }
2374
2750
  }
2375
2751
  #else
2376
- kernel void kernel_mul_mv_q3_K_f32(
2752
+ void kernel_mul_mv_q3_K_f32_impl(
2377
2753
  device const void * src0,
2378
2754
  device const float * src1,
2379
2755
  device float * dst,
2380
2756
  constant int64_t & ne00,
2381
- constant int64_t & ne01[[buffer(4)]],
2382
- constant int64_t & ne02[[buffer(5)]],
2383
- constant int64_t & ne10[[buffer(9)]],
2384
- constant int64_t & ne12[[buffer(11)]],
2385
- constant int64_t & ne0 [[buffer(15)]],
2386
- constant int64_t & ne1 [[buffer(16)]],
2387
- constant uint & r2 [[buffer(17)]],
2388
- constant uint & r3 [[buffer(18)]],
2757
+ constant int64_t & ne01,
2758
+ constant int64_t & ne02,
2759
+ constant int64_t & ne10,
2760
+ constant int64_t & ne12,
2761
+ constant int64_t & ne0,
2762
+ constant int64_t & ne1,
2763
+ constant uint & r2,
2764
+ constant uint & r3,
2389
2765
  uint3 tgpig[[threadgroup_position_in_grid]],
2390
2766
  uint tiisg[[thread_index_in_simdgroup]],
2391
2767
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2450,20 +2826,41 @@ kernel void kernel_mul_mv_q3_K_f32(
2450
2826
  }
2451
2827
  #endif
2452
2828
 
// Kernel entry point exported under the host name "kernel_mul_mv_q3_K_f32".
// It contains no computation of its own: every argument is forwarded, in the
// same order, to kernel_mul_mv_q3_K_f32_impl.
// ne01, ne02, ne10, ne12, ne0, ne1, r2 and r3 carry explicit [[buffer(n)]]
// indices (4, 5, 9, 11, 15, 16, 17, 18); src0, src1, dst and ne00 carry none.
2829
+ [[host_name("kernel_mul_mv_q3_K_f32")]]
2830
+ kernel void kernel_mul_mv_q3_K_f32(
2831
+ device const void * src0,
2832
+ device const float * src1,
2833
+ device float * dst,
2834
+ constant int64_t & ne00,
2835
+ constant int64_t & ne01[[buffer(4)]],
2836
+ constant int64_t & ne02[[buffer(5)]],
2837
+ constant int64_t & ne10[[buffer(9)]],
2838
+ constant int64_t & ne12[[buffer(11)]],
2839
+ constant int64_t & ne0 [[buffer(15)]],
2840
+ constant int64_t & ne1 [[buffer(16)]],
2841
+ constant uint & r2 [[buffer(17)]],
2842
+ constant uint & r3 [[buffer(18)]],
2843
+ uint3 tgpig[[threadgroup_position_in_grid]],
2844
+ uint tiisg[[thread_index_in_simdgroup]],
2845
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2846
+
// Delegate to the shared implementation function.
2847
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2848
+ }
2849
+
2453
2850
  #if QK_K == 256
2454
- kernel void kernel_mul_mv_q4_K_f32(
2851
+ void kernel_mul_mv_q4_K_f32_impl(
2455
2852
  device const void * src0,
2456
2853
  device const float * src1,
2457
2854
  device float * dst,
2458
2855
  constant int64_t & ne00,
2459
- constant int64_t & ne01 [[buffer(4)]],
2460
- constant int64_t & ne02 [[buffer(5)]],
2461
- constant int64_t & ne10 [[buffer(9)]],
2462
- constant int64_t & ne12 [[buffer(11)]],
2463
- constant int64_t & ne0 [[buffer(15)]],
2464
- constant int64_t & ne1 [[buffer(16)]],
2465
- constant uint & r2 [[buffer(17)]],
2466
- constant uint & r3 [[buffer(18)]],
2856
+ constant int64_t & ne01,
2857
+ constant int64_t & ne02,
2858
+ constant int64_t & ne10,
2859
+ constant int64_t & ne12,
2860
+ constant int64_t & ne0,
2861
+ constant int64_t & ne1,
2862
+ constant uint & r2,
2863
+ constant uint & r3,
2467
2864
  uint3 tgpig[[threadgroup_position_in_grid]],
2468
2865
  uint tiisg[[thread_index_in_simdgroup]],
2469
2866
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2564,19 +2961,19 @@ kernel void kernel_mul_mv_q4_K_f32(
2564
2961
  }
2565
2962
  }
2566
2963
  #else
2567
- kernel void kernel_mul_mv_q4_K_f32(
2964
+ void kernel_mul_mv_q4_K_f32_impl(
2568
2965
  device const void * src0,
2569
2966
  device const float * src1,
2570
2967
  device float * dst,
2571
2968
  constant int64_t & ne00,
2572
- constant int64_t & ne01[[buffer(4)]],
2573
- constant int64_t & ne02[[buffer(5)]],
2574
- constant int64_t & ne10[[buffer(9)]],
2575
- constant int64_t & ne12[[buffer(11)]],
2576
- constant int64_t & ne0 [[buffer(15)]],
2577
- constant int64_t & ne1 [[buffer(16)]],
2578
- constant uint & r2 [[buffer(17)]],
2579
- constant uint & r3 [[buffer(18)]],
2969
+ constant int64_t & ne01,
2970
+ constant int64_t & ne02,
2971
+ constant int64_t & ne10,
2972
+ constant int64_t & ne12,
2973
+ constant int64_t & ne0,
2974
+ constant int64_t & ne1,
2975
+ constant uint & r2,
2976
+ constant uint & r3,
2580
2977
  uint3 tgpig[[threadgroup_position_in_grid]],
2581
2978
  uint tiisg[[thread_index_in_simdgroup]],
2582
2979
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2660,7 +3057,8 @@ kernel void kernel_mul_mv_q4_K_f32(
2660
3057
  }
2661
3058
  #endif
2662
3059
 
2663
- kernel void kernel_mul_mv_q5_K_f32(
3060
+ [[host_name("kernel_mul_mv_q4_K_f32")]]
3061
+ kernel void kernel_mul_mv_q4_K_f32(
2664
3062
  device const void * src0,
2665
3063
  device const float * src1,
2666
3064
  device float * dst,
@@ -2677,6 +3075,26 @@ kernel void kernel_mul_mv_q5_K_f32(
2677
3075
  uint tiisg[[thread_index_in_simdgroup]],
2678
3076
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
2679
3077
 
3078
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3079
+ }
3080
+
3081
+ void kernel_mul_mv_q5_K_f32_impl(
3082
+ device const void * src0,
3083
+ device const float * src1,
3084
+ device float * dst,
3085
+ constant int64_t & ne00,
3086
+ constant int64_t & ne01,
3087
+ constant int64_t & ne02,
3088
+ constant int64_t & ne10,
3089
+ constant int64_t & ne12,
3090
+ constant int64_t & ne0,
3091
+ constant int64_t & ne1,
3092
+ constant uint & r2,
3093
+ constant uint & r3,
3094
+ uint3 tgpig[[threadgroup_position_in_grid]],
3095
+ uint tiisg[[thread_index_in_simdgroup]],
3096
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3097
+
2680
3098
  const int nb = ne00/QK_K;
2681
3099
 
2682
3100
  const int64_t r0 = tgpig.x;
@@ -2836,10 +3254,10 @@ kernel void kernel_mul_mv_q5_K_f32(
2836
3254
  dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
2837
3255
  }
2838
3256
  }
2839
-
2840
3257
  }
2841
3258
 
2842
- kernel void kernel_mul_mv_q6_K_f32(
3259
+ [[host_name("kernel_mul_mv_q5_K_f32")]]
3260
+ kernel void kernel_mul_mv_q5_K_f32(
2843
3261
  device const void * src0,
2844
3262
  device const float * src1,
2845
3263
  device float * dst,
@@ -2853,18 +3271,38 @@ kernel void kernel_mul_mv_q6_K_f32(
2853
3271
  constant uint & r2 [[buffer(17)]],
2854
3272
  constant uint & r3 [[buffer(18)]],
2855
3273
  uint3 tgpig[[threadgroup_position_in_grid]],
2856
- uint tiisg[[thread_index_in_simdgroup]],
2857
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2858
-
2859
- const uint8_t kmask1 = 0x03;
2860
- const uint8_t kmask2 = 0x0C;
2861
- const uint8_t kmask3 = 0x30;
2862
- const uint8_t kmask4 = 0xC0;
3274
+ uint tiisg[[thread_index_in_simdgroup]],
3275
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2863
3276
 
2864
- const int nb = ne00/QK_K;
3277
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3278
+ }
2865
3279
 
2866
- const int64_t r0 = tgpig.x;
2867
- const int64_t r1 = tgpig.y;
3280
+ void kernel_mul_mv_q6_K_f32_impl(
3281
+ device const void * src0,
3282
+ device const float * src1,
3283
+ device float * dst,
3284
+ constant int64_t & ne00,
3285
+ constant int64_t & ne01,
3286
+ constant int64_t & ne02,
3287
+ constant int64_t & ne10,
3288
+ constant int64_t & ne12,
3289
+ constant int64_t & ne0,
3290
+ constant int64_t & ne1,
3291
+ constant uint & r2,
3292
+ constant uint & r3,
3293
+ uint3 tgpig[[threadgroup_position_in_grid]],
3294
+ uint tiisg[[thread_index_in_simdgroup]],
3295
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3296
+
3297
+ const uint8_t kmask1 = 0x03;
3298
+ const uint8_t kmask2 = 0x0C;
3299
+ const uint8_t kmask3 = 0x30;
3300
+ const uint8_t kmask4 = 0xC0;
3301
+
3302
+ const int nb = ne00/QK_K;
3303
+
3304
+ const int64_t r0 = tgpig.x;
3305
+ const int64_t r1 = tgpig.y;
2868
3306
  const int im = tgpig.z;
2869
3307
 
2870
3308
  const int row = 2 * r0 + sgitg;
@@ -2945,6 +3383,27 @@ kernel void kernel_mul_mv_q6_K_f32(
2945
3383
  }
2946
3384
  }
2947
3385
 
// Kernel entry point exported under the host name "kernel_mul_mv_q6_K_f32".
// It contains no computation of its own: every argument is forwarded, in the
// same order, to kernel_mul_mv_q6_K_f32_impl.
// ne01, ne02, ne10, ne12, ne0, ne1, r2 and r3 carry explicit [[buffer(n)]]
// indices (4, 5, 9, 11, 15, 16, 17, 18); src0, src1, dst and ne00 carry none.
3386
+ [[host_name("kernel_mul_mv_q6_K_f32")]]
3387
+ kernel void kernel_mul_mv_q6_K_f32(
3388
+ device const void * src0,
3389
+ device const float * src1,
3390
+ device float * dst,
3391
+ constant int64_t & ne00,
3392
+ constant int64_t & ne01[[buffer(4)]],
3393
+ constant int64_t & ne02[[buffer(5)]],
3394
+ constant int64_t & ne10[[buffer(9)]],
3395
+ constant int64_t & ne12[[buffer(11)]],
3396
+ constant int64_t & ne0 [[buffer(15)]],
3397
+ constant int64_t & ne1 [[buffer(16)]],
3398
+ constant uint & r2 [[buffer(17)]],
3399
+ constant uint & r3 [[buffer(18)]],
3400
+ uint3 tgpig[[threadgroup_position_in_grid]],
3401
+ uint tiisg[[thread_index_in_simdgroup]],
3402
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3403
+
// Delegate to the shared implementation function.
3404
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3405
+ }
3406
+
2948
3407
  //============================= templates and their specializations =============================
2949
3408
 
2950
3409
  // NOTE: this is not dequantizing - we are simply fitting the template
@@ -3062,10 +3521,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
3062
3521
 
3063
3522
  template <typename type4x4>
3064
3523
  void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
3065
- const half d = xb->d;
3066
- const half min = xb->dmin;
3524
+ const float d = xb->d;
3525
+ const float min = xb->dmin;
3067
3526
  device const uint8_t * q = (device const uint8_t *)xb->qs;
3068
- half dl, ml;
3527
+ float dl, ml;
3069
3528
  uint8_t sc = xb->scales[il];
3070
3529
 
3071
3530
  #if QK_K == 256
@@ -3135,10 +3594,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
3135
3594
  q = q + (il/4) * 32 + 16 * (il&1);
3136
3595
  il = il & 3;
3137
3596
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3138
- const half d = il < 2 ? xb->d : xb->d / 16.h;
3139
- const half min = xb->dmin;
3140
- const half dl = d * sc[0];
3141
- const half ml = min * sc[1];
3597
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
3598
+ const float min = xb->dmin;
3599
+ const float dl = d * sc[0];
3600
+ const float ml = min * sc[1];
3142
3601
  #else
3143
3602
  q = q + 16 * (il&1);
3144
3603
  device const uint8_t * s = xb->scales;
@@ -3165,13 +3624,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
3165
3624
  uint8_t ul = 1 << (il/2);
3166
3625
  il = il & 3;
3167
3626
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3168
- const half d = il < 2 ? xb->d : xb->d / 16.h;
3169
- const half min = xb->dmin;
3170
- const half dl = d * sc[0];
3171
- const half ml = min * sc[1];
3627
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
3628
+ const float min = xb->dmin;
3629
+ const float dl = d * sc[0];
3630
+ const float ml = min * sc[1];
3172
3631
 
3173
- const ushort mask = il<2 ? 0x0F : 0xF0;
3174
- const half qh_val = il<2 ? 16.h : 256.h;
3632
+ const ushort mask = il<2 ? 0x0F : 0xF0;
3633
+ const float qh_val = il<2 ? 16.f : 256.f;
3175
3634
  for (int i = 0; i < 16; ++i) {
3176
3635
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
3177
3636
  }
@@ -3219,22 +3678,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
// Diff hunk for the templated kernel_get_rows. Rows starting with '-' are the
// removed lines, rows starting with '+' are their replacements, and rows that
// hold only a number are the diff's old/new line counters.
//
// Behaviour of the '+' version:
//   * tgpig.x and tgpig.y are taken as i10 and i11.
//   * An int32 row index r is loaded from src1 at byte offset i11*nb11 + i10*nb10.
//   * i02 is set equal to i11.
//   * Each thread walks ind = tiitg, tiitg + tptg.x, ... while ind < ne00/16 and
//     calls dequantize_func on block (ind/nl), sub-index (ind%nl), of the source
//     row located at byte offset r*nb01 + i02*nb02 from src0.
//   * The resulting float4x4 is stored at slot ind of the destination row located
//     at byte offset i11*nb2 + i10*nb1 from dst.
// ne10 is declared but not referenced in the rows shown here.
// NOTE(review): r comes straight from src1 and no range check on it is visible here.
3219
3678
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
3220
3679
  kernel void kernel_get_rows(
3221
3680
  device const void * src0,
3222
- device const int * src1,
3681
+ device const char * src1,
3223
3682
  device float * dst,
3224
3683
  constant int64_t & ne00,
3225
3684
  constant uint64_t & nb01,
3685
+ constant uint64_t & nb02,
3686
+ constant int64_t & ne10,
3687
+ constant uint64_t & nb10,
3688
+ constant uint64_t & nb11,
3226
3689
  constant uint64_t & nb1,
3227
- uint tgpig[[threadgroup_position_in_grid]],
3690
+ constant uint64_t & nb2,
3691
+ uint3 tgpig[[threadgroup_position_in_grid]],
3228
3692
  uint tiitg[[thread_index_in_threadgroup]],
3229
- uint tptg[[threads_per_threadgroup]]) {
3230
- const int i = tgpig;
3231
- const int r = ((device int32_t *) src1)[i];
3693
+ uint3 tptg [[threads_per_threadgroup]]) {
3694
+ //const int64_t i = tgpig;
3695
+ //const int64_t r = ((device int32_t *) src1)[i];
3696
+
3697
+ const int64_t i10 = tgpig.x;
3698
+ const int64_t i11 = tgpig.y;
3232
3699
 
3233
- for (int ind = tiitg; ind < ne00/16; ind += tptg) {
// Row index read through byte strides nb10/nb11 rather than plain int indexing.
3700
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3701
+
3702
+ const int64_t i02 = i11;
3703
+
3704
+ for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
3234
3705
  float4x4 temp;
3235
3706
  dequantize_func(
3236
- ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
3237
- *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
3707
+ ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
3708
+ *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
3709
+ }
3710
+ }
3711
+
// kernel_get_rows_f32: copies one row of float values from src0 to dst.
//   * tgpig.x and tgpig.y are taken as i10 and i11.
//   * An int32 row index r is loaded from src1 at byte offset i11*nb11 + i10*nb10.
//   * i02 is set equal to i11.
//   * Each thread copies elements ind = tiitg, tiitg + tptg.x, ... while ind < ne00,
//     reading from the source row at byte offset r*nb01 + i02*nb02 and writing to the
//     destination row at byte offset i11*nb2 + i10*nb1.
// ne10 is declared but not referenced in this body.
// NOTE(review): the loop counter is declared int while ne00 is int64_t, and r is
// used without a range check visible here - confirm both are fine for all callers.
3712
+ kernel void kernel_get_rows_f32(
3713
+ device const void * src0,
3714
+ device const char * src1,
3715
+ device float * dst,
3716
+ constant int64_t & ne00,
3717
+ constant uint64_t & nb01,
3718
+ constant uint64_t & nb02,
3719
+ constant int64_t & ne10,
3720
+ constant uint64_t & nb10,
3721
+ constant uint64_t & nb11,
3722
+ constant uint64_t & nb1,
3723
+ constant uint64_t & nb2,
3724
+ uint3 tgpig[[threadgroup_position_in_grid]],
3725
+ uint tiitg[[thread_index_in_threadgroup]],
3726
+ uint3 tptg [[threads_per_threadgroup]]) {
3727
+ const int64_t i10 = tgpig.x;
3728
+ const int64_t i11 = tgpig.y;
3729
+
// Row index read through byte strides nb10/nb11.
3730
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3731
+
3732
+ const int64_t i02 = i11;
3733
+
// Element-wise float -> float copy, strided across the threadgroup by tptg.x.
3734
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3735
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3736
+ ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
3737
+ }
3738
+ }
3739
+
// kernel_get_rows_f16: copies one row from src0 to dst, reading each element as
// half and storing it as float.
//   * tgpig.x and tgpig.y are taken as i10 and i11.
//   * An int32 row index r is loaded from src1 at byte offset i11*nb11 + i10*nb10.
//   * i02 is set equal to i11.
//   * Each thread handles elements ind = tiitg, tiitg + tptg.x, ... while ind < ne00,
//     reading from the source row at byte offset r*nb01 + i02*nb02 and writing to the
//     destination row at byte offset i11*nb2 + i10*nb1.
// ne10 is declared but not referenced in this body.
// NOTE(review): the loop counter is declared int while ne00 is int64_t, and r is
// used without a range check visible here - confirm both are fine for all callers.
3740
+ kernel void kernel_get_rows_f16(
3741
+ device const void * src0,
3742
+ device const char * src1,
3743
+ device float * dst,
3744
+ constant int64_t & ne00,
3745
+ constant uint64_t & nb01,
3746
+ constant uint64_t & nb02,
3747
+ constant int64_t & ne10,
3748
+ constant uint64_t & nb10,
3749
+ constant uint64_t & nb11,
3750
+ constant uint64_t & nb1,
3751
+ constant uint64_t & nb2,
3752
+ uint3 tgpig[[threadgroup_position_in_grid]],
3753
+ uint tiitg[[thread_index_in_threadgroup]],
3754
+ uint3 tptg [[threads_per_threadgroup]]) {
3755
+ const int64_t i10 = tgpig.x;
3756
+ const int64_t i11 = tgpig.y;
3757
+
// Row index read through byte strides nb10/nb11.
3758
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3759
+
3760
+ const int64_t i02 = i11;
3761
+
// Element-wise half -> float copy, strided across the threadgroup by tptg.x.
3762
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3763
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3764
+ ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
3238
3765
  }
3239
3766
  }
3240
3767
 
@@ -3426,19 +3953,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
3426
3953
 
3427
3954
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3428
3955
  kernel void kernel_mul_mm_id(
3429
- device const int32_t * ids,
3956
+ device const uchar * ids,
3430
3957
  device const uchar * src1,
3431
- device float * dst,
3958
+ device uchar * dst,
3959
+ constant int64_t & nbi1,
3432
3960
  constant int64_t & ne00,
3433
3961
  constant int64_t & ne02,
3434
3962
  constant int64_t & nb01,
3435
3963
  constant int64_t & nb02,
3436
3964
  constant int64_t & ne12,
3965
+ constant int64_t & ne13,
3437
3966
  constant int64_t & nb10,
3438
3967
  constant int64_t & nb11,
3439
3968
  constant int64_t & nb12,
3440
3969
  constant int64_t & ne0,
3441
3970
  constant int64_t & ne1,
3971
+ constant int64_t & nb1,
3442
3972
  constant uint & r2,
3443
3973
  constant uint & r3,
3444
3974
  constant int & idx,
@@ -3456,10 +3986,16 @@ kernel void kernel_mul_mm_id(
3456
3986
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3457
3987
  device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
3458
3988
 
3989
+ const int64_t bid = tgpig.z/(ne12*ne13);
3990
+
3991
+ tgpig.z = tgpig.z%(ne12*ne13);
3992
+
3993
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
3994
+
3459
3995
  kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3460
- src0[ids[idx]],
3461
- src1,
3462
- dst,
3996
+ src0[id],
3997
+ src1 + bid*nb11,
3998
+ (device float *) (dst + bid*nb1),
3463
3999
  ne00,
3464
4000
  ne02,
3465
4001
  nb01,
@@ -3484,17 +4020,26 @@ kernel void kernel_mul_mm_id(
3484
4020
  #define QK_NL 4
3485
4021
  #endif
3486
4022
 
4023
+ //
4024
+ // get rows
4025
+ //
4026
+
3487
4027
  typedef void (get_rows_t)(
3488
4028
  device const void * src0,
3489
- device const int * src1,
4029
+ device const char * src1,
3490
4030
  device float * dst,
3491
4031
  constant int64_t & ne00,
3492
4032
  constant uint64_t & nb01,
4033
+ constant uint64_t & nb02,
4034
+ constant int64_t & ne10,
4035
+ constant uint64_t & nb10,
4036
+ constant uint64_t & nb11,
3493
4037
  constant uint64_t & nb1,
3494
- uint, uint, uint);
4038
+ constant uint64_t & nb2,
4039
+ uint3, uint, uint3);
3495
4040
 
3496
- template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
3497
- template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
4041
+ //template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
4042
+ //template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
3498
4043
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
3499
4044
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
3500
4045
  template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
@@ -3506,6 +4051,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
3506
4051
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
3507
4052
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
3508
4053
 
4054
+ //
4055
+ // matrix-matrix multiplication
4056
+ //
4057
+
3509
4058
  typedef void (mat_mm_t)(
3510
4059
  device const uchar * src0,
3511
4060
  device const uchar * src1,
@@ -3538,20 +4087,27 @@ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
3538
4087
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
3539
4088
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
3540
4089
 
4090
+ //
4091
+ // indirect matrix-matrix multiplication
4092
+ //
4093
+
3541
4094
  typedef void (mat_mm_id_t)(
3542
- device const int32_t * ids,
4095
+ device const uchar * ids,
3543
4096
  device const uchar * src1,
3544
- device float * dst,
4097
+ device uchar * dst,
4098
+ constant int64_t & nbi1,
3545
4099
  constant int64_t & ne00,
3546
4100
  constant int64_t & ne02,
3547
4101
  constant int64_t & nb01,
3548
4102
  constant int64_t & nb02,
3549
4103
  constant int64_t & ne12,
4104
+ constant int64_t & ne13,
3550
4105
  constant int64_t & nb10,
3551
4106
  constant int64_t & nb11,
3552
4107
  constant int64_t & nb12,
3553
4108
  constant int64_t & ne0,
3554
4109
  constant int64_t & ne1,
4110
+ constant int64_t & nb1,
3555
4111
  constant uint & r2,
3556
4112
  constant uint & r3,
3557
4113
  constant int & idx,
@@ -3578,3 +4134,775 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
3578
4134
  template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
3579
4135
  template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
3580
4136
  template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
4137
+
4138
+ //
4139
+ // matrix-vector multiplication
4140
+ //
4141
+
// Kernel entry point exported under the host name "kernel_mul_mv_id_f32_f32".
// Picks one of eight source pointers (src00..src07) using an int32 read from
// ids, then forwards to kernel_mul_mv_f32_f32_impl.
//   * bid is tgpig.z divided by (ne12*ne13); tgpig.z is then replaced by the
//     remainder of that division before being passed on.
//   * id is element idx of the int32 array starting bid*nbi1 bytes into ids.
//   * src1 is advanced by bid*nb11 bytes and dst by bid*nb1 bytes for the call.
// tiitg and sgitg are declared but not referenced in this body.
// NOTE(review): id indexes the 8-element src0 array with no range check visible here.
4142
+ [[host_name("kernel_mul_mv_id_f32_f32")]]
4143
+ kernel void kernel_mul_mv_id_f32_f32(
4144
+ device const char * ids,
4145
+ device const char * src1,
4146
+ device uchar * dst,
4147
+ constant int64_t & nbi1,
4148
+ constant int64_t & ne00,
4149
+ constant int64_t & ne01,
4150
+ constant int64_t & ne02,
4151
+ constant uint64_t & nb00,
4152
+ constant uint64_t & nb01,
4153
+ constant uint64_t & nb02,
4154
+ constant int64_t & ne10,
4155
+ constant int64_t & ne11,
4156
+ constant int64_t & ne12,
4157
+ constant int64_t & ne13,
4158
+ constant uint64_t & nb10,
4159
+ constant uint64_t & nb11,
4160
+ constant uint64_t & nb12,
4161
+ constant int64_t & ne0,
4162
+ constant int64_t & ne1,
4163
+ constant int64_t & nb1,
4164
+ constant uint & r2,
4165
+ constant uint & r3,
4166
+ constant int & idx,
4167
+ device const char * src00,
4168
+ device const char * src01,
4169
+ device const char * src02,
4170
+ device const char * src03,
4171
+ device const char * src04,
4172
+ device const char * src05,
4173
+ device const char * src06,
4174
+ device const char * src07,
4175
+ uint3 tgpig[[threadgroup_position_in_grid]],
4176
+ uint tiitg[[thread_index_in_threadgroup]],
4177
+ uint tiisg[[thread_index_in_simdgroup]],
4178
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
// Gather the eight candidate source pointers so one can be chosen by index.
4179
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4180
+
// Split tgpig.z into a quotient (bid) and a remainder (kept in tgpig.z).
4181
+ const int64_t bid = tgpig.z/(ne12*ne13);
4182
+
4183
+ tgpig.z = tgpig.z%(ne12*ne13);
4184
+
// Selector: element idx of the int32 row at byte offset bid*nbi1 in ids.
4185
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4186
+
4187
+ kernel_mul_mv_f32_f32_impl(
4188
+ src0[id],
4189
+ src1 + bid*nb11,
4190
+ (device float *) (dst + bid*nb1),
4191
+ ne00,
4192
+ ne01,
4193
+ ne02,
4194
+ nb00,
4195
+ nb01,
4196
+ nb02,
4197
+ ne10,
4198
+ ne11,
4199
+ ne12,
4200
+ nb10,
4201
+ nb11,
4202
+ nb12,
4203
+ ne0,
4204
+ ne1,
4205
+ r2,
4206
+ r3,
4207
+ tgpig,
4208
+ tiisg);
4209
+ }
4210
+
// Kernel entry point exported under the host name "kernel_mul_mv_id_f16_f32".
// Picks one of eight source pointers (src00..src07) using an int32 read from
// ids, then forwards to kernel_mul_mv_f16_f32_impl.
//   * bid is tgpig.z divided by (ne12*ne13); tgpig.z is then replaced by the
//     remainder of that division before being passed on.
//   * id is element idx of the int32 array starting bid*nbi1 bytes into ids.
//   * src1 is advanced by bid*nb11 bytes and dst by bid*nb1 bytes for the call.
// tiitg and sgitg are declared but not referenced in this body.
// NOTE(review): id indexes the 8-element src0 array with no range check visible here.
4211
+ [[host_name("kernel_mul_mv_id_f16_f32")]]
4212
+ kernel void kernel_mul_mv_id_f16_f32(
4213
+ device const char * ids,
4214
+ device const char * src1,
4215
+ device uchar * dst,
4216
+ constant int64_t & nbi1,
4217
+ constant int64_t & ne00,
4218
+ constant int64_t & ne01,
4219
+ constant int64_t & ne02,
4220
+ constant uint64_t & nb00,
4221
+ constant uint64_t & nb01,
4222
+ constant uint64_t & nb02,
4223
+ constant int64_t & ne10,
4224
+ constant int64_t & ne11,
4225
+ constant int64_t & ne12,
4226
+ constant int64_t & ne13,
4227
+ constant uint64_t & nb10,
4228
+ constant uint64_t & nb11,
4229
+ constant uint64_t & nb12,
4230
+ constant int64_t & ne0,
4231
+ constant int64_t & ne1,
4232
+ constant int64_t & nb1,
4233
+ constant uint & r2,
4234
+ constant uint & r3,
4235
+ constant int & idx,
4236
+ device const char * src00,
4237
+ device const char * src01,
4238
+ device const char * src02,
4239
+ device const char * src03,
4240
+ device const char * src04,
4241
+ device const char * src05,
4242
+ device const char * src06,
4243
+ device const char * src07,
4244
+ uint3 tgpig[[threadgroup_position_in_grid]],
4245
+ uint tiitg[[thread_index_in_threadgroup]],
4246
+ uint tiisg[[thread_index_in_simdgroup]],
4247
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
// Gather the eight candidate source pointers so one can be chosen by index.
4248
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4249
+
// Split tgpig.z into a quotient (bid) and a remainder (kept in tgpig.z).
4250
+ const int64_t bid = tgpig.z/(ne12*ne13);
4251
+
4252
+ tgpig.z = tgpig.z%(ne12*ne13);
4253
+
// Selector: element idx of the int32 row at byte offset bid*nbi1 in ids.
4254
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4255
+
4256
+ kernel_mul_mv_f16_f32_impl(
4257
+ src0[id],
4258
+ src1 + bid*nb11,
4259
+ (device float *) (dst + bid*nb1),
4260
+ ne00,
4261
+ ne01,
4262
+ ne02,
4263
+ nb00,
4264
+ nb01,
4265
+ nb02,
4266
+ ne10,
4267
+ ne11,
4268
+ ne12,
4269
+ nb10,
4270
+ nb11,
4271
+ nb12,
4272
+ ne0,
4273
+ ne1,
4274
+ r2,
4275
+ r3,
4276
+ tgpig,
4277
+ tiisg);
4278
+ }
4279
+
// Kernel entry point exported under the host name "kernel_mul_mv_id_q8_0_f32".
// Picks one of eight source pointers (src00..src07) using an int32 read from
// ids, then forwards to kernel_mul_mv_q8_0_f32_impl.
//   * bid is tgpig.z divided by (ne12*ne13); tgpig.z is then replaced by the
//     remainder of that division before being passed on.
//   * id is element idx of the int32 array starting bid*nbi1 bytes into ids.
//   * src1 is advanced by bid*nb11 bytes and cast to float; dst is advanced by
//     bid*nb1 bytes and cast to float.
// Only ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2 and r3 are forwarded; the
// nb0x / nb10 / nb12 / ne11 arguments and tiitg are not referenced in this body.
// NOTE(review): id indexes the 8-element src0 array with no range check visible here.
4280
+ [[host_name("kernel_mul_mv_id_q8_0_f32")]]
4281
+ kernel void kernel_mul_mv_id_q8_0_f32(
4282
+ device const char * ids,
4283
+ device const char * src1,
4284
+ device uchar * dst,
4285
+ constant int64_t & nbi1,
4286
+ constant int64_t & ne00,
4287
+ constant int64_t & ne01,
4288
+ constant int64_t & ne02,
4289
+ constant uint64_t & nb00,
4290
+ constant uint64_t & nb01,
4291
+ constant uint64_t & nb02,
4292
+ constant int64_t & ne10,
4293
+ constant int64_t & ne11,
4294
+ constant int64_t & ne12,
4295
+ constant int64_t & ne13,
4296
+ constant uint64_t & nb10,
4297
+ constant uint64_t & nb11,
4298
+ constant uint64_t & nb12,
4299
+ constant int64_t & ne0,
4300
+ constant int64_t & ne1,
4301
+ constant int64_t & nb1,
4302
+ constant uint & r2,
4303
+ constant uint & r3,
4304
+ constant int & idx,
4305
+ device const char * src00,
4306
+ device const char * src01,
4307
+ device const char * src02,
4308
+ device const char * src03,
4309
+ device const char * src04,
4310
+ device const char * src05,
4311
+ device const char * src06,
4312
+ device const char * src07,
4313
+ uint3 tgpig[[threadgroup_position_in_grid]],
4314
+ uint tiitg[[thread_index_in_threadgroup]],
4315
+ uint tiisg[[thread_index_in_simdgroup]],
4316
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
// Gather the eight candidate source pointers so one can be chosen by index.
4317
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4318
+
// Split tgpig.z into a quotient (bid) and a remainder (kept in tgpig.z).
4319
+ const int64_t bid = tgpig.z/(ne12*ne13);
4320
+
4321
+ tgpig.z = tgpig.z%(ne12*ne13);
4322
+
// Selector: element idx of the int32 row at byte offset bid*nbi1 in ids.
4323
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4324
+
4325
+ kernel_mul_mv_q8_0_f32_impl(
4326
+ src0[id],
4327
+ (device const float *) (src1 + bid*nb11),
4328
+ (device float *) ( dst + bid*nb1),
4329
+ ne00,
4330
+ ne01,
4331
+ ne02,
4332
+ ne10,
4333
+ ne12,
4334
+ ne0,
4335
+ ne1,
4336
+ r2,
4337
+ r3,
4338
+ tgpig,
4339
+ tiisg,
4340
+ sgitg);
4341
+ }
4342
+
// Kernel entry point exported under the host name "kernel_mul_mv_id_q4_0_f32".
// Picks one of eight source pointers (src00..src07) using an int32 read from
// ids, then forwards to mul_vec_q_n_f32_impl instantiated with
// <block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>.
//   * bid is tgpig.z divided by (ne12*ne13); tgpig.z is then replaced by the
//     remainder of that division before being passed on.
//   * id is element idx of the int32 array starting bid*nbi1 bytes into ids.
//   * src1 is advanced by bid*nb11 bytes and cast to float; dst is advanced by
//     bid*nb1 bytes and cast to float.
// Only ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2 and r3 are forwarded; the
// nb0x / nb10 / nb12 / ne11 arguments and tiitg are not referenced in this body.
// NOTE(review): id indexes the 8-element src0 array with no range check visible here.
4343
+ [[host_name("kernel_mul_mv_id_q4_0_f32")]]
4344
+ kernel void kernel_mul_mv_id_q4_0_f32(
4345
+ device const char * ids,
4346
+ device const char * src1,
4347
+ device uchar * dst,
4348
+ constant int64_t & nbi1,
4349
+ constant int64_t & ne00,
4350
+ constant int64_t & ne01,
4351
+ constant int64_t & ne02,
4352
+ constant uint64_t & nb00,
4353
+ constant uint64_t & nb01,
4354
+ constant uint64_t & nb02,
4355
+ constant int64_t & ne10,
4356
+ constant int64_t & ne11,
4357
+ constant int64_t & ne12,
4358
+ constant int64_t & ne13,
4359
+ constant uint64_t & nb10,
4360
+ constant uint64_t & nb11,
4361
+ constant uint64_t & nb12,
4362
+ constant int64_t & ne0,
4363
+ constant int64_t & ne1,
4364
+ constant int64_t & nb1,
4365
+ constant uint & r2,
4366
+ constant uint & r3,
4367
+ constant int & idx,
4368
+ device const char * src00,
4369
+ device const char * src01,
4370
+ device const char * src02,
4371
+ device const char * src03,
4372
+ device const char * src04,
4373
+ device const char * src05,
4374
+ device const char * src06,
4375
+ device const char * src07,
4376
+ uint3 tgpig[[threadgroup_position_in_grid]],
4377
+ uint tiitg[[thread_index_in_threadgroup]],
4378
+ uint tiisg[[thread_index_in_simdgroup]],
4379
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
// Gather the eight candidate source pointers so one can be chosen by index.
4380
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4381
+
// Split tgpig.z into a quotient (bid) and a remainder (kept in tgpig.z).
4382
+ const int64_t bid = tgpig.z/(ne12*ne13);
4383
+
4384
+ tgpig.z = tgpig.z%(ne12*ne13);
4385
+
// Selector: element idx of the int32 row at byte offset bid*nbi1 in ids.
4386
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4387
+
4388
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4389
+ src0[id],
4390
+ (device const float *) (src1 + bid*nb11),
4391
+ (device float *) ( dst + bid*nb1),
4392
+ ne00,
4393
+ ne01,
4394
+ ne02,
4395
+ ne10,
4396
+ ne12,
4397
+ ne0,
4398
+ ne1,
4399
+ r2,
4400
+ r3,
4401
+ tgpig,
4402
+ tiisg,
4403
+ sgitg);
4404
+ }
4405
+
// Kernel entry point exported under the host name "kernel_mul_mv_id_q4_1_f32".
// Picks one of eight source pointers (src00..src07) using an int32 read from
// ids, then forwards to mul_vec_q_n_f32_impl instantiated with
// <block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>.
//   * bid is tgpig.z divided by (ne12*ne13); tgpig.z is then replaced by the
//     remainder of that division before being passed on.
//   * id is element idx of the int32 array starting bid*nbi1 bytes into ids.
//   * src1 is advanced by bid*nb11 bytes and cast to float; dst is advanced by
//     bid*nb1 bytes and cast to float.
// Only ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2 and r3 are forwarded; the
// nb0x / nb10 / nb12 / ne11 arguments and tiitg are not referenced in this body.
// NOTE(review): id indexes the 8-element src0 array with no range check visible here.
4406
+ [[host_name("kernel_mul_mv_id_q4_1_f32")]]
4407
+ kernel void kernel_mul_mv_id_q4_1_f32(
4408
+ device const char * ids,
4409
+ device const char * src1,
4410
+ device uchar * dst,
4411
+ constant int64_t & nbi1,
4412
+ constant int64_t & ne00,
4413
+ constant int64_t & ne01,
4414
+ constant int64_t & ne02,
4415
+ constant uint64_t & nb00,
4416
+ constant uint64_t & nb01,
4417
+ constant uint64_t & nb02,
4418
+ constant int64_t & ne10,
4419
+ constant int64_t & ne11,
4420
+ constant int64_t & ne12,
4421
+ constant int64_t & ne13,
4422
+ constant uint64_t & nb10,
4423
+ constant uint64_t & nb11,
4424
+ constant uint64_t & nb12,
4425
+ constant int64_t & ne0,
4426
+ constant int64_t & ne1,
4427
+ constant int64_t & nb1,
4428
+ constant uint & r2,
4429
+ constant uint & r3,
4430
+ constant int & idx,
4431
+ device const char * src00,
4432
+ device const char * src01,
4433
+ device const char * src02,
4434
+ device const char * src03,
4435
+ device const char * src04,
4436
+ device const char * src05,
4437
+ device const char * src06,
4438
+ device const char * src07,
4439
+ uint3 tgpig[[threadgroup_position_in_grid]],
4440
+ uint tiitg[[thread_index_in_threadgroup]],
4441
+ uint tiisg[[thread_index_in_simdgroup]],
4442
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
// Gather the eight candidate source pointers so one can be chosen by index.
4443
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4444
+
// Split tgpig.z into a quotient (bid) and a remainder (kept in tgpig.z).
4445
+ const int64_t bid = tgpig.z/(ne12*ne13);
4446
+
4447
+ tgpig.z = tgpig.z%(ne12*ne13);
4448
+
// Selector: element idx of the int32 row at byte offset bid*nbi1 in ids.
4449
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4450
+
4451
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4452
+ src0[id],
4453
+ (device const float *) (src1 + bid*nb11),
4454
+ (device float *) ( dst + bid*nb1),
4455
+ ne00,
4456
+ ne01,
4457
+ ne02,
4458
+ ne10,
4459
+ ne12,
4460
+ ne0,
4461
+ ne1,
4462
+ r2,
4463
+ r3,
4464
+ tgpig,
4465
+ tiisg,
4466
+ sgitg);
4467
+ }
4468
+
// Kernel entry point exported under the host name "kernel_mul_mv_id_q5_0_f32".
// Picks one of eight source pointers (src00..src07) using an int32 read from
// ids, then forwards to mul_vec_q_n_f32_impl instantiated with
// <block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>.
//   * bid is tgpig.z divided by (ne12*ne13); tgpig.z is then replaced by the
//     remainder of that division before being passed on.
//   * id is element idx of the int32 array starting bid*nbi1 bytes into ids.
//   * src1 is advanced by bid*nb11 bytes and cast to float; dst is advanced by
//     bid*nb1 bytes and cast to float.
// Only ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2 and r3 are forwarded; the
// nb0x / nb10 / nb12 / ne11 arguments and tiitg are not referenced in this body.
// NOTE(review): id indexes the 8-element src0 array with no range check visible here.
4469
+ [[host_name("kernel_mul_mv_id_q5_0_f32")]]
4470
+ kernel void kernel_mul_mv_id_q5_0_f32(
4471
+ device const char * ids,
4472
+ device const char * src1,
4473
+ device uchar * dst,
4474
+ constant int64_t & nbi1,
4475
+ constant int64_t & ne00,
4476
+ constant int64_t & ne01,
4477
+ constant int64_t & ne02,
4478
+ constant uint64_t & nb00,
4479
+ constant uint64_t & nb01,
4480
+ constant uint64_t & nb02,
4481
+ constant int64_t & ne10,
4482
+ constant int64_t & ne11,
4483
+ constant int64_t & ne12,
4484
+ constant int64_t & ne13,
4485
+ constant uint64_t & nb10,
4486
+ constant uint64_t & nb11,
4487
+ constant uint64_t & nb12,
4488
+ constant int64_t & ne0,
4489
+ constant int64_t & ne1,
4490
+ constant int64_t & nb1,
4491
+ constant uint & r2,
4492
+ constant uint & r3,
4493
+ constant int & idx,
4494
+ device const char * src00,
4495
+ device const char * src01,
4496
+ device const char * src02,
4497
+ device const char * src03,
4498
+ device const char * src04,
4499
+ device const char * src05,
4500
+ device const char * src06,
4501
+ device const char * src07,
4502
+ uint3 tgpig[[threadgroup_position_in_grid]],
4503
+ uint tiitg[[thread_index_in_threadgroup]],
4504
+ uint tiisg[[thread_index_in_simdgroup]],
4505
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
// Gather the eight candidate source pointers so one can be chosen by index.
4506
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4507
+
// Split tgpig.z into a quotient (bid) and a remainder (kept in tgpig.z).
4508
+ const int64_t bid = tgpig.z/(ne12*ne13);
4509
+
4510
+ tgpig.z = tgpig.z%(ne12*ne13);
4511
+
// Selector: element idx of the int32 row at byte offset bid*nbi1 in ids.
4512
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4513
+
4514
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4515
+ src0[id],
4516
+ (device const float *) (src1 + bid*nb11),
4517
+ (device float *) ( dst + bid*nb1),
4518
+ ne00,
4519
+ ne01,
4520
+ ne02,
4521
+ ne10,
4522
+ ne12,
4523
+ ne0,
4524
+ ne1,
4525
+ r2,
4526
+ r3,
4527
+ tgpig,
4528
+ tiisg,
4529
+ sgitg);
4530
+ }
4531
+
// Kernel entry point exported under the host name "kernel_mul_mv_id_q5_1_f32".
// Picks one of eight source pointers (src00..src07) using an int32 read from
// ids, then forwards to mul_vec_q_n_f32_impl instantiated with
// <block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>.
//   * bid is tgpig.z divided by (ne12*ne13); tgpig.z is then replaced by the
//     remainder of that division before being passed on.
//   * id is element idx of the int32 array starting bid*nbi1 bytes into ids.
//   * src1 is advanced by bid*nb11 bytes and cast to float; dst is advanced by
//     bid*nb1 bytes and cast to float.
// Only ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2 and r3 are forwarded; the
// nb0x / nb10 / nb12 / ne11 arguments and tiitg are not referenced in this body.
// NOTE(review): id indexes the 8-element src0 array with no range check visible here.
4532
+ [[host_name("kernel_mul_mv_id_q5_1_f32")]]
4533
+ kernel void kernel_mul_mv_id_q5_1_f32(
4534
+ device const char * ids,
4535
+ device const char * src1,
4536
+ device uchar * dst,
4537
+ constant int64_t & nbi1,
4538
+ constant int64_t & ne00,
4539
+ constant int64_t & ne01,
4540
+ constant int64_t & ne02,
4541
+ constant uint64_t & nb00,
4542
+ constant uint64_t & nb01,
4543
+ constant uint64_t & nb02,
4544
+ constant int64_t & ne10,
4545
+ constant int64_t & ne11,
4546
+ constant int64_t & ne12,
4547
+ constant int64_t & ne13,
4548
+ constant uint64_t & nb10,
4549
+ constant uint64_t & nb11,
4550
+ constant uint64_t & nb12,
4551
+ constant int64_t & ne0,
4552
+ constant int64_t & ne1,
4553
+ constant int64_t & nb1,
4554
+ constant uint & r2,
4555
+ constant uint & r3,
4556
+ constant int & idx,
4557
+ device const char * src00,
4558
+ device const char * src01,
4559
+ device const char * src02,
4560
+ device const char * src03,
4561
+ device const char * src04,
4562
+ device const char * src05,
4563
+ device const char * src06,
4564
+ device const char * src07,
4565
+ uint3 tgpig[[threadgroup_position_in_grid]],
4566
+ uint tiitg[[thread_index_in_threadgroup]],
4567
+ uint tiisg[[thread_index_in_simdgroup]],
4568
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
// Gather the eight candidate source pointers so one can be chosen by index.
4569
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4570
+
// Split tgpig.z into a quotient (bid) and a remainder (kept in tgpig.z).
4571
+ const int64_t bid = tgpig.z/(ne12*ne13);
4572
+
4573
+ tgpig.z = tgpig.z%(ne12*ne13);
4574
+
// Selector: element idx of the int32 row at byte offset bid*nbi1 in ids.
4575
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4576
+
4577
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4578
+ src0[id],
4579
+ (device const float *) (src1 + bid*nb11),
4580
+ (device float *) ( dst + bid*nb1),
4581
+ ne00,
4582
+ ne01,
4583
+ ne02,
4584
+ ne10,
4585
+ ne12,
4586
+ ne0,
4587
+ ne1,
4588
+ r2,
4589
+ r3,
4590
+ tgpig,
4591
+ tiisg,
4592
+ sgitg);
4593
+ }
4594
+
4595
+ [[host_name("kernel_mul_mv_id_q2_K_f32")]] // name under which the host looks this kernel up
4596
+ kernel void kernel_mul_mv_id_q2_K_f32( // Indirect mat-vec wrapper: reads an index from `ids`, uses it to pick one of src00..src07, then forwards to kernel_mul_mv_q2_K_f32_impl.
4597
+ device const char * ids, // raw bytes; reinterpreted as int32_t entries in the body
4598
+ device const char * src1, // raw bytes; offset by bid*nb11 and handed to the impl as float input
4599
+ device uchar * dst, // raw bytes; offset by bid*nb1 and handed to the impl as float output
4600
+ constant int64_t & nbi1, // byte step applied per bid when addressing ids
4601
+ constant int64_t & ne00, // forwarded unchanged to the impl
4602
+ constant int64_t & ne01, // forwarded unchanged to the impl
4603
+ constant int64_t & ne02, // forwarded unchanged to the impl
4604
+ constant uint64_t & nb00, // not referenced in this wrapper
4605
+ constant uint64_t & nb01, // not referenced in this wrapper
4606
+ constant uint64_t & nb02, // not referenced in this wrapper
4607
+ constant int64_t & ne10, // forwarded unchanged to the impl
4608
+ constant int64_t & ne11, // not referenced in this wrapper
4609
+ constant int64_t & ne12, // with ne13, splits tgpig.z into bid and a remainder; also forwarded to the impl
4610
+ constant int64_t & ne13, // with ne12, splits tgpig.z into bid and a remainder
4611
+ constant uint64_t & nb10, // not referenced in this wrapper
4612
+ constant uint64_t & nb11, // byte step applied per bid when addressing src1
4613
+ constant uint64_t & nb12, // not referenced in this wrapper
4614
+ constant int64_t & ne0, // forwarded unchanged to the impl
4615
+ constant int64_t & ne1, // forwarded unchanged to the impl
4616
+ constant int64_t & nb1, // byte step applied per bid when addressing dst
4617
+ constant uint & r2, // forwarded unchanged to the impl
4618
+ constant uint & r3, // forwarded unchanged to the impl
4619
+ constant int & idx, // which int32 entry to read at ids + bid*nbi1
4620
+ device const char * src00, // src00..src07: the eight candidate src0 buffers; one is selected by id
4621
+ device const char * src01,
4622
+ device const char * src02,
4623
+ device const char * src03,
4624
+ device const char * src04,
4625
+ device const char * src05,
4626
+ device const char * src06,
4627
+ device const char * src07,
4628
+ uint3 tgpig[[threadgroup_position_in_grid]], // .z is split into bid and a remainder below
4629
+ uint tiitg[[thread_index_in_threadgroup]], // not referenced in this wrapper
4630
+ uint tiisg[[thread_index_in_simdgroup]], // forwarded to the impl
4631
+ uint sgitg[[simdgroup_index_in_threadgroup]]) { // forwarded to the impl
4632
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // gather the candidates so they can be indexed by id
4633
+
4634
+ const int64_t bid = tgpig.z/(ne12*ne13); // quotient part of tgpig.z; drives the per-bid byte offsets below
4635
+
4636
+ tgpig.z = tgpig.z%(ne12*ne13); // keep only the remainder, so the impl receives z < ne12*ne13
4637
+
4638
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // NOTE(review): id is not range-checked before indexing src0[8]; presumably the host guarantees 0 <= id < 8 — confirm
4639
+
4640
+ kernel_mul_mv_q2_K_f32_impl( // all remaining work is delegated to the shared implementation
4641
+ src0[id], // the selected src0 buffer
4642
+ (device const float *) (src1 + bid*nb11), // src1 advanced by bid*nb11 bytes
4643
+ (device float *) ( dst + bid*nb1), // dst advanced by bid*nb1 bytes
4644
+ ne00,
4645
+ ne01,
4646
+ ne02,
4647
+ ne10,
4648
+ ne12,
4649
+ ne0,
4650
+ ne1,
4651
+ r2,
4652
+ r3,
4653
+ tgpig, // passed with the reduced .z computed above
4654
+ tiisg,
4655
+ sgitg);
4656
+ }
4657
+
4658
+ [[host_name("kernel_mul_mv_id_q3_K_f32")]] // name under which the host looks this kernel up
4659
+ kernel void kernel_mul_mv_id_q3_K_f32( // Indirect mat-vec wrapper: reads an index from `ids`, uses it to pick one of src00..src07, then forwards to kernel_mul_mv_q3_K_f32_impl.
4660
+ device const char * ids, // raw bytes; reinterpreted as int32_t entries in the body
4661
+ device const char * src1, // raw bytes; offset by bid*nb11 and handed to the impl as float input
4662
+ device uchar * dst, // raw bytes; offset by bid*nb1 and handed to the impl as float output
4663
+ constant int64_t & nbi1, // byte step applied per bid when addressing ids
4664
+ constant int64_t & ne00, // forwarded unchanged to the impl
4665
+ constant int64_t & ne01, // forwarded unchanged to the impl
4666
+ constant int64_t & ne02, // forwarded unchanged to the impl
4667
+ constant uint64_t & nb00, // not referenced in this wrapper
4668
+ constant uint64_t & nb01, // not referenced in this wrapper
4669
+ constant uint64_t & nb02, // not referenced in this wrapper
4670
+ constant int64_t & ne10, // forwarded unchanged to the impl
4671
+ constant int64_t & ne11, // not referenced in this wrapper
4672
+ constant int64_t & ne12, // with ne13, splits tgpig.z into bid and a remainder; also forwarded to the impl
4673
+ constant int64_t & ne13, // with ne12, splits tgpig.z into bid and a remainder
4674
+ constant uint64_t & nb10, // not referenced in this wrapper
4675
+ constant uint64_t & nb11, // byte step applied per bid when addressing src1
4676
+ constant uint64_t & nb12, // not referenced in this wrapper
4677
+ constant int64_t & ne0, // forwarded unchanged to the impl
4678
+ constant int64_t & ne1, // forwarded unchanged to the impl
4679
+ constant int64_t & nb1, // byte step applied per bid when addressing dst
4680
+ constant uint & r2, // forwarded unchanged to the impl
4681
+ constant uint & r3, // forwarded unchanged to the impl
4682
+ constant int & idx, // which int32 entry to read at ids + bid*nbi1
4683
+ device const char * src00, // src00..src07: the eight candidate src0 buffers; one is selected by id
4684
+ device const char * src01,
4685
+ device const char * src02,
4686
+ device const char * src03,
4687
+ device const char * src04,
4688
+ device const char * src05,
4689
+ device const char * src06,
4690
+ device const char * src07,
4691
+ uint3 tgpig[[threadgroup_position_in_grid]], // .z is split into bid and a remainder below
4692
+ uint tiitg[[thread_index_in_threadgroup]], // not referenced in this wrapper
4693
+ uint tiisg[[thread_index_in_simdgroup]], // forwarded to the impl
4694
+ uint sgitg[[simdgroup_index_in_threadgroup]]) { // forwarded to the impl
4695
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // gather the candidates so they can be indexed by id
4696
+
4697
+ const int64_t bid = tgpig.z/(ne12*ne13); // quotient part of tgpig.z; drives the per-bid byte offsets below
4698
+
4699
+ tgpig.z = tgpig.z%(ne12*ne13); // keep only the remainder, so the impl receives z < ne12*ne13
4700
+
4701
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // NOTE(review): id is not range-checked before indexing src0[8]; presumably the host guarantees 0 <= id < 8 — confirm
4702
+
4703
+ kernel_mul_mv_q3_K_f32_impl( // all remaining work is delegated to the shared implementation
4704
+ src0[id], // the selected src0 buffer
4705
+ (device const float *) (src1 + bid*nb11), // src1 advanced by bid*nb11 bytes
4706
+ (device float *) ( dst + bid*nb1), // dst advanced by bid*nb1 bytes
4707
+ ne00,
4708
+ ne01,
4709
+ ne02,
4710
+ ne10,
4711
+ ne12,
4712
+ ne0,
4713
+ ne1,
4714
+ r2,
4715
+ r3,
4716
+ tgpig, // passed with the reduced .z computed above
4717
+ tiisg,
4718
+ sgitg);
4719
+ }
4720
+
4721
+ [[host_name("kernel_mul_mv_id_q4_K_f32")]] // name under which the host looks this kernel up
4722
+ kernel void kernel_mul_mv_id_q4_K_f32( // Indirect mat-vec wrapper: reads an index from `ids`, uses it to pick one of src00..src07, then forwards to kernel_mul_mv_q4_K_f32_impl.
4723
+ device const char * ids, // raw bytes; reinterpreted as int32_t entries in the body
4724
+ device const char * src1, // raw bytes; offset by bid*nb11 and handed to the impl as float input
4725
+ device uchar * dst, // raw bytes; offset by bid*nb1 and handed to the impl as float output
4726
+ constant int64_t & nbi1, // byte step applied per bid when addressing ids
4727
+ constant int64_t & ne00, // forwarded unchanged to the impl
4728
+ constant int64_t & ne01, // forwarded unchanged to the impl
4729
+ constant int64_t & ne02, // forwarded unchanged to the impl
4730
+ constant uint64_t & nb00, // not referenced in this wrapper
4731
+ constant uint64_t & nb01, // not referenced in this wrapper
4732
+ constant uint64_t & nb02, // not referenced in this wrapper
4733
+ constant int64_t & ne10, // forwarded unchanged to the impl
4734
+ constant int64_t & ne11, // not referenced in this wrapper
4735
+ constant int64_t & ne12, // with ne13, splits tgpig.z into bid and a remainder; also forwarded to the impl
4736
+ constant int64_t & ne13, // with ne12, splits tgpig.z into bid and a remainder
4737
+ constant uint64_t & nb10, // not referenced in this wrapper
4738
+ constant uint64_t & nb11, // byte step applied per bid when addressing src1
4739
+ constant uint64_t & nb12, // not referenced in this wrapper
4740
+ constant int64_t & ne0, // forwarded unchanged to the impl
4741
+ constant int64_t & ne1, // forwarded unchanged to the impl
4742
+ constant int64_t & nb1, // byte step applied per bid when addressing dst
4743
+ constant uint & r2, // forwarded unchanged to the impl
4744
+ constant uint & r3, // forwarded unchanged to the impl
4745
+ constant int & idx, // which int32 entry to read at ids + bid*nbi1
4746
+ device const char * src00, // src00..src07: the eight candidate src0 buffers; one is selected by id
4747
+ device const char * src01,
4748
+ device const char * src02,
4749
+ device const char * src03,
4750
+ device const char * src04,
4751
+ device const char * src05,
4752
+ device const char * src06,
4753
+ device const char * src07,
4754
+ uint3 tgpig[[threadgroup_position_in_grid]], // .z is split into bid and a remainder below
4755
+ uint tiitg[[thread_index_in_threadgroup]], // not referenced in this wrapper
4756
+ uint tiisg[[thread_index_in_simdgroup]], // forwarded to the impl
4757
+ uint sgitg[[simdgroup_index_in_threadgroup]]) { // forwarded to the impl
4758
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // gather the candidates so they can be indexed by id
4759
+
4760
+ const int64_t bid = tgpig.z/(ne12*ne13); // quotient part of tgpig.z; drives the per-bid byte offsets below
4761
+
4762
+ tgpig.z = tgpig.z%(ne12*ne13); // keep only the remainder, so the impl receives z < ne12*ne13
4763
+
4764
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // NOTE(review): id is not range-checked before indexing src0[8]; presumably the host guarantees 0 <= id < 8 — confirm
4765
+
4766
+ kernel_mul_mv_q4_K_f32_impl( // all remaining work is delegated to the shared implementation
4767
+ src0[id], // the selected src0 buffer
4768
+ (device const float *) (src1 + bid*nb11), // src1 advanced by bid*nb11 bytes
4769
+ (device float *) ( dst + bid*nb1), // dst advanced by bid*nb1 bytes
4770
+ ne00,
4771
+ ne01,
4772
+ ne02,
4773
+ ne10,
4774
+ ne12,
4775
+ ne0,
4776
+ ne1,
4777
+ r2,
4778
+ r3,
4779
+ tgpig, // passed with the reduced .z computed above
4780
+ tiisg,
4781
+ sgitg);
4782
+ }
4783
+
4784
+ [[host_name("kernel_mul_mv_id_q5_K_f32")]] // name under which the host looks this kernel up
4785
+ kernel void kernel_mul_mv_id_q5_K_f32( // Indirect mat-vec wrapper: reads an index from `ids`, uses it to pick one of src00..src07, then forwards to kernel_mul_mv_q5_K_f32_impl.
4786
+ device const char * ids, // raw bytes; reinterpreted as int32_t entries in the body
4787
+ device const char * src1, // raw bytes; offset by bid*nb11 and handed to the impl as float input
4788
+ device uchar * dst, // raw bytes; offset by bid*nb1 and handed to the impl as float output
4789
+ constant int64_t & nbi1, // byte step applied per bid when addressing ids
4790
+ constant int64_t & ne00, // forwarded unchanged to the impl
4791
+ constant int64_t & ne01, // forwarded unchanged to the impl
4792
+ constant int64_t & ne02, // forwarded unchanged to the impl
4793
+ constant uint64_t & nb00, // not referenced in this wrapper
4794
+ constant uint64_t & nb01, // not referenced in this wrapper
4795
+ constant uint64_t & nb02, // not referenced in this wrapper
4796
+ constant int64_t & ne10, // forwarded unchanged to the impl
4797
+ constant int64_t & ne11, // not referenced in this wrapper
4798
+ constant int64_t & ne12, // with ne13, splits tgpig.z into bid and a remainder; also forwarded to the impl
4799
+ constant int64_t & ne13, // with ne12, splits tgpig.z into bid and a remainder
4800
+ constant uint64_t & nb10, // not referenced in this wrapper
4801
+ constant uint64_t & nb11, // byte step applied per bid when addressing src1
4802
+ constant uint64_t & nb12, // not referenced in this wrapper
4803
+ constant int64_t & ne0, // forwarded unchanged to the impl
4804
+ constant int64_t & ne1, // forwarded unchanged to the impl
4805
+ constant int64_t & nb1, // byte step applied per bid when addressing dst
4806
+ constant uint & r2, // forwarded unchanged to the impl
4807
+ constant uint & r3, // forwarded unchanged to the impl
4808
+ constant int & idx, // which int32 entry to read at ids + bid*nbi1
4809
+ device const char * src00, // src00..src07: the eight candidate src0 buffers; one is selected by id
4810
+ device const char * src01,
4811
+ device const char * src02,
4812
+ device const char * src03,
4813
+ device const char * src04,
4814
+ device const char * src05,
4815
+ device const char * src06,
4816
+ device const char * src07,
4817
+ uint3 tgpig[[threadgroup_position_in_grid]], // .z is split into bid and a remainder below
4818
+ uint tiitg[[thread_index_in_threadgroup]], // not referenced in this wrapper
4819
+ uint tiisg[[thread_index_in_simdgroup]], // forwarded to the impl
4820
+ uint sgitg[[simdgroup_index_in_threadgroup]]) { // forwarded to the impl
4821
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // gather the candidates so they can be indexed by id
4822
+
4823
+ const int64_t bid = tgpig.z/(ne12*ne13); // quotient part of tgpig.z; drives the per-bid byte offsets below
4824
+
4825
+ tgpig.z = tgpig.z%(ne12*ne13); // keep only the remainder, so the impl receives z < ne12*ne13
4826
+
4827
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // NOTE(review): id is not range-checked before indexing src0[8]; presumably the host guarantees 0 <= id < 8 — confirm
4828
+
4829
+ kernel_mul_mv_q5_K_f32_impl( // all remaining work is delegated to the shared implementation
4830
+ src0[id], // the selected src0 buffer
4831
+ (device const float *) (src1 + bid*nb11), // src1 advanced by bid*nb11 bytes
4832
+ (device float *) ( dst + bid*nb1), // dst advanced by bid*nb1 bytes
4833
+ ne00,
4834
+ ne01,
4835
+ ne02,
4836
+ ne10,
4837
+ ne12,
4838
+ ne0,
4839
+ ne1,
4840
+ r2,
4841
+ r3,
4842
+ tgpig, // passed with the reduced .z computed above
4843
+ tiisg,
4844
+ sgitg);
4845
+ }
4846
+
4847
+ [[host_name("kernel_mul_mv_id_q6_K_f32")]] // name under which the host looks this kernel up
4848
+ kernel void kernel_mul_mv_id_q6_K_f32( // Indirect mat-vec wrapper: reads an index from `ids`, uses it to pick one of src00..src07, then forwards to kernel_mul_mv_q6_K_f32_impl.
4849
+ device const char * ids, // raw bytes; reinterpreted as int32_t entries in the body
4850
+ device const char * src1, // raw bytes; offset by bid*nb11 and handed to the impl as float input
4851
+ device uchar * dst, // raw bytes; offset by bid*nb1 and handed to the impl as float output
4852
+ constant int64_t & nbi1, // byte step applied per bid when addressing ids
4853
+ constant int64_t & ne00, // forwarded unchanged to the impl
4854
+ constant int64_t & ne01, // forwarded unchanged to the impl
4855
+ constant int64_t & ne02, // forwarded unchanged to the impl
4856
+ constant uint64_t & nb00, // not referenced in this wrapper
4857
+ constant uint64_t & nb01, // not referenced in this wrapper
4858
+ constant uint64_t & nb02, // not referenced in this wrapper
4859
+ constant int64_t & ne10, // forwarded unchanged to the impl
4860
+ constant int64_t & ne11, // not referenced in this wrapper
4861
+ constant int64_t & ne12, // with ne13, splits tgpig.z into bid and a remainder; also forwarded to the impl
4862
+ constant int64_t & ne13, // with ne12, splits tgpig.z into bid and a remainder
4863
+ constant uint64_t & nb10, // not referenced in this wrapper
4864
+ constant uint64_t & nb11, // byte step applied per bid when addressing src1
4865
+ constant uint64_t & nb12, // not referenced in this wrapper
4866
+ constant int64_t & ne0, // forwarded unchanged to the impl
4867
+ constant int64_t & ne1, // forwarded unchanged to the impl
4868
+ constant int64_t & nb1, // byte step applied per bid when addressing dst
4869
+ constant uint & r2, // forwarded unchanged to the impl
4870
+ constant uint & r3, // forwarded unchanged to the impl
4871
+ constant int & idx, // which int32 entry to read at ids + bid*nbi1
4872
+ device const char * src00, // src00..src07: the eight candidate src0 buffers; one is selected by id
4873
+ device const char * src01,
4874
+ device const char * src02,
4875
+ device const char * src03,
4876
+ device const char * src04,
4877
+ device const char * src05,
4878
+ device const char * src06,
4879
+ device const char * src07,
4880
+ uint3 tgpig[[threadgroup_position_in_grid]], // .z is split into bid and a remainder below
4881
+ uint tiitg[[thread_index_in_threadgroup]], // not referenced in this wrapper
4882
+ uint tiisg[[thread_index_in_simdgroup]], // forwarded to the impl
4883
+ uint sgitg[[simdgroup_index_in_threadgroup]]) { // forwarded to the impl
4884
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // gather the candidates so they can be indexed by id
4885
+
4886
+ const int64_t bid = tgpig.z/(ne12*ne13); // quotient part of tgpig.z; drives the per-bid byte offsets below
4887
+
4888
+ tgpig.z = tgpig.z%(ne12*ne13); // keep only the remainder, so the impl receives z < ne12*ne13
4889
+
4890
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // NOTE(review): id is not range-checked before indexing src0[8]; presumably the host guarantees 0 <= id < 8 — confirm
4891
+
4892
+ kernel_mul_mv_q6_K_f32_impl( // all remaining work is delegated to the shared implementation
4893
+ src0[id], // the selected src0 buffer
4894
+ (device const float *) (src1 + bid*nb11), // src1 advanced by bid*nb11 bytes
4895
+ (device float *) ( dst + bid*nb1), // dst advanced by bid*nb1 bytes
4896
+ ne00,
4897
+ ne01,
4898
+ ne02,
4899
+ ne10,
4900
+ ne12,
4901
+ ne0,
4902
+ ne1,
4903
+ r2,
4904
+ r3,
4905
+ tgpig, // passed with the reduced .z computed above
4906
+ tiisg,
4907
+ sgitg);
4908
+ }