node-llama-cpp 2.8.2 → 2.8.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -59,26 +59,27 @@ kernel void kernel_add(
59
59
  constant int64_t & ne01,
60
60
  constant int64_t & ne02,
61
61
  constant int64_t & ne03,
62
- constant int64_t & nb00,
63
- constant int64_t & nb01,
64
- constant int64_t & nb02,
65
- constant int64_t & nb03,
62
+ constant uint64_t & nb00,
63
+ constant uint64_t & nb01,
64
+ constant uint64_t & nb02,
65
+ constant uint64_t & nb03,
66
66
  constant int64_t & ne10,
67
67
  constant int64_t & ne11,
68
68
  constant int64_t & ne12,
69
69
  constant int64_t & ne13,
70
- constant int64_t & nb10,
71
- constant int64_t & nb11,
72
- constant int64_t & nb12,
73
- constant int64_t & nb13,
70
+ constant uint64_t & nb10,
71
+ constant uint64_t & nb11,
72
+ constant uint64_t & nb12,
73
+ constant uint64_t & nb13,
74
74
  constant int64_t & ne0,
75
75
  constant int64_t & ne1,
76
76
  constant int64_t & ne2,
77
77
  constant int64_t & ne3,
78
- constant int64_t & nb0,
79
- constant int64_t & nb1,
80
- constant int64_t & nb2,
81
- constant int64_t & nb3,
78
+ constant uint64_t & nb0,
79
+ constant uint64_t & nb1,
80
+ constant uint64_t & nb2,
81
+ constant uint64_t & nb3,
82
+ constant int64_t & offs,
82
83
  uint3 tgpig[[threadgroup_position_in_grid]],
83
84
  uint3 tpitg[[thread_position_in_threadgroup]],
84
85
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -90,9 +91,9 @@ kernel void kernel_add(
90
91
  const int64_t i12 = i02 % ne12;
91
92
  const int64_t i11 = i01 % ne11;
92
93
 
93
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
94
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
94
95
  device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
95
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
96
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
96
97
 
97
98
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
98
99
  const int i10 = i0 % ne10;
@@ -108,26 +109,26 @@ kernel void kernel_mul(
108
109
  constant int64_t & ne01,
109
110
  constant int64_t & ne02,
110
111
  constant int64_t & ne03,
111
- constant int64_t & nb00,
112
- constant int64_t & nb01,
113
- constant int64_t & nb02,
114
- constant int64_t & nb03,
112
+ constant uint64_t & nb00,
113
+ constant uint64_t & nb01,
114
+ constant uint64_t & nb02,
115
+ constant uint64_t & nb03,
115
116
  constant int64_t & ne10,
116
117
  constant int64_t & ne11,
117
118
  constant int64_t & ne12,
118
119
  constant int64_t & ne13,
119
- constant int64_t & nb10,
120
- constant int64_t & nb11,
121
- constant int64_t & nb12,
122
- constant int64_t & nb13,
120
+ constant uint64_t & nb10,
121
+ constant uint64_t & nb11,
122
+ constant uint64_t & nb12,
123
+ constant uint64_t & nb13,
123
124
  constant int64_t & ne0,
124
125
  constant int64_t & ne1,
125
126
  constant int64_t & ne2,
126
127
  constant int64_t & ne3,
127
- constant int64_t & nb0,
128
- constant int64_t & nb1,
129
- constant int64_t & nb2,
130
- constant int64_t & nb3,
128
+ constant uint64_t & nb0,
129
+ constant uint64_t & nb1,
130
+ constant uint64_t & nb2,
131
+ constant uint64_t & nb3,
131
132
  uint3 tgpig[[threadgroup_position_in_grid]],
132
133
  uint3 tpitg[[thread_position_in_threadgroup]],
133
134
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -157,26 +158,26 @@ kernel void kernel_div(
157
158
  constant int64_t & ne01,
158
159
  constant int64_t & ne02,
159
160
  constant int64_t & ne03,
160
- constant int64_t & nb00,
161
- constant int64_t & nb01,
162
- constant int64_t & nb02,
163
- constant int64_t & nb03,
161
+ constant uint64_t & nb00,
162
+ constant uint64_t & nb01,
163
+ constant uint64_t & nb02,
164
+ constant uint64_t & nb03,
164
165
  constant int64_t & ne10,
165
166
  constant int64_t & ne11,
166
167
  constant int64_t & ne12,
167
168
  constant int64_t & ne13,
168
- constant int64_t & nb10,
169
- constant int64_t & nb11,
170
- constant int64_t & nb12,
171
- constant int64_t & nb13,
169
+ constant uint64_t & nb10,
170
+ constant uint64_t & nb11,
171
+ constant uint64_t & nb12,
172
+ constant uint64_t & nb13,
172
173
  constant int64_t & ne0,
173
174
  constant int64_t & ne1,
174
175
  constant int64_t & ne2,
175
176
  constant int64_t & ne3,
176
- constant int64_t & nb0,
177
- constant int64_t & nb1,
178
- constant int64_t & nb2,
179
- constant int64_t & nb3,
177
+ constant uint64_t & nb0,
178
+ constant uint64_t & nb1,
179
+ constant uint64_t & nb2,
180
+ constant uint64_t & nb3,
180
181
  uint3 tgpig[[threadgroup_position_in_grid]],
181
182
  uint3 tpitg[[thread_position_in_threadgroup]],
182
183
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -204,7 +205,7 @@ kernel void kernel_add_row(
204
205
  device const float4 * src0,
205
206
  device const float4 * src1,
206
207
  device float4 * dst,
207
- constant int64_t & nb [[buffer(27)]],
208
+ constant uint64_t & nb [[buffer(28)]],
208
209
  uint tpig[[thread_position_in_grid]]) {
209
210
  dst[tpig] = src0[tpig] + src1[tpig % nb];
210
211
  }
@@ -213,7 +214,7 @@ kernel void kernel_mul_row(
213
214
  device const float4 * src0,
214
215
  device const float4 * src1,
215
216
  device float4 * dst,
216
- constant int64_t & nb [[buffer(27)]],
217
+ constant uint64_t & nb [[buffer(28)]],
217
218
  uint tpig[[thread_position_in_grid]]) {
218
219
  dst[tpig] = src0[tpig] * src1[tpig % nb];
219
220
  }
@@ -222,7 +223,7 @@ kernel void kernel_div_row(
222
223
  device const float4 * src0,
223
224
  device const float4 * src1,
224
225
  device float4 * dst,
225
- constant int64_t & nb [[buffer(27)]],
226
+ constant uint64_t & nb [[buffer(28)]],
226
227
  uint tpig[[thread_position_in_grid]]) {
227
228
  dst[tpig] = src0[tpig] / src1[tpig % nb];
228
229
  }
@@ -243,19 +244,53 @@ kernel void kernel_scale_4(
243
244
  dst[tpig] = src0[tpig] * scale;
244
245
  }
245
246
 
246
- kernel void kernel_silu(
247
- device const float4 * src0,
248
- device float4 * dst,
247
+ kernel void kernel_relu(
248
+ device const float * src0,
249
+ device float * dst,
249
250
  uint tpig[[thread_position_in_grid]]) {
250
- device const float4 & x = src0[tpig];
251
- dst[tpig] = x / (1.0f + exp(-x));
251
+ dst[tpig] = max(0.0f, src0[tpig]);
252
252
  }
253
253
 
254
- kernel void kernel_relu(
254
+ kernel void kernel_tanh(
255
255
  device const float * src0,
256
256
  device float * dst,
257
257
  uint tpig[[thread_position_in_grid]]) {
258
- dst[tpig] = max(0.0f, src0[tpig]);
258
+ device const float & x = src0[tpig];
259
+ dst[tpig] = precise::tanh(x);
260
+ }
261
+
262
+ constant float GELU_COEF_A = 0.044715f;
263
+ constant float GELU_QUICK_COEF = -1.702f;
264
+ constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
265
+
266
+ kernel void kernel_gelu(
267
+ device const float4 * src0,
268
+ device float4 * dst,
269
+ uint tpig[[thread_position_in_grid]]) {
270
+ device const float4 & x = src0[tpig];
271
+
272
+ // BEWARE !!!
273
+ // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
274
+ // This was observed with Falcon 7B and 40B models
275
+ //
276
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
277
+ }
278
+
279
+ kernel void kernel_gelu_quick(
280
+ device const float4 * src0,
281
+ device float4 * dst,
282
+ uint tpig[[thread_position_in_grid]]) {
283
+ device const float4 & x = src0[tpig];
284
+
285
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
286
+ }
287
+
288
+ kernel void kernel_silu(
289
+ device const float4 * src0,
290
+ device float4 * dst,
291
+ uint tpig[[thread_position_in_grid]]) {
292
+ device const float4 & x = src0[tpig];
293
+ dst[tpig] = x / (1.0f + exp(-x));
259
294
  }
260
295
 
261
296
  kernel void kernel_sqr(
@@ -272,26 +307,26 @@ kernel void kernel_sum_rows(
272
307
  constant int64_t & ne01,
273
308
  constant int64_t & ne02,
274
309
  constant int64_t & ne03,
275
- constant int64_t & nb00,
276
- constant int64_t & nb01,
277
- constant int64_t & nb02,
278
- constant int64_t & nb03,
310
+ constant uint64_t & nb00,
311
+ constant uint64_t & nb01,
312
+ constant uint64_t & nb02,
313
+ constant uint64_t & nb03,
279
314
  constant int64_t & ne10,
280
315
  constant int64_t & ne11,
281
316
  constant int64_t & ne12,
282
317
  constant int64_t & ne13,
283
- constant int64_t & nb10,
284
- constant int64_t & nb11,
285
- constant int64_t & nb12,
286
- constant int64_t & nb13,
318
+ constant uint64_t & nb10,
319
+ constant uint64_t & nb11,
320
+ constant uint64_t & nb12,
321
+ constant uint64_t & nb13,
287
322
  constant int64_t & ne0,
288
323
  constant int64_t & ne1,
289
324
  constant int64_t & ne2,
290
325
  constant int64_t & ne3,
291
- constant int64_t & nb0,
292
- constant int64_t & nb1,
293
- constant int64_t & nb2,
294
- constant int64_t & nb3,
326
+ constant uint64_t & nb0,
327
+ constant uint64_t & nb1,
328
+ constant uint64_t & nb2,
329
+ constant uint64_t & nb3,
295
330
  uint3 tpig[[thread_position_in_grid]]) {
296
331
  int64_t i3 = tpig.z;
297
332
  int64_t i2 = tpig.y;
@@ -313,22 +348,6 @@ kernel void kernel_sum_rows(
313
348
  dst_row[0] = row_sum;
314
349
  }
315
350
 
316
- constant float GELU_COEF_A = 0.044715f;
317
- constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
318
-
319
- kernel void kernel_gelu(
320
- device const float4 * src0,
321
- device float4 * dst,
322
- uint tpig[[thread_position_in_grid]]) {
323
- device const float4 & x = src0[tpig];
324
-
325
- // BEWARE !!!
326
- // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
327
- // This was observed with Falcon 7B and 40B models
328
- //
329
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
330
- }
331
-
332
351
  kernel void kernel_soft_max(
333
352
  device const float * src0,
334
353
  device const float * src1,
@@ -347,9 +366,9 @@ kernel void kernel_soft_max(
347
366
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
348
367
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
349
368
 
350
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
351
- device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
352
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
369
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
370
+ device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
371
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
353
372
 
354
373
  // parallel max
355
374
  float lmax = -INFINITY;
@@ -385,7 +404,12 @@ kernel void kernel_soft_max(
385
404
  pdst[i00] = exp_psrc0;
386
405
  }
387
406
 
407
+ // This barrier fixes a failing test
408
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
409
+ threadgroup_barrier(mem_flags::mem_none);
410
+
388
411
  float sum = simd_sum(lsum);
412
+
389
413
  if (ntg > N_SIMDWIDTH) {
390
414
  if (sgitg == 0) {
391
415
  buf[tiisg] = 0.0f;
@@ -428,9 +452,9 @@ kernel void kernel_soft_max_4(
428
452
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
429
453
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
430
454
 
431
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
432
- device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
433
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
455
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
456
+ device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
457
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
434
458
 
435
459
  // parallel max
436
460
  float4 lmax4 = -INFINITY;
@@ -468,7 +492,13 @@ kernel void kernel_soft_max_4(
468
492
  }
469
493
 
470
494
  const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
495
+
496
+ // This barrier fixes a failing test
497
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
498
+ threadgroup_barrier(mem_flags::mem_none);
499
+
471
500
  float sum = simd_sum(lsum);
501
+
472
502
  if (ntg > N_SIMDWIDTH) {
473
503
  if (sgitg == 0) {
474
504
  buf[tiisg] = 0.0f;
@@ -639,6 +669,94 @@ kernel void kernel_rms_norm(
639
669
  }
640
670
  }
641
671
 
672
+ kernel void kernel_group_norm(
673
+ device const float * src0,
674
+ device float * dst,
675
+ constant int64_t & ne00,
676
+ constant int64_t & ne01,
677
+ constant int64_t & ne02,
678
+ constant uint64_t & nb00,
679
+ constant uint64_t & nb01,
680
+ constant uint64_t & nb02,
681
+ constant int32_t & n_groups,
682
+ constant float & eps,
683
+ threadgroup float * buf [[threadgroup(0)]],
684
+ uint tgpig[[threadgroup_position_in_grid]],
685
+ uint tpitg[[thread_position_in_threadgroup]],
686
+ uint sgitg[[simdgroup_index_in_threadgroup]],
687
+ uint tiisg[[thread_index_in_simdgroup]],
688
+ uint ntg[[threads_per_threadgroup]]) {
689
+ const int64_t ne = ne00*ne01*ne02;
690
+ const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
691
+
692
+ int start = tgpig * gs;
693
+ int end = start + gs;
694
+
695
+ start += tpitg;
696
+
697
+ if (end >= ne) {
698
+ end = ne;
699
+ }
700
+
701
+ float tmp = 0.0f; // partial sum for thread in warp
702
+
703
+ for (int j = start; j < end; j += ntg) {
704
+ tmp += src0[j];
705
+ }
706
+
707
+ threadgroup_barrier(mem_flags::mem_threadgroup);
708
+ tmp = simd_sum(tmp);
709
+ if (ntg > N_SIMDWIDTH) {
710
+ if (sgitg == 0) {
711
+ buf[tiisg] = 0.0f;
712
+ }
713
+
714
+ threadgroup_barrier(mem_flags::mem_threadgroup);
715
+
716
+ if (tiisg == 0) {
717
+ buf[sgitg] = tmp;
718
+ }
719
+
720
+ threadgroup_barrier(mem_flags::mem_threadgroup);
721
+
722
+ tmp = buf[tiisg];
723
+ tmp = simd_sum(tmp);
724
+ }
725
+
726
+ const float mean = tmp / gs;
727
+ tmp = 0.0f;
728
+
729
+ for (int j = start; j < end; j += ntg) {
730
+ float xi = src0[j] - mean;
731
+ dst[j] = xi;
732
+ tmp += xi * xi;
733
+ }
734
+
735
+ tmp = simd_sum(tmp);
736
+ if (ntg > N_SIMDWIDTH) {
737
+ if (sgitg == 0) {
738
+ buf[tiisg] = 0.0f;
739
+ }
740
+
741
+ threadgroup_barrier(mem_flags::mem_threadgroup);
742
+
743
+ if (tiisg == 0) {
744
+ buf[sgitg] = tmp;
745
+ }
746
+
747
+ threadgroup_barrier(mem_flags::mem_threadgroup);
748
+
749
+ tmp = buf[tiisg];
750
+ tmp = simd_sum(tmp);
751
+ }
752
+
753
+ const float variance = tmp / gs;
754
+ const float scale = 1.0f/sqrt(variance + eps);
755
+ for (int j = start; j < end; j += ntg) {
756
+ dst[j] *= scale;
757
+ }
758
+ }
759
+
642
760
  // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
643
761
  // il indicates where the q4 quants begin (0 or QK4_0/4)
644
762
  // we assume that the yl's have been multiplied with the appropriate scale factor
@@ -728,10 +846,10 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
728
846
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
729
847
  //Note: This is a template, but strictly speaking it only applies to
730
848
  // quantizations where the block size is 32. It also does not
731
- // giard against the number of rows not being divisible by
849
+ // guard against the number of rows not being divisible by
732
850
  // N_DST, so this is another explicit assumption of the implementation.
733
851
  template<typename block_q_type, int nr, int nsg, int nw>
734
- void mul_vec_q_n_f32(
852
+ void mul_vec_q_n_f32_impl(
735
853
  device const void * src0,
736
854
  device const float * src1,
737
855
  device float * dst,
@@ -802,18 +920,25 @@ kernel void kernel_mul_mv_q4_0_f32(
802
920
  device const float * src1,
803
921
  device float * dst,
804
922
  constant int64_t & ne00,
805
- constant int64_t & ne01[[buffer(4)]],
806
- constant int64_t & ne02[[buffer(5)]],
807
- constant int64_t & ne10[[buffer(9)]],
808
- constant int64_t & ne12[[buffer(11)]],
809
- constant int64_t & ne0 [[buffer(15)]],
810
- constant int64_t & ne1 [[buffer(16)]],
811
- constant uint & r2 [[buffer(17)]],
812
- constant uint & r3 [[buffer(18)]],
923
+ constant int64_t & ne01,
924
+ constant int64_t & ne02,
925
+ constant uint64_t & nb00,
926
+ constant uint64_t & nb01,
927
+ constant uint64_t & nb02,
928
+ constant int64_t & ne10,
929
+ constant int64_t & ne11,
930
+ constant int64_t & ne12,
931
+ constant uint64_t & nb10,
932
+ constant uint64_t & nb11,
933
+ constant uint64_t & nb12,
934
+ constant int64_t & ne0,
935
+ constant int64_t & ne1,
936
+ constant uint & r2,
937
+ constant uint & r3,
813
938
  uint3 tgpig[[threadgroup_position_in_grid]],
814
939
  uint tiisg[[thread_index_in_simdgroup]],
815
940
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
816
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
941
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
817
942
  }
818
943
 
819
944
  kernel void kernel_mul_mv_q4_1_f32(
@@ -821,18 +946,25 @@ kernel void kernel_mul_mv_q4_1_f32(
821
946
  device const float * src1,
822
947
  device float * dst,
823
948
  constant int64_t & ne00,
824
- constant int64_t & ne01[[buffer(4)]],
825
- constant int64_t & ne02[[buffer(5)]],
826
- constant int64_t & ne10[[buffer(9)]],
827
- constant int64_t & ne12[[buffer(11)]],
828
- constant int64_t & ne0 [[buffer(15)]],
829
- constant int64_t & ne1 [[buffer(16)]],
830
- constant uint & r2 [[buffer(17)]],
831
- constant uint & r3 [[buffer(18)]],
949
+ constant int64_t & ne01,
950
+ constant int64_t & ne02,
951
+ constant uint64_t & nb00,
952
+ constant uint64_t & nb01,
953
+ constant uint64_t & nb02,
954
+ constant int64_t & ne10,
955
+ constant int64_t & ne11,
956
+ constant int64_t & ne12,
957
+ constant uint64_t & nb10,
958
+ constant uint64_t & nb11,
959
+ constant uint64_t & nb12,
960
+ constant int64_t & ne0,
961
+ constant int64_t & ne1,
962
+ constant uint & r2,
963
+ constant uint & r3,
832
964
  uint3 tgpig[[threadgroup_position_in_grid]],
833
965
  uint tiisg[[thread_index_in_simdgroup]],
834
966
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
835
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
967
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
836
968
  }
837
969
 
838
970
  kernel void kernel_mul_mv_q5_0_f32(
@@ -840,18 +972,25 @@ kernel void kernel_mul_mv_q5_0_f32(
840
972
  device const float * src1,
841
973
  device float * dst,
842
974
  constant int64_t & ne00,
843
- constant int64_t & ne01[[buffer(4)]],
844
- constant int64_t & ne02[[buffer(5)]],
845
- constant int64_t & ne10[[buffer(9)]],
846
- constant int64_t & ne12[[buffer(11)]],
847
- constant int64_t & ne0 [[buffer(15)]],
848
- constant int64_t & ne1 [[buffer(16)]],
849
- constant uint & r2 [[buffer(17)]],
850
- constant uint & r3 [[buffer(18)]],
975
+ constant int64_t & ne01,
976
+ constant int64_t & ne02,
977
+ constant uint64_t & nb00,
978
+ constant uint64_t & nb01,
979
+ constant uint64_t & nb02,
980
+ constant int64_t & ne10,
981
+ constant int64_t & ne11,
982
+ constant int64_t & ne12,
983
+ constant uint64_t & nb10,
984
+ constant uint64_t & nb11,
985
+ constant uint64_t & nb12,
986
+ constant int64_t & ne0,
987
+ constant int64_t & ne1,
988
+ constant uint & r2,
989
+ constant uint & r3,
851
990
  uint3 tgpig[[threadgroup_position_in_grid]],
852
991
  uint tiisg[[thread_index_in_simdgroup]],
853
992
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
854
- mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
993
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
855
994
  }
856
995
 
857
996
  kernel void kernel_mul_mv_q5_1_f32(
@@ -859,39 +998,46 @@ kernel void kernel_mul_mv_q5_1_f32(
859
998
  device const float * src1,
860
999
  device float * dst,
861
1000
  constant int64_t & ne00,
862
- constant int64_t & ne01[[buffer(4)]],
863
- constant int64_t & ne02[[buffer(5)]],
864
- constant int64_t & ne10[[buffer(9)]],
865
- constant int64_t & ne12[[buffer(11)]],
866
- constant int64_t & ne0 [[buffer(15)]],
867
- constant int64_t & ne1 [[buffer(16)]],
868
- constant uint & r2 [[buffer(17)]],
869
- constant uint & r3 [[buffer(18)]],
1001
+ constant int64_t & ne01,
1002
+ constant int64_t & ne02,
1003
+ constant uint64_t & nb00,
1004
+ constant uint64_t & nb01,
1005
+ constant uint64_t & nb02,
1006
+ constant int64_t & ne10,
1007
+ constant int64_t & ne11,
1008
+ constant int64_t & ne12,
1009
+ constant uint64_t & nb10,
1010
+ constant uint64_t & nb11,
1011
+ constant uint64_t & nb12,
1012
+ constant int64_t & ne0,
1013
+ constant int64_t & ne1,
1014
+ constant uint & r2,
1015
+ constant uint & r3,
870
1016
  uint3 tgpig[[threadgroup_position_in_grid]],
871
1017
  uint tiisg[[thread_index_in_simdgroup]],
872
1018
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
873
- mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1019
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
874
1020
  }
875
1021
 
876
1022
 
877
1023
  #define NB_Q8_0 8
878
1024
 
879
- kernel void kernel_mul_mv_q8_0_f32(
1025
+ void kernel_mul_mv_q8_0_f32_impl(
880
1026
  device const void * src0,
881
1027
  device const float * src1,
882
1028
  device float * dst,
883
1029
  constant int64_t & ne00,
884
- constant int64_t & ne01[[buffer(4)]],
885
- constant int64_t & ne02[[buffer(5)]],
886
- constant int64_t & ne10[[buffer(9)]],
887
- constant int64_t & ne12[[buffer(11)]],
888
- constant int64_t & ne0 [[buffer(15)]],
889
- constant int64_t & ne1 [[buffer(16)]],
890
- constant uint & r2 [[buffer(17)]],
891
- constant uint & r3 [[buffer(18)]],
1030
+ constant int64_t & ne01,
1031
+ constant int64_t & ne02,
1032
+ constant int64_t & ne10,
1033
+ constant int64_t & ne12,
1034
+ constant int64_t & ne0,
1035
+ constant int64_t & ne1,
1036
+ constant uint & r2,
1037
+ constant uint & r3,
892
1038
  uint3 tgpig[[threadgroup_position_in_grid]],
893
- uint tiisg[[thread_index_in_simdgroup]],
894
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
1039
+ uint tiisg[[thread_index_in_simdgroup]],
1040
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
895
1041
  const int nr = N_DST;
896
1042
  const int nsg = N_SIMDGROUP;
897
1043
  const int nw = N_SIMDWIDTH;
@@ -945,9 +1091,36 @@ kernel void kernel_mul_mv_q8_0_f32(
945
1091
  }
946
1092
  }
947
1093
 
1094
+ [[host_name("kernel_mul_mv_q8_0_f32")]]
1095
+ kernel void kernel_mul_mv_q8_0_f32(
1096
+ device const void * src0,
1097
+ device const float * src1,
1098
+ device float * dst,
1099
+ constant int64_t & ne00,
1100
+ constant int64_t & ne01,
1101
+ constant int64_t & ne02,
1102
+ constant uint64_t & nb00,
1103
+ constant uint64_t & nb01,
1104
+ constant uint64_t & nb02,
1105
+ constant int64_t & ne10,
1106
+ constant int64_t & ne11,
1107
+ constant int64_t & ne12,
1108
+ constant uint64_t & nb10,
1109
+ constant uint64_t & nb11,
1110
+ constant uint64_t & nb12,
1111
+ constant int64_t & ne0,
1112
+ constant int64_t & ne1,
1113
+ constant uint & r2,
1114
+ constant uint & r3,
1115
+ uint3 tgpig[[threadgroup_position_in_grid]],
1116
+ uint tiisg[[thread_index_in_simdgroup]],
1117
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1118
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1119
+ }
1120
+
948
1121
  #define N_F32_F32 4
949
1122
 
950
- kernel void kernel_mul_mv_f32_f32(
1123
+ void kernel_mul_mv_f32_f32_impl(
951
1124
  device const char * src0,
952
1125
  device const char * src1,
953
1126
  device float * dst,
@@ -965,8 +1138,8 @@ kernel void kernel_mul_mv_f32_f32(
965
1138
  constant uint64_t & nb12,
966
1139
  constant int64_t & ne0,
967
1140
  constant int64_t & ne1,
968
- constant uint & r2 [[buffer(17)]],
969
- constant uint & r3 [[buffer(18)]],
1141
+ constant uint & r2,
1142
+ constant uint & r3,
970
1143
  uint3 tgpig[[threadgroup_position_in_grid]],
971
1144
  uint tiisg[[thread_index_in_simdgroup]]) {
972
1145
 
@@ -1025,6 +1198,32 @@ kernel void kernel_mul_mv_f32_f32(
1025
1198
  }
1026
1199
  }
1027
1200
 
1201
+ [[host_name("kernel_mul_mv_f32_f32")]]
1202
+ kernel void kernel_mul_mv_f32_f32(
1203
+ device const char * src0,
1204
+ device const char * src1,
1205
+ device float * dst,
1206
+ constant int64_t & ne00,
1207
+ constant int64_t & ne01,
1208
+ constant int64_t & ne02,
1209
+ constant uint64_t & nb00,
1210
+ constant uint64_t & nb01,
1211
+ constant uint64_t & nb02,
1212
+ constant int64_t & ne10,
1213
+ constant int64_t & ne11,
1214
+ constant int64_t & ne12,
1215
+ constant uint64_t & nb10,
1216
+ constant uint64_t & nb11,
1217
+ constant uint64_t & nb12,
1218
+ constant int64_t & ne0,
1219
+ constant int64_t & ne1,
1220
+ constant uint & r2,
1221
+ constant uint & r3,
1222
+ uint3 tgpig[[threadgroup_position_in_grid]],
1223
+ uint tiisg[[thread_index_in_simdgroup]]) {
1224
+ kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1225
+ }
1226
+
1028
1227
  #define N_F16_F16 4
1029
1228
 
1030
1229
  kernel void kernel_mul_mv_f16_f16(
@@ -1045,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16(
1045
1244
  constant uint64_t & nb12,
1046
1245
  constant int64_t & ne0,
1047
1246
  constant int64_t & ne1,
1048
- constant uint & r2 [[buffer(17)]],
1049
- constant uint & r3 [[buffer(18)]],
1247
+ constant uint & r2,
1248
+ constant uint & r3,
1050
1249
  uint3 tgpig[[threadgroup_position_in_grid]],
1051
1250
  uint tiisg[[thread_index_in_simdgroup]]) {
1052
1251
 
@@ -1105,7 +1304,7 @@ kernel void kernel_mul_mv_f16_f16(
1105
1304
  }
1106
1305
  }
1107
1306
 
1108
- kernel void kernel_mul_mv_f16_f32_1row(
1307
+ void kernel_mul_mv_f16_f32_1row_impl(
1109
1308
  device const char * src0,
1110
1309
  device const char * src1,
1111
1310
  device float * dst,
@@ -1123,8 +1322,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
1123
1322
  constant uint64_t & nb12,
1124
1323
  constant int64_t & ne0,
1125
1324
  constant int64_t & ne1,
1126
- constant uint & r2 [[buffer(17)]],
1127
- constant uint & r3 [[buffer(18)]],
1325
+ constant uint & r2,
1326
+ constant uint & r3,
1128
1327
  uint3 tgpig[[threadgroup_position_in_grid]],
1129
1328
  uint tiisg[[thread_index_in_simdgroup]]) {
1130
1329
 
@@ -1161,12 +1360,10 @@ kernel void kernel_mul_mv_f16_f32_1row(
1161
1360
  dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1162
1361
  }
1163
1362
  }
1164
-
1165
1363
  }
1166
1364
 
1167
- #define N_F16_F32 4
1168
-
1169
- kernel void kernel_mul_mv_f16_f32(
1365
+ [[host_name("kernel_mul_mv_f16_f32_1row")]]
1366
+ kernel void kernel_mul_mv_f16_f32_1row(
1170
1367
  device const char * src0,
1171
1368
  device const char * src1,
1172
1369
  device float * dst,
@@ -1184,21 +1381,48 @@ kernel void kernel_mul_mv_f16_f32(
1184
1381
  constant uint64_t & nb12,
1185
1382
  constant int64_t & ne0,
1186
1383
  constant int64_t & ne1,
1187
- constant uint & r2 [[buffer(17)]],
1188
- constant uint & r3 [[buffer(18)]],
1384
+ constant uint & r2,
1385
+ constant uint & r3,
1189
1386
  uint3 tgpig[[threadgroup_position_in_grid]],
1190
- uint tiisg[[thread_index_in_simdgroup]]) {
1191
-
1192
- const int64_t r0 = tgpig.x;
1193
- const int64_t rb = tgpig.y*N_F16_F32;
1194
- const int64_t im = tgpig.z;
1195
-
1196
- const uint i12 = im%ne12;
1197
- const uint i13 = im/ne12;
1387
+ uint tiisg[[thread_index_in_simdgroup]]) {
1388
+ kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1389
+ }
1198
1390
 
1199
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1391
+ #define N_F16_F32 4
1200
1392
 
1201
- device const half * x = (device const half *) (src0 + offset0);
1393
+ void kernel_mul_mv_f16_f32_impl(
1394
+ device const char * src0,
1395
+ device const char * src1,
1396
+ device float * dst,
1397
+ constant int64_t & ne00,
1398
+ constant int64_t & ne01,
1399
+ constant int64_t & ne02,
1400
+ constant uint64_t & nb00,
1401
+ constant uint64_t & nb01,
1402
+ constant uint64_t & nb02,
1403
+ constant int64_t & ne10,
1404
+ constant int64_t & ne11,
1405
+ constant int64_t & ne12,
1406
+ constant uint64_t & nb10,
1407
+ constant uint64_t & nb11,
1408
+ constant uint64_t & nb12,
1409
+ constant int64_t & ne0,
1410
+ constant int64_t & ne1,
1411
+ constant uint & r2,
1412
+ constant uint & r3,
1413
+ uint3 tgpig[[threadgroup_position_in_grid]],
1414
+ uint tiisg[[thread_index_in_simdgroup]]) {
1415
+
1416
+ const int64_t r0 = tgpig.x;
1417
+ const int64_t rb = tgpig.y*N_F16_F32;
1418
+ const int64_t im = tgpig.z;
1419
+
1420
+ const uint i12 = im%ne12;
1421
+ const uint i13 = im/ne12;
1422
+
1423
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1424
+
1425
+ device const half * x = (device const half *) (src0 + offset0);
1202
1426
 
1203
1427
  if (ne00 < 128) {
1204
1428
  for (int row = 0; row < N_F16_F32; ++row) {
@@ -1244,6 +1468,32 @@ kernel void kernel_mul_mv_f16_f32(
1244
1468
  }
1245
1469
  }
1246
1470
 
1471
+ [[host_name("kernel_mul_mv_f16_f32")]]
1472
+ kernel void kernel_mul_mv_f16_f32(
1473
+ device const char * src0,
1474
+ device const char * src1,
1475
+ device float * dst,
1476
+ constant int64_t & ne00,
1477
+ constant int64_t & ne01,
1478
+ constant int64_t & ne02,
1479
+ constant uint64_t & nb00,
1480
+ constant uint64_t & nb01,
1481
+ constant uint64_t & nb02,
1482
+ constant int64_t & ne10,
1483
+ constant int64_t & ne11,
1484
+ constant int64_t & ne12,
1485
+ constant uint64_t & nb10,
1486
+ constant uint64_t & nb11,
1487
+ constant uint64_t & nb12,
1488
+ constant int64_t & ne0,
1489
+ constant int64_t & ne1,
1490
+ constant uint & r2,
1491
+ constant uint & r3,
1492
+ uint3 tgpig[[threadgroup_position_in_grid]],
1493
+ uint tiisg[[thread_index_in_simdgroup]]) {
1494
+ kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1495
+ }
1496
+
1247
1497
  // Assumes row size (ne00) is a multiple of 4
1248
1498
  kernel void kernel_mul_mv_f16_f32_l4(
1249
1499
  device const char * src0,
@@ -1263,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
1263
1513
  constant uint64_t & nb12,
1264
1514
  constant int64_t & ne0,
1265
1515
  constant int64_t & ne1,
1266
- constant uint & r2 [[buffer(17)]],
1267
- constant uint & r3 [[buffer(18)]],
1516
+ constant uint & r2,
1517
+ constant uint & r3,
1268
1518
  uint3 tgpig[[threadgroup_position_in_grid]],
1269
1519
  uint tiisg[[thread_index_in_simdgroup]]) {
1270
1520
 
@@ -1328,7 +1578,8 @@ kernel void kernel_alibi_f32(
1328
1578
  const int64_t i3 = n / (ne2*ne1*ne0);
1329
1579
  const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1330
1580
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1331
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1581
+ //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1582
+
1332
1583
  const int64_t k = i3*ne3 + i2;
1333
1584
 
1334
1585
  float m_k;
@@ -1487,8 +1738,9 @@ kernel void kernel_rope(
1487
1738
  dst_data[1] = x0*sin_theta + x1*cos_theta;
1488
1739
  }
1489
1740
  } else {
1490
- for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
1491
- for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
1741
+ for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
1742
+ if (ic < n_dims) {
1743
+ const int64_t ib = 0;
1492
1744
 
1493
1745
  // simplified from `(ib * n_dims + ic) * inv_ndims`
1494
1746
  const float cur_rot = inv_ndims*ic - ib;
@@ -1507,6 +1759,14 @@ kernel void kernel_rope(
1507
1759
 
1508
1760
  dst_data[0] = x0*cos_theta - x1*sin_theta;
1509
1761
  dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1762
+ } else {
1763
+ const int64_t i0 = ic;
1764
+
1765
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1766
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1767
+
1768
+ dst_data[0] = src[0];
1769
+ dst_data[1] = src[1];
1510
1770
  }
1511
1771
  }
1512
1772
  }
@@ -1548,6 +1808,97 @@ kernel void kernel_im2col_f16(
1548
1808
  }
1549
1809
  }
1550
1810
 
1811
+ kernel void kernel_upscale_f32(
1812
+ device const char * src0,
1813
+ device char * dst,
1814
+ constant int64_t & ne00,
1815
+ constant int64_t & ne01,
1816
+ constant int64_t & ne02,
1817
+ constant int64_t & ne03,
1818
+ constant uint64_t & nb00,
1819
+ constant uint64_t & nb01,
1820
+ constant uint64_t & nb02,
1821
+ constant uint64_t & nb03,
1822
+ constant int64_t & ne0,
1823
+ constant int64_t & ne1,
1824
+ constant int64_t & ne2,
1825
+ constant int64_t & ne3,
1826
+ constant uint64_t & nb0,
1827
+ constant uint64_t & nb1,
1828
+ constant uint64_t & nb2,
1829
+ constant uint64_t & nb3,
1830
+ constant int32_t & sf,
1831
+ uint3 tgpig[[threadgroup_position_in_grid]],
1832
+ uint3 tpitg[[thread_position_in_threadgroup]],
1833
+ uint3 ntg[[threads_per_threadgroup]]) {
1834
+
1835
+ const int64_t i3 = tgpig.z;
1836
+ const int64_t i2 = tgpig.y;
1837
+ const int64_t i1 = tgpig.x;
1838
+
1839
+ const int64_t i03 = i3;
1840
+ const int64_t i02 = i2;
1841
+ const int64_t i01 = i1/sf;
1842
+
1843
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
1844
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
1845
+
1846
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1847
+ dst_ptr[i0] = src0_ptr[i0/sf];
1848
+ }
1849
+ }
1850
+
1851
+ kernel void kernel_pad_f32(
1852
+ device const char * src0,
1853
+ device char * dst,
1854
+ constant int64_t & ne00,
1855
+ constant int64_t & ne01,
1856
+ constant int64_t & ne02,
1857
+ constant int64_t & ne03,
1858
+ constant uint64_t & nb00,
1859
+ constant uint64_t & nb01,
1860
+ constant uint64_t & nb02,
1861
+ constant uint64_t & nb03,
1862
+ constant int64_t & ne0,
1863
+ constant int64_t & ne1,
1864
+ constant int64_t & ne2,
1865
+ constant int64_t & ne3,
1866
+ constant uint64_t & nb0,
1867
+ constant uint64_t & nb1,
1868
+ constant uint64_t & nb2,
1869
+ constant uint64_t & nb3,
1870
+ uint3 tgpig[[threadgroup_position_in_grid]],
1871
+ uint3 tpitg[[thread_position_in_threadgroup]],
1872
+ uint3 ntg[[threads_per_threadgroup]]) {
1873
+
1874
+ const int64_t i3 = tgpig.z;
1875
+ const int64_t i2 = tgpig.y;
1876
+ const int64_t i1 = tgpig.x;
1877
+
1878
+ const int64_t i03 = i3;
1879
+ const int64_t i02 = i2;
1880
+ const int64_t i01 = i1;
1881
+
1882
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
1883
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
1884
+
1885
+ if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
1886
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1887
+ if (i0 < ne00) {
1888
+ dst_ptr[i0] = src0_ptr[i0];
1889
+ } else {
1890
+ dst_ptr[i0] = 0.0f;
1891
+ }
1892
+ }
1893
+
1894
+ return;
1895
+ }
1896
+
1897
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1898
+ dst_ptr[i0] = 0.0f;
1899
+ }
1900
+ }
1901
+
1551
1902
  // bitonic sort implementation following the CUDA kernels as reference
1552
1903
  typedef void (argsort_t)(
1553
1904
  device const float * x,
@@ -1600,9 +1951,17 @@ kernel void kernel_argsort_f32_i32(
1600
1951
  template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
1601
1952
  template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
1602
1953
 
1954
+ kernel void kernel_leaky_relu_f32(
1955
+ device const float * src0,
1956
+ device float * dst,
1957
+ constant float & slope,
1958
+ uint tpig[[thread_position_in_grid]]) {
1959
+ dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
1960
+ }
1961
+
1603
1962
  kernel void kernel_cpy_f16_f16(
1604
- device const half * src0,
1605
- device half * dst,
1963
+ device const half * src0,
1964
+ device half * dst,
1606
1965
  constant int64_t & ne00,
1607
1966
  constant int64_t & ne01,
1608
1967
  constant int64_t & ne02,
@@ -1641,6 +2000,47 @@ kernel void kernel_cpy_f16_f16(
1641
2000
  }
1642
2001
  }
1643
2002
 
2003
+ kernel void kernel_cpy_f16_f32(
2004
+ device const half * src0,
2005
+ device float * dst,
2006
+ constant int64_t & ne00,
2007
+ constant int64_t & ne01,
2008
+ constant int64_t & ne02,
2009
+ constant int64_t & ne03,
2010
+ constant uint64_t & nb00,
2011
+ constant uint64_t & nb01,
2012
+ constant uint64_t & nb02,
2013
+ constant uint64_t & nb03,
2014
+ constant int64_t & ne0,
2015
+ constant int64_t & ne1,
2016
+ constant int64_t & ne2,
2017
+ constant int64_t & ne3,
2018
+ constant uint64_t & nb0,
2019
+ constant uint64_t & nb1,
2020
+ constant uint64_t & nb2,
2021
+ constant uint64_t & nb3,
2022
+ uint3 tgpig[[threadgroup_position_in_grid]],
2023
+ uint3 tpitg[[thread_position_in_threadgroup]],
2024
+ uint3 ntg[[threads_per_threadgroup]]) {
2025
+ const int64_t i03 = tgpig[2];
2026
+ const int64_t i02 = tgpig[1];
2027
+ const int64_t i01 = tgpig[0];
2028
+
2029
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2030
+
2031
+ const int64_t i3 = n / (ne2*ne1*ne0);
2032
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2033
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2034
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2035
+
2036
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2037
+
2038
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2039
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2040
+ dst_data[i00] = src[0];
2041
+ }
2042
+ }
2043
+
1644
2044
  kernel void kernel_cpy_f32_f16(
1645
2045
  device const float * src0,
1646
2046
  device half * dst,
@@ -1917,9 +2317,9 @@ kernel void kernel_cpy_f32_q4_1(
1917
2317
  }
1918
2318
 
1919
2319
  kernel void kernel_concat(
1920
- device const char * src0,
1921
- device const char * src1,
1922
- device char * dst,
2320
+ device const char * src0,
2321
+ device const char * src1,
2322
+ device char * dst,
1923
2323
  constant int64_t & ne00,
1924
2324
  constant int64_t & ne01,
1925
2325
  constant int64_t & ne02,
@@ -1956,7 +2356,7 @@ kernel void kernel_concat(
1956
2356
  const int64_t i12 = i02 % ne12;
1957
2357
  const int64_t i11 = i01 % ne11;
1958
2358
 
1959
- device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
2359
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
1960
2360
  device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
1961
2361
  device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
1962
2362
 
@@ -2046,37 +2446,34 @@ typedef struct {
2046
2446
  } block_q6_K;
2047
2447
  // 210 bytes / block
2048
2448
 
2049
- static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
2050
- uchar4 r;
2051
- if (j < 4) {
2052
- r[0] = q[j+0] & 63;
2053
- r[2] = q[j+1] & 63;
2054
- r[1] = q[j+4] & 63;
2055
- r[3] = q[j+5] & 63;
2056
- } else {
2057
- r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
2058
- r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
2059
- r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
2060
- r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
2061
- }
2062
- return r;
2063
- }
2449
+ typedef struct {
2450
+ half d;
2451
+ uint16_t qs[QK_K/8];
2452
+ } block_iq2_xxs;
2453
+ // 66 bytes / block for QK_K = 256, so 2.0625 bpw
2454
+
2455
+ typedef struct {
2456
+ half d;
2457
+ uint16_t qs[QK_K/8];
2458
+ uint8_t scales[QK_K/32];
2459
+ } block_iq2_xs;
2460
+ // 74 bytes / block for QK_K = 256, so 2.3125 bpw
2064
2461
 
2065
2462
  //====================================== dot products =========================
2066
2463
 
2067
- kernel void kernel_mul_mv_q2_K_f32(
2464
+ void kernel_mul_mv_q2_K_f32_impl(
2068
2465
  device const void * src0,
2069
2466
  device const float * src1,
2070
2467
  device float * dst,
2071
2468
  constant int64_t & ne00,
2072
- constant int64_t & ne01[[buffer(4)]],
2073
- constant int64_t & ne02[[buffer(5)]],
2074
- constant int64_t & ne10[[buffer(9)]],
2075
- constant int64_t & ne12[[buffer(11)]],
2076
- constant int64_t & ne0 [[buffer(15)]],
2077
- constant int64_t & ne1 [[buffer(16)]],
2078
- constant uint & r2 [[buffer(17)]],
2079
- constant uint & r3 [[buffer(18)]],
2469
+ constant int64_t & ne01,
2470
+ constant int64_t & ne02,
2471
+ constant int64_t & ne10,
2472
+ constant int64_t & ne12,
2473
+ constant int64_t & ne0,
2474
+ constant int64_t & ne1,
2475
+ constant uint & r2,
2476
+ constant uint & r3,
2080
2477
  uint3 tgpig[[threadgroup_position_in_grid]],
2081
2478
  uint tiisg[[thread_index_in_simdgroup]],
2082
2479
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2214,23 +2611,51 @@ kernel void kernel_mul_mv_q2_K_f32(
2214
2611
  }
2215
2612
  }
2216
2613
 
2614
+ [[host_name("kernel_mul_mv_q2_K_f32")]]
2615
+ kernel void kernel_mul_mv_q2_K_f32(
2616
+ device const void * src0,
2617
+ device const float * src1,
2618
+ device float * dst,
2619
+ constant int64_t & ne00,
2620
+ constant int64_t & ne01,
2621
+ constant int64_t & ne02,
2622
+ constant uint64_t & nb00,
2623
+ constant uint64_t & nb01,
2624
+ constant uint64_t & nb02,
2625
+ constant int64_t & ne10,
2626
+ constant int64_t & ne11,
2627
+ constant int64_t & ne12,
2628
+ constant uint64_t & nb10,
2629
+ constant uint64_t & nb11,
2630
+ constant uint64_t & nb12,
2631
+ constant int64_t & ne0,
2632
+ constant int64_t & ne1,
2633
+ constant uint & r2,
2634
+ constant uint & r3,
2635
+ uint3 tgpig[[threadgroup_position_in_grid]],
2636
+ uint tiisg[[thread_index_in_simdgroup]],
2637
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2638
+
2639
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2640
+ }
2641
+
2217
2642
  #if QK_K == 256
2218
- kernel void kernel_mul_mv_q3_K_f32(
2643
+ void kernel_mul_mv_q3_K_f32_impl(
2219
2644
  device const void * src0,
2220
2645
  device const float * src1,
2221
2646
  device float * dst,
2222
2647
  constant int64_t & ne00,
2223
- constant int64_t & ne01[[buffer(4)]],
2224
- constant int64_t & ne02[[buffer(5)]],
2225
- constant int64_t & ne10[[buffer(9)]],
2226
- constant int64_t & ne12[[buffer(11)]],
2227
- constant int64_t & ne0 [[buffer(15)]],
2228
- constant int64_t & ne1 [[buffer(16)]],
2229
- constant uint & r2 [[buffer(17)]],
2230
- constant uint & r3 [[buffer(18)]],
2648
+ constant int64_t & ne01,
2649
+ constant int64_t & ne02,
2650
+ constant int64_t & ne10,
2651
+ constant int64_t & ne12,
2652
+ constant int64_t & ne0,
2653
+ constant int64_t & ne1,
2654
+ constant uint & r2,
2655
+ constant uint & r3,
2231
2656
  uint3 tgpig[[threadgroup_position_in_grid]],
2232
- uint tiisg[[thread_index_in_simdgroup]],
2233
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2657
+ uint tiisg[[thread_index_in_simdgroup]],
2658
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2234
2659
 
2235
2660
  const int nb = ne00/QK_K;
2236
2661
 
@@ -2373,19 +2798,19 @@ kernel void kernel_mul_mv_q3_K_f32(
2373
2798
  }
2374
2799
  }
2375
2800
  #else
2376
- kernel void kernel_mul_mv_q3_K_f32(
2801
+ void kernel_mul_mv_q3_K_f32_impl(
2377
2802
  device const void * src0,
2378
2803
  device const float * src1,
2379
2804
  device float * dst,
2380
2805
  constant int64_t & ne00,
2381
- constant int64_t & ne01[[buffer(4)]],
2382
- constant int64_t & ne02[[buffer(5)]],
2383
- constant int64_t & ne10[[buffer(9)]],
2384
- constant int64_t & ne12[[buffer(11)]],
2385
- constant int64_t & ne0 [[buffer(15)]],
2386
- constant int64_t & ne1 [[buffer(16)]],
2387
- constant uint & r2 [[buffer(17)]],
2388
- constant uint & r3 [[buffer(18)]],
2806
+ constant int64_t & ne01,
2807
+ constant int64_t & ne02,
2808
+ constant int64_t & ne10,
2809
+ constant int64_t & ne12,
2810
+ constant int64_t & ne0,
2811
+ constant int64_t & ne1,
2812
+ constant uint & r2,
2813
+ constant uint & r3,
2389
2814
  uint3 tgpig[[threadgroup_position_in_grid]],
2390
2815
  uint tiisg[[thread_index_in_simdgroup]],
2391
2816
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2450,20 +2875,48 @@ kernel void kernel_mul_mv_q3_K_f32(
2450
2875
  }
2451
2876
  #endif
2452
2877
 
2878
+ [[host_name("kernel_mul_mv_q3_K_f32")]]
2879
+ kernel void kernel_mul_mv_q3_K_f32(
2880
+ device const void * src0,
2881
+ device const float * src1,
2882
+ device float * dst,
2883
+ constant int64_t & ne00,
2884
+ constant int64_t & ne01,
2885
+ constant int64_t & ne02,
2886
+ constant uint64_t & nb00,
2887
+ constant uint64_t & nb01,
2888
+ constant uint64_t & nb02,
2889
+ constant int64_t & ne10,
2890
+ constant int64_t & ne11,
2891
+ constant int64_t & ne12,
2892
+ constant uint64_t & nb10,
2893
+ constant uint64_t & nb11,
2894
+ constant uint64_t & nb12,
2895
+ constant int64_t & ne0,
2896
+ constant int64_t & ne1,
2897
+ constant uint & r2,
2898
+ constant uint & r3,
2899
+ uint3 tgpig[[threadgroup_position_in_grid]],
2900
+ uint tiisg[[thread_index_in_simdgroup]],
2901
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2902
+
2903
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2904
+ }
2905
+
2453
2906
  #if QK_K == 256
2454
- kernel void kernel_mul_mv_q4_K_f32(
2907
+ void kernel_mul_mv_q4_K_f32_impl(
2455
2908
  device const void * src0,
2456
2909
  device const float * src1,
2457
2910
  device float * dst,
2458
2911
  constant int64_t & ne00,
2459
- constant int64_t & ne01 [[buffer(4)]],
2460
- constant int64_t & ne02 [[buffer(5)]],
2461
- constant int64_t & ne10 [[buffer(9)]],
2462
- constant int64_t & ne12 [[buffer(11)]],
2463
- constant int64_t & ne0 [[buffer(15)]],
2464
- constant int64_t & ne1 [[buffer(16)]],
2465
- constant uint & r2 [[buffer(17)]],
2466
- constant uint & r3 [[buffer(18)]],
2912
+ constant int64_t & ne01,
2913
+ constant int64_t & ne02,
2914
+ constant int64_t & ne10,
2915
+ constant int64_t & ne12,
2916
+ constant int64_t & ne0,
2917
+ constant int64_t & ne1,
2918
+ constant uint & r2,
2919
+ constant uint & r3,
2467
2920
  uint3 tgpig[[threadgroup_position_in_grid]],
2468
2921
  uint tiisg[[thread_index_in_simdgroup]],
2469
2922
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2564,31 +3017,31 @@ kernel void kernel_mul_mv_q4_K_f32(
2564
3017
  }
2565
3018
  }
2566
3019
  #else
2567
- kernel void kernel_mul_mv_q4_K_f32(
3020
+ void kernel_mul_mv_q4_K_f32_impl(
2568
3021
  device const void * src0,
2569
3022
  device const float * src1,
2570
3023
  device float * dst,
2571
3024
  constant int64_t & ne00,
2572
- constant int64_t & ne01[[buffer(4)]],
2573
- constant int64_t & ne02[[buffer(5)]],
2574
- constant int64_t & ne10[[buffer(9)]],
2575
- constant int64_t & ne12[[buffer(11)]],
2576
- constant int64_t & ne0 [[buffer(15)]],
2577
- constant int64_t & ne1 [[buffer(16)]],
2578
- constant uint & r2 [[buffer(17)]],
2579
- constant uint & r3 [[buffer(18)]],
2580
- uint3 tgpig[[threadgroup_position_in_grid]],
2581
- uint tiisg[[thread_index_in_simdgroup]],
2582
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2583
-
2584
- const int ix = tiisg/4; // 0...7
3025
+ constant int64_t & ne01,
3026
+ constant int64_t & ne02,
3027
+ constant int64_t & ne10,
3028
+ constant int64_t & ne12,
3029
+ constant int64_t & ne0,
3030
+ constant int64_t & ne1,
3031
+ constant uint & r2,
3032
+ constant uint & r3,
3033
+ uint3 tgpig[[threadgroup_position_in_grid]],
3034
+ uint tiisg[[thread_index_in_simdgroup]],
3035
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3036
+
3037
+ const int ix = tiisg/4; // 0...7
2585
3038
  const int it = tiisg%4; // 0...3
2586
3039
 
2587
3040
  const int nb = ne00/QK_K;
2588
3041
  const int r0 = tgpig.x;
2589
3042
  const int r1 = tgpig.y;
2590
3043
  const int im = tgpig.z;
2591
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
3044
+ const int first_row = r0 * N_DST;
2592
3045
  const int ib_row = first_row * nb;
2593
3046
 
2594
3047
  const uint i12 = im%ne12;
@@ -2654,29 +3107,57 @@ kernel void kernel_mul_mv_q4_K_f32(
2654
3107
  for (int row = 0; row < N_DST; ++row) {
2655
3108
  all_sum = simd_sum(sumf[row]);
2656
3109
  if (tiisg == 0) {
2657
- dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
3110
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
2658
3111
  }
2659
3112
  }
2660
3113
  }
2661
3114
  #endif
2662
3115
 
2663
- kernel void kernel_mul_mv_q5_K_f32(
3116
+ [[host_name("kernel_mul_mv_q4_K_f32")]]
3117
+ kernel void kernel_mul_mv_q4_K_f32(
2664
3118
  device const void * src0,
2665
3119
  device const float * src1,
2666
3120
  device float * dst,
2667
3121
  constant int64_t & ne00,
2668
- constant int64_t & ne01[[buffer(4)]],
2669
- constant int64_t & ne02[[buffer(5)]],
2670
- constant int64_t & ne10[[buffer(9)]],
2671
- constant int64_t & ne12[[buffer(11)]],
2672
- constant int64_t & ne0 [[buffer(15)]],
2673
- constant int64_t & ne1 [[buffer(16)]],
2674
- constant uint & r2 [[buffer(17)]],
2675
- constant uint & r3 [[buffer(18)]],
3122
+ constant int64_t & ne01,
3123
+ constant int64_t & ne02,
3124
+ constant uint64_t & nb00,
3125
+ constant uint64_t & nb01,
3126
+ constant uint64_t & nb02,
3127
+ constant int64_t & ne10,
3128
+ constant int64_t & ne11,
3129
+ constant int64_t & ne12,
3130
+ constant uint64_t & nb10,
3131
+ constant uint64_t & nb11,
3132
+ constant uint64_t & nb12,
3133
+ constant int64_t & ne0,
3134
+ constant int64_t & ne1,
3135
+ constant uint & r2,
3136
+ constant uint & r3,
2676
3137
  uint3 tgpig[[threadgroup_position_in_grid]],
2677
3138
  uint tiisg[[thread_index_in_simdgroup]],
2678
3139
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
2679
3140
 
3141
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3142
+ }
3143
+
3144
+ void kernel_mul_mv_q5_K_f32_impl(
3145
+ device const void * src0,
3146
+ device const float * src1,
3147
+ device float * dst,
3148
+ constant int64_t & ne00,
3149
+ constant int64_t & ne01,
3150
+ constant int64_t & ne02,
3151
+ constant int64_t & ne10,
3152
+ constant int64_t & ne12,
3153
+ constant int64_t & ne0,
3154
+ constant int64_t & ne1,
3155
+ constant uint & r2,
3156
+ constant uint & r3,
3157
+ uint3 tgpig[[threadgroup_position_in_grid]],
3158
+ uint tiisg[[thread_index_in_simdgroup]],
3159
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3160
+
2680
3161
  const int nb = ne00/QK_K;
2681
3162
 
2682
3163
  const int64_t r0 = tgpig.x;
@@ -2836,25 +3317,52 @@ kernel void kernel_mul_mv_q5_K_f32(
2836
3317
  dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
2837
3318
  }
2838
3319
  }
3320
+ }
3321
+
3322
+ [[host_name("kernel_mul_mv_q5_K_f32")]]
3323
+ kernel void kernel_mul_mv_q5_K_f32(
3324
+ device const void * src0,
3325
+ device const float * src1,
3326
+ device float * dst,
3327
+ constant int64_t & ne00,
3328
+ constant int64_t & ne01,
3329
+ constant int64_t & ne02,
3330
+ constant uint64_t & nb00,
3331
+ constant uint64_t & nb01,
3332
+ constant uint64_t & nb02,
3333
+ constant int64_t & ne10,
3334
+ constant int64_t & ne11,
3335
+ constant int64_t & ne12,
3336
+ constant uint64_t & nb10,
3337
+ constant uint64_t & nb11,
3338
+ constant uint64_t & nb12,
3339
+ constant int64_t & ne0,
3340
+ constant int64_t & ne1,
3341
+ constant uint & r2,
3342
+ constant uint & r3,
3343
+ uint3 tgpig[[threadgroup_position_in_grid]],
3344
+ uint tiisg[[thread_index_in_simdgroup]],
3345
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2839
3346
 
3347
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2840
3348
  }
2841
3349
 
2842
- kernel void kernel_mul_mv_q6_K_f32(
3350
+ void kernel_mul_mv_q6_K_f32_impl(
2843
3351
  device const void * src0,
2844
3352
  device const float * src1,
2845
3353
  device float * dst,
2846
3354
  constant int64_t & ne00,
2847
- constant int64_t & ne01[[buffer(4)]],
2848
- constant int64_t & ne02[[buffer(5)]],
2849
- constant int64_t & ne10[[buffer(9)]],
2850
- constant int64_t & ne12[[buffer(11)]],
2851
- constant int64_t & ne0 [[buffer(15)]],
2852
- constant int64_t & ne1 [[buffer(16)]],
2853
- constant uint & r2 [[buffer(17)]],
2854
- constant uint & r3 [[buffer(18)]],
3355
+ constant int64_t & ne01,
3356
+ constant int64_t & ne02,
3357
+ constant int64_t & ne10,
3358
+ constant int64_t & ne12,
3359
+ constant int64_t & ne0,
3360
+ constant int64_t & ne1,
3361
+ constant uint & r2,
3362
+ constant uint & r3,
2855
3363
  uint3 tgpig[[threadgroup_position_in_grid]],
2856
- uint tiisg[[thread_index_in_simdgroup]],
2857
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3364
+ uint tiisg[[thread_index_in_simdgroup]],
3365
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2858
3366
 
2859
3367
  const uint8_t kmask1 = 0x03;
2860
3368
  const uint8_t kmask2 = 0x0C;
@@ -2945,160 +3453,677 @@ kernel void kernel_mul_mv_q6_K_f32(
2945
3453
  }
2946
3454
  }
2947
3455
 
2948
- //============================= templates and their specializations =============================
3456
+ [[host_name("kernel_mul_mv_q6_K_f32")]]
3457
+ kernel void kernel_mul_mv_q6_K_f32(
3458
+ device const void * src0,
3459
+ device const float * src1,
3460
+ device float * dst,
3461
+ constant int64_t & ne00,
3462
+ constant int64_t & ne01,
3463
+ constant int64_t & ne02,
3464
+ constant uint64_t & nb00,
3465
+ constant uint64_t & nb01,
3466
+ constant uint64_t & nb02,
3467
+ constant int64_t & ne10,
3468
+ constant int64_t & ne11,
3469
+ constant int64_t & ne12,
3470
+ constant uint64_t & nb10,
3471
+ constant uint64_t & nb11,
3472
+ constant uint64_t & nb12,
3473
+ constant int64_t & ne0,
3474
+ constant int64_t & ne1,
3475
+ constant uint & r2,
3476
+ constant uint & r3,
3477
+ uint3 tgpig[[threadgroup_position_in_grid]],
3478
+ uint tiisg[[thread_index_in_simdgroup]],
3479
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2949
3480
 
2950
- // NOTE: this is not dequantizing - we are simply fitting the template
2951
- template <typename type4x4>
2952
- void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
2953
- float4x4 temp = *(((device float4x4 *)src));
2954
- for (int i = 0; i < 16; i++){
2955
- reg[i/4][i%4] = temp[i/4][i%4];
2956
- }
3481
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2957
3482
  }
2958
3483
 
2959
- template <typename type4x4>
2960
- void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
2961
- half4x4 temp = *(((device half4x4 *)src));
2962
- for (int i = 0; i < 16; i++){
2963
- reg[i/4][i%4] = temp[i/4][i%4];
2964
- }
2965
- }
3484
+ // ======================= "True" 2-bit
3485
+
3486
+ constexpr constant static uint64_t iq2xxs_grid[256] = {
3487
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
3488
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
3489
+ 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
3490
+ 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
3491
+ 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
3492
+ 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
3493
+ 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
3494
+ 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
3495
+ 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
3496
+ 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
3497
+ 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
3498
+ 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
3499
+ 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
3500
+ 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
3501
+ 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
3502
+ 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
3503
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
3504
+ 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
3505
+ 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
3506
+ 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
3507
+ 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
3508
+ 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
3509
+ 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
3510
+ 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
3511
+ 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
3512
+ 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
3513
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
3514
+ 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
3515
+ 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
3516
+ 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
3517
+ 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
3518
+ 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
3519
+ 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
3520
+ 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
3521
+ 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
3522
+ 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
3523
+ 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
3524
+ 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
3525
+ 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
3526
+ 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
3527
+ 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
3528
+ 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
3529
+ 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
3530
+ 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
3531
+ 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
3532
+ 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
3533
+ 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
3534
+ 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
3535
+ 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
3536
+ 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
3537
+ 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
3538
+ 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
3539
+ 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
3540
+ 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
3541
+ 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
3542
+ 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
3543
+ 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
3544
+ 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
3545
+ 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
3546
+ 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
3547
+ 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
3548
+ 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
3549
+ 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
3550
+ 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
3551
+ };
2966
3552
 
2967
- template <typename type4x4>
2968
- void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
2969
- device const uint16_t * qs = ((device const uint16_t *)xb + 1);
2970
- const float d1 = il ? (xb->d / 16.h) : xb->d;
2971
- const float d2 = d1 / 256.f;
2972
- const float md = -8.h * xb->d;
2973
- const ushort mask0 = il ? 0x00F0 : 0x000F;
2974
- const ushort mask1 = mask0 << 8;
3553
+ constexpr constant static uint64_t iq2xs_grid[512] = {
3554
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
3555
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
3556
+ 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
3557
+ 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
3558
+ 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
3559
+ 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
3560
+ 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
3561
+ 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
3562
+ 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
3563
+ 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
3564
+ 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
3565
+ 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
3566
+ 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
3567
+ 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
3568
+ 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
3569
+ 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
3570
+ 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
3571
+ 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
3572
+ 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
3573
+ 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
3574
+ 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
3575
+ 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
3576
+ 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
3577
+ 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
3578
+ 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
3579
+ 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
3580
+ 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
3581
+ 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
3582
+ 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
3583
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
3584
+ 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
3585
+ 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
3586
+ 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
3587
+ 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
3588
+ 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
3589
+ 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
3590
+ 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
3591
+ 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
3592
+ 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
3593
+ 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
3594
+ 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
3595
+ 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
3596
+ 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
3597
+ 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
3598
+ 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
3599
+ 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
3600
+ 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
3601
+ 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
3602
+ 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
3603
+ 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
3604
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
3605
+ 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
3606
+ 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
3607
+ 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
3608
+ 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
3609
+ 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
3610
+ 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
3611
+ 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
3612
+ 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
3613
+ 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
3614
+ 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
3615
+ 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
3616
+ 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
3617
+ 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
3618
+ 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
3619
+ 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
3620
+ 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
3621
+ 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
3622
+ 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
3623
+ 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
3624
+ 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
3625
+ 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
3626
+ 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
3627
+ 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
3628
+ 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
3629
+ 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
3630
+ 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
3631
+ 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
3632
+ 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
3633
+ 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
3634
+ 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
3635
+ 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
3636
+ 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
3637
+ 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
3638
+ 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
3639
+ 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
3640
+ 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
3641
+ 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
3642
+ 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
3643
+ 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
3644
+ 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
3645
+ 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
3646
+ 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
3647
+ 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
3648
+ 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
3649
+ 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
3650
+ 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
3651
+ 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
3652
+ 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
3653
+ 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
3654
+ 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
3655
+ 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
3656
+ 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
3657
+ 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
3658
+ 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
3659
+ 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
3660
+ 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
3661
+ 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
3662
+ 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
3663
+ 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
3664
+ 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
3665
+ 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
3666
+ 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
3667
+ 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
3668
+ 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
3669
+ 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
3670
+ 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
3671
+ 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
3672
+ 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
3673
+ 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
3674
+ 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
3675
+ 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
3676
+ 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
3677
+ 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
3678
+ 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
3679
+ 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
3680
+ 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
3681
+ 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
3682
+ };
2975
3683
 
2976
- for (int i=0;i<8;i++) {
2977
- reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
2978
- reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
2979
- }
2980
- }
3684
+ constexpr constant static uint8_t ksigns_iq2xs[128] = {
3685
+ 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
3686
+ 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
3687
+ 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
3688
+ 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
3689
+ 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
3690
+ 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
3691
+ 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
3692
+ 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
3693
+ };
2981
3694
 
2982
- template <typename type4x4>
2983
- void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
2984
- device const uint16_t * qs = ((device const uint16_t *)xb + 2);
2985
- const float d1 = il ? (xb->d / 16.h) : xb->d;
2986
- const float d2 = d1 / 256.f;
2987
- const float m = xb->m;
2988
- const ushort mask0 = il ? 0x00F0 : 0x000F;
2989
- const ushort mask1 = mask0 << 8;
3695
+ constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
2990
3696
 
2991
- for (int i=0;i<8;i++) {
2992
- reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
2993
- reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
2994
- }
2995
- }
3697
+ void kernel_mul_mv_iq2_xxs_f32_impl(
3698
+ device const void * src0,
3699
+ device const float * src1,
3700
+ device float * dst,
3701
+ constant int64_t & ne00,
3702
+ constant int64_t & ne01,
3703
+ constant int64_t & ne02,
3704
+ constant int64_t & ne10,
3705
+ constant int64_t & ne12,
3706
+ constant int64_t & ne0,
3707
+ constant int64_t & ne1,
3708
+ constant uint & r2,
3709
+ constant uint & r3,
3710
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3711
+ uint3 tgpig[[threadgroup_position_in_grid]],
3712
+ uint tiisg[[thread_index_in_simdgroup]],
3713
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2996
3714
 
2997
- template <typename type4x4>
2998
- void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
2999
- device const uint16_t * qs = ((device const uint16_t *)xb + 3);
3000
- const float d = xb->d;
3001
- const float md = -16.h * xb->d;
3002
- const ushort mask = il ? 0x00F0 : 0x000F;
3715
+ const int nb = ne00/QK_K;
3716
+ const int r0 = tgpig.x;
3717
+ const int r1 = tgpig.y;
3718
+ const int im = tgpig.z;
3003
3719
 
3004
- const uint32_t qh = *((device const uint32_t *)xb->qh);
3720
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
3721
+ const int ib_row = first_row * nb;
3005
3722
 
3006
- const int x_mv = il ? 4 : 0;
3723
+ const uint i12 = im%ne12;
3724
+ const uint i13 = im/ne12;
3007
3725
 
3008
- const int gh_mv = il ? 12 : 0;
3009
- const int gh_bk = il ? 0 : 4;
3726
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3010
3727
 
3011
- for (int i = 0; i < 8; i++) {
3012
- // extract the 5-th bits for x0 and x1
3013
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
3014
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
3728
+ device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
3729
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3015
3730
 
3016
- // combine the 4-bits from qs with the 5th bit
3017
- const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
3018
- const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
3731
+ float yl[32];
3732
+ float sumf[N_DST]={0.f}, all_sum;
3019
3733
 
3020
- reg[i/2][2*(i%2)+0] = d * x0 + md;
3021
- reg[i/2][2*(i%2)+1] = d * x1 + md;
3734
+ const int nb32 = nb * (QK_K / 32);
3735
+
3736
+ threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
3737
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
3738
+ {
3739
+ int nval = 4;
3740
+ int pos = (32*sgitg + tiisg)*nval;
3741
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
3742
+ nval = 2;
3743
+ pos = (32*sgitg + tiisg)*nval;
3744
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
3745
+ threadgroup_barrier(mem_flags::mem_threadgroup);
3022
3746
  }
3023
- }
3024
3747
 
3025
- template <typename type4x4>
3026
- void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
3027
- device const uint16_t * qs = ((device const uint16_t *)xb + 4);
3028
- const float d = xb->d;
3029
- const float m = xb->m;
3030
- const ushort mask = il ? 0x00F0 : 0x000F;
3748
+ #if QK_K == 256
3749
+ const int ix = tiisg;
3031
3750
 
3032
- const uint32_t qh = *((device const uint32_t *)xb->qh);
3751
+ device const float * y4 = y + 32 * ix;
3033
3752
 
3034
- const int x_mv = il ? 4 : 0;
3753
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
3035
3754
 
3036
- const int gh_mv = il ? 12 : 0;
3037
- const int gh_bk = il ? 0 : 4;
3755
+ for (int i = 0; i < 32; ++i) {
3756
+ yl[i] = y4[i];
3757
+ }
3038
3758
 
3039
- for (int i = 0; i < 8; i++) {
3040
- // extract the 5-th bits for x0 and x1
3041
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
3042
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
3759
+ const int ibl = ib32 / (QK_K / 32);
3760
+ const int ib = ib32 % (QK_K / 32);
3043
3761
 
3044
- // combine the 4-bits from qs with the 5th bit
3045
- const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
3046
- const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
3762
+ device const block_iq2_xxs * xr = x + ibl;
3763
+ device const uint16_t * q2 = xr->qs + 4 * ib;
3764
+ device const half * dh = &xr->d;
3047
3765
 
3048
- reg[i/2][2*(i%2)+0] = d * x0 + m;
3049
- reg[i/2][2*(i%2)+1] = d * x1 + m;
3050
- }
3051
- }
3766
+ for (int row = 0; row < N_DST; row++) {
3052
3767
 
3053
- template <typename type4x4>
3054
- void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
3055
- device const int8_t * qs = ((device const int8_t *)xb->qs);
3056
- const half d = xb->d;
3768
+ const float db = dh[0];
3769
+ device const uint8_t * aux8 = (device const uint8_t *)q2;
3770
+ const uint32_t aux32 = q2[2] | (q2[3] << 16);
3771
+ const float d = db * (0.5f + (aux32 >> 28));
3057
3772
 
3058
- for (int i=0;i<16;i++) {
3059
- reg[i/4][i%4] = (qs[i + 16*il] * d);
3060
- }
3061
- }
3773
+ float sum = 0;
3774
+ for (int l = 0; l < 4; ++l) {
3775
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
3776
+ const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
3777
+ for (int j = 0; j < 8; ++j) {
3778
+ sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3779
+ }
3780
+ }
3781
+ sumf[row] += d * sum;
3062
3782
 
3063
- template <typename type4x4>
3064
- void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
3065
- const half d = xb->d;
3066
- const half min = xb->dmin;
3067
- device const uint8_t * q = (device const uint8_t *)xb->qs;
3068
- half dl, ml;
3069
- uint8_t sc = xb->scales[il];
3783
+ dh += nb*sizeof(block_iq2_xxs)/2;
3784
+ q2 += nb*sizeof(block_iq2_xxs)/2;
3785
+ }
3070
3786
 
3071
- #if QK_K == 256
3072
- q = q + 32*(il/8) + 16*(il&1);
3073
- il = (il/2)%4;
3787
+ y4 += 32 * 32;
3788
+ }
3789
+ #else
3790
+ // TODO
3074
3791
  #endif
3075
- half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
3076
- uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
3077
- dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
3078
- for (int i = 0; i < 16; ++i) {
3079
- reg[i/4][i%4] = dl * (q[i] & mask) - ml;
3792
+
3793
+ for (int row = 0; row < N_DST; ++row) {
3794
+ all_sum = simd_sum(sumf[row]);
3795
+ if (tiisg == 0) {
3796
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
3797
+ }
3080
3798
  }
3081
3799
  }
3082
3800
 
3083
- template <typename type4x4>
3084
- void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
3085
- const half d_all = xb->d;
3086
- device const uint8_t * q = (device const uint8_t *)xb->qs;
3087
- device const uint8_t * h = (device const uint8_t *)xb->hmask;
3088
- device const int8_t * scales = (device const int8_t *)xb->scales;
3089
-
3090
- #if QK_K == 256
3091
- q = q + 32 * (il/8) + 16 * (il&1);
3092
- h = h + 16 * (il&1);
3093
- uint8_t m = 1 << (il/2);
3094
- uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
3801
+ [[host_name("kernel_mul_mv_iq2_xxs_f32")]]
3802
+ kernel void kernel_mul_mv_iq2_xxs_f32(
3803
+ device const void * src0,
3804
+ device const float * src1,
3805
+ device float * dst,
3806
+ constant int64_t & ne00,
3807
+ constant int64_t & ne01,
3808
+ constant int64_t & ne02,
3809
+ constant uint64_t & nb00,
3810
+ constant uint64_t & nb01,
3811
+ constant uint64_t & nb02,
3812
+ constant int64_t & ne10,
3813
+ constant int64_t & ne11,
3814
+ constant int64_t & ne12,
3815
+ constant uint64_t & nb10,
3816
+ constant uint64_t & nb11,
3817
+ constant uint64_t & nb12,
3818
+ constant int64_t & ne0,
3819
+ constant int64_t & ne1,
3820
+ constant uint & r2,
3821
+ constant uint & r3,
3822
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3823
+ uint3 tgpig[[threadgroup_position_in_grid]],
3824
+ uint tiisg[[thread_index_in_simdgroup]],
3825
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3826
+
3827
+ kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
3828
+ }
3829
+
3830
+ void kernel_mul_mv_iq2_xs_f32_impl(
3831
+ device const void * src0,
3832
+ device const float * src1,
3833
+ device float * dst,
3834
+ constant int64_t & ne00,
3835
+ constant int64_t & ne01,
3836
+ constant int64_t & ne02,
3837
+ constant int64_t & ne10,
3838
+ constant int64_t & ne12,
3839
+ constant int64_t & ne0,
3840
+ constant int64_t & ne1,
3841
+ constant uint & r2,
3842
+ constant uint & r3,
3843
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3844
+ uint3 tgpig[[threadgroup_position_in_grid]],
3845
+ uint tiisg[[thread_index_in_simdgroup]],
3846
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3847
+
3848
+ const int nb = ne00/QK_K;
3849
+ const int r0 = tgpig.x;
3850
+ const int r1 = tgpig.y;
3851
+ const int im = tgpig.z;
3852
+
3853
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
3854
+ const int ib_row = first_row * nb;
3855
+
3856
+ const uint i12 = im%ne12;
3857
+ const uint i13 = im/ne12;
3858
+
3859
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3860
+
3861
+ device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
3862
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3863
+
3864
+ float yl[32];
3865
+ float sumf[N_DST]={0.f}, all_sum;
3866
+
3867
+ const int nb32 = nb * (QK_K / 32);
3868
+
3869
+ threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
3870
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512);
3871
+ {
3872
+ int nval = 8;
3873
+ int pos = (32*sgitg + tiisg)*nval;
3874
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
3875
+ nval = 2;
3876
+ pos = (32*sgitg + tiisg)*nval;
3877
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
3878
+ threadgroup_barrier(mem_flags::mem_threadgroup);
3879
+ }
3880
+
3881
+ #if QK_K == 256
3882
+ const int ix = tiisg;
3883
+
3884
+ device const float * y4 = y + 32 * ix;
3885
+
3886
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
3887
+
3888
+ for (int i = 0; i < 32; ++i) {
3889
+ yl[i] = y4[i];
3890
+ }
3891
+
3892
+ const int ibl = ib32 / (QK_K / 32);
3893
+ const int ib = ib32 % (QK_K / 32);
3894
+
3895
+ device const block_iq2_xs * xr = x + ibl;
3896
+ device const uint16_t * q2 = xr->qs + 4 * ib;
3897
+ device const uint8_t * sc = xr->scales + ib;
3898
+ device const half * dh = &xr->d;
3899
+
3900
+ for (int row = 0; row < N_DST; row++) {
3901
+
3902
+ const float db = dh[0];
3903
+ const uint8_t ls1 = sc[0] & 0xf;
3904
+ const uint8_t ls2 = sc[0] >> 4;
3905
+ const float d1 = db * (0.5f + ls1);
3906
+ const float d2 = db * (0.5f + ls2);
3907
+
3908
+ float sum1 = 0, sum2 = 0;
3909
+ for (int l = 0; l < 2; ++l) {
3910
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
3911
+ const uint8_t signs = shared_signs[(q2[l] >> 9)];
3912
+ for (int j = 0; j < 8; ++j) {
3913
+ sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3914
+ }
3915
+ }
3916
+ for (int l = 2; l < 4; ++l) {
3917
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
3918
+ const uint8_t signs = shared_signs[(q2[l] >> 9)];
3919
+ for (int j = 0; j < 8; ++j) {
3920
+ sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3921
+ }
3922
+ }
3923
+ sumf[row] += d1 * sum1 + d2 * sum2;
3924
+
3925
+ dh += nb*sizeof(block_iq2_xs)/2;
3926
+ q2 += nb*sizeof(block_iq2_xs)/2;
3927
+ sc += nb*sizeof(block_iq2_xs);
3928
+ }
3929
+
3930
+ y4 += 32 * 32;
3931
+ }
3932
+ #else
3933
+ // TODO
3934
+ #endif
3935
+
3936
+ for (int row = 0; row < N_DST; ++row) {
3937
+ all_sum = simd_sum(sumf[row]);
3938
+ if (tiisg == 0) {
3939
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
3940
+ }
3941
+ }
3942
+ }
3943
+
3944
+ [[host_name("kernel_mul_mv_iq2_xs_f32")]]
3945
+ kernel void kernel_mul_mv_iq2_xs_f32(
3946
+ device const void * src0,
3947
+ device const float * src1,
3948
+ device float * dst,
3949
+ constant int64_t & ne00,
3950
+ constant int64_t & ne01,
3951
+ constant int64_t & ne02,
3952
+ constant uint64_t & nb00,
3953
+ constant uint64_t & nb01,
3954
+ constant uint64_t & nb02,
3955
+ constant int64_t & ne10,
3956
+ constant int64_t & ne11,
3957
+ constant int64_t & ne12,
3958
+ constant uint64_t & nb10,
3959
+ constant uint64_t & nb11,
3960
+ constant uint64_t & nb12,
3961
+ constant int64_t & ne0,
3962
+ constant int64_t & ne1,
3963
+ constant uint & r2,
3964
+ constant uint & r3,
3965
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3966
+ uint3 tgpig[[threadgroup_position_in_grid]],
3967
+ uint tiisg[[thread_index_in_simdgroup]],
3968
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3969
+
3970
+ kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
3971
+ }
3972
+
3973
+ //============================= templates and their specializations =============================
3974
+
3975
+ // NOTE: this is not dequantizing - we are simply fitting the template
3976
+ template <typename type4x4>
3977
+ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
3978
+ float4x4 temp = *(((device float4x4 *)src));
3979
+ for (int i = 0; i < 16; i++){
3980
+ reg[i/4][i%4] = temp[i/4][i%4];
3981
+ }
3982
+ }
3983
+
3984
+ template <typename type4x4>
3985
+ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
3986
+ half4x4 temp = *(((device half4x4 *)src));
3987
+ for (int i = 0; i < 16; i++){
3988
+ reg[i/4][i%4] = temp[i/4][i%4];
3989
+ }
3990
+ }
3991
+
3992
+ template <typename type4x4>
3993
+ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
3994
+ device const uint16_t * qs = ((device const uint16_t *)xb + 1);
3995
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
3996
+ const float d2 = d1 / 256.f;
3997
+ const float md = -8.h * xb->d;
3998
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
3999
+ const ushort mask1 = mask0 << 8;
4000
+
4001
+ for (int i=0;i<8;i++) {
4002
+ reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
4003
+ reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
4004
+ }
4005
+ }
4006
+
4007
+ template <typename type4x4>
4008
+ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
4009
+ device const uint16_t * qs = ((device const uint16_t *)xb + 2);
4010
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
4011
+ const float d2 = d1 / 256.f;
4012
+ const float m = xb->m;
4013
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
4014
+ const ushort mask1 = mask0 << 8;
4015
+
4016
+ for (int i=0;i<8;i++) {
4017
+ reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
4018
+ reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
4019
+ }
4020
+ }
4021
+
4022
+ template <typename type4x4>
4023
+ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
4024
+ device const uint16_t * qs = ((device const uint16_t *)xb + 3);
4025
+ const float d = xb->d;
4026
+ const float md = -16.h * xb->d;
4027
+ const ushort mask = il ? 0x00F0 : 0x000F;
4028
+
4029
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
4030
+
4031
+ const int x_mv = il ? 4 : 0;
4032
+
4033
+ const int gh_mv = il ? 12 : 0;
4034
+ const int gh_bk = il ? 0 : 4;
4035
+
4036
+ for (int i = 0; i < 8; i++) {
4037
+ // extract the 5-th bits for x0 and x1
4038
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
4039
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
4040
+
4041
+ // combine the 4-bits from qs with the 5th bit
4042
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
4043
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
4044
+
4045
+ reg[i/2][2*(i%2)+0] = d * x0 + md;
4046
+ reg[i/2][2*(i%2)+1] = d * x1 + md;
4047
+ }
4048
+ }
4049
+
4050
+ template <typename type4x4>
4051
+ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
4052
+ device const uint16_t * qs = ((device const uint16_t *)xb + 4);
4053
+ const float d = xb->d;
4054
+ const float m = xb->m;
4055
+ const ushort mask = il ? 0x00F0 : 0x000F;
4056
+
4057
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
4058
+
4059
+ const int x_mv = il ? 4 : 0;
4060
+
4061
+ const int gh_mv = il ? 12 : 0;
4062
+ const int gh_bk = il ? 0 : 4;
4063
+
4064
+ for (int i = 0; i < 8; i++) {
4065
+ // extract the 5-th bits for x0 and x1
4066
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
4067
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
4068
+
4069
+ // combine the 4-bits from qs with the 5th bit
4070
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
4071
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
4072
+
4073
+ reg[i/2][2*(i%2)+0] = d * x0 + m;
4074
+ reg[i/2][2*(i%2)+1] = d * x1 + m;
4075
+ }
4076
+ }
4077
+
4078
+ template <typename type4x4>
4079
+ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
4080
+ device const int8_t * qs = ((device const int8_t *)xb->qs);
4081
+ const half d = xb->d;
4082
+
4083
+ for (int i = 0; i < 16; i++) {
4084
+ reg[i/4][i%4] = (qs[i + 16*il] * d);
4085
+ }
4086
+ }
4087
+
4088
+ template <typename type4x4>
4089
+ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
4090
+ const float d = xb->d;
4091
+ const float min = xb->dmin;
4092
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
4093
+ float dl, ml;
4094
+ uint8_t sc = xb->scales[il];
4095
+
4096
+ #if QK_K == 256
4097
+ q = q + 32*(il/8) + 16*(il&1);
4098
+ il = (il/2)%4;
4099
+ #endif
4100
+ half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
4101
+ uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
4102
+ dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
4103
+ for (int i = 0; i < 16; ++i) {
4104
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
4105
+ }
4106
+ }
4107
+
4108
+ template <typename type4x4>
4109
+ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
4110
+ const half d_all = xb->d;
4111
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
4112
+ device const uint8_t * h = (device const uint8_t *)xb->hmask;
4113
+ device const int8_t * scales = (device const int8_t *)xb->scales;
4114
+
4115
+ #if QK_K == 256
4116
+ q = q + 32 * (il/8) + 16 * (il&1);
4117
+ h = h + 16 * (il&1);
4118
+ uint8_t m = 1 << (il/2);
4119
+ uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
3095
4120
  ((il/4)>0 ? 12 : 3);
3096
4121
  uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
3097
4122
  uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
3098
4123
  int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
3099
4124
  : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
3100
- half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
3101
- const half ml = 4.h * dl;
4125
+ float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
4126
+ const float ml = 4.f * dl;
3102
4127
 
3103
4128
  il = (il/2) & 3;
3104
4129
  const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
@@ -3135,10 +4160,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
3135
4160
  q = q + (il/4) * 32 + 16 * (il&1);
3136
4161
  il = il & 3;
3137
4162
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3138
- const half d = il < 2 ? xb->d : xb->d / 16.h;
3139
- const half min = xb->dmin;
3140
- const half dl = d * sc[0];
3141
- const half ml = min * sc[1];
4163
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
4164
+ const float min = xb->dmin;
4165
+ const float dl = d * sc[0];
4166
+ const float ml = min * sc[1];
3142
4167
  #else
3143
4168
  q = q + 16 * (il&1);
3144
4169
  device const uint8_t * s = xb->scales;
@@ -3165,13 +4190,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
3165
4190
  uint8_t ul = 1 << (il/2);
3166
4191
  il = il & 3;
3167
4192
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3168
- const half d = il < 2 ? xb->d : xb->d / 16.h;
3169
- const half min = xb->dmin;
3170
- const half dl = d * sc[0];
3171
- const half ml = min * sc[1];
4193
+ const float d = il < 2 ? xb->d : xb->d / 16.f;
4194
+ const float min = xb->dmin;
4195
+ const float dl = d * sc[0];
4196
+ const float ml = min * sc[1];
3172
4197
 
3173
- const ushort mask = il<2 ? 0x0F : 0xF0;
3174
- const half qh_val = il<2 ? 16.h : 256.h;
4198
+ const ushort mask = il<2 ? 0x0F : 0xF0;
4199
+ const float qh_val = il<2 ? 16.f : 256.f;
3175
4200
  for (int i = 0; i < 16; ++i) {
3176
4201
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
3177
4202
  }
@@ -3198,17 +4223,17 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
3198
4223
  #if QK_K == 256
3199
4224
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
3200
4225
  qh = qh + 32*(il/8) + 16*(il&1);
3201
- half sc = scales[(il%2) + 2 * ((il/2))];
4226
+ float sc = scales[(il%2) + 2 * ((il/2))];
3202
4227
  il = (il/2) & 3;
3203
4228
  #else
3204
4229
  ql = ql + 16 * (il&1);
3205
- half sc = scales[il];
4230
+ float sc = scales[il];
3206
4231
  #endif
3207
4232
  const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
3208
4233
  const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
3209
- const half coef = il>1 ? 1.f/16.h : 1.h;
3210
- const half ml = d_all * sc * 32.h;
3211
- const half dl = d_all * sc * coef;
4234
+ const float coef = il>1 ? 1.f/16.f : 1.f;
4235
+ const float ml = d_all * sc * 32.f;
4236
+ const float dl = d_all * sc * coef;
3212
4237
  for (int i = 0; i < 16; ++i) {
3213
4238
  const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
3214
4239
  : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
@@ -3216,28 +4241,171 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
3216
4241
  }
3217
4242
  }
3218
4243
 
3219
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
3220
- kernel void kernel_get_rows(
4244
+ template <typename type4x4>
4245
+ void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
4246
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
4247
+ const float d = xb->d;
4248
+ const int ib32 = il/2;
4249
+ il = il%2;
4250
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
4251
+ // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
4252
+ device const uint16_t * q2 = xb->qs + 4*ib32;
4253
+ const uint32_t aux32_g = q2[0] | (q2[1] << 16);
4254
+ const uint32_t aux32_s = q2[2] | (q2[3] << 16);
4255
+ thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
4256
+ const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
4257
+ constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
4258
+ uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
4259
+ for (int i = 0; i < 8; ++i) {
4260
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4261
+ }
4262
+ grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
4263
+ signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
4264
+ for (int i = 0; i < 8; ++i) {
4265
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4266
+ }
4267
+ }
4268
+
4269
+ template <typename type4x4>
4270
+ void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
4271
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
4272
+ const float d = xb->d;
4273
+ const int ib32 = il/2;
4274
+ il = il%2;
4275
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
4276
+ device const uint16_t * q2 = xb->qs + 4*ib32;
4277
+ const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
4278
+ constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
4279
+ uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
4280
+ for (int i = 0; i < 8; ++i) {
4281
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4282
+ }
4283
+ grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
4284
+ signs = ksigns_iq2xs[q2[2*il+1] >> 9];
4285
+ for (int i = 0; i < 8; ++i) {
4286
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4287
+ }
4288
+ }
4289
+
4290
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
4291
+ kernel void kernel_get_rows(
3221
4292
  device const void * src0,
3222
- device const int * src1,
4293
+ device const char * src1,
3223
4294
  device float * dst,
3224
4295
  constant int64_t & ne00,
3225
4296
  constant uint64_t & nb01,
4297
+ constant uint64_t & nb02,
4298
+ constant int64_t & ne10,
4299
+ constant uint64_t & nb10,
4300
+ constant uint64_t & nb11,
3226
4301
  constant uint64_t & nb1,
3227
- uint tgpig[[threadgroup_position_in_grid]],
4302
+ constant uint64_t & nb2,
4303
+ uint3 tgpig[[threadgroup_position_in_grid]],
3228
4304
  uint tiitg[[thread_index_in_threadgroup]],
3229
- uint tptg[[threads_per_threadgroup]]) {
3230
- const int i = tgpig;
3231
- const int r = ((device int32_t *) src1)[i];
4305
+ uint3 tptg [[threads_per_threadgroup]]) {
4306
+ //const int64_t i = tgpig;
4307
+ //const int64_t r = ((device int32_t *) src1)[i];
3232
4308
 
3233
- for (int ind = tiitg; ind < ne00/16; ind += tptg) {
4309
+ const int64_t i10 = tgpig.x;
4310
+ const int64_t i11 = tgpig.y;
4311
+
4312
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
4313
+
4314
+ const int64_t i02 = i11;
4315
+
4316
+ for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
3234
4317
  float4x4 temp;
3235
4318
  dequantize_func(
3236
- ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
3237
- *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
4319
+ ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
4320
+ *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
4321
+ }
4322
+ }
4323
+
4324
+ kernel void kernel_get_rows_f32(
4325
+ device const void * src0,
4326
+ device const char * src1,
4327
+ device float * dst,
4328
+ constant int64_t & ne00,
4329
+ constant uint64_t & nb01,
4330
+ constant uint64_t & nb02,
4331
+ constant int64_t & ne10,
4332
+ constant uint64_t & nb10,
4333
+ constant uint64_t & nb11,
4334
+ constant uint64_t & nb1,
4335
+ constant uint64_t & nb2,
4336
+ uint3 tgpig[[threadgroup_position_in_grid]],
4337
+ uint tiitg[[thread_index_in_threadgroup]],
4338
+ uint3 tptg [[threads_per_threadgroup]]) {
4339
+ const int64_t i10 = tgpig.x;
4340
+ const int64_t i11 = tgpig.y;
4341
+
4342
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
4343
+
4344
+ const int64_t i02 = i11;
4345
+
4346
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
4347
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
4348
+ ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
4349
+ }
4350
+ }
4351
+
4352
+ kernel void kernel_get_rows_f16(
4353
+ device const void * src0,
4354
+ device const char * src1,
4355
+ device float * dst,
4356
+ constant int64_t & ne00,
4357
+ constant uint64_t & nb01,
4358
+ constant uint64_t & nb02,
4359
+ constant int64_t & ne10,
4360
+ constant uint64_t & nb10,
4361
+ constant uint64_t & nb11,
4362
+ constant uint64_t & nb1,
4363
+ constant uint64_t & nb2,
4364
+ uint3 tgpig[[threadgroup_position_in_grid]],
4365
+ uint tiitg[[thread_index_in_threadgroup]],
4366
+ uint3 tptg [[threads_per_threadgroup]]) {
4367
+ const int64_t i10 = tgpig.x;
4368
+ const int64_t i11 = tgpig.y;
4369
+
4370
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
4371
+
4372
+ const int64_t i02 = i11;
4373
+
4374
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
4375
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
4376
+ ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
4377
+ }
4378
+ }
4379
+
4380
+ kernel void kernel_get_rows_i32(
4381
+ device const void * src0,
4382
+ device const char * src1,
4383
+ device int32_t * dst,
4384
+ constant int64_t & ne00,
4385
+ constant uint64_t & nb01,
4386
+ constant uint64_t & nb02,
4387
+ constant int64_t & ne10,
4388
+ constant uint64_t & nb10,
4389
+ constant uint64_t & nb11,
4390
+ constant uint64_t & nb1,
4391
+ constant uint64_t & nb2,
4392
+ uint3 tgpig[[threadgroup_position_in_grid]],
4393
+ uint tiitg[[thread_index_in_threadgroup]],
4394
+ uint3 tptg [[threads_per_threadgroup]]) {
4395
+ const int64_t i10 = tgpig.x;
4396
+ const int64_t i11 = tgpig.y;
4397
+
4398
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
4399
+
4400
+ const int64_t i02 = i11;
4401
+
4402
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
4403
+ ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
4404
+ ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
3238
4405
  }
3239
4406
  }
3240
4407
 
4408
+
3241
4409
  #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
3242
4410
  #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
3243
4411
  #define BLOCK_SIZE_K 32
@@ -3256,12 +4424,12 @@ void kernel_mul_mm_impl(device const uchar * src0,
3256
4424
  device float * dst,
3257
4425
  constant int64_t & ne00,
3258
4426
  constant int64_t & ne02,
3259
- constant int64_t & nb01,
3260
- constant int64_t & nb02,
4427
+ constant uint64_t & nb01,
4428
+ constant uint64_t & nb02,
3261
4429
  constant int64_t & ne12,
3262
- constant int64_t & nb10,
3263
- constant int64_t & nb11,
3264
- constant int64_t & nb12,
4430
+ constant uint64_t & nb10,
4431
+ constant uint64_t & nb11,
4432
+ constant uint64_t & nb12,
3265
4433
  constant int64_t & ne0,
3266
4434
  constant int64_t & ne1,
3267
4435
  constant uint & r2,
@@ -3382,18 +4550,143 @@ void kernel_mul_mm_impl(device const uchar * src0,
3382
4550
  }
3383
4551
  }
3384
4552
 
4553
+ // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
4554
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
4555
+ void kernel_mul_mm_id_impl(
4556
+ device const uchar * src0,
4557
+ device const uchar * src1,
4558
+ thread short * src1ids,
4559
+ device float * dst,
4560
+ constant int64_t & ne00,
4561
+ constant int64_t & ne02,
4562
+ constant uint64_t & nb01,
4563
+ constant uint64_t & nb02,
4564
+ constant int64_t & ne12,
4565
+ constant uint64_t & nb10,
4566
+ constant uint64_t & nb11,
4567
+ constant uint64_t & nb12,
4568
+ constant int64_t & ne0,
4569
+ int64_t ne1,
4570
+ constant uint & r2,
4571
+ constant uint & r3,
4572
+ threadgroup uchar * shared_memory,
4573
+ uint3 tgpig[[threadgroup_position_in_grid]],
4574
+ uint tiitg[[thread_index_in_threadgroup]],
4575
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4576
+
4577
+ threadgroup half * sa = (threadgroup half *)(shared_memory);
4578
+ threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
4579
+
4580
+ const uint r0 = tgpig.y;
4581
+ const uint r1 = tgpig.x;
4582
+ const uint im = tgpig.z;
4583
+
4584
+ if (r1 * BLOCK_SIZE_N >= ne1) return;
4585
+
4586
+ // if this block is of 64x32 shape or smaller
4587
+ short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
4588
+ short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
4589
+
4590
+ // a thread shouldn't load data outside of the matrix
4591
+ short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
4592
+ short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
4593
+
4594
+ simdgroup_half8x8 ma[4];
4595
+ simdgroup_float8x8 mb[2];
4596
+ simdgroup_float8x8 c_res[8];
4597
+ for (int i = 0; i < 8; i++){
4598
+ c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
4599
+ }
4600
+
4601
+ short il = (tiitg % THREAD_PER_ROW);
4602
+
4603
+ const uint i12 = im%ne12;
4604
+ const uint i13 = im/ne12;
4605
+
4606
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
4607
+ ushort offset1 = il/nl;
4608
+
4609
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
4610
+ device const float * y = (device const float *)(src1
4611
+ + nb12 * im
4612
+ + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
4613
+ + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
4614
+
4615
+ for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
4616
+ // load data and store to threadgroup memory
4617
+ half4x4 temp_a;
4618
+ dequantize_func(x, il, temp_a);
4619
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4620
+
4621
+ for (int i = 0; i < 16; i++) {
4622
+ *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
4623
+ + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
4624
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
4625
+ }
4626
+
4627
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
4628
+
4629
+ il = (il + 2 < nl) ? il + 2 : il % 2;
4630
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
4631
+ y += BLOCK_SIZE_K;
4632
+
4633
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4634
+
4635
+ // load matrices from threadgroup memory and conduct outer products
4636
+ threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
4637
+ threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
4638
+
4639
+ for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
4640
+ for (int i = 0; i < 4; i++) {
4641
+ simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
4642
+ }
4643
+ simdgroup_barrier(mem_flags::mem_none);
4644
+ for (int i = 0; i < 2; i++) {
4645
+ simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
4646
+ }
4647
+
4648
+ lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
4649
+ lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
4650
+
4651
+ for (int i = 0; i < 8; i++){
4652
+ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
4653
+ }
4654
+ }
4655
+ }
4656
+
4657
+ {
4658
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4659
+ threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
4660
+ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
4661
+ for (int i = 0; i < 8; i++) {
4662
+ simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
4663
+ }
4664
+
4665
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4666
+
4667
+ device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
4668
+ if (sgitg == 0) {
4669
+ for (int i = 0; i < n_rows; i++) {
4670
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
4671
+ *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
4672
+ }
4673
+ }
4674
+ }
4675
+ }
4676
+ }
4677
+
3385
4678
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3386
4679
  kernel void kernel_mul_mm(device const uchar * src0,
3387
4680
  device const uchar * src1,
3388
4681
  device float * dst,
3389
4682
  constant int64_t & ne00,
3390
4683
  constant int64_t & ne02,
3391
- constant int64_t & nb01,
3392
- constant int64_t & nb02,
4684
+ constant uint64_t & nb01,
4685
+ constant uint64_t & nb02,
3393
4686
  constant int64_t & ne12,
3394
- constant int64_t & nb10,
3395
- constant int64_t & nb11,
3396
- constant int64_t & nb12,
4687
+ constant uint64_t & nb10,
4688
+ constant uint64_t & nb11,
4689
+ constant uint64_t & nb12,
3397
4690
  constant int64_t & ne0,
3398
4691
  constant int64_t & ne1,
3399
4692
  constant uint & r2,
@@ -3426,19 +4719,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
3426
4719
 
3427
4720
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3428
4721
  kernel void kernel_mul_mm_id(
3429
- device const int32_t * ids,
4722
+ device const uchar * ids,
3430
4723
  device const uchar * src1,
3431
4724
  device float * dst,
4725
+ constant uint64_t & nbi1,
3432
4726
  constant int64_t & ne00,
3433
4727
  constant int64_t & ne02,
3434
- constant int64_t & nb01,
3435
- constant int64_t & nb02,
4728
+ constant uint64_t & nb01,
4729
+ constant uint64_t & nb02,
3436
4730
  constant int64_t & ne12,
3437
- constant int64_t & nb10,
3438
- constant int64_t & nb11,
3439
- constant int64_t & nb12,
4731
+ constant int64_t & ne13,
4732
+ constant uint64_t & nb10,
4733
+ constant uint64_t & nb11,
4734
+ constant uint64_t & nb12,
3440
4735
  constant int64_t & ne0,
3441
4736
  constant int64_t & ne1,
4737
+ constant uint64_t & nb1,
3442
4738
  constant uint & r2,
3443
4739
  constant uint & r3,
3444
4740
  constant int & idx,
@@ -3454,11 +4750,27 @@ kernel void kernel_mul_mm_id(
3454
4750
  uint3 tgpig[[threadgroup_position_in_grid]],
3455
4751
  uint tiitg[[thread_index_in_threadgroup]],
3456
4752
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3457
- device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4753
+ device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
3458
4754
 
3459
- kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3460
- src0[ids[idx]],
4755
+ // expert id
4756
+ const int32_t id = tgpig.z/(ne12*ne13);
4757
+
4758
+ tgpig.z = tgpig.z%(ne12*ne13);
4759
+
4760
+ // row indices of src1 for expert id
4761
+ int64_t _ne1 = 0;
4762
+ short src1ids[512];
4763
+
4764
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
4765
+ if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
4766
+ src1ids[_ne1++] = i1;
4767
+ }
4768
+ }
4769
+
4770
+ kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
4771
+ src0s[id],
3461
4772
  src1,
4773
+ src1ids,
3462
4774
  dst,
3463
4775
  ne00,
3464
4776
  ne02,
@@ -3469,7 +4781,7 @@ kernel void kernel_mul_mm_id(
3469
4781
  nb11,
3470
4782
  nb12,
3471
4783
  ne0,
3472
- ne1,
4784
+ _ne1,
3473
4785
  r2,
3474
4786
  r3,
3475
4787
  shared_memory,
@@ -3484,17 +4796,26 @@ kernel void kernel_mul_mm_id(
3484
4796
  #define QK_NL 4
3485
4797
  #endif
3486
4798
 
4799
+ //
4800
+ // get rows
4801
+ //
4802
+
3487
4803
  typedef void (get_rows_t)(
3488
4804
  device const void * src0,
3489
- device const int * src1,
4805
+ device const char * src1,
3490
4806
  device float * dst,
3491
4807
  constant int64_t & ne00,
3492
4808
  constant uint64_t & nb01,
4809
+ constant uint64_t & nb02,
4810
+ constant int64_t & ne10,
4811
+ constant uint64_t & nb10,
4812
+ constant uint64_t & nb11,
3493
4813
  constant uint64_t & nb1,
3494
- uint, uint, uint);
4814
+ constant uint64_t & nb2,
4815
+ uint3, uint, uint3);
3495
4816
 
3496
- template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
3497
- template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
4817
+ //template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
4818
+ //template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
3498
4819
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
3499
4820
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
3500
4821
  template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
@@ -3505,6 +4826,12 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows
3505
4826
  template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
3506
4827
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
3507
4828
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
4829
+ template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4830
+ template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
4831
+
4832
+ //
4833
+ // matrix-matrix multiplication
4834
+ //
3508
4835
 
3509
4836
  typedef void (mat_mm_t)(
3510
4837
  device const uchar * src0,
@@ -3512,12 +4839,12 @@ typedef void (mat_mm_t)(
3512
4839
  device float * dst,
3513
4840
  constant int64_t & ne00,
3514
4841
  constant int64_t & ne02,
3515
- constant int64_t & nb01,
3516
- constant int64_t & nb02,
4842
+ constant uint64_t & nb01,
4843
+ constant uint64_t & nb02,
3517
4844
  constant int64_t & ne12,
3518
- constant int64_t & nb10,
3519
- constant int64_t & nb11,
3520
- constant int64_t & nb12,
4845
+ constant uint64_t & nb10,
4846
+ constant uint64_t & nb11,
4847
+ constant uint64_t & nb12,
3521
4848
  constant int64_t & ne0,
3522
4849
  constant int64_t & ne1,
3523
4850
  constant uint & r2,
@@ -3537,21 +4864,30 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
3537
4864
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
3538
4865
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
3539
4866
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
4867
+ template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4868
+ template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
4869
+
4870
+ //
4871
+ // indirect matrix-matrix multiplication
4872
+ //
3540
4873
 
3541
4874
  typedef void (mat_mm_id_t)(
3542
- device const int32_t * ids,
4875
+ device const uchar * ids,
3543
4876
  device const uchar * src1,
3544
4877
  device float * dst,
4878
+ constant uint64_t & nbi1,
3545
4879
  constant int64_t & ne00,
3546
4880
  constant int64_t & ne02,
3547
- constant int64_t & nb01,
3548
- constant int64_t & nb02,
4881
+ constant uint64_t & nb01,
4882
+ constant uint64_t & nb02,
3549
4883
  constant int64_t & ne12,
3550
- constant int64_t & nb10,
3551
- constant int64_t & nb11,
3552
- constant int64_t & nb12,
4884
+ constant int64_t & ne13,
4885
+ constant uint64_t & nb10,
4886
+ constant uint64_t & nb11,
4887
+ constant uint64_t & nb12,
3553
4888
  constant int64_t & ne0,
3554
4889
  constant int64_t & ne1,
4890
+ constant uint64_t & nb1,
3555
4891
  constant uint & r2,
3556
4892
  constant uint & r3,
3557
4893
  constant int & idx,
@@ -3578,3 +4914,907 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
3578
4914
  template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
3579
4915
  template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
3580
4916
  template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
4917
+ template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4918
+ template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
4919
+
4920
+ //
4921
+ // matrix-vector multiplication
4922
+ //
4923
+
4924
+ [[host_name("kernel_mul_mv_id_f32_f32")]]
4925
+ kernel void kernel_mul_mv_id_f32_f32(
4926
+ device const char * ids,
4927
+ device const char * src1,
4928
+ device float * dst,
4929
+ constant uint64_t & nbi1,
4930
+ constant int64_t & ne00,
4931
+ constant int64_t & ne01,
4932
+ constant int64_t & ne02,
4933
+ constant uint64_t & nb00,
4934
+ constant uint64_t & nb01,
4935
+ constant uint64_t & nb02,
4936
+ constant int64_t & ne10,
4937
+ constant int64_t & ne11,
4938
+ constant int64_t & ne12,
4939
+ constant int64_t & ne13,
4940
+ constant uint64_t & nb10,
4941
+ constant uint64_t & nb11,
4942
+ constant uint64_t & nb12,
4943
+ constant int64_t & ne0,
4944
+ constant int64_t & ne1,
4945
+ constant uint64_t & nb1,
4946
+ constant uint & r2,
4947
+ constant uint & r3,
4948
+ constant int & idx,
4949
+ device const char * src00,
4950
+ device const char * src01,
4951
+ device const char * src02,
4952
+ device const char * src03,
4953
+ device const char * src04,
4954
+ device const char * src05,
4955
+ device const char * src06,
4956
+ device const char * src07,
4957
+ uint3 tgpig[[threadgroup_position_in_grid]],
4958
+ uint tiitg[[thread_index_in_threadgroup]],
4959
+ uint tiisg[[thread_index_in_simdgroup]],
4960
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4961
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4962
+
4963
+ const int64_t bid = tgpig.z/(ne12*ne13);
4964
+
4965
+ tgpig.z = tgpig.z%(ne12*ne13);
4966
+
4967
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4968
+
4969
+ kernel_mul_mv_f32_f32_impl(
4970
+ src0[id],
4971
+ src1 + bid*nb11,
4972
+ dst + bid*ne0,
4973
+ ne00,
4974
+ ne01,
4975
+ ne02,
4976
+ nb00,
4977
+ nb01,
4978
+ nb02,
4979
+ ne10,
4980
+ ne11,
4981
+ ne12,
4982
+ nb10,
4983
+ nb11,
4984
+ nb12,
4985
+ ne0,
4986
+ ne1,
4987
+ r2,
4988
+ r3,
4989
+ tgpig,
4990
+ tiisg);
4991
+ }
4992
+
4993
+ [[host_name("kernel_mul_mv_id_f16_f32")]]
4994
+ kernel void kernel_mul_mv_id_f16_f32(
4995
+ device const char * ids,
4996
+ device const char * src1,
4997
+ device float * dst,
4998
+ constant uint64_t & nbi1,
4999
+ constant int64_t & ne00,
5000
+ constant int64_t & ne01,
5001
+ constant int64_t & ne02,
5002
+ constant uint64_t & nb00,
5003
+ constant uint64_t & nb01,
5004
+ constant uint64_t & nb02,
5005
+ constant int64_t & ne10,
5006
+ constant int64_t & ne11,
5007
+ constant int64_t & ne12,
5008
+ constant int64_t & ne13,
5009
+ constant uint64_t & nb10,
5010
+ constant uint64_t & nb11,
5011
+ constant uint64_t & nb12,
5012
+ constant int64_t & ne0,
5013
+ constant int64_t & ne1,
5014
+ constant uint64_t & nb1,
5015
+ constant uint & r2,
5016
+ constant uint & r3,
5017
+ constant int & idx,
5018
+ device const char * src00,
5019
+ device const char * src01,
5020
+ device const char * src02,
5021
+ device const char * src03,
5022
+ device const char * src04,
5023
+ device const char * src05,
5024
+ device const char * src06,
5025
+ device const char * src07,
5026
+ uint3 tgpig[[threadgroup_position_in_grid]],
5027
+ uint tiitg[[thread_index_in_threadgroup]],
5028
+ uint tiisg[[thread_index_in_simdgroup]],
5029
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5030
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5031
+
5032
+ const int64_t bid = tgpig.z/(ne12*ne13);
5033
+
5034
+ tgpig.z = tgpig.z%(ne12*ne13);
5035
+
5036
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5037
+
5038
+ kernel_mul_mv_f16_f32_impl(
5039
+ src0[id],
5040
+ src1 + bid*nb11,
5041
+ dst + bid*ne0,
5042
+ ne00,
5043
+ ne01,
5044
+ ne02,
5045
+ nb00,
5046
+ nb01,
5047
+ nb02,
5048
+ ne10,
5049
+ ne11,
5050
+ ne12,
5051
+ nb10,
5052
+ nb11,
5053
+ nb12,
5054
+ ne0,
5055
+ ne1,
5056
+ r2,
5057
+ r3,
5058
+ tgpig,
5059
+ tiisg);
5060
+ }
5061
+
5062
+ [[host_name("kernel_mul_mv_id_q8_0_f32")]]
5063
+ kernel void kernel_mul_mv_id_q8_0_f32(
5064
+ device const char * ids,
5065
+ device const char * src1,
5066
+ device float * dst,
5067
+ constant uint64_t & nbi1,
5068
+ constant int64_t & ne00,
5069
+ constant int64_t & ne01,
5070
+ constant int64_t & ne02,
5071
+ constant uint64_t & nb00,
5072
+ constant uint64_t & nb01,
5073
+ constant uint64_t & nb02,
5074
+ constant int64_t & ne10,
5075
+ constant int64_t & ne11,
5076
+ constant int64_t & ne12,
5077
+ constant int64_t & ne13,
5078
+ constant uint64_t & nb10,
5079
+ constant uint64_t & nb11,
5080
+ constant uint64_t & nb12,
5081
+ constant int64_t & ne0,
5082
+ constant int64_t & ne1,
5083
+ constant uint64_t & nb1,
5084
+ constant uint & r2,
5085
+ constant uint & r3,
5086
+ constant int & idx,
5087
+ device const char * src00,
5088
+ device const char * src01,
5089
+ device const char * src02,
5090
+ device const char * src03,
5091
+ device const char * src04,
5092
+ device const char * src05,
5093
+ device const char * src06,
5094
+ device const char * src07,
5095
+ uint3 tgpig[[threadgroup_position_in_grid]],
5096
+ uint tiitg[[thread_index_in_threadgroup]],
5097
+ uint tiisg[[thread_index_in_simdgroup]],
5098
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5099
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5100
+
5101
+ const int64_t bid = tgpig.z/(ne12*ne13);
5102
+
5103
+ tgpig.z = tgpig.z%(ne12*ne13);
5104
+
5105
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5106
+
5107
+ kernel_mul_mv_q8_0_f32_impl(
5108
+ src0[id],
5109
+ (device const float *) (src1 + bid*nb11),
5110
+ dst + bid*ne0,
5111
+ ne00,
5112
+ ne01,
5113
+ ne02,
5114
+ ne10,
5115
+ ne12,
5116
+ ne0,
5117
+ ne1,
5118
+ r2,
5119
+ r3,
5120
+ tgpig,
5121
+ tiisg,
5122
+ sgitg);
5123
+ }
5124
+
5125
+ [[host_name("kernel_mul_mv_id_q4_0_f32")]]
5126
+ kernel void kernel_mul_mv_id_q4_0_f32(
5127
+ device const char * ids,
5128
+ device const char * src1,
5129
+ device float * dst,
5130
+ constant uint64_t & nbi1,
5131
+ constant int64_t & ne00,
5132
+ constant int64_t & ne01,
5133
+ constant int64_t & ne02,
5134
+ constant uint64_t & nb00,
5135
+ constant uint64_t & nb01,
5136
+ constant uint64_t & nb02,
5137
+ constant int64_t & ne10,
5138
+ constant int64_t & ne11,
5139
+ constant int64_t & ne12,
5140
+ constant int64_t & ne13,
5141
+ constant uint64_t & nb10,
5142
+ constant uint64_t & nb11,
5143
+ constant uint64_t & nb12,
5144
+ constant int64_t & ne0,
5145
+ constant int64_t & ne1,
5146
+ constant uint64_t & nb1,
5147
+ constant uint & r2,
5148
+ constant uint & r3,
5149
+ constant int & idx,
5150
+ device const char * src00,
5151
+ device const char * src01,
5152
+ device const char * src02,
5153
+ device const char * src03,
5154
+ device const char * src04,
5155
+ device const char * src05,
5156
+ device const char * src06,
5157
+ device const char * src07,
5158
+ uint3 tgpig[[threadgroup_position_in_grid]],
5159
+ uint tiitg[[thread_index_in_threadgroup]],
5160
+ uint tiisg[[thread_index_in_simdgroup]],
5161
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5162
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5163
+
5164
+ const int64_t bid = tgpig.z/(ne12*ne13);
5165
+
5166
+ tgpig.z = tgpig.z%(ne12*ne13);
5167
+
5168
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5169
+
5170
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
5171
+ src0[id],
5172
+ (device const float *) (src1 + bid*nb11),
5173
+ dst + bid*ne0,
5174
+ ne00,
5175
+ ne01,
5176
+ ne02,
5177
+ ne10,
5178
+ ne12,
5179
+ ne0,
5180
+ ne1,
5181
+ r2,
5182
+ r3,
5183
+ tgpig,
5184
+ tiisg,
5185
+ sgitg);
5186
+ }
5187
+
5188
+ [[host_name("kernel_mul_mv_id_q4_1_f32")]]
5189
+ kernel void kernel_mul_mv_id_q4_1_f32(
5190
+ device const char * ids,
5191
+ device const char * src1,
5192
+ device float * dst,
5193
+ constant uint64_t & nbi1,
5194
+ constant int64_t & ne00,
5195
+ constant int64_t & ne01,
5196
+ constant int64_t & ne02,
5197
+ constant uint64_t & nb00,
5198
+ constant uint64_t & nb01,
5199
+ constant uint64_t & nb02,
5200
+ constant int64_t & ne10,
5201
+ constant int64_t & ne11,
5202
+ constant int64_t & ne12,
5203
+ constant int64_t & ne13,
5204
+ constant uint64_t & nb10,
5205
+ constant uint64_t & nb11,
5206
+ constant uint64_t & nb12,
5207
+ constant int64_t & ne0,
5208
+ constant int64_t & ne1,
5209
+ constant uint64_t & nb1,
5210
+ constant uint & r2,
5211
+ constant uint & r3,
5212
+ constant int & idx,
5213
+ device const char * src00,
5214
+ device const char * src01,
5215
+ device const char * src02,
5216
+ device const char * src03,
5217
+ device const char * src04,
5218
+ device const char * src05,
5219
+ device const char * src06,
5220
+ device const char * src07,
5221
+ uint3 tgpig[[threadgroup_position_in_grid]],
5222
+ uint tiitg[[thread_index_in_threadgroup]],
5223
+ uint tiisg[[thread_index_in_simdgroup]],
5224
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5225
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5226
+
5227
+ const int64_t bid = tgpig.z/(ne12*ne13);
5228
+
5229
+ tgpig.z = tgpig.z%(ne12*ne13);
5230
+
5231
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5232
+
5233
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
5234
+ src0[id],
5235
+ (device const float *) (src1 + bid*nb11),
5236
+ dst + bid*ne0,
5237
+ ne00,
5238
+ ne01,
5239
+ ne02,
5240
+ ne10,
5241
+ ne12,
5242
+ ne0,
5243
+ ne1,
5244
+ r2,
5245
+ r3,
5246
+ tgpig,
5247
+ tiisg,
5248
+ sgitg);
5249
+ }
5250
+
5251
+ [[host_name("kernel_mul_mv_id_q5_0_f32")]]
5252
+ kernel void kernel_mul_mv_id_q5_0_f32(
5253
+ device const char * ids,
5254
+ device const char * src1,
5255
+ device float * dst,
5256
+ constant uint64_t & nbi1,
5257
+ constant int64_t & ne00,
5258
+ constant int64_t & ne01,
5259
+ constant int64_t & ne02,
5260
+ constant uint64_t & nb00,
5261
+ constant uint64_t & nb01,
5262
+ constant uint64_t & nb02,
5263
+ constant int64_t & ne10,
5264
+ constant int64_t & ne11,
5265
+ constant int64_t & ne12,
5266
+ constant int64_t & ne13,
5267
+ constant uint64_t & nb10,
5268
+ constant uint64_t & nb11,
5269
+ constant uint64_t & nb12,
5270
+ constant int64_t & ne0,
5271
+ constant int64_t & ne1,
5272
+ constant uint64_t & nb1,
5273
+ constant uint & r2,
5274
+ constant uint & r3,
5275
+ constant int & idx,
5276
+ device const char * src00,
5277
+ device const char * src01,
5278
+ device const char * src02,
5279
+ device const char * src03,
5280
+ device const char * src04,
5281
+ device const char * src05,
5282
+ device const char * src06,
5283
+ device const char * src07,
5284
+ uint3 tgpig[[threadgroup_position_in_grid]],
5285
+ uint tiitg[[thread_index_in_threadgroup]],
5286
+ uint tiisg[[thread_index_in_simdgroup]],
5287
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5288
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5289
+
5290
+ const int64_t bid = tgpig.z/(ne12*ne13);
5291
+
5292
+ tgpig.z = tgpig.z%(ne12*ne13);
5293
+
5294
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5295
+
5296
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
5297
+ src0[id],
5298
+ (device const float *) (src1 + bid*nb11),
5299
+ dst + bid*ne0,
5300
+ ne00,
5301
+ ne01,
5302
+ ne02,
5303
+ ne10,
5304
+ ne12,
5305
+ ne0,
5306
+ ne1,
5307
+ r2,
5308
+ r3,
5309
+ tgpig,
5310
+ tiisg,
5311
+ sgitg);
5312
+ }
5313
+
5314
+ [[host_name("kernel_mul_mv_id_q5_1_f32")]]
5315
+ kernel void kernel_mul_mv_id_q5_1_f32(
5316
+ device const char * ids,
5317
+ device const char * src1,
5318
+ device float * dst,
5319
+ constant uint64_t & nbi1,
5320
+ constant int64_t & ne00,
5321
+ constant int64_t & ne01,
5322
+ constant int64_t & ne02,
5323
+ constant uint64_t & nb00,
5324
+ constant uint64_t & nb01,
5325
+ constant uint64_t & nb02,
5326
+ constant int64_t & ne10,
5327
+ constant int64_t & ne11,
5328
+ constant int64_t & ne12,
5329
+ constant int64_t & ne13,
5330
+ constant uint64_t & nb10,
5331
+ constant uint64_t & nb11,
5332
+ constant uint64_t & nb12,
5333
+ constant int64_t & ne0,
5334
+ constant int64_t & ne1,
5335
+ constant uint64_t & nb1,
5336
+ constant uint & r2,
5337
+ constant uint & r3,
5338
+ constant int & idx,
5339
+ device const char * src00,
5340
+ device const char * src01,
5341
+ device const char * src02,
5342
+ device const char * src03,
5343
+ device const char * src04,
5344
+ device const char * src05,
5345
+ device const char * src06,
5346
+ device const char * src07,
5347
+ uint3 tgpig[[threadgroup_position_in_grid]],
5348
+ uint tiitg[[thread_index_in_threadgroup]],
5349
+ uint tiisg[[thread_index_in_simdgroup]],
5350
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5351
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5352
+
5353
+ const int64_t bid = tgpig.z/(ne12*ne13);
5354
+
5355
+ tgpig.z = tgpig.z%(ne12*ne13);
5356
+
5357
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5358
+
5359
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
5360
+ src0[id],
5361
+ (device const float *) (src1 + bid*nb11),
5362
+ dst + bid*ne0,
5363
+ ne00,
5364
+ ne01,
5365
+ ne02,
5366
+ ne10,
5367
+ ne12,
5368
+ ne0,
5369
+ ne1,
5370
+ r2,
5371
+ r3,
5372
+ tgpig,
5373
+ tiisg,
5374
+ sgitg);
5375
+ }
5376
+
5377
+ [[host_name("kernel_mul_mv_id_q2_K_f32")]]
5378
+ kernel void kernel_mul_mv_id_q2_K_f32(
5379
+ device const char * ids,
5380
+ device const char * src1,
5381
+ device float * dst,
5382
+ constant uint64_t & nbi1,
5383
+ constant int64_t & ne00,
5384
+ constant int64_t & ne01,
5385
+ constant int64_t & ne02,
5386
+ constant uint64_t & nb00,
5387
+ constant uint64_t & nb01,
5388
+ constant uint64_t & nb02,
5389
+ constant int64_t & ne10,
5390
+ constant int64_t & ne11,
5391
+ constant int64_t & ne12,
5392
+ constant int64_t & ne13,
5393
+ constant uint64_t & nb10,
5394
+ constant uint64_t & nb11,
5395
+ constant uint64_t & nb12,
5396
+ constant int64_t & ne0,
5397
+ constant int64_t & ne1,
5398
+ constant uint64_t & nb1,
5399
+ constant uint & r2,
5400
+ constant uint & r3,
5401
+ constant int & idx,
5402
+ device const char * src00,
5403
+ device const char * src01,
5404
+ device const char * src02,
5405
+ device const char * src03,
5406
+ device const char * src04,
5407
+ device const char * src05,
5408
+ device const char * src06,
5409
+ device const char * src07,
5410
+ uint3 tgpig[[threadgroup_position_in_grid]],
5411
+ uint tiitg[[thread_index_in_threadgroup]],
5412
+ uint tiisg[[thread_index_in_simdgroup]],
5413
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5414
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5415
+
5416
+ const int64_t bid = tgpig.z/(ne12*ne13);
5417
+
5418
+ tgpig.z = tgpig.z%(ne12*ne13);
5419
+
5420
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5421
+
5422
+ kernel_mul_mv_q2_K_f32_impl(
5423
+ src0[id],
5424
+ (device const float *) (src1 + bid*nb11),
5425
+ dst + bid*ne0,
5426
+ ne00,
5427
+ ne01,
5428
+ ne02,
5429
+ ne10,
5430
+ ne12,
5431
+ ne0,
5432
+ ne1,
5433
+ r2,
5434
+ r3,
5435
+ tgpig,
5436
+ tiisg,
5437
+ sgitg);
5438
+ }
5439
+
5440
+ [[host_name("kernel_mul_mv_id_q3_K_f32")]]
5441
+ kernel void kernel_mul_mv_id_q3_K_f32(
5442
+ device const char * ids,
5443
+ device const char * src1,
5444
+ device float * dst,
5445
+ constant uint64_t & nbi1,
5446
+ constant int64_t & ne00,
5447
+ constant int64_t & ne01,
5448
+ constant int64_t & ne02,
5449
+ constant uint64_t & nb00,
5450
+ constant uint64_t & nb01,
5451
+ constant uint64_t & nb02,
5452
+ constant int64_t & ne10,
5453
+ constant int64_t & ne11,
5454
+ constant int64_t & ne12,
5455
+ constant int64_t & ne13,
5456
+ constant uint64_t & nb10,
5457
+ constant uint64_t & nb11,
5458
+ constant uint64_t & nb12,
5459
+ constant int64_t & ne0,
5460
+ constant int64_t & ne1,
5461
+ constant uint64_t & nb1,
5462
+ constant uint & r2,
5463
+ constant uint & r3,
5464
+ constant int & idx,
5465
+ device const char * src00,
5466
+ device const char * src01,
5467
+ device const char * src02,
5468
+ device const char * src03,
5469
+ device const char * src04,
5470
+ device const char * src05,
5471
+ device const char * src06,
5472
+ device const char * src07,
5473
+ uint3 tgpig[[threadgroup_position_in_grid]],
5474
+ uint tiitg[[thread_index_in_threadgroup]],
5475
+ uint tiisg[[thread_index_in_simdgroup]],
5476
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5477
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5478
+
5479
+ const int64_t bid = tgpig.z/(ne12*ne13);
5480
+
5481
+ tgpig.z = tgpig.z%(ne12*ne13);
5482
+
5483
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5484
+
5485
+ kernel_mul_mv_q3_K_f32_impl(
5486
+ src0[id],
5487
+ (device const float *) (src1 + bid*nb11),
5488
+ dst + bid*ne0,
5489
+ ne00,
5490
+ ne01,
5491
+ ne02,
5492
+ ne10,
5493
+ ne12,
5494
+ ne0,
5495
+ ne1,
5496
+ r2,
5497
+ r3,
5498
+ tgpig,
5499
+ tiisg,
5500
+ sgitg);
5501
+ }
5502
+
5503
+ [[host_name("kernel_mul_mv_id_q4_K_f32")]]
5504
+ kernel void kernel_mul_mv_id_q4_K_f32(
5505
+ device const char * ids,
5506
+ device const char * src1,
5507
+ device float * dst,
5508
+ constant uint64_t & nbi1,
5509
+ constant int64_t & ne00,
5510
+ constant int64_t & ne01,
5511
+ constant int64_t & ne02,
5512
+ constant uint64_t & nb00,
5513
+ constant uint64_t & nb01,
5514
+ constant uint64_t & nb02,
5515
+ constant int64_t & ne10,
5516
+ constant int64_t & ne11,
5517
+ constant int64_t & ne12,
5518
+ constant int64_t & ne13,
5519
+ constant uint64_t & nb10,
5520
+ constant uint64_t & nb11,
5521
+ constant uint64_t & nb12,
5522
+ constant int64_t & ne0,
5523
+ constant int64_t & ne1,
5524
+ constant uint64_t & nb1,
5525
+ constant uint & r2,
5526
+ constant uint & r3,
5527
+ constant int & idx,
5528
+ device const char * src00,
5529
+ device const char * src01,
5530
+ device const char * src02,
5531
+ device const char * src03,
5532
+ device const char * src04,
5533
+ device const char * src05,
5534
+ device const char * src06,
5535
+ device const char * src07,
5536
+ uint3 tgpig[[threadgroup_position_in_grid]],
5537
+ uint tiitg[[thread_index_in_threadgroup]],
5538
+ uint tiisg[[thread_index_in_simdgroup]],
5539
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5540
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5541
+
5542
+ const int64_t bid = tgpig.z/(ne12*ne13);
5543
+
5544
+ tgpig.z = tgpig.z%(ne12*ne13);
5545
+
5546
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5547
+
5548
+ kernel_mul_mv_q4_K_f32_impl(
5549
+ src0[id],
5550
+ (device const float *) (src1 + bid*nb11),
5551
+ dst + bid*ne0,
5552
+ ne00,
5553
+ ne01,
5554
+ ne02,
5555
+ ne10,
5556
+ ne12,
5557
+ ne0,
5558
+ ne1,
5559
+ r2,
5560
+ r3,
5561
+ tgpig,
5562
+ tiisg,
5563
+ sgitg);
5564
+ }
5565
+
5566
+ [[host_name("kernel_mul_mv_id_q5_K_f32")]]
5567
+ kernel void kernel_mul_mv_id_q5_K_f32(
5568
+ device const char * ids,
5569
+ device const char * src1,
5570
+ device float * dst,
5571
+ constant uint64_t & nbi1,
5572
+ constant int64_t & ne00,
5573
+ constant int64_t & ne01,
5574
+ constant int64_t & ne02,
5575
+ constant uint64_t & nb00,
5576
+ constant uint64_t & nb01,
5577
+ constant uint64_t & nb02,
5578
+ constant int64_t & ne10,
5579
+ constant int64_t & ne11,
5580
+ constant int64_t & ne12,
5581
+ constant int64_t & ne13,
5582
+ constant uint64_t & nb10,
5583
+ constant uint64_t & nb11,
5584
+ constant uint64_t & nb12,
5585
+ constant int64_t & ne0,
5586
+ constant int64_t & ne1,
5587
+ constant uint64_t & nb1,
5588
+ constant uint & r2,
5589
+ constant uint & r3,
5590
+ constant int & idx,
5591
+ device const char * src00,
5592
+ device const char * src01,
5593
+ device const char * src02,
5594
+ device const char * src03,
5595
+ device const char * src04,
5596
+ device const char * src05,
5597
+ device const char * src06,
5598
+ device const char * src07,
5599
+ uint3 tgpig[[threadgroup_position_in_grid]],
5600
+ uint tiitg[[thread_index_in_threadgroup]],
5601
+ uint tiisg[[thread_index_in_simdgroup]],
5602
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5603
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5604
+
5605
+ const int64_t bid = tgpig.z/(ne12*ne13);
5606
+
5607
+ tgpig.z = tgpig.z%(ne12*ne13);
5608
+
5609
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5610
+
5611
+ kernel_mul_mv_q5_K_f32_impl(
5612
+ src0[id],
5613
+ (device const float *) (src1 + bid*nb11),
5614
+ dst + bid*ne0,
5615
+ ne00,
5616
+ ne01,
5617
+ ne02,
5618
+ ne10,
5619
+ ne12,
5620
+ ne0,
5621
+ ne1,
5622
+ r2,
5623
+ r3,
5624
+ tgpig,
5625
+ tiisg,
5626
+ sgitg);
5627
+ }
5628
+
5629
+ [[host_name("kernel_mul_mv_id_q6_K_f32")]]
5630
+ kernel void kernel_mul_mv_id_q6_K_f32(
5631
+ device const char * ids,
5632
+ device const char * src1,
5633
+ device float * dst,
5634
+ constant uint64_t & nbi1,
5635
+ constant int64_t & ne00,
5636
+ constant int64_t & ne01,
5637
+ constant int64_t & ne02,
5638
+ constant uint64_t & nb00,
5639
+ constant uint64_t & nb01,
5640
+ constant uint64_t & nb02,
5641
+ constant int64_t & ne10,
5642
+ constant int64_t & ne11,
5643
+ constant int64_t & ne12,
5644
+ constant int64_t & ne13,
5645
+ constant uint64_t & nb10,
5646
+ constant uint64_t & nb11,
5647
+ constant uint64_t & nb12,
5648
+ constant int64_t & ne0,
5649
+ constant int64_t & ne1,
5650
+ constant uint64_t & nb1,
5651
+ constant uint & r2,
5652
+ constant uint & r3,
5653
+ constant int & idx,
5654
+ device const char * src00,
5655
+ device const char * src01,
5656
+ device const char * src02,
5657
+ device const char * src03,
5658
+ device const char * src04,
5659
+ device const char * src05,
5660
+ device const char * src06,
5661
+ device const char * src07,
5662
+ uint3 tgpig[[threadgroup_position_in_grid]],
5663
+ uint tiitg[[thread_index_in_threadgroup]],
5664
+ uint tiisg[[thread_index_in_simdgroup]],
5665
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5666
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5667
+
5668
+ const int64_t bid = tgpig.z/(ne12*ne13);
5669
+
5670
+ tgpig.z = tgpig.z%(ne12*ne13);
5671
+
5672
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5673
+
5674
+ kernel_mul_mv_q6_K_f32_impl(
5675
+ src0[id],
5676
+ (device const float *) (src1 + bid*nb11),
5677
+ dst + bid*ne0,
5678
+ ne00,
5679
+ ne01,
5680
+ ne02,
5681
+ ne10,
5682
+ ne12,
5683
+ ne0,
5684
+ ne1,
5685
+ r2,
5686
+ r3,
5687
+ tgpig,
5688
+ tiisg,
5689
+ sgitg);
5690
+ }
5691
+
5692
+ [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
5693
+ kernel void kernel_mul_mv_id_iq2_xxs_f32(
5694
+ device const char * ids,
5695
+ device const char * src1,
5696
+ device float * dst,
5697
+ constant uint64_t & nbi1,
5698
+ constant int64_t & ne00,
5699
+ constant int64_t & ne01,
5700
+ constant int64_t & ne02,
5701
+ constant uint64_t & nb00,
5702
+ constant uint64_t & nb01,
5703
+ constant uint64_t & nb02,
5704
+ constant int64_t & ne10,
5705
+ constant int64_t & ne11,
5706
+ constant int64_t & ne12,
5707
+ constant int64_t & ne13,
5708
+ constant uint64_t & nb10,
5709
+ constant uint64_t & nb11,
5710
+ constant uint64_t & nb12,
5711
+ constant int64_t & ne0,
5712
+ constant int64_t & ne1,
5713
+ constant uint64_t & nb1,
5714
+ constant uint & r2,
5715
+ constant uint & r3,
5716
+ constant int & idx,
5717
+ device const char * src00,
5718
+ device const char * src01,
5719
+ device const char * src02,
5720
+ device const char * src03,
5721
+ device const char * src04,
5722
+ device const char * src05,
5723
+ device const char * src06,
5724
+ device const char * src07,
5725
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
5726
+ uint3 tgpig[[threadgroup_position_in_grid]],
5727
+ uint tiitg[[thread_index_in_threadgroup]],
5728
+ uint tiisg[[thread_index_in_simdgroup]],
5729
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5730
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5731
+
5732
+ const int64_t bid = tgpig.z/(ne12*ne13);
5733
+
5734
+ tgpig.z = tgpig.z%(ne12*ne13);
5735
+
5736
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5737
+
5738
+ kernel_mul_mv_iq2_xxs_f32_impl(
5739
+ src0[id],
5740
+ (device const float *) (src1 + bid*nb11),
5741
+ dst + bid*ne0,
5742
+ ne00,
5743
+ ne01,
5744
+ ne02,
5745
+ ne10,
5746
+ ne12,
5747
+ ne0,
5748
+ ne1,
5749
+ r2,
5750
+ r3,
5751
+ shared_values,
5752
+ tgpig,
5753
+ tiisg,
5754
+ sgitg);
5755
+ }
5756
+
5757
+ [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
5758
+ kernel void kernel_mul_mv_id_iq2_xs_f32(
5759
+ device const char * ids,
5760
+ device const char * src1,
5761
+ device float * dst,
5762
+ constant uint64_t & nbi1,
5763
+ constant int64_t & ne00,
5764
+ constant int64_t & ne01,
5765
+ constant int64_t & ne02,
5766
+ constant uint64_t & nb00,
5767
+ constant uint64_t & nb01,
5768
+ constant uint64_t & nb02,
5769
+ constant int64_t & ne10,
5770
+ constant int64_t & ne11,
5771
+ constant int64_t & ne12,
5772
+ constant int64_t & ne13,
5773
+ constant uint64_t & nb10,
5774
+ constant uint64_t & nb11,
5775
+ constant uint64_t & nb12,
5776
+ constant int64_t & ne0,
5777
+ constant int64_t & ne1,
5778
+ constant uint64_t & nb1,
5779
+ constant uint & r2,
5780
+ constant uint & r3,
5781
+ constant int & idx,
5782
+ device const char * src00,
5783
+ device const char * src01,
5784
+ device const char * src02,
5785
+ device const char * src03,
5786
+ device const char * src04,
5787
+ device const char * src05,
5788
+ device const char * src06,
5789
+ device const char * src07,
5790
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
5791
+ uint3 tgpig[[threadgroup_position_in_grid]],
5792
+ uint tiitg[[thread_index_in_threadgroup]],
5793
+ uint tiisg[[thread_index_in_simdgroup]],
5794
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5795
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5796
+
5797
+ const int64_t bid = tgpig.z/(ne12*ne13);
5798
+
5799
+ tgpig.z = tgpig.z%(ne12*ne13);
5800
+
5801
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5802
+
5803
+ kernel_mul_mv_iq2_xs_f32_impl(
5804
+ src0[id],
5805
+ (device const float *) (src1 + bid*nb11),
5806
+ dst + bid*ne0,
5807
+ ne00,
5808
+ ne01,
5809
+ ne02,
5810
+ ne10,
5811
+ ne12,
5812
+ ne0,
5813
+ ne1,
5814
+ r2,
5815
+ r3,
5816
+ shared_values,
5817
+ tgpig,
5818
+ tiisg,
5819
+ sgitg);
5820
+ }