llama_cpp 0.9.5 → 0.10.1

@@ -3,6 +3,8 @@
  using namespace metal;

  #define MAX(x, y) ((x) > (y) ? (x) : (y))
+ #define MIN(x, y) ((x) < (y) ? (x) : (y))
+ #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }

  #define QK4_0 32
  #define QR4_0 2
@@ -41,8 +43,13 @@ typedef struct {

  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32

- // general-purpose kernel for addition of two tensors
- // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
+ enum ggml_sort_order {
+ GGML_SORT_ASC,
+ GGML_SORT_DESC,
+ };
+
+ // general-purpose kernel for addition, multiplication and division of two tensors
+ // pros: works for non-contiguous tensors, supports broadcast across all dims
  // cons: not very efficient
  kernel void kernel_add(
  device const char * src0,
@@ -72,6 +79,7 @@ kernel void kernel_add(
  constant int64_t & nb1,
  constant int64_t & nb2,
  constant int64_t & nb3,
+ constant int64_t & offs,
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint3 tpitg[[thread_position_in_threadgroup]],
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -83,16 +91,111 @@ kernel void kernel_add(
  const int64_t i12 = i02 % ne12;
  const int64_t i11 = i01 % ne11;

- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
+ }
+ }
+
+ kernel void kernel_mul(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant int64_t & nb00,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant int64_t & nb0,
+ constant int64_t & nb1,
+ constant int64_t & nb2,
+ constant int64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;

  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
+ }
+ }
+
+ kernel void kernel_div(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant int64_t & nb00,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant int64_t & nb0,
+ constant int64_t & nb1,
+ constant int64_t & nb2,
+ constant int64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;

- src0_ptr += ntg.x*nb00;
- src1_ptr += ntg.x*nb10;
- dst_ptr += ntg.x*nb0;
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
  }
  }

@@ -102,28 +205,27 @@ kernel void kernel_add_row(
  device const float4 * src0,
  device const float4 * src1,
  device float4 * dst,
- constant int64_t & nb [[buffer(27)]],
+ constant int64_t & nb [[buffer(28)]],
  uint tpig[[thread_position_in_grid]]) {
  dst[tpig] = src0[tpig] + src1[tpig % nb];
  }

- kernel void kernel_mul(
+ kernel void kernel_mul_row(
  device const float4 * src0,
  device const float4 * src1,
  device float4 * dst,
+ constant int64_t & nb [[buffer(28)]],
  uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * src1[tpig];
+ dst[tpig] = src0[tpig] * src1[tpig % nb];
  }

- // assumption: src1 is a row
- // broadcast src1 into src0
- kernel void kernel_mul_row(
+ kernel void kernel_div_row(
  device const float4 * src0,
  device const float4 * src1,
  device float4 * dst,
- constant int64_t & nb,
+ constant int64_t & nb [[buffer(28)]],
  uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * src1[tpig % nb];
+ dst[tpig] = src0[tpig] / src1[tpig % nb];
  }

  kernel void kernel_scale(
@@ -142,14 +244,6 @@ kernel void kernel_scale_4(
  dst[tpig] = src0[tpig] * scale;
  }

- kernel void kernel_silu(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float4 & x = src0[tpig];
- dst[tpig] = x / (1.0f + exp(-x));
- }
-
  kernel void kernel_relu(
  device const float * src0,
  device float * dst,
@@ -157,15 +251,17 @@ kernel void kernel_relu(
  dst[tpig] = max(0.0f, src0[tpig]);
  }

- kernel void kernel_sqr(
+ kernel void kernel_tanh(
  device const float * src0,
  device float * dst,
  uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * src0[tpig];
+ device const float & x = src0[tpig];
+ dst[tpig] = precise::tanh(x);
  }

- constant float GELU_COEF_A = 0.044715f;
- constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+ constant float GELU_COEF_A = 0.044715f;
+ constant float GELU_QUICK_COEF = -1.702f;
+ constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

  kernel void kernel_gelu(
  device const float4 * src0,
@@ -180,6 +276,78 @@ kernel void kernel_gelu(
  dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
  }

+ kernel void kernel_gelu_quick(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+ }
+
+ kernel void kernel_silu(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+ dst[tpig] = x / (1.0f + exp(-x));
+ }
+
+ kernel void kernel_sqr(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * src0[tpig];
+ }
+
+ kernel void kernel_sum_rows(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant int64_t & nb00,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant int64_t & nb0,
+ constant int64_t & nb1,
+ constant int64_t & nb2,
+ constant int64_t & nb3,
+ uint3 tpig[[thread_position_in_grid]]) {
+ int64_t i3 = tpig.z;
+ int64_t i2 = tpig.y;
+ int64_t i1 = tpig.x;
+
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+ return;
+ }
+
+ device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
+ device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+
+ float row_sum = 0;
+
+ for (int64_t i0 = 0; i0 < ne00; i0++) {
+ row_sum += src_row[i0];
+ }
+
+ dst_row[0] = row_sum;
+ }
+
  kernel void kernel_soft_max(
  device const float * src0,
  device const float * src1,
@@ -198,9 +366,9 @@ kernel void kernel_soft_max(
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
- device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

  // parallel max
  float lmax = -INFINITY;
@@ -236,7 +404,12 @@ kernel void kernel_soft_max(
  pdst[i00] = exp_psrc0;
  }

+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
  float sum = simd_sum(lsum);
+
  if (ntg > N_SIMDWIDTH) {
  if (sgitg == 0) {
  buf[tiisg] = 0.0f;
@@ -279,9 +452,9 @@ kernel void kernel_soft_max_4(
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
- device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

  // parallel max
  float4 lmax4 = -INFINITY;
@@ -319,7 +492,13 @@ kernel void kernel_soft_max_4(
  }

  const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
  float sum = simd_sum(lsum);
+
  if (ntg > N_SIMDWIDTH) {
  if (sgitg == 0) {
  buf[tiisg] = 0.0f;
@@ -490,6 +669,94 @@ kernel void kernel_rms_norm(
  }
  }

+ kernel void kernel_group_norm(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int32_t & n_groups,
+ constant float & eps,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t ne = ne00*ne01*ne02;
+ const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+
+ int start = tgpig * gs;
+ int end = start + gs;
+
+ start += tpitg;
+
+ if (end >= ne) {
+ end = ne;
+ }
+
+ float tmp = 0.0f; // partial sum for thread in warp
+
+ for (int j = start; j < end; j += ntg) {
+ tmp += src0[j];
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ tmp = simd_sum(tmp);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = tmp;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ tmp = buf[tiisg];
+ tmp = simd_sum(tmp);
+ }
+
+ const float mean = tmp / gs;
+ tmp = 0.0f;
+
+ for (int j = start; j < end; j += ntg) {
+ float xi = src0[j] - mean;
+ dst[j] = xi;
+ tmp += xi * xi;
+ }
+
+ tmp = simd_sum(tmp);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = tmp;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ tmp = buf[tiisg];
+ tmp = simd_sum(tmp);
+ }
+
+ const float variance = tmp / gs;
+ const float scale = 1.0f/sqrt(variance + eps);
+ for (int j = start; j < end; j += ntg) {
+ dst[j] *= scale;
+ }
+ }
+
  // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
  // il indicates where the q4 quants begin (0 or QK4_0/4)
  // we assume that the yl's have been multiplied with the appropriate scale factor
@@ -582,9 +849,20 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
  // giard against the number of rows not being divisible by
  // N_DST, so this is another explicit assumption of the implementation.
  template<typename block_q_type, int nr, int nsg, int nw>
- void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
- int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
- uint3 tgpig, uint tiisg, uint sgitg) {
+ void mul_vec_q_n_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ uint3 tgpig, uint tiisg, uint sgitg) {
  const int nb = ne00/QK4_0;

  const int r0 = tgpig.x;
@@ -593,7 +871,10 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device

  const int first_row = (r0 * nsg + sgitg) * nr;

- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);

  device const block_q_type * x = (device const block_q_type *) src0 + offset0;
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
@@ -643,13 +924,14 @@ kernel void kernel_mul_mv_q4_0_f32(
  constant int64_t & ne02[[buffer(5)]],
  constant int64_t & ne10[[buffer(9)]],
  constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint tiisg[[thread_index_in_simdgroup]],
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
  }

  kernel void kernel_mul_mv_q4_1_f32(
@@ -661,13 +943,14 @@ kernel void kernel_mul_mv_q4_1_f32(
  constant int64_t & ne02[[buffer(5)]],
  constant int64_t & ne10[[buffer(9)]],
  constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint tiisg[[thread_index_in_simdgroup]],
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
  }

  kernel void kernel_mul_mv_q5_0_f32(
@@ -679,13 +962,14 @@ kernel void kernel_mul_mv_q5_0_f32(
  constant int64_t & ne02[[buffer(5)]],
  constant int64_t & ne10[[buffer(9)]],
  constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint tiisg[[thread_index_in_simdgroup]],
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
  }

  kernel void kernel_mul_mv_q5_1_f32(
@@ -697,33 +981,35 @@ kernel void kernel_mul_mv_q5_1_f32(
  constant int64_t & ne02[[buffer(5)]],
  constant int64_t & ne10[[buffer(9)]],
  constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint tiisg[[thread_index_in_simdgroup]],
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
  }


  #define NB_Q8_0 8

- kernel void kernel_mul_mv_q8_0_f32(
+ void kernel_mul_mv_q8_0_f32_impl(
  device const void * src0,
  device const float * src1,
  device float * dst,
  constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
  uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
  const int nr = N_DST;
  const int nsg = N_SIMDGROUP;
  const int nw = N_SIMDWIDTH;
@@ -732,8 +1018,14 @@ kernel void kernel_mul_mv_q8_0_f32(
  const int r0 = tgpig.x;
  const int r1 = tgpig.y;
  const int im = tgpig.z;
+
  const int first_row = (r0 * nsg + sgitg) * nr;
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
  device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;

@@ -771,11 +1063,31 @@ kernel void kernel_mul_mv_q8_0_f32(
  }
  }

- #define N_F32_F32 4
-
- kernel void kernel_mul_mv_f32_f32(
- device const char * src0,
- device const char * src1,
+ [[host_name("kernel_mul_mv_q8_0_f32")]]
+ kernel void kernel_mul_mv_q8_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ }
+
+ #define N_F32_F32 4
+
+ void kernel_mul_mv_f32_f32_impl(
+ device const char * src0,
+ device const char * src1,
  device float * dst,
  constant int64_t & ne00,
  constant int64_t & ne01,
@@ -791,6 +1103,8 @@ kernel void kernel_mul_mv_f32_f32(
  constant uint64_t & nb12,
  constant int64_t & ne0,
  constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint tiisg[[thread_index_in_simdgroup]]) {

@@ -798,7 +1112,12 @@ kernel void kernel_mul_mv_f32_f32(
  const int64_t rb = tgpig.y*N_F32_F32;
  const int64_t im = tgpig.z;

- device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const float * x = (device const float *) (src0 + offset0);

  if (ne00 < 128) {
  for (int row = 0; row < N_F32_F32; ++row) {
@@ -844,6 +1163,32 @@ kernel void kernel_mul_mv_f32_f32(
  }
  }

+ [[host_name("kernel_mul_mv_f32_f32")]]
+ kernel void kernel_mul_mv_f32_f32(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+ }
+
  #define N_F16_F16 4

  kernel void kernel_mul_mv_f16_f16(
@@ -864,6 +1209,8 @@ kernel void kernel_mul_mv_f16_f16(
  constant uint64_t & nb12,
  constant int64_t & ne0,
  constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint tiisg[[thread_index_in_simdgroup]]) {

@@ -871,7 +1218,12 @@ kernel void kernel_mul_mv_f16_f16(
  const int64_t rb = tgpig.y*N_F16_F16;
  const int64_t im = tgpig.z;

- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const half * x = (device const half *) (src0 + offset0);

  if (ne00 < 128) {
  for (int row = 0; row < N_F16_F16; ++row) {
@@ -917,7 +1269,7 @@ kernel void kernel_mul_mv_f16_f16(
  }
  }

- kernel void kernel_mul_mv_f16_f32_1row(
+ void kernel_mul_mv_f16_f32_1row_impl(
  device const char * src0,
  device const char * src1,
  device float * dst,
@@ -935,6 +1287,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
  constant uint64_t & nb12,
  constant int64_t & ne0,
  constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint tiisg[[thread_index_in_simdgroup]]) {

@@ -942,7 +1296,12 @@ kernel void kernel_mul_mv_f16_f32_1row(
  const int64_t r1 = tgpig.y;
  const int64_t im = tgpig.z;

- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const half * x = (device const half *) (src0 + offset0);
  device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);

  float sumf = 0;
@@ -966,12 +1325,37 @@ kernel void kernel_mul_mv_f16_f32_1row(
  dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
  }
  }
+ }

+ [[host_name("kernel_mul_mv_f16_f32_1row")]]
+ kernel void kernel_mul_mv_f16_f32_1row(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
  }

  #define N_F16_F32 4

- kernel void kernel_mul_mv_f16_f32(
+ void kernel_mul_mv_f16_f32_impl(
  device const char * src0,
  device const char * src1,
  device float * dst,
@@ -989,6 +1373,8 @@ kernel void kernel_mul_mv_f16_f32(
  constant uint64_t & nb12,
  constant int64_t & ne0,
  constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint tiisg[[thread_index_in_simdgroup]]) {

@@ -996,7 +1382,12 @@ kernel void kernel_mul_mv_f16_f32(
  const int64_t rb = tgpig.y*N_F16_F32;
  const int64_t im = tgpig.z;

- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const half * x = (device const half *) (src0 + offset0);

  if (ne00 < 128) {
  for (int row = 0; row < N_F16_F32; ++row) {
@@ -1042,6 +1433,32 @@ kernel void kernel_mul_mv_f16_f32(
  }
  }

+ [[host_name("kernel_mul_mv_f16_f32")]]
+ kernel void kernel_mul_mv_f16_f32(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+ }
+
  // Assumes row size (ne00) is a multiple of 4
  kernel void kernel_mul_mv_f16_f32_l4(
  device const char * src0,
@@ -1061,6 +1478,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
  constant uint64_t & nb12,
  constant int64_t & ne0,
  constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint tiisg[[thread_index_in_simdgroup]]) {

@@ -1068,7 +1487,12 @@ kernel void kernel_mul_mv_f16_f32_l4(
  const int64_t r0 = tgpig.x;
  const int64_t im = tgpig.z;

- device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const half4 * x4 = (device const half4 *) (src0 + offset0);

  for (int r1 = 0; r1 < nrows; ++r1) {
  device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
@@ -1120,17 +1544,21 @@ kernel void kernel_alibi_f32(
  const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
  const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+ const int64_t k = i3*ne3 + i2;

- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
  float m_k;
- if (i2 < n_heads_log2_floor) {
- m_k = pow(m0, i2 + 1);
+ if (k < n_heads_log2_floor) {
+ m_k = pow(m0, k + 1);
  } else {
- m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
+ m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
  }
+
+ device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
+ device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
- dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
+ const float src_v = *(device float *)(src_row + i00*nb00);
+ device float * dst_v = (device float *)(dst_row + i00*nb0);
+ *dst_v = i00 * m_k + src_v;
  }
  }

@@ -1335,9 +1763,160 @@ kernel void kernel_im2col_f16(
  }
  }

+ kernel void kernel_upscale_f32(
+ device const char * src0,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int32_t & sf,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const int64_t i3 = tgpig.z;
+ const int64_t i2 = tgpig.y;
+ const int64_t i1 = tgpig.x;
+
+ const int64_t i03 = i3;
+ const int64_t i02 = i2;
+ const int64_t i01 = i1/sf;
+
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ dst_ptr[i0] = src0_ptr[i0/sf];
+ }
+ }
+
+ kernel void kernel_pad_f32(
+ device const char * src0,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const int64_t i3 = tgpig.z;
+ const int64_t i2 = tgpig.y;
+ const int64_t i1 = tgpig.x;
+
+ const int64_t i03 = i3;
+ const int64_t i02 = i2;
+ const int64_t i01 = i1;
+
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+
+ if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ if (i0 < ne00) {
+ dst_ptr[i0] = src0_ptr[i0];
+ } else {
+ dst_ptr[i0] = 0.0f;
+ }
+ }
+
+ return;
+ }
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ dst_ptr[i0] = 0.0f;
+ }
+ }
+
+ // bitonic sort implementation following the CUDA kernels as reference
+ typedef void (argsort_t)(
+ device const float * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]);
+
+ template<ggml_sort_order order>
+ kernel void kernel_argsort_f32_i32(
+ device const float * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
+ // bitonic sort
+ int col = tpitg[0];
+ int row = tgpig[1];
+
+ if (col >= ncols) return;
+
+ device const float * x_row = x + row * ncols;
+ device int32_t * dst_row = dst + row * ncols;
+
+ // initialize indices
+ if (col < ncols) {
+ dst_row[col] = col;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (int k = 2; k <= ncols; k *= 2) {
+ for (int j = k / 2; j > 0; j /= 2) {
+ int ixj = col ^ j;
+ if (ixj > col) {
+ if ((col & k) == 0) {
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ } else {
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ }
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ }
+ }
+
+ template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
+ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
+
+ kernel void kernel_leaky_relu_f32(
+ device const float * src0,
+ device float * dst,
+ constant float & slope,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+ }
+
  kernel void kernel_cpy_f16_f16(
- device const half * src0,
- device half * dst,
+ device const half * src0,
+ device half * dst,
  constant int64_t & ne00,
  constant int64_t & ne01,
  constant int64_t & ne02,
@@ -1376,6 +1955,47 @@ kernel void kernel_cpy_f16_f16(
  }
  }

+ kernel void kernel_cpy_f16_f32(
+ device const half * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ dst_data[i00] = src[0];
+ }
+ }
+
  kernel void kernel_cpy_f32_f16(
  device const float * src0,
  device half * dst,
@@ -1460,47 +2080,238 @@ kernel void kernel_cpy_f32_f32(
  }
  }

- kernel void kernel_concat(
- device const char * src0,
- device const char * src1,
- device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
-
- const int64_t i03 = tgpig.z;
- const int64_t i02 = tgpig.y;
- const int64_t i01 = tgpig.x;
-
- const int64_t i13 = i03 % ne13;
+ kernel void kernel_cpy_f32_q8_0(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
+
+ device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_0; j++) {
+ const float v = src[j];
+ amax = MAX(amax, fabs(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK8_0].d = d;
+
+ for (int j = 0; j < QK8_0; ++j) {
+ const float x0 = src[j]*id;
+
+ dst_data[i00/QK8_0].qs[j] = round(x0);
+ }
+ }
+ }
+
+ kernel void kernel_cpy_f32_q4_0(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
+
+ device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < QK4_0; j++) {
+ const float v = src[j];
+ if (amax < fabs(v)) {
+ amax = fabs(v);
+ max = v;
+ }
+ }
+
+ const float d = max / -8;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK4_0].d = d;
+
+ for (int j = 0; j < QK4_0/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK4_0/2 + j]*id;
+
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+ dst_data[i00/QK4_0].qs[j] = xi0;
+ dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
+ }
+ }
+ }
+
+ kernel void kernel_cpy_f32_q4_1(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
+
+ device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float min = FLT_MAX;
+ float max = -FLT_MAX;
+
+ for (int j = 0; j < QK4_1; j++) {
+ const float v = src[j];
+ if (min > v) min = v;
+ if (max < v) max = v;
+ }
+
+ const float d = (max - min) / ((1 << 4) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK4_1].d = d;
+ dst_data[i00/QK4_1].m = min;
+
+ for (int j = 0; j < QK4_1/2; ++j) {
+ const float x0 = (src[0 + j] - min)*id;
+ const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+ dst_data[i00/QK4_1].qs[j] = xi0;
+ dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
+ }
+ }
+ }
+
+ kernel void kernel_concat(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
+
+ const int64_t i13 = i03 % ne13;
  const int64_t i12 = i02 % ne12;
  const int64_t i11 = i01 % ne11;

- device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
  device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
  device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;

@@ -1608,32 +2419,39 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {

  //====================================== dot products =========================

- kernel void kernel_mul_mv_q2_K_f32(
+ void kernel_mul_mv_q2_K_f32_impl(
  device const void * src0,
  device const float * src1,
  device float * dst,
  constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
  uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {

  const int nb = ne00/QK_K;
  const int r0 = tgpig.x;
  const int r1 = tgpig.y;
- const int r2 = tgpig.z;
+ const int im = tgpig.z;

  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
  const int ib_row = first_row * nb;
- const uint offset0 = r2/gqa*(nb*ne0);
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
  device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
  float yl[32];
  float sumf[N_DST]={0.f}, all_sum;

@@ -1642,11 +2460,11 @@ kernel void kernel_mul_mv_q2_K_f32(
  #if QK_K == 256
  const int ix = tiisg/8; // 0...3
  const int it = tiisg%8; // 0...7
- const int im = it/4; // 0 or 1
+ const int iq = it/4; // 0 or 1
  const int ir = it%4; // 0...3
  const int is = (8*ir)/16;// 0 or 1

- device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
+ device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;

  for (int ib = ix; ib < nb; ib += 4) {

@@ -1658,8 +2476,8 @@ kernel void kernel_mul_mv_q2_K_f32(
  yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
  }

- device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
  device const half * dh = &x[ib].d;

  for (int row = 0; row < N_DST; row++) {
@@ -1746,13 +2564,13 @@ kernel void kernel_mul_mv_q2_K_f32(
  for (int row = 0; row < N_DST; ++row) {
  all_sum = simd_sum(sumf[row]);
  if (tiisg == 0) {
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
  }
  }
  }

- #if QK_K == 256
- kernel void kernel_mul_mv_q3_K_f32(
+ [[host_name("kernel_mul_mv_q2_K_f32")]]
+ kernel void kernel_mul_mv_q2_K_f32(
  device const void * src0,
  device const float * src1,
  device float * dst,
@@ -1761,23 +2579,50 @@ kernel void kernel_mul_mv_q3_K_f32(
  constant int64_t & ne02[[buffer(5)]],
  constant int64_t & ne10[[buffer(9)]],
  constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ }
+
+ #if QK_K == 256
+ void kernel_mul_mv_q3_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {

  const int nb = ne00/QK_K;

  const int64_t r0 = tgpig.x;
  const int64_t r1 = tgpig.y;
- const int64_t r2 = tgpig.z;
+ const int64_t im = tgpig.z;

  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
- const uint offset0 = r2/gqa*(nb*ne0);
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
  device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;

  float yl[32];

@@ -1899,40 +2744,47 @@ kernel void kernel_mul_mv_q3_K_f32(
1899
2744
  }
1900
2745
  if (tiisg == 0) {
1901
2746
  for (int row = 0; row < 2; ++row) {
1902
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
2747
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
1903
2748
  }
1904
2749
  }
1905
2750
  }
1906
2751
  #else
1907
- kernel void kernel_mul_mv_q3_K_f32(
2752
+ void kernel_mul_mv_q3_K_f32_impl(
1908
2753
  device const void * src0,
1909
2754
  device const float * src1,
1910
2755
  device float * dst,
1911
2756
  constant int64_t & ne00,
1912
- constant int64_t & ne01[[buffer(4)]],
1913
- constant int64_t & ne02[[buffer(5)]],
1914
- constant int64_t & ne10[[buffer(9)]],
1915
- constant int64_t & ne12[[buffer(11)]],
1916
- constant int64_t & ne0[[buffer(15)]],
1917
- constant int64_t & ne1[[buffer(16)]],
1918
- constant uint & gqa[[buffer(17)]],
2757
+ constant int64_t & ne01,
2758
+ constant int64_t & ne02,
2759
+ constant int64_t & ne10,
2760
+ constant int64_t & ne12,
2761
+ constant int64_t & ne0,
2762
+ constant int64_t & ne1,
2763
+ constant uint & r2,
2764
+ constant uint & r3,
1919
2765
  uint3 tgpig[[threadgroup_position_in_grid]],
1920
- uint tiisg[[thread_index_in_simdgroup]],
1921
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2766
+ uint tiisg[[thread_index_in_simdgroup]],
2767
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1922
2768
 
1923
2769
  const int nb = ne00/QK_K;
1924
2770
 
1925
2771
  const int64_t r0 = tgpig.x;
1926
2772
  const int64_t r1 = tgpig.y;
1927
- const int64_t r2 = tgpig.z;
2773
+ const int64_t im = tgpig.z;
1928
2774
 
1929
2775
  const int row = 2 * r0 + sgitg;
1930
- const uint offset0 = r2/gqa*(nb*ne0);
2776
+
2777
+ const uint i12 = im%ne12;
2778
+ const uint i13 = im/ne12;
2779
+
2780
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2781
+
1931
2782
  device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
1932
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2783
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2784
+
1933
2785
  const int ix = tiisg/4;
1934
2786
  const int il = 4 * (tiisg%4);// 0, 4, 8, 12
1935
- const int im = il/8; // 0, 0, 1, 1
2787
+ const int iq = il/8; // 0, 0, 1, 1
1936
2788
  const int in = il%8; // 0, 4, 0, 4
1937
2789
 
1938
2790
  float2 sum = {0.f, 0.f};
@@ -1952,7 +2804,7 @@ kernel void kernel_mul_mv_q3_K_f32(
1952
2804
  const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
1953
2805
 
1954
2806
  for (int l = 0; l < 4; l += 2) {
1955
- const uint16_t hm = h[l/2] >> im;
2807
+ const uint16_t hm = h[l/2] >> iq;
1956
2808
  sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
1957
2809
  + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
1958
2810
  + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
@@ -1968,28 +2820,50 @@ kernel void kernel_mul_mv_q3_K_f32(
1968
2820
 
1969
2821
  const float tot = simd_sum(sumf);
1970
2822
  if (tiisg == 0) {
1971
- dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
2823
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
1972
2824
  }
1973
2825
 
1974
2826
  }
1975
2827
  #endif
1976
2828
 
2829
+ [[host_name("kernel_mul_mv_q3_K_f32")]]
2830
+ kernel void kernel_mul_mv_q3_K_f32(
2831
+ device const void * src0,
2832
+ device const float * src1,
2833
+ device float * dst,
2834
+ constant int64_t & ne00,
2835
+ constant int64_t & ne01[[buffer(4)]],
2836
+ constant int64_t & ne02[[buffer(5)]],
2837
+ constant int64_t & ne10[[buffer(9)]],
2838
+ constant int64_t & ne12[[buffer(11)]],
2839
+ constant int64_t & ne0 [[buffer(15)]],
2840
+ constant int64_t & ne1 [[buffer(16)]],
2841
+ constant uint & r2 [[buffer(17)]],
2842
+ constant uint & r3 [[buffer(18)]],
2843
+ uint3 tgpig[[threadgroup_position_in_grid]],
2844
+ uint tiisg[[thread_index_in_simdgroup]],
2845
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2846
+
2847
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2848
+ }
2849
+
1977
2850
  #if QK_K == 256
1978
- kernel void kernel_mul_mv_q4_K_f32(
2851
+ void kernel_mul_mv_q4_K_f32_impl(
1979
2852
  device const void * src0,
1980
2853
  device const float * src1,
1981
2854
  device float * dst,
1982
2855
  constant int64_t & ne00,
1983
- constant int64_t & ne01 [[buffer(4)]],
1984
- constant int64_t & ne02 [[buffer(5)]],
1985
- constant int64_t & ne10 [[buffer(9)]],
1986
- constant int64_t & ne12 [[buffer(11)]],
1987
- constant int64_t & ne0 [[buffer(15)]],
1988
- constant int64_t & ne1 [[buffer(16)]],
1989
- constant uint & gqa [[buffer(17)]],
2856
+ constant int64_t & ne01,
2857
+ constant int64_t & ne02,
2858
+ constant int64_t & ne10,
2859
+ constant int64_t & ne12,
2860
+ constant int64_t & ne0,
2861
+ constant int64_t & ne1,
2862
+ constant uint & r2,
2863
+ constant uint & r3,
1990
2864
  uint3 tgpig[[threadgroup_position_in_grid]],
1991
- uint tiisg[[thread_index_in_simdgroup]],
1992
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2865
+ uint tiisg[[thread_index_in_simdgroup]],
2866
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1993
2867
 
1994
2868
  const uint16_t kmask1 = 0x3f3f;
1995
2869
  const uint16_t kmask2 = 0x0f0f;
@@ -1997,26 +2871,32 @@ kernel void kernel_mul_mv_q4_K_f32(
1997
2871
 
1998
2872
  const int ix = tiisg/8; // 0...3
1999
2873
  const int it = tiisg%8; // 0...7
2000
- const int im = it/4; // 0 or 1
2874
+ const int iq = it/4; // 0 or 1
2001
2875
  const int ir = it%4; // 0...3
2002
2876
 
2003
2877
  const int nb = ne00/QK_K;
2004
2878
  const int r0 = tgpig.x;
2005
2879
  const int r1 = tgpig.y;
2006
- const int r2 = tgpig.z;
2880
+ const int im = tgpig.z;
2007
2881
  //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2008
2882
  const int first_row = r0 * N_DST;
2009
2883
  const int ib_row = first_row * nb;
2010
- const uint offset0 = r2/gqa*(nb*ne0);
2884
+
2885
+ const uint i12 = im%ne12;
2886
+ const uint i13 = im/ne12;
2887
+
2888
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2889
+
2011
2890
  device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
2012
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2891
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2892
+
2013
2893
  float yl[16];
2014
2894
  float yh[16];
2015
2895
  float sumf[N_DST]={0.f}, all_sum;
2016
2896
 
2017
2897
  const int step = sizeof(block_q4_K) * nb / 2;
2018
2898
 
2019
- device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
2899
+ device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
2020
2900
 
2021
2901
  uint16_t sc16[4];
2022
2902
  thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
@@ -2031,8 +2911,8 @@ kernel void kernel_mul_mv_q4_K_f32(
2031
2911
  yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
2032
2912
  }
2033
2913
 
2034
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
2035
- device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
2914
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
2915
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
2036
2916
  device const half * dh = &x[ib].d;
2037
2917
 
2038
2918
  for (int row = 0; row < N_DST; row++) {
@@ -2076,23 +2956,24 @@ kernel void kernel_mul_mv_q4_K_f32(
2076
2956
  for (int row = 0; row < N_DST; ++row) {
2077
2957
  all_sum = simd_sum(sumf[row]);
2078
2958
  if (tiisg == 0) {
2079
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
2959
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
2080
2960
  }
2081
2961
  }
2082
2962
  }
2083
2963
  #else
2084
- kernel void kernel_mul_mv_q4_K_f32(
2964
+ void kernel_mul_mv_q4_K_f32_impl(
2085
2965
  device const void * src0,
2086
2966
  device const float * src1,
2087
2967
  device float * dst,
2088
2968
  constant int64_t & ne00,
2089
- constant int64_t & ne01[[buffer(4)]],
2090
- constant int64_t & ne02[[buffer(5)]],
2091
- constant int64_t & ne10[[buffer(9)]],
2092
- constant int64_t & ne12[[buffer(11)]],
2093
- constant int64_t & ne0[[buffer(15)]],
2094
- constant int64_t & ne1[[buffer(16)]],
2095
- constant uint & gqa[[buffer(17)]],
2969
+ constant int64_t & ne01,
2970
+ constant int64_t & ne02,
2971
+ constant int64_t & ne10,
2972
+ constant int64_t & ne12,
2973
+ constant int64_t & ne0,
2974
+ constant int64_t & ne1,
2975
+ constant uint & r2,
2976
+ constant uint & r3,
2096
2977
  uint3 tgpig[[threadgroup_position_in_grid]],
2097
2978
  uint tiisg[[thread_index_in_simdgroup]],
2098
2979
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2103,12 +2984,18 @@ kernel void kernel_mul_mv_q4_K_f32(
2103
2984
  const int nb = ne00/QK_K;
2104
2985
  const int r0 = tgpig.x;
2105
2986
  const int r1 = tgpig.y;
2106
- const int r2 = tgpig.z;
2987
+ const int im = tgpig.z;
2107
2988
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2108
2989
  const int ib_row = first_row * nb;
2109
- const uint offset0 = r2/gqa*(nb*ne0);
2990
+
2991
+ const uint i12 = im%ne12;
2992
+ const uint i13 = im/ne12;
2993
+
2994
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2995
+
2110
2996
  device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
2111
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2997
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2998
+
2112
2999
  float yl[8];
2113
3000
  float yh[8];
2114
3001
  float sumf[N_DST]={0.f}, all_sum;
@@ -2164,13 +3051,14 @@ kernel void kernel_mul_mv_q4_K_f32(
2164
3051
  for (int row = 0; row < N_DST; ++row) {
2165
3052
  all_sum = simd_sum(sumf[row]);
2166
3053
  if (tiisg == 0) {
2167
- dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
3054
+ dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
2168
3055
  }
2169
3056
  }
2170
3057
  }
2171
3058
  #endif
2172
3059
 
2173
- kernel void kernel_mul_mv_q5_K_f32(
3060
+ [[host_name("kernel_mul_mv_q4_K_f32")]]
3061
+ kernel void kernel_mul_mv_q4_K_f32(
2174
3062
  device const void * src0,
2175
3063
  device const float * src1,
2176
3064
  device float * dst,
@@ -2179,23 +3067,49 @@ kernel void kernel_mul_mv_q5_K_f32(
2179
3067
  constant int64_t & ne02[[buffer(5)]],
2180
3068
  constant int64_t & ne10[[buffer(9)]],
2181
3069
  constant int64_t & ne12[[buffer(11)]],
2182
- constant int64_t & ne0[[buffer(15)]],
2183
- constant int64_t & ne1[[buffer(16)]],
2184
- constant uint & gqa[[buffer(17)]],
3070
+ constant int64_t & ne0 [[buffer(15)]],
3071
+ constant int64_t & ne1 [[buffer(16)]],
3072
+ constant uint & r2 [[buffer(17)]],
3073
+ constant uint & r3 [[buffer(18)]],
2185
3074
  uint3 tgpig[[threadgroup_position_in_grid]],
2186
3075
  uint tiisg[[thread_index_in_simdgroup]],
2187
3076
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
2188
3077
 
3078
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3079
+ }
3080
+
3081
+ void kernel_mul_mv_q5_K_f32_impl(
3082
+ device const void * src0,
3083
+ device const float * src1,
3084
+ device float * dst,
3085
+ constant int64_t & ne00,
3086
+ constant int64_t & ne01,
3087
+ constant int64_t & ne02,
3088
+ constant int64_t & ne10,
3089
+ constant int64_t & ne12,
3090
+ constant int64_t & ne0,
3091
+ constant int64_t & ne1,
3092
+ constant uint & r2,
3093
+ constant uint & r3,
3094
+ uint3 tgpig[[threadgroup_position_in_grid]],
3095
+ uint tiisg[[thread_index_in_simdgroup]],
3096
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3097
+
2189
3098
  const int nb = ne00/QK_K;
2190
3099
 
2191
3100
  const int64_t r0 = tgpig.x;
2192
3101
  const int64_t r1 = tgpig.y;
2193
- const int r2 = tgpig.z;
3102
+ const int im = tgpig.z;
2194
3103
 
2195
3104
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
2196
- const uint offset0 = r2/gqa*(nb*ne0);
3105
+
3106
+ const uint i12 = im%ne12;
3107
+ const uint i13 = im/ne12;
3108
+
3109
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3110
+
2197
3111
  device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
2198
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
3112
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2199
3113
 
2200
3114
  float sumf[2]={0.f};
2201
3115
 
@@ -2211,15 +3125,15 @@ kernel void kernel_mul_mv_q5_K_f32(
2211
3125
 
2212
3126
  const int tid = tiisg/4;
2213
3127
  const int ix = tiisg%4;
2214
- const int im = tid/4;
3128
+ const int iq = tid/4;
2215
3129
  const int ir = tid%4;
2216
3130
  const int n = 8;
2217
3131
 
2218
3132
  const int l0 = n*ir;
2219
- const int q_offset = 32*im + l0;
2220
- const int y_offset = 64*im + l0;
3133
+ const int q_offset = 32*iq + l0;
3134
+ const int y_offset = 64*iq + l0;
2221
3135
 
2222
- const uint8_t hm1 = 1u << (2*im);
3136
+ const uint8_t hm1 = 1u << (2*iq);
2223
3137
  const uint8_t hm2 = hm1 << 1;
2224
3138
  const uint8_t hm3 = hm1 << 4;
2225
3139
  const uint8_t hm4 = hm2 << 4;
@@ -2234,7 +3148,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2234
3148
  device const uint8_t * q1 = x[i].qs + q_offset;
2235
3149
  device const uint8_t * qh = x[i].qh + l0;
2236
3150
  device const half * dh = &x[i].d;
2237
- device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
3151
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
2238
3152
 
2239
3153
  device const float * y2 = y1 + 128;
2240
3154
  float4 sumy = {0.f, 0.f, 0.f, 0.f};
@@ -2290,7 +3204,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2290
3204
 
2291
3205
  const int il = 4 * (tiisg/8); // 0, 4, 8, 12
2292
3206
  const int ix = tiisg%8;
2293
- const int im = il/8; // 0, 0, 1, 1
3207
+ const int iq = il/8; // 0, 0, 1, 1
2294
3208
  const int in = il%8; // 0, 4, 0, 4
2295
3209
 
2296
3210
  device const float * y = yy + ix*QK_K + il;
@@ -2315,7 +3229,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2315
3229
 
2316
3230
  float2 acc = {0.f, 0.f};
2317
3231
  for (int l = 0; l < 4; ++l) {
2318
- const uint8_t hl = h[l] >> im;
3232
+ const uint8_t hl = h[l] >> iq;
2319
3233
  acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
2320
3234
  + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
2321
3235
  acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
@@ -2337,27 +3251,48 @@ kernel void kernel_mul_mv_q5_K_f32(
2337
3251
  for (int row = 0; row < 2; ++row) {
2338
3252
  const float tot = simd_sum(sumf[row]);
2339
3253
  if (tiisg == 0) {
2340
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
3254
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
2341
3255
  }
2342
3256
  }
3257
+ }
3258
+
3259
+ [[host_name("kernel_mul_mv_q5_K_f32")]]
3260
+ kernel void kernel_mul_mv_q5_K_f32(
3261
+ device const void * src0,
3262
+ device const float * src1,
3263
+ device float * dst,
3264
+ constant int64_t & ne00,
3265
+ constant int64_t & ne01[[buffer(4)]],
3266
+ constant int64_t & ne02[[buffer(5)]],
3267
+ constant int64_t & ne10[[buffer(9)]],
3268
+ constant int64_t & ne12[[buffer(11)]],
3269
+ constant int64_t & ne0 [[buffer(15)]],
3270
+ constant int64_t & ne1 [[buffer(16)]],
3271
+ constant uint & r2 [[buffer(17)]],
3272
+ constant uint & r3 [[buffer(18)]],
3273
+ uint3 tgpig[[threadgroup_position_in_grid]],
3274
+ uint tiisg[[thread_index_in_simdgroup]],
3275
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2343
3276
 
3277
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2344
3278
  }
2345
3279
 
2346
- kernel void kernel_mul_mv_q6_K_f32(
3280
+ void kernel_mul_mv_q6_K_f32_impl(
2347
3281
  device const void * src0,
2348
3282
  device const float * src1,
2349
3283
  device float * dst,
2350
3284
  constant int64_t & ne00,
2351
- constant int64_t & ne01[[buffer(4)]],
2352
- constant int64_t & ne02[[buffer(5)]],
2353
- constant int64_t & ne10[[buffer(9)]],
2354
- constant int64_t & ne12[[buffer(11)]],
2355
- constant int64_t & ne0[[buffer(15)]],
2356
- constant int64_t & ne1[[buffer(16)]],
2357
- constant uint & gqa[[buffer(17)]],
3285
+ constant int64_t & ne01,
3286
+ constant int64_t & ne02,
3287
+ constant int64_t & ne10,
3288
+ constant int64_t & ne12,
3289
+ constant int64_t & ne0,
3290
+ constant int64_t & ne1,
3291
+ constant uint & r2,
3292
+ constant uint & r3,
2358
3293
  uint3 tgpig[[threadgroup_position_in_grid]],
2359
- uint tiisg[[thread_index_in_simdgroup]],
2360
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3294
+ uint tiisg[[thread_index_in_simdgroup]],
3295
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2361
3296
 
2362
3297
  const uint8_t kmask1 = 0x03;
2363
3298
  const uint8_t kmask2 = 0x0C;
@@ -2368,12 +3303,17 @@ kernel void kernel_mul_mv_q6_K_f32(
2368
3303
 
2369
3304
  const int64_t r0 = tgpig.x;
2370
3305
  const int64_t r1 = tgpig.y;
2371
- const int r2 = tgpig.z;
3306
+ const int im = tgpig.z;
2372
3307
 
2373
3308
  const int row = 2 * r0 + sgitg;
2374
- const uint offset0 = r2/gqa*(nb*ne0);
3309
+
3310
+ const uint i12 = im%ne12;
3311
+ const uint i13 = im/ne12;
3312
+
3313
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3314
+
2375
3315
  device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
2376
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
3316
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2377
3317
 
2378
3318
  float sumf = 0;
2379
3319
 
@@ -2439,10 +3379,31 @@ kernel void kernel_mul_mv_q6_K_f32(
2439
3379
 
2440
3380
  const float tot = simd_sum(sumf);
2441
3381
  if (tiisg == 0) {
2442
- dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
3382
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
2443
3383
  }
2444
3384
  }
2445
3385
 
3386
+ [[host_name("kernel_mul_mv_q6_K_f32")]]
3387
+ kernel void kernel_mul_mv_q6_K_f32(
3388
+ device const void * src0,
3389
+ device const float * src1,
3390
+ device float * dst,
3391
+ constant int64_t & ne00,
3392
+ constant int64_t & ne01[[buffer(4)]],
3393
+ constant int64_t & ne02[[buffer(5)]],
3394
+ constant int64_t & ne10[[buffer(9)]],
3395
+ constant int64_t & ne12[[buffer(11)]],
3396
+ constant int64_t & ne0 [[buffer(15)]],
3397
+ constant int64_t & ne1 [[buffer(16)]],
3398
+ constant uint & r2 [[buffer(17)]],
3399
+ constant uint & r3 [[buffer(18)]],
3400
+ uint3 tgpig[[threadgroup_position_in_grid]],
3401
+ uint tiisg[[thread_index_in_simdgroup]],
3402
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3403
+
3404
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3405
+ }
3406
+
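Note on the matrix-vector hunks above: the former per-kernel `gqa` divisor (and the `r2 = tgpig.z` batch index) is replaced by an explicit matrix index `im = tgpig.z` plus two broadcast ratios `r2` and `r3`, and the old local index that used to be called `im` is renamed `iq` so the two no longer collide. The host-side sketch below is illustrative only; `BroadcastOffsets` and `src0_offset_for` are hypothetical names, and it assumes `r2` and `r3` are the ratios `ne12/ne02` and `ne13/ne03`, restating the `offset0` arithmetic the kernels now use.

// Hedged sketch, not part of the diff: how the new r2/r3 arguments are
// presumably derived and how they map a dst matrix index im back to src0.
// Assumes ne12 % ne02 == 0 and ne13 % ne03 == 0 (broadcast across dims 2 and 3).
#include <cstdint>

struct BroadcastOffsets {
    uint32_t r2, r3;   // broadcast ratios along dims 2 and 3
    int64_t  offset0;  // block offset into src0 for dst matrix im
};

static BroadcastOffsets src0_offset_for(int64_t im, int64_t nb /* blocks per row */,
                                        int64_t ne01, int64_t ne02, int64_t ne03,
                                        int64_t ne12, int64_t ne13) {
    BroadcastOffsets b;
    b.r2 = (uint32_t)(ne12/ne02);   // how many src1 matrices share one src0 matrix in dim 2
    b.r3 = (uint32_t)(ne13/ne03);   // same for dim 3
    const int64_t i12 = im % ne12;  // dst matrix index within dim 2
    const int64_t i13 = im / ne12;  // dst matrix index within dim 3
    b.offset0 = (i12/b.r2)*(nb*ne01) + (i13/b.r3)*(nb*ne01*ne02);
    return b;
}

Splitting the single `gqa` factor into per-dimension ratios appears to be what lets the same kernels broadcast across both dim 2 and dim 3 instead of only the grouped-query-attention grouping.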
  //============================= templates and their specializations =============================
 
  // NOTE: this is not dequantizing - we are simply fitting the template
@@ -2560,10 +3521,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
 
  template <typename type4x4>
  void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
- const half d = xb->d;
- const half min = xb->dmin;
+ const float d = xb->d;
+ const float min = xb->dmin;
  device const uint8_t * q = (device const uint8_t *)xb->qs;
- half dl, ml;
+ float dl, ml;
  uint8_t sc = xb->scales[il];
 
  #if QK_K == 256
@@ -2633,10 +3594,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
  q = q + (il/4) * 32 + 16 * (il&1);
  il = il & 3;
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
- const half d = il < 2 ? xb->d : xb->d / 16.h;
- const half min = xb->dmin;
- const half dl = d * sc[0];
- const half ml = min * sc[1];
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
+ const float min = xb->dmin;
+ const float dl = d * sc[0];
+ const float ml = min * sc[1];
  #else
  q = q + 16 * (il&1);
  device const uint8_t * s = xb->scales;
@@ -2663,13 +3624,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
  uint8_t ul = 1 << (il/2);
  il = il & 3;
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
- const half d = il < 2 ? xb->d : xb->d / 16.h;
- const half min = xb->dmin;
- const half dl = d * sc[0];
- const half ml = min * sc[1];
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
+ const float min = xb->dmin;
+ const float dl = d * sc[0];
+ const float ml = min * sc[1];
 
- const ushort mask = il<2 ? 0x0F : 0xF0;
- const half qh_val = il<2 ? 16.h : 256.h;
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ const float qh_val = il<2 ? 16.f : 256.f;
  for (int i = 0; i < 16; ++i) {
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
  }
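In the dequantize hunks above, the per-block scale `d`, minimum `dmin`, the derived `dl`/`ml`, and `qh_val` are promoted from `half` to `float`, presumably so the intermediate scale products are carried at full precision. A minimal stand-alone restatement of that pattern; `dequant_elem` and its parameters are hypothetical names used only for illustration:

// Illustrative only: the promoted-scale pattern from the hunks above, as plain C++.
// xb_d and xb_dmin stand in for the half-precision block fields; the arithmetic
// itself is done in float, as in the updated dequantize_q2_K/q4_K/q5_K paths.
#include <cstdint>

static inline float dequant_elem(float xb_d, float xb_dmin,
                                 uint8_t sc, uint8_t mn, uint8_t q) {
    const float dl = xb_d    * sc;  // block scale times quantized sub-scale, in float
    const float ml = xb_dmin * mn;  // block minimum times quantized sub-minimum, in float
    return dl * q - ml;             // dequantized element
}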
@@ -2717,22 +3678,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
  kernel void kernel_get_rows(
  device const void * src0,
- device const int * src1,
+ device const char * src1,
  device float * dst,
  constant int64_t & ne00,
  constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
  constant uint64_t & nb1,
- uint tgpig[[threadgroup_position_in_grid]],
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
  uint tiitg[[thread_index_in_threadgroup]],
- uint tptg[[threads_per_threadgroup]]) {
- const int i = tgpig;
- const int r = ((device int32_t *) src1)[i];
+ uint3 tptg [[threads_per_threadgroup]]) {
+ //const int64_t i = tgpig;
+ //const int64_t r = ((device int32_t *) src1)[i];
+
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
 
- for (int ind = tiitg; ind < ne00/16; ind += tptg) {
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
  float4x4 temp;
  dequantize_func(
- ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
- *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
+ ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+ *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+ }
+ }
+
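The reworked `kernel_get_rows` above no longer treats `src1` as a flat `int` array: the index tensor is addressed through the `nb10`/`nb11` byte strides, the grid becomes 2-D, and the second grid coordinate both selects the index row (`i11`) and the source matrix (`i02 = i11`). Below is a hedged CPU reference of the same lookup for the f32 case, assuming contiguous row-major layouts; `get_rows_ref` is a hypothetical name, not part of the library.

// Hedged CPU sketch of what the reworked get_rows computes:
// indices is an int32 tensor of shape [ne11][ne10]; each (i11, i10) pair selects
// row r (expected to be < ne01) from matrix i02 = i11 of src0 and copies it to dst.
#include <cstdint>
#include <cstring>

static void get_rows_ref(const float * src0, const int32_t * indices, float * dst,
                         int64_t ne00,   // row length
                         int64_t ne01,   // rows per src0 matrix
                         int64_t ne10,   // indices per index row
                         int64_t ne11) { // number of index rows
    for (int64_t i11 = 0; i11 < ne11; ++i11) {
        for (int64_t i10 = 0; i10 < ne10; ++i10) {
            const int64_t r   = indices[i11*ne10 + i10];
            const int64_t i02 = i11;  // the index row also picks the src0 matrix
            memcpy(dst  + (i11*ne10 + i10)*ne00,
                   src0 + (i02*ne01 + r)*ne00,
                   ne00*sizeof(float));
        }
    }
}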
+ kernel void kernel_get_rows_f32(
3713
+ device const void * src0,
3714
+ device const char * src1,
3715
+ device float * dst,
3716
+ constant int64_t & ne00,
3717
+ constant uint64_t & nb01,
3718
+ constant uint64_t & nb02,
3719
+ constant int64_t & ne10,
3720
+ constant uint64_t & nb10,
3721
+ constant uint64_t & nb11,
3722
+ constant uint64_t & nb1,
3723
+ constant uint64_t & nb2,
3724
+ uint3 tgpig[[threadgroup_position_in_grid]],
3725
+ uint tiitg[[thread_index_in_threadgroup]],
3726
+ uint3 tptg [[threads_per_threadgroup]]) {
3727
+ const int64_t i10 = tgpig.x;
3728
+ const int64_t i11 = tgpig.y;
3729
+
3730
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3731
+
3732
+ const int64_t i02 = i11;
3733
+
3734
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3735
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3736
+ ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
3737
+ }
3738
+ }
3739
+
3740
+ kernel void kernel_get_rows_f16(
3741
+ device const void * src0,
3742
+ device const char * src1,
3743
+ device float * dst,
3744
+ constant int64_t & ne00,
3745
+ constant uint64_t & nb01,
3746
+ constant uint64_t & nb02,
3747
+ constant int64_t & ne10,
3748
+ constant uint64_t & nb10,
3749
+ constant uint64_t & nb11,
3750
+ constant uint64_t & nb1,
3751
+ constant uint64_t & nb2,
3752
+ uint3 tgpig[[threadgroup_position_in_grid]],
3753
+ uint tiitg[[thread_index_in_threadgroup]],
3754
+ uint3 tptg [[threads_per_threadgroup]]) {
3755
+ const int64_t i10 = tgpig.x;
3756
+ const int64_t i11 = tgpig.y;
3757
+
3758
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3759
+
3760
+ const int64_t i02 = i11;
3761
+
3762
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3763
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3764
+ ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
2736
3765
  }
2737
3766
  }
2738
3767
 
@@ -2749,24 +3778,25 @@ kernel void kernel_get_rows(
2749
3778
 
2750
3779
  // each block_q contains 16*nl weights
2751
3780
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
2752
- kernel void kernel_mul_mm(device const uchar * src0,
2753
- device const uchar * src1,
2754
- device float * dst,
2755
- constant int64_t & ne00,
2756
- constant int64_t & ne02,
2757
- constant int64_t & nb01,
2758
- constant int64_t & nb02,
2759
- constant int64_t & ne12,
2760
- constant int64_t & nb10,
2761
- constant int64_t & nb11,
2762
- constant int64_t & nb12,
2763
- constant int64_t & ne0,
2764
- constant int64_t & ne1,
2765
- constant uint & gqa,
2766
- threadgroup uchar * shared_memory [[threadgroup(0)]],
2767
- uint3 tgpig[[threadgroup_position_in_grid]],
2768
- uint tiitg[[thread_index_in_threadgroup]],
2769
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3781
+ void kernel_mul_mm_impl(device const uchar * src0,
3782
+ device const uchar * src1,
3783
+ device float * dst,
3784
+ constant int64_t & ne00,
3785
+ constant int64_t & ne02,
3786
+ constant int64_t & nb01,
3787
+ constant int64_t & nb02,
3788
+ constant int64_t & ne12,
3789
+ constant int64_t & nb10,
3790
+ constant int64_t & nb11,
3791
+ constant int64_t & nb12,
3792
+ constant int64_t & ne0,
3793
+ constant int64_t & ne1,
3794
+ constant uint & r2,
3795
+ constant uint & r3,
3796
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3797
+ uint3 tgpig[[threadgroup_position_in_grid]],
3798
+ uint tiitg[[thread_index_in_threadgroup]],
3799
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2770
3800
 
2771
3801
  threadgroup half * sa = (threadgroup half *)(shared_memory);
2772
3802
  threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
@@ -2792,7 +3822,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
2792
3822
 
2793
3823
  short il = (tiitg % THREAD_PER_ROW);
2794
3824
 
2795
- uint offset0 = im/gqa*nb02;
3825
+ const uint i12 = im%ne12;
3826
+ const uint i13 = im/ne12;
3827
+
3828
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
2796
3829
  ushort offset1 = il/nl;
2797
3830
 
2798
3831
  device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
@@ -2876,17 +3909,137 @@ kernel void kernel_mul_mm(device const uchar * src0,
2876
3909
  }
2877
3910
  }
2878
3911
 
3912
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3913
+ kernel void kernel_mul_mm(device const uchar * src0,
3914
+ device const uchar * src1,
3915
+ device float * dst,
3916
+ constant int64_t & ne00,
3917
+ constant int64_t & ne02,
3918
+ constant int64_t & nb01,
3919
+ constant int64_t & nb02,
3920
+ constant int64_t & ne12,
3921
+ constant int64_t & nb10,
3922
+ constant int64_t & nb11,
3923
+ constant int64_t & nb12,
3924
+ constant int64_t & ne0,
3925
+ constant int64_t & ne1,
3926
+ constant uint & r2,
3927
+ constant uint & r3,
3928
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3929
+ uint3 tgpig[[threadgroup_position_in_grid]],
3930
+ uint tiitg[[thread_index_in_threadgroup]],
3931
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3932
+ kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3933
+ src0,
3934
+ src1,
3935
+ dst,
3936
+ ne00,
3937
+ ne02,
3938
+ nb01,
3939
+ nb02,
3940
+ ne12,
3941
+ nb10,
3942
+ nb11,
3943
+ nb12,
3944
+ ne0,
3945
+ ne1,
3946
+ r2,
3947
+ r3,
3948
+ shared_memory,
3949
+ tgpig,
3950
+ tiitg,
3951
+ sgitg);
3952
+ }
3953
+
3954
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3955
+ kernel void kernel_mul_mm_id(
3956
+ device const uchar * ids,
3957
+ device const uchar * src1,
3958
+ device uchar * dst,
3959
+ constant int64_t & nbi1,
3960
+ constant int64_t & ne00,
3961
+ constant int64_t & ne02,
3962
+ constant int64_t & nb01,
3963
+ constant int64_t & nb02,
3964
+ constant int64_t & ne12,
3965
+ constant int64_t & ne13,
3966
+ constant int64_t & nb10,
3967
+ constant int64_t & nb11,
3968
+ constant int64_t & nb12,
3969
+ constant int64_t & ne0,
3970
+ constant int64_t & ne1,
3971
+ constant int64_t & nb1,
3972
+ constant uint & r2,
3973
+ constant uint & r3,
3974
+ constant int & idx,
3975
+ device const uchar * src00,
3976
+ device const uchar * src01,
3977
+ device const uchar * src02,
3978
+ device const uchar * src03,
3979
+ device const uchar * src04,
3980
+ device const uchar * src05,
3981
+ device const uchar * src06,
3982
+ device const uchar * src07,
3983
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3984
+ uint3 tgpig[[threadgroup_position_in_grid]],
3985
+ uint tiitg[[thread_index_in_threadgroup]],
3986
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3987
+ device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
3988
+
3989
+ const int64_t bid = tgpig.z/(ne12*ne13);
3990
+
3991
+ tgpig.z = tgpig.z%(ne12*ne13);
3992
+
3993
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
3994
+
3995
+ kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3996
+ src0[id],
3997
+ src1 + bid*nb11,
3998
+ (device float *) (dst + bid*nb1),
3999
+ ne00,
4000
+ ne02,
4001
+ nb01,
4002
+ nb02,
4003
+ ne12,
4004
+ nb10,
4005
+ nb11,
4006
+ nb12,
4007
+ ne0,
4008
+ ne1,
4009
+ r2,
4010
+ r3,
4011
+ shared_memory,
4012
+ tgpig,
4013
+ tiitg,
4014
+ sgitg);
4015
+ }
4016
+
2879
4017
  #if QK_K == 256
2880
4018
  #define QK_NL 16
2881
4019
  #else
2882
4020
  #define QK_NL 4
2883
4021
  #endif
2884
4022
 
2885
- typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
2886
- constant uint64_t &, constant uint64_t &, uint, uint, uint);
4023
+ //
4024
+ // get rows
4025
+ //
2887
4026
 
2888
- template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
2889
- template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
4027
+ typedef void (get_rows_t)(
4028
+ device const void * src0,
4029
+ device const char * src1,
4030
+ device float * dst,
4031
+ constant int64_t & ne00,
4032
+ constant uint64_t & nb01,
4033
+ constant uint64_t & nb02,
4034
+ constant int64_t & ne10,
4035
+ constant uint64_t & nb10,
4036
+ constant uint64_t & nb11,
4037
+ constant uint64_t & nb1,
4038
+ constant uint64_t & nb2,
4039
+ uint3, uint, uint3);
4040
+
4041
+ //template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
4042
+ //template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
2890
4043
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
2891
4044
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
2892
4045
  template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
@@ -2898,6 +4051,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
2898
4051
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
2899
4052
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
2900
4053
 
4054
+ //
4055
+ // matrix-matrix multiplication
4056
+ //
4057
+
2901
4058
  typedef void (mat_mm_t)(
2902
4059
  device const uchar * src0,
2903
4060
  device const uchar * src1,
@@ -2912,8 +4069,10 @@ typedef void (mat_mm_t)(
2912
4069
  constant int64_t & nb12,
2913
4070
  constant int64_t & ne0,
2914
4071
  constant int64_t & ne1,
2915
- constant uint & gqa,
2916
- threadgroup uchar *, uint3, uint, uint);
4072
+ constant uint & r2,
4073
+ constant uint & r3,
4074
+ threadgroup uchar *,
4075
+ uint3, uint, uint);
2917
4076
 
2918
4077
  template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
2919
4078
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
@@ -2927,3 +4086,823 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
2927
4086
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
2928
4087
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
2929
4088
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
4089
+
4090
+ //
4091
+ // indirect matrix-matrix multiplication
4092
+ //
4093
+
4094
+ typedef void (mat_mm_id_t)(
4095
+ device const uchar * ids,
4096
+ device const uchar * src1,
4097
+ device uchar * dst,
4098
+ constant int64_t & nbi1,
4099
+ constant int64_t & ne00,
4100
+ constant int64_t & ne02,
4101
+ constant int64_t & nb01,
4102
+ constant int64_t & nb02,
4103
+ constant int64_t & ne12,
4104
+ constant int64_t & ne13,
4105
+ constant int64_t & nb10,
4106
+ constant int64_t & nb11,
4107
+ constant int64_t & nb12,
4108
+ constant int64_t & ne0,
4109
+ constant int64_t & ne1,
4110
+ constant int64_t & nb1,
4111
+ constant uint & r2,
4112
+ constant uint & r3,
4113
+ constant int & idx,
4114
+ device const uchar * src00,
4115
+ device const uchar * src01,
4116
+ device const uchar * src02,
4117
+ device const uchar * src03,
4118
+ device const uchar * src04,
4119
+ device const uchar * src05,
4120
+ device const uchar * src06,
4121
+ device const uchar * src07,
4122
+ threadgroup uchar *,
4123
+ uint3, uint, uint);
4124
+
4125
+ template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
4126
+ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
4127
+ template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
4128
+ template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
4129
+ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
4130
+ template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
4131
+ template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
4132
+ template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
4133
+ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
4134
+ template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
4135
+ template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
4136
+ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
4137
+
4138
+ //
4139
+ // matrix-vector multiplication
4140
+ //
4141
+
4142
+ [[host_name("kernel_mul_mv_id_f32_f32")]]
4143
+ kernel void kernel_mul_mv_id_f32_f32(
4144
+ device const char * ids,
4145
+ device const char * src1,
4146
+ device uchar * dst,
4147
+ constant int64_t & nbi1,
4148
+ constant int64_t & ne00,
4149
+ constant int64_t & ne01,
4150
+ constant int64_t & ne02,
4151
+ constant uint64_t & nb00,
4152
+ constant uint64_t & nb01,
4153
+ constant uint64_t & nb02,
4154
+ constant int64_t & ne10,
4155
+ constant int64_t & ne11,
4156
+ constant int64_t & ne12,
4157
+ constant int64_t & ne13,
4158
+ constant uint64_t & nb10,
4159
+ constant uint64_t & nb11,
4160
+ constant uint64_t & nb12,
4161
+ constant int64_t & ne0,
4162
+ constant int64_t & ne1,
4163
+ constant int64_t & nb1,
4164
+ constant uint & r2,
4165
+ constant uint & r3,
4166
+ constant int & idx,
4167
+ device const char * src00,
4168
+ device const char * src01,
4169
+ device const char * src02,
4170
+ device const char * src03,
4171
+ device const char * src04,
4172
+ device const char * src05,
4173
+ device const char * src06,
4174
+ device const char * src07,
4175
+ uint3 tgpig[[threadgroup_position_in_grid]],
4176
+ uint tiitg[[thread_index_in_threadgroup]],
4177
+ uint tiisg[[thread_index_in_simdgroup]],
4178
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4179
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4180
+
4181
+ const int64_t bid = tgpig.z/(ne12*ne13);
4182
+
4183
+ tgpig.z = tgpig.z%(ne12*ne13);
4184
+
4185
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4186
+
4187
+ kernel_mul_mv_f32_f32_impl(
4188
+ src0[id],
4189
+ src1 + bid*nb11,
4190
+ (device float *) (dst + bid*nb1),
4191
+ ne00,
4192
+ ne01,
4193
+ ne02,
4194
+ nb00,
4195
+ nb01,
4196
+ nb02,
4197
+ ne10,
4198
+ ne11,
4199
+ ne12,
4200
+ nb10,
4201
+ nb11,
4202
+ nb12,
4203
+ ne0,
4204
+ ne1,
4205
+ r2,
4206
+ r3,
4207
+ tgpig,
4208
+ tiisg);
4209
+ }
4210
+
4211
+ [[host_name("kernel_mul_mv_id_f16_f32")]]
4212
+ kernel void kernel_mul_mv_id_f16_f32(
4213
+ device const char * ids,
4214
+ device const char * src1,
4215
+ device uchar * dst,
4216
+ constant int64_t & nbi1,
4217
+ constant int64_t & ne00,
4218
+ constant int64_t & ne01,
4219
+ constant int64_t & ne02,
4220
+ constant uint64_t & nb00,
4221
+ constant uint64_t & nb01,
4222
+ constant uint64_t & nb02,
4223
+ constant int64_t & ne10,
4224
+ constant int64_t & ne11,
4225
+ constant int64_t & ne12,
4226
+ constant int64_t & ne13,
4227
+ constant uint64_t & nb10,
4228
+ constant uint64_t & nb11,
4229
+ constant uint64_t & nb12,
4230
+ constant int64_t & ne0,
4231
+ constant int64_t & ne1,
4232
+ constant int64_t & nb1,
4233
+ constant uint & r2,
4234
+ constant uint & r3,
4235
+ constant int & idx,
4236
+ device const char * src00,
4237
+ device const char * src01,
4238
+ device const char * src02,
4239
+ device const char * src03,
4240
+ device const char * src04,
4241
+ device const char * src05,
4242
+ device const char * src06,
4243
+ device const char * src07,
4244
+ uint3 tgpig[[threadgroup_position_in_grid]],
4245
+ uint tiitg[[thread_index_in_threadgroup]],
4246
+ uint tiisg[[thread_index_in_simdgroup]],
4247
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4248
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4249
+
4250
+ const int64_t bid = tgpig.z/(ne12*ne13);
4251
+
4252
+ tgpig.z = tgpig.z%(ne12*ne13);
4253
+
4254
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4255
+
4256
+ kernel_mul_mv_f16_f32_impl(
4257
+ src0[id],
4258
+ src1 + bid*nb11,
4259
+ (device float *) (dst + bid*nb1),
4260
+ ne00,
4261
+ ne01,
4262
+ ne02,
4263
+ nb00,
4264
+ nb01,
4265
+ nb02,
4266
+ ne10,
4267
+ ne11,
4268
+ ne12,
4269
+ nb10,
4270
+ nb11,
4271
+ nb12,
4272
+ ne0,
4273
+ ne1,
4274
+ r2,
4275
+ r3,
4276
+ tgpig,
4277
+ tiisg);
4278
+ }
4279
+
4280
+ [[host_name("kernel_mul_mv_id_q8_0_f32")]]
4281
+ kernel void kernel_mul_mv_id_q8_0_f32(
4282
+ device const char * ids,
4283
+ device const char * src1,
4284
+ device uchar * dst,
4285
+ constant int64_t & nbi1,
4286
+ constant int64_t & ne00,
4287
+ constant int64_t & ne01,
4288
+ constant int64_t & ne02,
4289
+ constant uint64_t & nb00,
4290
+ constant uint64_t & nb01,
4291
+ constant uint64_t & nb02,
4292
+ constant int64_t & ne10,
4293
+ constant int64_t & ne11,
4294
+ constant int64_t & ne12,
4295
+ constant int64_t & ne13,
4296
+ constant uint64_t & nb10,
4297
+ constant uint64_t & nb11,
4298
+ constant uint64_t & nb12,
4299
+ constant int64_t & ne0,
4300
+ constant int64_t & ne1,
4301
+ constant int64_t & nb1,
4302
+ constant uint & r2,
4303
+ constant uint & r3,
4304
+ constant int & idx,
4305
+ device const char * src00,
4306
+ device const char * src01,
4307
+ device const char * src02,
4308
+ device const char * src03,
4309
+ device const char * src04,
4310
+ device const char * src05,
4311
+ device const char * src06,
4312
+ device const char * src07,
4313
+ uint3 tgpig[[threadgroup_position_in_grid]],
4314
+ uint tiitg[[thread_index_in_threadgroup]],
4315
+ uint tiisg[[thread_index_in_simdgroup]],
4316
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4317
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4318
+
4319
+ const int64_t bid = tgpig.z/(ne12*ne13);
4320
+
4321
+ tgpig.z = tgpig.z%(ne12*ne13);
4322
+
4323
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4324
+
4325
+ kernel_mul_mv_q8_0_f32_impl(
4326
+ src0[id],
4327
+ (device const float *) (src1 + bid*nb11),
4328
+ (device float *) ( dst + bid*nb1),
4329
+ ne00,
4330
+ ne01,
4331
+ ne02,
4332
+ ne10,
4333
+ ne12,
4334
+ ne0,
4335
+ ne1,
4336
+ r2,
4337
+ r3,
4338
+ tgpig,
4339
+ tiisg,
4340
+ sgitg);
4341
+ }
4342
+
4343
+ [[host_name("kernel_mul_mv_id_q4_0_f32")]]
4344
+ kernel void kernel_mul_mv_id_q4_0_f32(
4345
+ device const char * ids,
4346
+ device const char * src1,
4347
+ device uchar * dst,
4348
+ constant int64_t & nbi1,
4349
+ constant int64_t & ne00,
4350
+ constant int64_t & ne01,
4351
+ constant int64_t & ne02,
4352
+ constant uint64_t & nb00,
4353
+ constant uint64_t & nb01,
4354
+ constant uint64_t & nb02,
4355
+ constant int64_t & ne10,
4356
+ constant int64_t & ne11,
4357
+ constant int64_t & ne12,
4358
+ constant int64_t & ne13,
4359
+ constant uint64_t & nb10,
4360
+ constant uint64_t & nb11,
4361
+ constant uint64_t & nb12,
4362
+ constant int64_t & ne0,
4363
+ constant int64_t & ne1,
4364
+ constant int64_t & nb1,
4365
+ constant uint & r2,
4366
+ constant uint & r3,
4367
+ constant int & idx,
4368
+ device const char * src00,
4369
+ device const char * src01,
4370
+ device const char * src02,
4371
+ device const char * src03,
4372
+ device const char * src04,
4373
+ device const char * src05,
4374
+ device const char * src06,
4375
+ device const char * src07,
4376
+ uint3 tgpig[[threadgroup_position_in_grid]],
4377
+ uint tiitg[[thread_index_in_threadgroup]],
4378
+ uint tiisg[[thread_index_in_simdgroup]],
4379
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4380
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4381
+
4382
+ const int64_t bid = tgpig.z/(ne12*ne13);
4383
+
4384
+ tgpig.z = tgpig.z%(ne12*ne13);
4385
+
4386
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4387
+
4388
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4389
+ src0[id],
4390
+ (device const float *) (src1 + bid*nb11),
4391
+ (device float *) ( dst + bid*nb1),
4392
+ ne00,
4393
+ ne01,
4394
+ ne02,
4395
+ ne10,
4396
+ ne12,
4397
+ ne0,
4398
+ ne1,
4399
+ r2,
4400
+ r3,
4401
+ tgpig,
4402
+ tiisg,
4403
+ sgitg);
4404
+ }
4405
+
4406
+ [[host_name("kernel_mul_mv_id_q4_1_f32")]]
4407
+ kernel void kernel_mul_mv_id_q4_1_f32(
4408
+ device const char * ids,
4409
+ device const char * src1,
4410
+ device uchar * dst,
4411
+ constant int64_t & nbi1,
4412
+ constant int64_t & ne00,
4413
+ constant int64_t & ne01,
4414
+ constant int64_t & ne02,
4415
+ constant uint64_t & nb00,
4416
+ constant uint64_t & nb01,
4417
+ constant uint64_t & nb02,
4418
+ constant int64_t & ne10,
4419
+ constant int64_t & ne11,
4420
+ constant int64_t & ne12,
4421
+ constant int64_t & ne13,
4422
+ constant uint64_t & nb10,
4423
+ constant uint64_t & nb11,
4424
+ constant uint64_t & nb12,
4425
+ constant int64_t & ne0,
4426
+ constant int64_t & ne1,
4427
+ constant int64_t & nb1,
4428
+ constant uint & r2,
4429
+ constant uint & r3,
4430
+ constant int & idx,
4431
+ device const char * src00,
4432
+ device const char * src01,
4433
+ device const char * src02,
4434
+ device const char * src03,
4435
+ device const char * src04,
4436
+ device const char * src05,
4437
+ device const char * src06,
4438
+ device const char * src07,
4439
+ uint3 tgpig[[threadgroup_position_in_grid]],
4440
+ uint tiitg[[thread_index_in_threadgroup]],
4441
+ uint tiisg[[thread_index_in_simdgroup]],
4442
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4443
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4444
+
4445
+ const int64_t bid = tgpig.z/(ne12*ne13);
4446
+
4447
+ tgpig.z = tgpig.z%(ne12*ne13);
4448
+
4449
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4450
+
4451
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4452
+ src0[id],
4453
+ (device const float *) (src1 + bid*nb11),
4454
+ (device float *) ( dst + bid*nb1),
4455
+ ne00,
4456
+ ne01,
4457
+ ne02,
4458
+ ne10,
4459
+ ne12,
4460
+ ne0,
4461
+ ne1,
4462
+ r2,
4463
+ r3,
4464
+ tgpig,
4465
+ tiisg,
4466
+ sgitg);
4467
+ }
4468
+
4469
+ [[host_name("kernel_mul_mv_id_q5_0_f32")]]
4470
+ kernel void kernel_mul_mv_id_q5_0_f32(
4471
+ device const char * ids,
4472
+ device const char * src1,
4473
+ device uchar * dst,
4474
+ constant int64_t & nbi1,
4475
+ constant int64_t & ne00,
4476
+ constant int64_t & ne01,
4477
+ constant int64_t & ne02,
4478
+ constant uint64_t & nb00,
4479
+ constant uint64_t & nb01,
4480
+ constant uint64_t & nb02,
4481
+ constant int64_t & ne10,
4482
+ constant int64_t & ne11,
4483
+ constant int64_t & ne12,
4484
+ constant int64_t & ne13,
4485
+ constant uint64_t & nb10,
4486
+ constant uint64_t & nb11,
4487
+ constant uint64_t & nb12,
4488
+ constant int64_t & ne0,
4489
+ constant int64_t & ne1,
4490
+ constant int64_t & nb1,
4491
+ constant uint & r2,
4492
+ constant uint & r3,
4493
+ constant int & idx,
4494
+ device const char * src00,
4495
+ device const char * src01,
4496
+ device const char * src02,
4497
+ device const char * src03,
4498
+ device const char * src04,
4499
+ device const char * src05,
4500
+ device const char * src06,
4501
+ device const char * src07,
4502
+ uint3 tgpig[[threadgroup_position_in_grid]],
4503
+ uint tiitg[[thread_index_in_threadgroup]],
4504
+ uint tiisg[[thread_index_in_simdgroup]],
4505
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4506
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4507
+
4508
+ const int64_t bid = tgpig.z/(ne12*ne13);
4509
+
4510
+ tgpig.z = tgpig.z%(ne12*ne13);
4511
+
4512
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4513
+
4514
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4515
+ src0[id],
4516
+ (device const float *) (src1 + bid*nb11),
4517
+ (device float *) ( dst + bid*nb1),
4518
+ ne00,
4519
+ ne01,
4520
+ ne02,
4521
+ ne10,
4522
+ ne12,
4523
+ ne0,
4524
+ ne1,
4525
+ r2,
4526
+ r3,
4527
+ tgpig,
4528
+ tiisg,
4529
+ sgitg);
4530
+ }
4531
+
4532
+ [[host_name("kernel_mul_mv_id_q5_1_f32")]]
4533
+ kernel void kernel_mul_mv_id_q5_1_f32(
4534
+ device const char * ids,
4535
+ device const char * src1,
4536
+ device uchar * dst,
4537
+ constant int64_t & nbi1,
4538
+ constant int64_t & ne00,
4539
+ constant int64_t & ne01,
4540
+ constant int64_t & ne02,
4541
+ constant uint64_t & nb00,
4542
+ constant uint64_t & nb01,
4543
+ constant uint64_t & nb02,
4544
+ constant int64_t & ne10,
4545
+ constant int64_t & ne11,
4546
+ constant int64_t & ne12,
4547
+ constant int64_t & ne13,
4548
+ constant uint64_t & nb10,
4549
+ constant uint64_t & nb11,
4550
+ constant uint64_t & nb12,
4551
+ constant int64_t & ne0,
4552
+ constant int64_t & ne1,
4553
+ constant int64_t & nb1,
4554
+ constant uint & r2,
4555
+ constant uint & r3,
4556
+ constant int & idx,
4557
+ device const char * src00,
4558
+ device const char * src01,
4559
+ device const char * src02,
4560
+ device const char * src03,
4561
+ device const char * src04,
4562
+ device const char * src05,
4563
+ device const char * src06,
4564
+ device const char * src07,
4565
+ uint3 tgpig[[threadgroup_position_in_grid]],
4566
+ uint tiitg[[thread_index_in_threadgroup]],
4567
+ uint tiisg[[thread_index_in_simdgroup]],
4568
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4569
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4570
+
4571
+ const int64_t bid = tgpig.z/(ne12*ne13);
4572
+
4573
+ tgpig.z = tgpig.z%(ne12*ne13);
4574
+
4575
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4576
+
4577
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4578
+ src0[id],
4579
+ (device const float *) (src1 + bid*nb11),
4580
+ (device float *) ( dst + bid*nb1),
4581
+ ne00,
4582
+ ne01,
4583
+ ne02,
4584
+ ne10,
4585
+ ne12,
4586
+ ne0,
4587
+ ne1,
4588
+ r2,
4589
+ r3,
4590
+ tgpig,
4591
+ tiisg,
4592
+ sgitg);
4593
+ }
4594
+
4595
+ [[host_name("kernel_mul_mv_id_q2_K_f32")]]
4596
+ kernel void kernel_mul_mv_id_q2_K_f32(
4597
+ device const char * ids,
4598
+ device const char * src1,
4599
+ device uchar * dst,
4600
+ constant int64_t & nbi1,
4601
+ constant int64_t & ne00,
4602
+ constant int64_t & ne01,
4603
+ constant int64_t & ne02,
4604
+ constant uint64_t & nb00,
4605
+ constant uint64_t & nb01,
4606
+ constant uint64_t & nb02,
4607
+ constant int64_t & ne10,
4608
+ constant int64_t & ne11,
4609
+ constant int64_t & ne12,
4610
+ constant int64_t & ne13,
4611
+ constant uint64_t & nb10,
4612
+ constant uint64_t & nb11,
4613
+ constant uint64_t & nb12,
4614
+ constant int64_t & ne0,
4615
+ constant int64_t & ne1,
4616
+ constant int64_t & nb1,
4617
+ constant uint & r2,
4618
+ constant uint & r3,
4619
+ constant int & idx,
4620
+ device const char * src00,
4621
+ device const char * src01,
4622
+ device const char * src02,
4623
+ device const char * src03,
4624
+ device const char * src04,
4625
+ device const char * src05,
4626
+ device const char * src06,
4627
+ device const char * src07,
4628
+ uint3 tgpig[[threadgroup_position_in_grid]],
4629
+ uint tiitg[[thread_index_in_threadgroup]],
4630
+ uint tiisg[[thread_index_in_simdgroup]],
4631
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4632
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4633
+
4634
+ const int64_t bid = tgpig.z/(ne12*ne13);
4635
+
4636
+ tgpig.z = tgpig.z%(ne12*ne13);
4637
+
4638
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4639
+
4640
+ kernel_mul_mv_q2_K_f32_impl(
4641
+ src0[id],
4642
+ (device const float *) (src1 + bid*nb11),
4643
+ (device float *) ( dst + bid*nb1),
4644
+ ne00,
4645
+ ne01,
4646
+ ne02,
4647
+ ne10,
4648
+ ne12,
4649
+ ne0,
4650
+ ne1,
4651
+ r2,
4652
+ r3,
4653
+ tgpig,
4654
+ tiisg,
4655
+ sgitg);
4656
+ }
4657
+
4658
+ [[host_name("kernel_mul_mv_id_q3_K_f32")]]
4659
+ kernel void kernel_mul_mv_id_q3_K_f32(
4660
+ device const char * ids,
4661
+ device const char * src1,
4662
+ device uchar * dst,
4663
+ constant int64_t & nbi1,
4664
+ constant int64_t & ne00,
4665
+ constant int64_t & ne01,
4666
+ constant int64_t & ne02,
4667
+ constant uint64_t & nb00,
4668
+ constant uint64_t & nb01,
4669
+ constant uint64_t & nb02,
4670
+ constant int64_t & ne10,
4671
+ constant int64_t & ne11,
4672
+ constant int64_t & ne12,
4673
+ constant int64_t & ne13,
4674
+ constant uint64_t & nb10,
4675
+ constant uint64_t & nb11,
4676
+ constant uint64_t & nb12,
4677
+ constant int64_t & ne0,
4678
+ constant int64_t & ne1,
4679
+ constant int64_t & nb1,
4680
+ constant uint & r2,
4681
+ constant uint & r3,
4682
+ constant int & idx,
4683
+ device const char * src00,
4684
+ device const char * src01,
4685
+ device const char * src02,
4686
+ device const char * src03,
4687
+ device const char * src04,
4688
+ device const char * src05,
4689
+ device const char * src06,
4690
+ device const char * src07,
4691
+ uint3 tgpig[[threadgroup_position_in_grid]],
4692
+ uint tiitg[[thread_index_in_threadgroup]],
4693
+ uint tiisg[[thread_index_in_simdgroup]],
4694
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4695
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4696
+
4697
+ const int64_t bid = tgpig.z/(ne12*ne13);
4698
+
4699
+ tgpig.z = tgpig.z%(ne12*ne13);
4700
+
4701
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4702
+
4703
+ kernel_mul_mv_q3_K_f32_impl(
4704
+ src0[id],
4705
+ (device const float *) (src1 + bid*nb11),
4706
+ (device float *) ( dst + bid*nb1),
4707
+ ne00,
4708
+ ne01,
4709
+ ne02,
4710
+ ne10,
4711
+ ne12,
4712
+ ne0,
4713
+ ne1,
4714
+ r2,
4715
+ r3,
4716
+ tgpig,
4717
+ tiisg,
4718
+ sgitg);
4719
+ }
4720
+
4721
+ [[host_name("kernel_mul_mv_id_q4_K_f32")]]
4722
+ kernel void kernel_mul_mv_id_q4_K_f32(
4723
+ device const char * ids,
4724
+ device const char * src1,
4725
+ device uchar * dst,
4726
+ constant int64_t & nbi1,
4727
+ constant int64_t & ne00,
4728
+ constant int64_t & ne01,
4729
+ constant int64_t & ne02,
4730
+ constant uint64_t & nb00,
4731
+ constant uint64_t & nb01,
4732
+ constant uint64_t & nb02,
4733
+ constant int64_t & ne10,
4734
+ constant int64_t & ne11,
4735
+ constant int64_t & ne12,
4736
+ constant int64_t & ne13,
4737
+ constant uint64_t & nb10,
4738
+ constant uint64_t & nb11,
4739
+ constant uint64_t & nb12,
4740
+ constant int64_t & ne0,
4741
+ constant int64_t & ne1,
4742
+ constant int64_t & nb1,
4743
+ constant uint & r2,
4744
+ constant uint & r3,
4745
+ constant int & idx,
4746
+ device const char * src00,
4747
+ device const char * src01,
4748
+ device const char * src02,
4749
+ device const char * src03,
4750
+ device const char * src04,
4751
+ device const char * src05,
4752
+ device const char * src06,
4753
+ device const char * src07,
4754
+ uint3 tgpig[[threadgroup_position_in_grid]],
4755
+ uint tiitg[[thread_index_in_threadgroup]],
4756
+ uint tiisg[[thread_index_in_simdgroup]],
4757
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4758
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4759
+
4760
+ const int64_t bid = tgpig.z/(ne12*ne13);
4761
+
4762
+ tgpig.z = tgpig.z%(ne12*ne13);
4763
+
4764
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4765
+
4766
+ kernel_mul_mv_q4_K_f32_impl(
4767
+ src0[id],
4768
+ (device const float *) (src1 + bid*nb11),
4769
+ (device float *) ( dst + bid*nb1),
4770
+ ne00,
4771
+ ne01,
4772
+ ne02,
4773
+ ne10,
4774
+ ne12,
4775
+ ne0,
4776
+ ne1,
4777
+ r2,
4778
+ r3,
4779
+ tgpig,
4780
+ tiisg,
4781
+ sgitg);
4782
+ }
4783
+
4784
+ [[host_name("kernel_mul_mv_id_q5_K_f32")]]
4785
+ kernel void kernel_mul_mv_id_q5_K_f32(
4786
+ device const char * ids,
4787
+ device const char * src1,
4788
+ device uchar * dst,
4789
+ constant int64_t & nbi1,
4790
+ constant int64_t & ne00,
4791
+ constant int64_t & ne01,
4792
+ constant int64_t & ne02,
4793
+ constant uint64_t & nb00,
4794
+ constant uint64_t & nb01,
4795
+ constant uint64_t & nb02,
4796
+ constant int64_t & ne10,
4797
+ constant int64_t & ne11,
4798
+ constant int64_t & ne12,
4799
+ constant int64_t & ne13,
4800
+ constant uint64_t & nb10,
4801
+ constant uint64_t & nb11,
4802
+ constant uint64_t & nb12,
4803
+ constant int64_t & ne0,
4804
+ constant int64_t & ne1,
4805
+ constant int64_t & nb1,
4806
+ constant uint & r2,
4807
+ constant uint & r3,
4808
+ constant int & idx,
4809
+ device const char * src00,
4810
+ device const char * src01,
4811
+ device const char * src02,
4812
+ device const char * src03,
4813
+ device const char * src04,
4814
+ device const char * src05,
4815
+ device const char * src06,
4816
+ device const char * src07,
4817
+ uint3 tgpig[[threadgroup_position_in_grid]],
4818
+ uint tiitg[[thread_index_in_threadgroup]],
4819
+ uint tiisg[[thread_index_in_simdgroup]],
4820
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4821
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4822
+
4823
+ const int64_t bid = tgpig.z/(ne12*ne13);
4824
+
4825
+ tgpig.z = tgpig.z%(ne12*ne13);
4826
+
4827
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4828
+
4829
+ kernel_mul_mv_q5_K_f32_impl(
4830
+ src0[id],
4831
+ (device const float *) (src1 + bid*nb11),
4832
+ (device float *) ( dst + bid*nb1),
4833
+ ne00,
4834
+ ne01,
4835
+ ne02,
4836
+ ne10,
4837
+ ne12,
4838
+ ne0,
4839
+ ne1,
4840
+ r2,
4841
+ r3,
4842
+ tgpig,
4843
+ tiisg,
4844
+ sgitg);
4845
+ }
4846
+
4847
+ [[host_name("kernel_mul_mv_id_q6_K_f32")]]
4848
+ kernel void kernel_mul_mv_id_q6_K_f32(
4849
+ device const char * ids,
4850
+ device const char * src1,
4851
+ device uchar * dst,
4852
+ constant int64_t & nbi1,
4853
+ constant int64_t & ne00,
4854
+ constant int64_t & ne01,
4855
+ constant int64_t & ne02,
4856
+ constant uint64_t & nb00,
4857
+ constant uint64_t & nb01,
4858
+ constant uint64_t & nb02,
4859
+ constant int64_t & ne10,
4860
+ constant int64_t & ne11,
4861
+ constant int64_t & ne12,
4862
+ constant int64_t & ne13,
4863
+ constant uint64_t & nb10,
4864
+ constant uint64_t & nb11,
4865
+ constant uint64_t & nb12,
4866
+ constant int64_t & ne0,
4867
+ constant int64_t & ne1,
4868
+ constant int64_t & nb1,
4869
+ constant uint & r2,
4870
+ constant uint & r3,
4871
+ constant int & idx,
4872
+ device const char * src00,
4873
+ device const char * src01,
4874
+ device const char * src02,
4875
+ device const char * src03,
4876
+ device const char * src04,
4877
+ device const char * src05,
4878
+ device const char * src06,
4879
+ device const char * src07,
4880
+ uint3 tgpig[[threadgroup_position_in_grid]],
4881
+ uint tiitg[[thread_index_in_threadgroup]],
4882
+ uint tiisg[[thread_index_in_simdgroup]],
4883
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4884
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4885
+
4886
+ const int64_t bid = tgpig.z/(ne12*ne13);
4887
+
4888
+ tgpig.z = tgpig.z%(ne12*ne13);
4889
+
4890
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4891
+
4892
+ kernel_mul_mv_q6_K_f32_impl(
4893
+ src0[id],
4894
+ (device const float *) (src1 + bid*nb11),
4895
+ (device float *) ( dst + bid*nb1),
4896
+ ne00,
4897
+ ne01,
4898
+ ne02,
4899
+ ne10,
4900
+ ne12,
4901
+ ne0,
4902
+ ne1,
4903
+ r2,
4904
+ r3,
4905
+ tgpig,
4906
+ tiisg,
4907
+ sgitg);
4908
+ }
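All of the `kernel_mul_mm_id` and `kernel_mul_mv_id_*` entry points added above share one dispatch prologue: up to eight candidate `src0` matrices arrive as separate buffers, `ids` holds one row of expert indices per batch row, `idx` selects which column of that row to use, and `tgpig.z` is split so that its high part is the batch row (`bid`) while the remainder stays the usual broadcast index before the matching `_impl` is called. The sketch below restates that prologue in plain C++ for illustration only; `IdDispatch` and `select_expert` are hypothetical names.

// Hedged sketch, not from the diff: the shared dispatch prologue of the
// kernel_mul_mv_id_* / kernel_mul_mm_id entry points, restated as host-style C++.
#include <cstdint>

struct IdDispatch {
    const char * src0;  // the selected expert matrix
    int64_t      bid;   // batch row: which row of ids / src1 / dst to use
    int64_t      imz;   // remaining z index, i.e. the usual broadcast index
};

static IdDispatch select_expert(const char * const src0s[8],
                                const char * ids, int64_t nbi1, int idx,
                                int64_t tgpig_z, int64_t ne12, int64_t ne13) {
    IdDispatch d;
    d.bid = tgpig_z / (ne12*ne13);  // which batch row this threadgroup works on
    d.imz = tgpig_z % (ne12*ne13);  // broadcast index forwarded to the impl
    // expert chosen for this row; id is expected to be in [0, 8)
    const int32_t id = ((const int32_t *)(ids + d.bid*nbi1))[idx];
    d.src0 = src0s[id];
    return d;
}

Routing on a runtime `id` this way keeps a single compiled kernel per quantization type while still letting every batch row use a different expert matrix.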