llama_cpp 0.9.5 → 0.10.1

This diff shows the changes between publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
@@ -3,6 +3,8 @@
3
3
  using namespace metal;
4
4
 
5
5
  #define MAX(x, y) ((x) > (y) ? (x) : (y))
6
+ #define MIN(x, y) ((x) < (y) ? (x) : (y))
7
+ #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
6
8
 
7
9
  #define QK4_0 32
8
10
  #define QR4_0 2
@@ -41,8 +43,13 @@ typedef struct {
41
43
 
42
44
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
43
45
 
44
- // general-purpose kernel for addition of two tensors
45
- // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
46
+ enum ggml_sort_order {
47
+ GGML_SORT_ASC,
48
+ GGML_SORT_DESC,
49
+ };
50
+
51
+ // general-purpose kernel for addition, multiplication and division of two tensors
52
+ // pros: works for non-contiguous tensors, supports broadcast across all dims
46
53
  // cons: not very efficient
47
54
  kernel void kernel_add(
48
55
  device const char * src0,
@@ -72,6 +79,7 @@ kernel void kernel_add(
72
79
  constant int64_t & nb1,
73
80
  constant int64_t & nb2,
74
81
  constant int64_t & nb3,
82
+ constant int64_t & offs,
75
83
  uint3 tgpig[[threadgroup_position_in_grid]],
76
84
  uint3 tpitg[[thread_position_in_threadgroup]],
77
85
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -83,16 +91,111 @@ kernel void kernel_add(
83
91
  const int64_t i12 = i02 % ne12;
84
92
  const int64_t i11 = i01 % ne11;
85
93
 
86
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
87
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
88
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
94
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
95
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
96
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
97
+
98
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
99
+ const int i10 = i0 % ne10;
100
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
101
+ }
102
+ }
103
+
104
+ kernel void kernel_mul(
105
+ device const char * src0,
106
+ device const char * src1,
107
+ device char * dst,
108
+ constant int64_t & ne00,
109
+ constant int64_t & ne01,
110
+ constant int64_t & ne02,
111
+ constant int64_t & ne03,
112
+ constant int64_t & nb00,
113
+ constant int64_t & nb01,
114
+ constant int64_t & nb02,
115
+ constant int64_t & nb03,
116
+ constant int64_t & ne10,
117
+ constant int64_t & ne11,
118
+ constant int64_t & ne12,
119
+ constant int64_t & ne13,
120
+ constant int64_t & nb10,
121
+ constant int64_t & nb11,
122
+ constant int64_t & nb12,
123
+ constant int64_t & nb13,
124
+ constant int64_t & ne0,
125
+ constant int64_t & ne1,
126
+ constant int64_t & ne2,
127
+ constant int64_t & ne3,
128
+ constant int64_t & nb0,
129
+ constant int64_t & nb1,
130
+ constant int64_t & nb2,
131
+ constant int64_t & nb3,
132
+ uint3 tgpig[[threadgroup_position_in_grid]],
133
+ uint3 tpitg[[thread_position_in_threadgroup]],
134
+ uint3 ntg[[threads_per_threadgroup]]) {
135
+ const int64_t i03 = tgpig.z;
136
+ const int64_t i02 = tgpig.y;
137
+ const int64_t i01 = tgpig.x;
138
+
139
+ const int64_t i13 = i03 % ne13;
140
+ const int64_t i12 = i02 % ne12;
141
+ const int64_t i11 = i01 % ne11;
142
+
143
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
144
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
145
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
89
146
 
90
147
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
91
- ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
148
+ const int i10 = i0 % ne10;
149
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
150
+ }
151
+ }
152
+
153
+ kernel void kernel_div(
154
+ device const char * src0,
155
+ device const char * src1,
156
+ device char * dst,
157
+ constant int64_t & ne00,
158
+ constant int64_t & ne01,
159
+ constant int64_t & ne02,
160
+ constant int64_t & ne03,
161
+ constant int64_t & nb00,
162
+ constant int64_t & nb01,
163
+ constant int64_t & nb02,
164
+ constant int64_t & nb03,
165
+ constant int64_t & ne10,
166
+ constant int64_t & ne11,
167
+ constant int64_t & ne12,
168
+ constant int64_t & ne13,
169
+ constant int64_t & nb10,
170
+ constant int64_t & nb11,
171
+ constant int64_t & nb12,
172
+ constant int64_t & nb13,
173
+ constant int64_t & ne0,
174
+ constant int64_t & ne1,
175
+ constant int64_t & ne2,
176
+ constant int64_t & ne3,
177
+ constant int64_t & nb0,
178
+ constant int64_t & nb1,
179
+ constant int64_t & nb2,
180
+ constant int64_t & nb3,
181
+ uint3 tgpig[[threadgroup_position_in_grid]],
182
+ uint3 tpitg[[thread_position_in_threadgroup]],
183
+ uint3 ntg[[threads_per_threadgroup]]) {
184
+ const int64_t i03 = tgpig.z;
185
+ const int64_t i02 = tgpig.y;
186
+ const int64_t i01 = tgpig.x;
187
+
188
+ const int64_t i13 = i03 % ne13;
189
+ const int64_t i12 = i02 % ne12;
190
+ const int64_t i11 = i01 % ne11;
92
191
 
93
- src0_ptr += ntg.x*nb00;
94
- src1_ptr += ntg.x*nb10;
95
- dst_ptr += ntg.x*nb0;
192
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
193
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
194
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
195
+
196
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
197
+ const int i10 = i0 % ne10;
198
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
96
199
  }
97
200
  }
98
201
 
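For reference, the broadcast rule used by the new kernel_add, kernel_mul and kernel_div kernels — src1 indices are wrapped with a per-dimension modulo (i13 = i03 % ne13, ..., i10 = i0 % ne10) — can be sketched on the host in plain C++. The tensor4d struct and add_broadcast helper below are illustrative names only, not part of the package; they mirror ggml's ne (extents) / nb (byte strides) convention so the indexing is easy to step through.

#include <cstdint>
#include <vector>

struct tensor4d {
    int64_t ne[4];              // extents per dimension
    int64_t nb[4];              // byte strides per dimension
    std::vector<char> data;     // raw storage holding floats
};

// dst[i0,i1,i2,i3] = src0[i0,i1,i2,i3] + src1[i0 % ne10, i1 % ne11, i2 % ne12, i3 % ne13]
static void add_broadcast(const tensor4d &src0, const tensor4d &src1, tensor4d &dst) {
    for (int64_t i3 = 0; i3 < src0.ne[3]; ++i3)
    for (int64_t i2 = 0; i2 < src0.ne[2]; ++i2)
    for (int64_t i1 = 0; i1 < src0.ne[1]; ++i1)
    for (int64_t i0 = 0; i0 < src0.ne[0]; ++i0) {
        const int64_t j3 = i3 % src1.ne[3];   // broadcast across dim 3
        const int64_t j2 = i2 % src1.ne[2];   // broadcast across dim 2
        const int64_t j1 = i1 % src1.ne[1];   // broadcast across dim 1
        const int64_t j0 = i0 % src1.ne[0];   // broadcast across dim 0 (new in these kernels)
        const float a = *(const float *)(src0.data.data() + i3*src0.nb[3] + i2*src0.nb[2] + i1*src0.nb[1] + i0*src0.nb[0]);
        const float b = *(const float *)(src1.data.data() + j3*src1.nb[3] + j2*src1.nb[2] + j1*src1.nb[1] + j0*src1.nb[0]);
        *(float *)(dst.data.data() + i3*dst.nb[3] + i2*dst.nb[2] + i1*dst.nb[1] + i0*dst.nb[0]) = a + b;
    }
}

The per-element modulo is what lets a smaller src1 (down to a single row) be repeated across the larger src0 without materializing the broadcast.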
@@ -102,28 +205,27 @@ kernel void kernel_add_row(
102
205
  device const float4 * src0,
103
206
  device const float4 * src1,
104
207
  device float4 * dst,
105
- constant int64_t & nb [[buffer(27)]],
208
+ constant int64_t & nb [[buffer(28)]],
106
209
  uint tpig[[thread_position_in_grid]]) {
107
210
  dst[tpig] = src0[tpig] + src1[tpig % nb];
108
211
  }
109
212
 
110
- kernel void kernel_mul(
213
+ kernel void kernel_mul_row(
111
214
  device const float4 * src0,
112
215
  device const float4 * src1,
113
216
  device float4 * dst,
217
+ constant int64_t & nb [[buffer(28)]],
114
218
  uint tpig[[thread_position_in_grid]]) {
115
- dst[tpig] = src0[tpig] * src1[tpig];
219
+ dst[tpig] = src0[tpig] * src1[tpig % nb];
116
220
  }
117
221
 
118
- // assumption: src1 is a row
119
- // broadcast src1 into src0
120
- kernel void kernel_mul_row(
222
+ kernel void kernel_div_row(
121
223
  device const float4 * src0,
122
224
  device const float4 * src1,
123
225
  device float4 * dst,
124
- constant int64_t & nb,
226
+ constant int64_t & nb [[buffer(28)]],
125
227
  uint tpig[[thread_position_in_grid]]) {
126
- dst[tpig] = src0[tpig] * src1[tpig % nb];
228
+ dst[tpig] = src0[tpig] / src1[tpig % nb];
127
229
  }
128
230
 
129
231
  kernel void kernel_scale(
@@ -142,14 +244,6 @@ kernel void kernel_scale_4(
142
244
  dst[tpig] = src0[tpig] * scale;
143
245
  }
144
246
 
145
- kernel void kernel_silu(
146
- device const float4 * src0,
147
- device float4 * dst,
148
- uint tpig[[thread_position_in_grid]]) {
149
- device const float4 & x = src0[tpig];
150
- dst[tpig] = x / (1.0f + exp(-x));
151
- }
152
-
153
247
  kernel void kernel_relu(
154
248
  device const float * src0,
155
249
  device float * dst,
@@ -157,15 +251,17 @@ kernel void kernel_relu(
157
251
  dst[tpig] = max(0.0f, src0[tpig]);
158
252
  }
159
253
 
160
- kernel void kernel_sqr(
254
+ kernel void kernel_tanh(
161
255
  device const float * src0,
162
256
  device float * dst,
163
257
  uint tpig[[thread_position_in_grid]]) {
164
- dst[tpig] = src0[tpig] * src0[tpig];
258
+ device const float & x = src0[tpig];
259
+ dst[tpig] = precise::tanh(x);
165
260
  }
166
261
 
167
- constant float GELU_COEF_A = 0.044715f;
168
- constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
262
+ constant float GELU_COEF_A = 0.044715f;
263
+ constant float GELU_QUICK_COEF = -1.702f;
264
+ constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
169
265
 
170
266
  kernel void kernel_gelu(
171
267
  device const float4 * src0,
@@ -180,6 +276,78 @@ kernel void kernel_gelu(
180
276
  dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
181
277
  }
182
278
 
279
+ kernel void kernel_gelu_quick(
280
+ device const float4 * src0,
281
+ device float4 * dst,
282
+ uint tpig[[thread_position_in_grid]]) {
283
+ device const float4 & x = src0[tpig];
284
+
285
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
286
+ }
287
+
288
+ kernel void kernel_silu(
289
+ device const float4 * src0,
290
+ device float4 * dst,
291
+ uint tpig[[thread_position_in_grid]]) {
292
+ device const float4 & x = src0[tpig];
293
+ dst[tpig] = x / (1.0f + exp(-x));
294
+ }
295
+
296
+ kernel void kernel_sqr(
297
+ device const float * src0,
298
+ device float * dst,
299
+ uint tpig[[thread_position_in_grid]]) {
300
+ dst[tpig] = src0[tpig] * src0[tpig];
301
+ }
302
+
303
+ kernel void kernel_sum_rows(
304
+ device const float * src0,
305
+ device float * dst,
306
+ constant int64_t & ne00,
307
+ constant int64_t & ne01,
308
+ constant int64_t & ne02,
309
+ constant int64_t & ne03,
310
+ constant int64_t & nb00,
311
+ constant int64_t & nb01,
312
+ constant int64_t & nb02,
313
+ constant int64_t & nb03,
314
+ constant int64_t & ne10,
315
+ constant int64_t & ne11,
316
+ constant int64_t & ne12,
317
+ constant int64_t & ne13,
318
+ constant int64_t & nb10,
319
+ constant int64_t & nb11,
320
+ constant int64_t & nb12,
321
+ constant int64_t & nb13,
322
+ constant int64_t & ne0,
323
+ constant int64_t & ne1,
324
+ constant int64_t & ne2,
325
+ constant int64_t & ne3,
326
+ constant int64_t & nb0,
327
+ constant int64_t & nb1,
328
+ constant int64_t & nb2,
329
+ constant int64_t & nb3,
330
+ uint3 tpig[[thread_position_in_grid]]) {
331
+ int64_t i3 = tpig.z;
332
+ int64_t i2 = tpig.y;
333
+ int64_t i1 = tpig.x;
334
+
335
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
336
+ return;
337
+ }
338
+
339
+ device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
340
+ device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
341
+
342
+ float row_sum = 0;
343
+
344
+ for (int64_t i0 = 0; i0 < ne00; i0++) {
345
+ row_sum += src_row[i0];
346
+ }
347
+
348
+ dst_row[0] = row_sum;
349
+ }
350
+
183
351
  kernel void kernel_soft_max(
184
352
  device const float * src0,
185
353
  device const float * src1,
@@ -198,9 +366,9 @@ kernel void kernel_soft_max(
198
366
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
199
367
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
200
368
 
201
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
202
- device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
203
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
369
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
370
+ device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
371
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
204
372
 
205
373
  // parallel max
206
374
  float lmax = -INFINITY;
@@ -236,7 +404,12 @@ kernel void kernel_soft_max(
236
404
  pdst[i00] = exp_psrc0;
237
405
  }
238
406
 
407
+ // This barrier fixes a failing test
408
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
409
+ threadgroup_barrier(mem_flags::mem_none);
410
+
239
411
  float sum = simd_sum(lsum);
412
+
240
413
  if (ntg > N_SIMDWIDTH) {
241
414
  if (sgitg == 0) {
242
415
  buf[tiisg] = 0.0f;
@@ -279,9 +452,9 @@ kernel void kernel_soft_max_4(
279
452
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
280
453
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
281
454
 
282
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
283
- device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
284
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
455
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
456
+ device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
457
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
285
458
 
286
459
  // parallel max
287
460
  float4 lmax4 = -INFINITY;
@@ -319,7 +492,13 @@ kernel void kernel_soft_max_4(
319
492
  }
320
493
 
321
494
  const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
495
+
496
+ // This barrier fixes a failing test
497
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
498
+ threadgroup_barrier(mem_flags::mem_none);
499
+
322
500
  float sum = simd_sum(lsum);
501
+
323
502
  if (ntg > N_SIMDWIDTH) {
324
503
  if (sgitg == 0) {
325
504
  buf[tiisg] = 0.0f;
@@ -490,6 +669,94 @@ kernel void kernel_rms_norm(
490
669
  }
491
670
  }
492
671
 
672
+ kernel void kernel_group_norm(
673
+ device const float * src0,
674
+ device float * dst,
675
+ constant int64_t & ne00,
676
+ constant int64_t & ne01,
677
+ constant int64_t & ne02,
678
+ constant uint64_t & nb00,
679
+ constant uint64_t & nb01,
680
+ constant uint64_t & nb02,
681
+ constant int32_t & n_groups,
682
+ constant float & eps,
683
+ threadgroup float * buf [[threadgroup(0)]],
684
+ uint tgpig[[threadgroup_position_in_grid]],
685
+ uint tpitg[[thread_position_in_threadgroup]],
686
+ uint sgitg[[simdgroup_index_in_threadgroup]],
687
+ uint tiisg[[thread_index_in_simdgroup]],
688
+ uint ntg[[threads_per_threadgroup]]) {
689
+ const int64_t ne = ne00*ne01*ne02;
690
+ const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
691
+
692
+ int start = tgpig * gs;
693
+ int end = start + gs;
694
+
695
+ start += tpitg;
696
+
697
+ if (end >= ne) {
698
+ end = ne;
699
+ }
700
+
701
+ float tmp = 0.0f; // partial sum for thread in warp
702
+
703
+ for (int j = start; j < end; j += ntg) {
704
+ tmp += src0[j];
705
+ }
706
+
707
+ threadgroup_barrier(mem_flags::mem_threadgroup);
708
+ tmp = simd_sum(tmp);
709
+ if (ntg > N_SIMDWIDTH) {
710
+ if (sgitg == 0) {
711
+ buf[tiisg] = 0.0f;
712
+ }
713
+
714
+ threadgroup_barrier(mem_flags::mem_threadgroup);
715
+
716
+ if (tiisg == 0) {
717
+ buf[sgitg] = tmp;
718
+ }
719
+
720
+ threadgroup_barrier(mem_flags::mem_threadgroup);
721
+
722
+ tmp = buf[tiisg];
723
+ tmp = simd_sum(tmp);
724
+ }
725
+
726
+ const float mean = tmp / gs;
727
+ tmp = 0.0f;
728
+
729
+ for (int j = start; j < end; j += ntg) {
730
+ float xi = src0[j] - mean;
731
+ dst[j] = xi;
732
+ tmp += xi * xi;
733
+ }
734
+
735
+ tmp = simd_sum(tmp);
736
+ if (ntg > N_SIMDWIDTH) {
737
+ if (sgitg == 0) {
738
+ buf[tiisg] = 0.0f;
739
+ }
740
+
741
+ threadgroup_barrier(mem_flags::mem_threadgroup);
742
+
743
+ if (tiisg == 0) {
744
+ buf[sgitg] = tmp;
745
+ }
746
+
747
+ threadgroup_barrier(mem_flags::mem_threadgroup);
748
+
749
+ tmp = buf[tiisg];
750
+ tmp = simd_sum(tmp);
751
+ }
752
+
753
+ const float variance = tmp / gs;
754
+ const float scale = 1.0f/sqrt(variance + eps);
755
+ for (int j = start; j < end; j += ntg) {
756
+ dst[j] *= scale;
757
+ }
758
+ }
759
+
493
760
  // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
494
761
  // il indicates where the q4 quants begin (0 or QK4_0/4)
495
762
  // we assume that the yl's have been multiplied with the appropriate scale factor
@@ -582,9 +849,20 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
582
849
  // giard against the number of rows not being divisible by
583
850
  // N_DST, so this is another explicit assumption of the implementation.
584
851
  template<typename block_q_type, int nr, int nsg, int nw>
585
- void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
586
- int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
587
- uint3 tgpig, uint tiisg, uint sgitg) {
852
+ void mul_vec_q_n_f32_impl(
853
+ device const void * src0,
854
+ device const float * src1,
855
+ device float * dst,
856
+ int64_t ne00,
857
+ int64_t ne01,
858
+ int64_t ne02,
859
+ int64_t ne10,
860
+ int64_t ne12,
861
+ int64_t ne0,
862
+ int64_t ne1,
863
+ uint r2,
864
+ uint r3,
865
+ uint3 tgpig, uint tiisg, uint sgitg) {
588
866
  const int nb = ne00/QK4_0;
589
867
 
590
868
  const int r0 = tgpig.x;
@@ -593,7 +871,10 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
593
871
 
594
872
  const int first_row = (r0 * nsg + sgitg) * nr;
595
873
 
596
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
874
+ const uint i12 = im%ne12;
875
+ const uint i13 = im/ne12;
876
+
877
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
597
878
 
598
879
  device const block_q_type * x = (device const block_q_type *) src0 + offset0;
599
880
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
@@ -643,13 +924,14 @@ kernel void kernel_mul_mv_q4_0_f32(
643
924
  constant int64_t & ne02[[buffer(5)]],
644
925
  constant int64_t & ne10[[buffer(9)]],
645
926
  constant int64_t & ne12[[buffer(11)]],
646
- constant int64_t & ne0[[buffer(15)]],
647
- constant int64_t & ne1[[buffer(16)]],
648
- constant uint & gqa[[buffer(17)]],
927
+ constant int64_t & ne0 [[buffer(15)]],
928
+ constant int64_t & ne1 [[buffer(16)]],
929
+ constant uint & r2 [[buffer(17)]],
930
+ constant uint & r3 [[buffer(18)]],
649
931
  uint3 tgpig[[threadgroup_position_in_grid]],
650
932
  uint tiisg[[thread_index_in_simdgroup]],
651
933
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
652
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
934
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
653
935
  }
654
936
 
655
937
  kernel void kernel_mul_mv_q4_1_f32(
@@ -661,13 +943,14 @@ kernel void kernel_mul_mv_q4_1_f32(
661
943
  constant int64_t & ne02[[buffer(5)]],
662
944
  constant int64_t & ne10[[buffer(9)]],
663
945
  constant int64_t & ne12[[buffer(11)]],
664
- constant int64_t & ne0[[buffer(15)]],
665
- constant int64_t & ne1[[buffer(16)]],
666
- constant uint & gqa[[buffer(17)]],
946
+ constant int64_t & ne0 [[buffer(15)]],
947
+ constant int64_t & ne1 [[buffer(16)]],
948
+ constant uint & r2 [[buffer(17)]],
949
+ constant uint & r3 [[buffer(18)]],
667
950
  uint3 tgpig[[threadgroup_position_in_grid]],
668
951
  uint tiisg[[thread_index_in_simdgroup]],
669
952
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
670
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
953
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
671
954
  }
672
955
 
673
956
  kernel void kernel_mul_mv_q5_0_f32(
@@ -679,13 +962,14 @@ kernel void kernel_mul_mv_q5_0_f32(
679
962
  constant int64_t & ne02[[buffer(5)]],
680
963
  constant int64_t & ne10[[buffer(9)]],
681
964
  constant int64_t & ne12[[buffer(11)]],
682
- constant int64_t & ne0[[buffer(15)]],
683
- constant int64_t & ne1[[buffer(16)]],
684
- constant uint & gqa[[buffer(17)]],
965
+ constant int64_t & ne0 [[buffer(15)]],
966
+ constant int64_t & ne1 [[buffer(16)]],
967
+ constant uint & r2 [[buffer(17)]],
968
+ constant uint & r3 [[buffer(18)]],
685
969
  uint3 tgpig[[threadgroup_position_in_grid]],
686
970
  uint tiisg[[thread_index_in_simdgroup]],
687
971
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
688
- mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
972
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
689
973
  }
690
974
 
691
975
  kernel void kernel_mul_mv_q5_1_f32(
@@ -697,33 +981,35 @@ kernel void kernel_mul_mv_q5_1_f32(
697
981
  constant int64_t & ne02[[buffer(5)]],
698
982
  constant int64_t & ne10[[buffer(9)]],
699
983
  constant int64_t & ne12[[buffer(11)]],
700
- constant int64_t & ne0[[buffer(15)]],
701
- constant int64_t & ne1[[buffer(16)]],
702
- constant uint & gqa[[buffer(17)]],
984
+ constant int64_t & ne0 [[buffer(15)]],
985
+ constant int64_t & ne1 [[buffer(16)]],
986
+ constant uint & r2 [[buffer(17)]],
987
+ constant uint & r3 [[buffer(18)]],
703
988
  uint3 tgpig[[threadgroup_position_in_grid]],
704
989
  uint tiisg[[thread_index_in_simdgroup]],
705
990
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
706
- mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
991
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
707
992
  }
708
993
 
709
994
 
710
995
  #define NB_Q8_0 8
711
996
 
712
- kernel void kernel_mul_mv_q8_0_f32(
997
+ void kernel_mul_mv_q8_0_f32_impl(
713
998
  device const void * src0,
714
999
  device const float * src1,
715
1000
  device float * dst,
716
1001
  constant int64_t & ne00,
717
- constant int64_t & ne01[[buffer(4)]],
718
- constant int64_t & ne02[[buffer(5)]],
719
- constant int64_t & ne10[[buffer(9)]],
720
- constant int64_t & ne12[[buffer(11)]],
721
- constant int64_t & ne0[[buffer(15)]],
722
- constant int64_t & ne1[[buffer(16)]],
723
- constant uint & gqa[[buffer(17)]],
1002
+ constant int64_t & ne01,
1003
+ constant int64_t & ne02,
1004
+ constant int64_t & ne10,
1005
+ constant int64_t & ne12,
1006
+ constant int64_t & ne0,
1007
+ constant int64_t & ne1,
1008
+ constant uint & r2,
1009
+ constant uint & r3,
724
1010
  uint3 tgpig[[threadgroup_position_in_grid]],
725
- uint tiisg[[thread_index_in_simdgroup]],
726
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
1011
+ uint tiisg[[thread_index_in_simdgroup]],
1012
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
727
1013
  const int nr = N_DST;
728
1014
  const int nsg = N_SIMDGROUP;
729
1015
  const int nw = N_SIMDWIDTH;
@@ -732,8 +1018,14 @@ kernel void kernel_mul_mv_q8_0_f32(
732
1018
  const int r0 = tgpig.x;
733
1019
  const int r1 = tgpig.y;
734
1020
  const int im = tgpig.z;
1021
+
735
1022
  const int first_row = (r0 * nsg + sgitg) * nr;
736
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
1023
+
1024
+ const uint i12 = im%ne12;
1025
+ const uint i13 = im/ne12;
1026
+
1027
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
1028
+
737
1029
  device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
738
1030
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
739
1031
 
@@ -771,11 +1063,31 @@ kernel void kernel_mul_mv_q8_0_f32(
771
1063
  }
772
1064
  }
773
1065
 
774
- #define N_F32_F32 4
775
-
776
- kernel void kernel_mul_mv_f32_f32(
777
- device const char * src0,
778
- device const char * src1,
1066
+ [[host_name("kernel_mul_mv_q8_0_f32")]]
1067
+ kernel void kernel_mul_mv_q8_0_f32(
1068
+ device const void * src0,
1069
+ device const float * src1,
1070
+ device float * dst,
1071
+ constant int64_t & ne00,
1072
+ constant int64_t & ne01,
1073
+ constant int64_t & ne02,
1074
+ constant int64_t & ne10,
1075
+ constant int64_t & ne12,
1076
+ constant int64_t & ne0,
1077
+ constant int64_t & ne1,
1078
+ constant uint & r2 [[buffer(17)]],
1079
+ constant uint & r3 [[buffer(18)]],
1080
+ uint3 tgpig[[threadgroup_position_in_grid]],
1081
+ uint tiisg[[thread_index_in_simdgroup]],
1082
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1083
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1084
+ }
1085
+
1086
+ #define N_F32_F32 4
1087
+
1088
+ void kernel_mul_mv_f32_f32_impl(
1089
+ device const char * src0,
1090
+ device const char * src1,
779
1091
  device float * dst,
780
1092
  constant int64_t & ne00,
781
1093
  constant int64_t & ne01,
@@ -791,6 +1103,8 @@ kernel void kernel_mul_mv_f32_f32(
791
1103
  constant uint64_t & nb12,
792
1104
  constant int64_t & ne0,
793
1105
  constant int64_t & ne1,
1106
+ constant uint & r2,
1107
+ constant uint & r3,
794
1108
  uint3 tgpig[[threadgroup_position_in_grid]],
795
1109
  uint tiisg[[thread_index_in_simdgroup]]) {
796
1110
 
@@ -798,7 +1112,12 @@ kernel void kernel_mul_mv_f32_f32(
798
1112
  const int64_t rb = tgpig.y*N_F32_F32;
799
1113
  const int64_t im = tgpig.z;
800
1114
 
801
- device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1115
+ const uint i12 = im%ne12;
1116
+ const uint i13 = im/ne12;
1117
+
1118
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1119
+
1120
+ device const float * x = (device const float *) (src0 + offset0);
802
1121
 
803
1122
  if (ne00 < 128) {
804
1123
  for (int row = 0; row < N_F32_F32; ++row) {
@@ -844,6 +1163,32 @@ kernel void kernel_mul_mv_f32_f32(
844
1163
  }
845
1164
  }
846
1165
 
1166
+ [[host_name("kernel_mul_mv_f32_f32")]]
1167
+ kernel void kernel_mul_mv_f32_f32(
1168
+ device const char * src0,
1169
+ device const char * src1,
1170
+ device float * dst,
1171
+ constant int64_t & ne00,
1172
+ constant int64_t & ne01,
1173
+ constant int64_t & ne02,
1174
+ constant uint64_t & nb00,
1175
+ constant uint64_t & nb01,
1176
+ constant uint64_t & nb02,
1177
+ constant int64_t & ne10,
1178
+ constant int64_t & ne11,
1179
+ constant int64_t & ne12,
1180
+ constant uint64_t & nb10,
1181
+ constant uint64_t & nb11,
1182
+ constant uint64_t & nb12,
1183
+ constant int64_t & ne0,
1184
+ constant int64_t & ne1,
1185
+ constant uint & r2 [[buffer(17)]],
1186
+ constant uint & r3 [[buffer(18)]],
1187
+ uint3 tgpig[[threadgroup_position_in_grid]],
1188
+ uint tiisg[[thread_index_in_simdgroup]]) {
1189
+ kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1190
+ }
1191
+
847
1192
  #define N_F16_F16 4
848
1193
 
849
1194
  kernel void kernel_mul_mv_f16_f16(
@@ -864,6 +1209,8 @@ kernel void kernel_mul_mv_f16_f16(
864
1209
  constant uint64_t & nb12,
865
1210
  constant int64_t & ne0,
866
1211
  constant int64_t & ne1,
1212
+ constant uint & r2 [[buffer(17)]],
1213
+ constant uint & r3 [[buffer(18)]],
867
1214
  uint3 tgpig[[threadgroup_position_in_grid]],
868
1215
  uint tiisg[[thread_index_in_simdgroup]]) {
869
1216
 
@@ -871,7 +1218,12 @@ kernel void kernel_mul_mv_f16_f16(
871
1218
  const int64_t rb = tgpig.y*N_F16_F16;
872
1219
  const int64_t im = tgpig.z;
873
1220
 
874
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1221
+ const uint i12 = im%ne12;
1222
+ const uint i13 = im/ne12;
1223
+
1224
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1225
+
1226
+ device const half * x = (device const half *) (src0 + offset0);
875
1227
 
876
1228
  if (ne00 < 128) {
877
1229
  for (int row = 0; row < N_F16_F16; ++row) {
@@ -917,7 +1269,7 @@ kernel void kernel_mul_mv_f16_f16(
917
1269
  }
918
1270
  }
919
1271
 
920
- kernel void kernel_mul_mv_f16_f32_1row(
1272
+ void kernel_mul_mv_f16_f32_1row_impl(
921
1273
  device const char * src0,
922
1274
  device const char * src1,
923
1275
  device float * dst,
@@ -935,6 +1287,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
935
1287
  constant uint64_t & nb12,
936
1288
  constant int64_t & ne0,
937
1289
  constant int64_t & ne1,
1290
+ constant uint & r2,
1291
+ constant uint & r3,
938
1292
  uint3 tgpig[[threadgroup_position_in_grid]],
939
1293
  uint tiisg[[thread_index_in_simdgroup]]) {
940
1294
 
@@ -942,7 +1296,12 @@ kernel void kernel_mul_mv_f16_f32_1row(
942
1296
  const int64_t r1 = tgpig.y;
943
1297
  const int64_t im = tgpig.z;
944
1298
 
945
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1299
+ const uint i12 = im%ne12;
1300
+ const uint i13 = im/ne12;
1301
+
1302
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1303
+
1304
+ device const half * x = (device const half *) (src0 + offset0);
946
1305
  device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
947
1306
 
948
1307
  float sumf = 0;
@@ -966,12 +1325,37 @@ kernel void kernel_mul_mv_f16_f32_1row(
966
1325
  dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
967
1326
  }
968
1327
  }
1328
+ }
969
1329
 
1330
+ [[host_name("kernel_mul_mv_f16_f32_1row")]]
1331
+ kernel void kernel_mul_mv_f16_f32_1row(
1332
+ device const char * src0,
1333
+ device const char * src1,
1334
+ device float * dst,
1335
+ constant int64_t & ne00,
1336
+ constant int64_t & ne01,
1337
+ constant int64_t & ne02,
1338
+ constant uint64_t & nb00,
1339
+ constant uint64_t & nb01,
1340
+ constant uint64_t & nb02,
1341
+ constant int64_t & ne10,
1342
+ constant int64_t & ne11,
1343
+ constant int64_t & ne12,
1344
+ constant uint64_t & nb10,
1345
+ constant uint64_t & nb11,
1346
+ constant uint64_t & nb12,
1347
+ constant int64_t & ne0,
1348
+ constant int64_t & ne1,
1349
+ constant uint & r2 [[buffer(17)]],
1350
+ constant uint & r3 [[buffer(18)]],
1351
+ uint3 tgpig[[threadgroup_position_in_grid]],
1352
+ uint tiisg[[thread_index_in_simdgroup]]) {
1353
+ kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
970
1354
  }
971
1355
 
972
1356
  #define N_F16_F32 4
973
1357
 
974
- kernel void kernel_mul_mv_f16_f32(
1358
+ void kernel_mul_mv_f16_f32_impl(
975
1359
  device const char * src0,
976
1360
  device const char * src1,
977
1361
  device float * dst,
@@ -989,6 +1373,8 @@ kernel void kernel_mul_mv_f16_f32(
989
1373
  constant uint64_t & nb12,
990
1374
  constant int64_t & ne0,
991
1375
  constant int64_t & ne1,
1376
+ constant uint & r2,
1377
+ constant uint & r3,
992
1378
  uint3 tgpig[[threadgroup_position_in_grid]],
993
1379
  uint tiisg[[thread_index_in_simdgroup]]) {
994
1380
 
@@ -996,7 +1382,12 @@ kernel void kernel_mul_mv_f16_f32(
996
1382
  const int64_t rb = tgpig.y*N_F16_F32;
997
1383
  const int64_t im = tgpig.z;
998
1384
 
999
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1385
+ const uint i12 = im%ne12;
1386
+ const uint i13 = im/ne12;
1387
+
1388
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1389
+
1390
+ device const half * x = (device const half *) (src0 + offset0);
1000
1391
 
1001
1392
  if (ne00 < 128) {
1002
1393
  for (int row = 0; row < N_F16_F32; ++row) {
@@ -1042,6 +1433,32 @@ kernel void kernel_mul_mv_f16_f32(
1042
1433
  }
1043
1434
  }
1044
1435
 
1436
+ [[host_name("kernel_mul_mv_f16_f32")]]
1437
+ kernel void kernel_mul_mv_f16_f32(
1438
+ device const char * src0,
1439
+ device const char * src1,
1440
+ device float * dst,
1441
+ constant int64_t & ne00,
1442
+ constant int64_t & ne01,
1443
+ constant int64_t & ne02,
1444
+ constant uint64_t & nb00,
1445
+ constant uint64_t & nb01,
1446
+ constant uint64_t & nb02,
1447
+ constant int64_t & ne10,
1448
+ constant int64_t & ne11,
1449
+ constant int64_t & ne12,
1450
+ constant uint64_t & nb10,
1451
+ constant uint64_t & nb11,
1452
+ constant uint64_t & nb12,
1453
+ constant int64_t & ne0,
1454
+ constant int64_t & ne1,
1455
+ constant uint & r2 [[buffer(17)]],
1456
+ constant uint & r3 [[buffer(18)]],
1457
+ uint3 tgpig[[threadgroup_position_in_grid]],
1458
+ uint tiisg[[thread_index_in_simdgroup]]) {
1459
+ kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1460
+ }
1461
+
1045
1462
  // Assumes row size (ne00) is a multiple of 4
1046
1463
  kernel void kernel_mul_mv_f16_f32_l4(
1047
1464
  device const char * src0,
@@ -1061,6 +1478,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
1061
1478
  constant uint64_t & nb12,
1062
1479
  constant int64_t & ne0,
1063
1480
  constant int64_t & ne1,
1481
+ constant uint & r2 [[buffer(17)]],
1482
+ constant uint & r3 [[buffer(18)]],
1064
1483
  uint3 tgpig[[threadgroup_position_in_grid]],
1065
1484
  uint tiisg[[thread_index_in_simdgroup]]) {
1066
1485
 
@@ -1068,7 +1487,12 @@ kernel void kernel_mul_mv_f16_f32_l4(
1068
1487
  const int64_t r0 = tgpig.x;
1069
1488
  const int64_t im = tgpig.z;
1070
1489
 
1071
- device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1490
+ const uint i12 = im%ne12;
1491
+ const uint i13 = im/ne12;
1492
+
1493
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1494
+
1495
+ device const half4 * x4 = (device const half4 *) (src0 + offset0);
1072
1496
 
1073
1497
  for (int r1 = 0; r1 < nrows; ++r1) {
1074
1498
  device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
@@ -1120,17 +1544,21 @@ kernel void kernel_alibi_f32(
1120
1544
  const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1121
1545
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1122
1546
  const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1547
+ const int64_t k = i3*ne3 + i2;
1123
1548
 
1124
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1125
1549
  float m_k;
1126
- if (i2 < n_heads_log2_floor) {
1127
- m_k = pow(m0, i2 + 1);
1550
+ if (k < n_heads_log2_floor) {
1551
+ m_k = pow(m0, k + 1);
1128
1552
  } else {
1129
- m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
1553
+ m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
1130
1554
  }
1555
+
1556
+ device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
1557
+ device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
1131
1558
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1132
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1133
- dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
1559
+ const float src_v = *(device float *)(src_row + i00*nb00);
1560
+ device float * dst_v = (device float *)(dst_row + i00*nb0);
1561
+ *dst_v = i00 * m_k + src_v;
1134
1562
  }
1135
1563
  }
1136
1564
 
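The updated kernel_alibi_f32 adds a per-head linear position bias, dst[i00] = src[i00] + m_k*i00, where the slope m_k depends on the head index k. A host-side sketch of the slope computation is below; alibi_slope is an illustrative name, and the derivation of m0 and m1 from the head count follows the ALiBi paper's scheme rather than anything shown in this diff, so treat that part as an assumption.

#include <cmath>
#include <cstdint>

static float alibi_slope(int64_t k, int n_head) {
    // n_heads_log2_floor: largest power of two not exceeding the head count
    const int   n_floor = 1 << (int) floor(log2((double) n_head));
    const float m0 = powf(2.0f, -8.0f / n_floor);   // slopes for the first n_floor heads
    const float m1 = powf(2.0f, -4.0f / n_floor);   // interpolated slopes for the remaining heads
    return k < n_floor ? powf(m0, (float)(k + 1))
                       : powf(m1, (float)(2*(k - n_floor) + 1));
}

// The kernel then adds the bias along the row: dst[i00] = src[i00] + m_k * i00.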
@@ -1335,9 +1763,160 @@ kernel void kernel_im2col_f16(
1335
1763
  }
1336
1764
  }
1337
1765
 
1766
+ kernel void kernel_upscale_f32(
1767
+ device const char * src0,
1768
+ device char * dst,
1769
+ constant int64_t & ne00,
1770
+ constant int64_t & ne01,
1771
+ constant int64_t & ne02,
1772
+ constant int64_t & ne03,
1773
+ constant uint64_t & nb00,
1774
+ constant uint64_t & nb01,
1775
+ constant uint64_t & nb02,
1776
+ constant uint64_t & nb03,
1777
+ constant int64_t & ne0,
1778
+ constant int64_t & ne1,
1779
+ constant int64_t & ne2,
1780
+ constant int64_t & ne3,
1781
+ constant uint64_t & nb0,
1782
+ constant uint64_t & nb1,
1783
+ constant uint64_t & nb2,
1784
+ constant uint64_t & nb3,
1785
+ constant int32_t & sf,
1786
+ uint3 tgpig[[threadgroup_position_in_grid]],
1787
+ uint3 tpitg[[thread_position_in_threadgroup]],
1788
+ uint3 ntg[[threads_per_threadgroup]]) {
1789
+
1790
+ const int64_t i3 = tgpig.z;
1791
+ const int64_t i2 = tgpig.y;
1792
+ const int64_t i1 = tgpig.x;
1793
+
1794
+ const int64_t i03 = i3;
1795
+ const int64_t i02 = i2;
1796
+ const int64_t i01 = i1/sf;
1797
+
1798
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
1799
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
1800
+
1801
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1802
+ dst_ptr[i0] = src0_ptr[i0/sf];
1803
+ }
1804
+ }
1805
+
1806
+ kernel void kernel_pad_f32(
1807
+ device const char * src0,
1808
+ device char * dst,
1809
+ constant int64_t & ne00,
1810
+ constant int64_t & ne01,
1811
+ constant int64_t & ne02,
1812
+ constant int64_t & ne03,
1813
+ constant uint64_t & nb00,
1814
+ constant uint64_t & nb01,
1815
+ constant uint64_t & nb02,
1816
+ constant uint64_t & nb03,
1817
+ constant int64_t & ne0,
1818
+ constant int64_t & ne1,
1819
+ constant int64_t & ne2,
1820
+ constant int64_t & ne3,
1821
+ constant uint64_t & nb0,
1822
+ constant uint64_t & nb1,
1823
+ constant uint64_t & nb2,
1824
+ constant uint64_t & nb3,
1825
+ uint3 tgpig[[threadgroup_position_in_grid]],
1826
+ uint3 tpitg[[thread_position_in_threadgroup]],
1827
+ uint3 ntg[[threads_per_threadgroup]]) {
1828
+
1829
+ const int64_t i3 = tgpig.z;
1830
+ const int64_t i2 = tgpig.y;
1831
+ const int64_t i1 = tgpig.x;
1832
+
1833
+ const int64_t i03 = i3;
1834
+ const int64_t i02 = i2;
1835
+ const int64_t i01 = i1;
1836
+
1837
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
1838
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
1839
+
1840
+ if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
1841
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1842
+ if (i0 < ne00) {
1843
+ dst_ptr[i0] = src0_ptr[i0];
1844
+ } else {
1845
+ dst_ptr[i0] = 0.0f;
1846
+ }
1847
+ }
1848
+
1849
+ return;
1850
+ }
1851
+
1852
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1853
+ dst_ptr[i0] = 0.0f;
1854
+ }
1855
+ }
1856
+
1857
+ // bitonic sort implementation following the CUDA kernels as reference
1858
+ typedef void (argsort_t)(
1859
+ device const float * x,
1860
+ device int32_t * dst,
1861
+ constant int64_t & ncols,
1862
+ uint3 tgpig[[threadgroup_position_in_grid]],
1863
+ uint3 tpitg[[thread_position_in_threadgroup]]);
1864
+
1865
+ template<ggml_sort_order order>
1866
+ kernel void kernel_argsort_f32_i32(
1867
+ device const float * x,
1868
+ device int32_t * dst,
1869
+ constant int64_t & ncols,
1870
+ uint3 tgpig[[threadgroup_position_in_grid]],
1871
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
1872
+ // bitonic sort
1873
+ int col = tpitg[0];
1874
+ int row = tgpig[1];
1875
+
1876
+ if (col >= ncols) return;
1877
+
1878
+ device const float * x_row = x + row * ncols;
1879
+ device int32_t * dst_row = dst + row * ncols;
1880
+
1881
+ // initialize indices
1882
+ if (col < ncols) {
1883
+ dst_row[col] = col;
1884
+ }
1885
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1886
+
1887
+ for (int k = 2; k <= ncols; k *= 2) {
1888
+ for (int j = k / 2; j > 0; j /= 2) {
1889
+ int ixj = col ^ j;
1890
+ if (ixj > col) {
1891
+ if ((col & k) == 0) {
1892
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
1893
+ SWAP(dst_row[col], dst_row[ixj]);
1894
+ }
1895
+ } else {
1896
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
1897
+ SWAP(dst_row[col], dst_row[ixj]);
1898
+ }
1899
+ }
1900
+ }
1901
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1902
+ }
1903
+ }
1904
+ }
1905
+
1906
+ template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
1907
+ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
1908
+
1909
+ kernel void kernel_leaky_relu_f32(
1910
+ device const float * src0,
1911
+ device float * dst,
1912
+ constant float & slope,
1913
+ uint tpig[[thread_position_in_grid]]) {
1914
+ dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
1915
+ }
1916
+
1338
1917
  kernel void kernel_cpy_f16_f16(
1339
- device const half * src0,
1340
- device half * dst,
1918
+ device const half * src0,
1919
+ device half * dst,
1341
1920
  constant int64_t & ne00,
1342
1921
  constant int64_t & ne01,
1343
1922
  constant int64_t & ne02,
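The new kernel_argsort_f32_i32 kernels above implement a bitonic sorting network over row indices: each thread owns one column, and at every (k, j) stage it compares its element with the one at col ^ j, swapping indices so that each length-k subsequence ends up ordered in alternating directions. A serial C++ sketch of the same network is below, with one loop iteration standing in for one GPU thread; argsort_bitonic and sort_order are illustrative names, and ncols is assumed to be a power of two, as the bitonic network requires.

#include <cstdint>
#include <utility>
#include <vector>

enum sort_order { SORT_ASC, SORT_DESC };

static void argsort_bitonic(const std::vector<float> &x, std::vector<int32_t> &dst, sort_order order) {
    const int n = (int) x.size();
    dst.resize(n);
    for (int i = 0; i < n; ++i) dst[i] = i;        // initialize indices

    for (int k = 2; k <= n; k *= 2) {              // size of the bitonic subsequences
        for (int j = k / 2; j > 0; j /= 2) {       // compare-exchange distance
            for (int col = 0; col < n; ++col) {    // one GPU thread per column
                const int ixj = col ^ j;
                if (ixj <= col) continue;          // each pair handled once
                const bool up = (col & k) == 0;    // direction of this subsequence
                const bool mis_ordered = up
                    ? (order == SORT_ASC ? x[dst[col]] > x[dst[ixj]] : x[dst[col]] < x[dst[ixj]])
                    : (order == SORT_ASC ? x[dst[col]] < x[dst[ixj]] : x[dst[col]] > x[dst[ixj]]);
                if (mis_ordered) std::swap(dst[col], dst[ixj]);
            }
        }
    }
}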
@@ -1376,6 +1955,47 @@ kernel void kernel_cpy_f16_f16(
1376
1955
  }
1377
1956
  }
1378
1957
 
1958
+ kernel void kernel_cpy_f16_f32(
1959
+ device const half * src0,
1960
+ device float * dst,
1961
+ constant int64_t & ne00,
1962
+ constant int64_t & ne01,
1963
+ constant int64_t & ne02,
1964
+ constant int64_t & ne03,
1965
+ constant uint64_t & nb00,
1966
+ constant uint64_t & nb01,
1967
+ constant uint64_t & nb02,
1968
+ constant uint64_t & nb03,
1969
+ constant int64_t & ne0,
1970
+ constant int64_t & ne1,
1971
+ constant int64_t & ne2,
1972
+ constant int64_t & ne3,
1973
+ constant uint64_t & nb0,
1974
+ constant uint64_t & nb1,
1975
+ constant uint64_t & nb2,
1976
+ constant uint64_t & nb3,
1977
+ uint3 tgpig[[threadgroup_position_in_grid]],
1978
+ uint3 tpitg[[thread_position_in_threadgroup]],
1979
+ uint3 ntg[[threads_per_threadgroup]]) {
1980
+ const int64_t i03 = tgpig[2];
1981
+ const int64_t i02 = tgpig[1];
1982
+ const int64_t i01 = tgpig[0];
1983
+
1984
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1985
+
1986
+ const int64_t i3 = n / (ne2*ne1*ne0);
1987
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1988
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1989
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1990
+
1991
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1992
+
1993
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1994
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1995
+ dst_data[i00] = src[0];
1996
+ }
1997
+ }
1998
+
1379
1999
  kernel void kernel_cpy_f32_f16(
1380
2000
  device const float * src0,
1381
2001
  device half * dst,
@@ -1460,47 +2080,238 @@ kernel void kernel_cpy_f32_f32(
1460
2080
  }
1461
2081
  }
1462
2082
 
1463
- kernel void kernel_concat(
1464
- device const char * src0,
1465
- device const char * src1,
1466
- device char * dst,
1467
- constant int64_t & ne00,
1468
- constant int64_t & ne01,
1469
- constant int64_t & ne02,
1470
- constant int64_t & ne03,
1471
- constant uint64_t & nb00,
1472
- constant uint64_t & nb01,
1473
- constant uint64_t & nb02,
1474
- constant uint64_t & nb03,
1475
- constant int64_t & ne10,
1476
- constant int64_t & ne11,
1477
- constant int64_t & ne12,
1478
- constant int64_t & ne13,
1479
- constant uint64_t & nb10,
1480
- constant uint64_t & nb11,
1481
- constant uint64_t & nb12,
1482
- constant uint64_t & nb13,
1483
- constant int64_t & ne0,
1484
- constant int64_t & ne1,
1485
- constant int64_t & ne2,
1486
- constant int64_t & ne3,
1487
- constant uint64_t & nb0,
1488
- constant uint64_t & nb1,
1489
- constant uint64_t & nb2,
1490
- constant uint64_t & nb3,
1491
- uint3 tgpig[[threadgroup_position_in_grid]],
1492
- uint3 tpitg[[thread_position_in_threadgroup]],
1493
- uint3 ntg[[threads_per_threadgroup]]) {
1494
-
1495
- const int64_t i03 = tgpig.z;
1496
- const int64_t i02 = tgpig.y;
1497
- const int64_t i01 = tgpig.x;
1498
-
1499
- const int64_t i13 = i03 % ne13;
2083
+ kernel void kernel_cpy_f32_q8_0(
2084
+ device const float * src0,
2085
+ device void * dst,
2086
+ constant int64_t & ne00,
2087
+ constant int64_t & ne01,
2088
+ constant int64_t & ne02,
2089
+ constant int64_t & ne03,
2090
+ constant uint64_t & nb00,
2091
+ constant uint64_t & nb01,
2092
+ constant uint64_t & nb02,
2093
+ constant uint64_t & nb03,
2094
+ constant int64_t & ne0,
2095
+ constant int64_t & ne1,
2096
+ constant int64_t & ne2,
2097
+ constant int64_t & ne3,
2098
+ constant uint64_t & nb0,
2099
+ constant uint64_t & nb1,
2100
+ constant uint64_t & nb2,
2101
+ constant uint64_t & nb3,
2102
+ uint3 tgpig[[threadgroup_position_in_grid]],
2103
+ uint3 tpitg[[thread_position_in_threadgroup]],
2104
+ uint3 ntg[[threads_per_threadgroup]]) {
2105
+ const int64_t i03 = tgpig[2];
2106
+ const int64_t i02 = tgpig[1];
2107
+ const int64_t i01 = tgpig[0];
2108
+
2109
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2110
+
2111
+ const int64_t i3 = n / (ne2*ne1*ne0);
2112
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2113
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2114
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
2115
+
2116
+ device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2117
+
2118
+ for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
2119
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2120
+
2121
+ float amax = 0.0f; // absolute max
2122
+
2123
+ for (int j = 0; j < QK8_0; j++) {
2124
+ const float v = src[j];
2125
+ amax = MAX(amax, fabs(v));
2126
+ }
2127
+
2128
+ const float d = amax / ((1 << 7) - 1);
2129
+ const float id = d ? 1.0f/d : 0.0f;
2130
+
2131
+ dst_data[i00/QK8_0].d = d;
2132
+
2133
+ for (int j = 0; j < QK8_0; ++j) {
2134
+ const float x0 = src[j]*id;
2135
+
2136
+ dst_data[i00/QK8_0].qs[j] = round(x0);
2137
+ }
2138
+ }
2139
+ }
2140
+
2141
+ kernel void kernel_cpy_f32_q4_0(
2142
+ device const float * src0,
2143
+ device void * dst,
2144
+ constant int64_t & ne00,
2145
+ constant int64_t & ne01,
2146
+ constant int64_t & ne02,
2147
+ constant int64_t & ne03,
2148
+ constant uint64_t & nb00,
2149
+ constant uint64_t & nb01,
2150
+ constant uint64_t & nb02,
2151
+ constant uint64_t & nb03,
2152
+ constant int64_t & ne0,
2153
+ constant int64_t & ne1,
2154
+ constant int64_t & ne2,
2155
+ constant int64_t & ne3,
2156
+ constant uint64_t & nb0,
2157
+ constant uint64_t & nb1,
2158
+ constant uint64_t & nb2,
2159
+ constant uint64_t & nb3,
2160
+ uint3 tgpig[[threadgroup_position_in_grid]],
2161
+ uint3 tpitg[[thread_position_in_threadgroup]],
2162
+ uint3 ntg[[threads_per_threadgroup]]) {
2163
+ const int64_t i03 = tgpig[2];
2164
+ const int64_t i02 = tgpig[1];
2165
+ const int64_t i01 = tgpig[0];
2166
+
2167
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2168
+
2169
+ const int64_t i3 = n / (ne2*ne1*ne0);
2170
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2171
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2172
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
2173
+
2174
+ device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2175
+
2176
+ for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
2177
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2178
+
2179
+ float amax = 0.0f; // absolute max
2180
+ float max = 0.0f;
2181
+
2182
+ for (int j = 0; j < QK4_0; j++) {
2183
+ const float v = src[j];
2184
+ if (amax < fabs(v)) {
2185
+ amax = fabs(v);
2186
+ max = v;
2187
+ }
2188
+ }
2189
+
2190
+ const float d = max / -8;
2191
+ const float id = d ? 1.0f/d : 0.0f;
2192
+
2193
+ dst_data[i00/QK4_0].d = d;
2194
+
2195
+ for (int j = 0; j < QK4_0/2; ++j) {
2196
+ const float x0 = src[0 + j]*id;
2197
+ const float x1 = src[QK4_0/2 + j]*id;
2198
+
2199
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
2200
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
2201
+
2202
+ dst_data[i00/QK4_0].qs[j] = xi0;
2203
+ dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
2204
+ }
2205
+ }
2206
+ }
2207
+
2208
+ kernel void kernel_cpy_f32_q4_1(
2209
+ device const float * src0,
2210
+ device void * dst,
2211
+ constant int64_t & ne00,
2212
+ constant int64_t & ne01,
2213
+ constant int64_t & ne02,
2214
+ constant int64_t & ne03,
2215
+ constant uint64_t & nb00,
2216
+ constant uint64_t & nb01,
2217
+ constant uint64_t & nb02,
2218
+ constant uint64_t & nb03,
2219
+ constant int64_t & ne0,
2220
+ constant int64_t & ne1,
2221
+ constant int64_t & ne2,
2222
+ constant int64_t & ne3,
2223
+ constant uint64_t & nb0,
2224
+ constant uint64_t & nb1,
2225
+ constant uint64_t & nb2,
2226
+ constant uint64_t & nb3,
2227
+ uint3 tgpig[[threadgroup_position_in_grid]],
2228
+ uint3 tpitg[[thread_position_in_threadgroup]],
2229
+ uint3 ntg[[threads_per_threadgroup]]) {
2230
+ const int64_t i03 = tgpig[2];
2231
+ const int64_t i02 = tgpig[1];
2232
+ const int64_t i01 = tgpig[0];
2233
+
2234
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2235
+
2236
+ const int64_t i3 = n / (ne2*ne1*ne0);
2237
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2238
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2239
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
2240
+
2241
+ device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2242
+
2243
+ for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
2244
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2245
+
2246
+ float min = FLT_MAX;
2247
+ float max = -FLT_MAX;
2248
+
2249
+ for (int j = 0; j < QK4_1; j++) {
2250
+ const float v = src[j];
2251
+ if (min > v) min = v;
2252
+ if (max < v) max = v;
2253
+ }
2254
+
2255
+ const float d = (max - min) / ((1 << 4) - 1);
2256
+ const float id = d ? 1.0f/d : 0.0f;
2257
+
2258
+ dst_data[i00/QK4_1].d = d;
2259
+ dst_data[i00/QK4_1].m = min;
2260
+
2261
+ for (int j = 0; j < QK4_1/2; ++j) {
2262
+ const float x0 = (src[0 + j] - min)*id;
2263
+ const float x1 = (src[QK4_1/2 + j] - min)*id;
2264
+
2265
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
2266
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
2267
+
2268
+ dst_data[i00/QK4_1].qs[j] = xi0;
2269
+ dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
2270
+ }
2271
+ }
2272
+ }
2273
+
2274
+ kernel void kernel_concat(
2275
+ device const char * src0,
2276
+ device const char * src1,
2277
+ device char * dst,
2278
+ constant int64_t & ne00,
2279
+ constant int64_t & ne01,
2280
+ constant int64_t & ne02,
2281
+ constant int64_t & ne03,
2282
+ constant uint64_t & nb00,
2283
+ constant uint64_t & nb01,
2284
+ constant uint64_t & nb02,
2285
+ constant uint64_t & nb03,
2286
+ constant int64_t & ne10,
2287
+ constant int64_t & ne11,
2288
+ constant int64_t & ne12,
2289
+ constant int64_t & ne13,
2290
+ constant uint64_t & nb10,
2291
+ constant uint64_t & nb11,
2292
+ constant uint64_t & nb12,
2293
+ constant uint64_t & nb13,
2294
+ constant int64_t & ne0,
2295
+ constant int64_t & ne1,
2296
+ constant int64_t & ne2,
2297
+ constant int64_t & ne3,
2298
+ constant uint64_t & nb0,
2299
+ constant uint64_t & nb1,
2300
+ constant uint64_t & nb2,
2301
+ constant uint64_t & nb3,
2302
+ uint3 tgpig[[threadgroup_position_in_grid]],
2303
+ uint3 tpitg[[thread_position_in_threadgroup]],
2304
+ uint3 ntg[[threads_per_threadgroup]]) {
2305
+
2306
+ const int64_t i03 = tgpig.z;
2307
+ const int64_t i02 = tgpig.y;
2308
+ const int64_t i01 = tgpig.x;
2309
+
2310
+ const int64_t i13 = i03 % ne13;
1500
2311
  const int64_t i12 = i02 % ne12;
1501
2312
  const int64_t i11 = i01 % ne11;
1502
2313
 
1503
- device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
2314
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
1504
2315
  device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
1505
2316
  device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
1506
2317
 
@@ -1608,32 +2419,39 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
1608
2419
 
1609
2420
  //====================================== dot products =========================
1610
2421
 
1611
- kernel void kernel_mul_mv_q2_K_f32(
2422
+ void kernel_mul_mv_q2_K_f32_impl(
1612
2423
  device const void * src0,
1613
2424
  device const float * src1,
1614
2425
  device float * dst,
1615
2426
  constant int64_t & ne00,
1616
- constant int64_t & ne01[[buffer(4)]],
1617
- constant int64_t & ne02[[buffer(5)]],
1618
- constant int64_t & ne10[[buffer(9)]],
1619
- constant int64_t & ne12[[buffer(11)]],
1620
- constant int64_t & ne0[[buffer(15)]],
1621
- constant int64_t & ne1[[buffer(16)]],
1622
- constant uint & gqa[[buffer(17)]],
2427
+ constant int64_t & ne01,
2428
+ constant int64_t & ne02,
2429
+ constant int64_t & ne10,
2430
+ constant int64_t & ne12,
2431
+ constant int64_t & ne0,
2432
+ constant int64_t & ne1,
2433
+ constant uint & r2,
2434
+ constant uint & r3,
1623
2435
  uint3 tgpig[[threadgroup_position_in_grid]],
1624
- uint tiisg[[thread_index_in_simdgroup]],
1625
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2436
+ uint tiisg[[thread_index_in_simdgroup]],
2437
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1626
2438
 
1627
2439
  const int nb = ne00/QK_K;
1628
2440
  const int r0 = tgpig.x;
1629
2441
  const int r1 = tgpig.y;
1630
- const int r2 = tgpig.z;
2442
+ const int im = tgpig.z;
1631
2443
 
1632
2444
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1633
2445
  const int ib_row = first_row * nb;
1634
- const uint offset0 = r2/gqa*(nb*ne0);
2446
+
2447
+ const uint i12 = im%ne12;
2448
+ const uint i13 = im/ne12;
2449
+
2450
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2451
+
1635
2452
  device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
1636
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2453
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2454
+
1637
2455
  float yl[32];
1638
2456
  float sumf[N_DST]={0.f}, all_sum;
1639
2457
 
@@ -1642,11 +2460,11 @@ kernel void kernel_mul_mv_q2_K_f32(
1642
2460
  #if QK_K == 256
1643
2461
  const int ix = tiisg/8; // 0...3
1644
2462
  const int it = tiisg%8; // 0...7
1645
- const int im = it/4; // 0 or 1
2463
+ const int iq = it/4; // 0 or 1
1646
2464
  const int ir = it%4; // 0...3
1647
2465
  const int is = (8*ir)/16;// 0 or 1
1648
2466
 
1649
- device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
2467
+ device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
1650
2468
 
1651
2469
  for (int ib = ix; ib < nb; ib += 4) {
1652
2470
 
@@ -1658,8 +2476,8 @@ kernel void kernel_mul_mv_q2_K_f32(
1658
2476
  yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
1659
2477
  }
1660
2478
 
1661
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
1662
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
2479
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
2480
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
1663
2481
  device const half * dh = &x[ib].d;
1664
2482
 
1665
2483
  for (int row = 0; row < N_DST; row++) {
@@ -1746,13 +2564,13 @@ kernel void kernel_mul_mv_q2_K_f32(
1746
2564
  for (int row = 0; row < N_DST; ++row) {
1747
2565
  all_sum = simd_sum(sumf[row]);
1748
2566
  if (tiisg == 0) {
1749
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
2567
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
1750
2568
  }
1751
2569
  }
1752
2570
  }
1753
2571
 
1754
- #if QK_K == 256
1755
- kernel void kernel_mul_mv_q3_K_f32(
2572
+ [[host_name("kernel_mul_mv_q2_K_f32")]]
2573
+ kernel void kernel_mul_mv_q2_K_f32(
1756
2574
  device const void * src0,
1757
2575
  device const float * src1,
1758
2576
  device float * dst,
@@ -1761,23 +2579,50 @@ kernel void kernel_mul_mv_q3_K_f32(
1761
2579
  constant int64_t & ne02[[buffer(5)]],
1762
2580
  constant int64_t & ne10[[buffer(9)]],
1763
2581
  constant int64_t & ne12[[buffer(11)]],
1764
- constant int64_t & ne0[[buffer(15)]],
1765
- constant int64_t & ne1[[buffer(16)]],
1766
- constant uint & gqa[[buffer(17)]],
2582
+ constant int64_t & ne0 [[buffer(15)]],
2583
+ constant int64_t & ne1 [[buffer(16)]],
2584
+ constant uint & r2 [[buffer(17)]],
2585
+ constant uint & r3 [[buffer(18)]],
1767
2586
  uint3 tgpig[[threadgroup_position_in_grid]],
1768
- uint tiisg[[thread_index_in_simdgroup]],
1769
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2587
+ uint tiisg[[thread_index_in_simdgroup]],
2588
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2589
+
2590
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2591
+ }
2592
+
2593
+ #if QK_K == 256
2594
+ void kernel_mul_mv_q3_K_f32_impl(
2595
+ device const void * src0,
2596
+ device const float * src1,
2597
+ device float * dst,
2598
+ constant int64_t & ne00,
2599
+ constant int64_t & ne01,
2600
+ constant int64_t & ne02,
2601
+ constant int64_t & ne10,
2602
+ constant int64_t & ne12,
2603
+ constant int64_t & ne0,
2604
+ constant int64_t & ne1,
2605
+ constant uint & r2,
2606
+ constant uint & r3,
2607
+ uint3 tgpig[[threadgroup_position_in_grid]],
2608
+ uint tiisg[[thread_index_in_simdgroup]],
2609
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1770
2610
 
1771
2611
  const int nb = ne00/QK_K;
1772
2612
 
1773
2613
  const int64_t r0 = tgpig.x;
1774
2614
  const int64_t r1 = tgpig.y;
1775
- const int64_t r2 = tgpig.z;
2615
+ const int64_t im = tgpig.z;
1776
2616
 
1777
2617
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1778
- const uint offset0 = r2/gqa*(nb*ne0);
2618
+
2619
+ const uint i12 = im%ne12;
2620
+ const uint i13 = im/ne12;
2621
+
2622
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2623
+
1779
2624
  device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
1780
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2625
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
1781
2626
 
1782
2627
  float yl[32];
1783
2628
 
@@ -1899,40 +2744,47 @@ kernel void kernel_mul_mv_q3_K_f32(
1899
2744
  }
1900
2745
  if (tiisg == 0) {
1901
2746
  for (int row = 0; row < 2; ++row) {
1902
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
2747
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
1903
2748
  }
1904
2749
  }
1905
2750
  }
1906
2751
  #else
1907
- kernel void kernel_mul_mv_q3_K_f32(
2752
+ void kernel_mul_mv_q3_K_f32_impl(
1908
2753
  device const void * src0,
1909
2754
  device const float * src1,
1910
2755
  device float * dst,
1911
2756
  constant int64_t & ne00,
1912
- constant int64_t & ne01[[buffer(4)]],
1913
- constant int64_t & ne02[[buffer(5)]],
1914
- constant int64_t & ne10[[buffer(9)]],
1915
- constant int64_t & ne12[[buffer(11)]],
1916
- constant int64_t & ne0[[buffer(15)]],
1917
- constant int64_t & ne1[[buffer(16)]],
1918
- constant uint & gqa[[buffer(17)]],
2757
+ constant int64_t & ne01,
2758
+ constant int64_t & ne02,
2759
+ constant int64_t & ne10,
2760
+ constant int64_t & ne12,
2761
+ constant int64_t & ne0,
2762
+ constant int64_t & ne1,
2763
+ constant uint & r2,
2764
+ constant uint & r3,
1919
2765
  uint3 tgpig[[threadgroup_position_in_grid]],
1920
- uint tiisg[[thread_index_in_simdgroup]],
1921
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2766
+ uint tiisg[[thread_index_in_simdgroup]],
2767
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1922
2768
 
1923
2769
  const int nb = ne00/QK_K;
1924
2770
 
1925
2771
  const int64_t r0 = tgpig.x;
1926
2772
  const int64_t r1 = tgpig.y;
1927
- const int64_t r2 = tgpig.z;
2773
+ const int64_t im = tgpig.z;
1928
2774
 
1929
2775
  const int row = 2 * r0 + sgitg;
1930
- const uint offset0 = r2/gqa*(nb*ne0);
2776
+
2777
+ const uint i12 = im%ne12;
2778
+ const uint i13 = im/ne12;
2779
+
2780
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2781
+
1931
2782
  device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
1932
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2783
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2784
+
1933
2785
  const int ix = tiisg/4;
1934
2786
  const int il = 4 * (tiisg%4);// 0, 4, 8, 12
1935
- const int im = il/8; // 0, 0, 1, 1
2787
+ const int iq = il/8; // 0, 0, 1, 1
1936
2788
  const int in = il%8; // 0, 4, 0, 4
1937
2789
 
1938
2790
  float2 sum = {0.f, 0.f};
@@ -1952,7 +2804,7 @@ kernel void kernel_mul_mv_q3_K_f32(
1952
2804
  const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
1953
2805
 
1954
2806
  for (int l = 0; l < 4; l += 2) {
1955
- const uint16_t hm = h[l/2] >> im;
2807
+ const uint16_t hm = h[l/2] >> iq;
1956
2808
  sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
1957
2809
  + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
1958
2810
  + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
@@ -1968,28 +2820,50 @@ kernel void kernel_mul_mv_q3_K_f32(
1968
2820
 
1969
2821
  const float tot = simd_sum(sumf);
1970
2822
  if (tiisg == 0) {
1971
- dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
2823
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
1972
2824
  }
1973
2825
 
1974
2826
  }
1975
2827
  #endif
1976
2828
 
2829
+ [[host_name("kernel_mul_mv_q3_K_f32")]]
2830
+ kernel void kernel_mul_mv_q3_K_f32(
2831
+ device const void * src0,
2832
+ device const float * src1,
2833
+ device float * dst,
2834
+ constant int64_t & ne00,
2835
+ constant int64_t & ne01[[buffer(4)]],
2836
+ constant int64_t & ne02[[buffer(5)]],
2837
+ constant int64_t & ne10[[buffer(9)]],
2838
+ constant int64_t & ne12[[buffer(11)]],
2839
+ constant int64_t & ne0 [[buffer(15)]],
2840
+ constant int64_t & ne1 [[buffer(16)]],
2841
+ constant uint & r2 [[buffer(17)]],
2842
+ constant uint & r3 [[buffer(18)]],
2843
+ uint3 tgpig[[threadgroup_position_in_grid]],
2844
+ uint tiisg[[thread_index_in_simdgroup]],
2845
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2846
+
2847
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2848
+ }
2849
+
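As above for q2_K, each quantized mul_mv kernel is now split into a plain _impl function plus a thin entry point tagged with [[host_name(...)]], so the kernel_mul_mv_id_* entry points further down can reuse the same body. A condensed C++ analogy of the pattern (hypothetical qX_K name; the Metal kernel qualifier and host_name attribute are shown only as comments):

    // The body lives in an ordinary function that any entry point can call.
    void kernel_mul_mv_qX_K_f32_impl(const void * src0, const float * src1, float * dst
                                     /* , ... tensor dims and thread ids ... */) {
        // ... the dot-product loop, unchanged from the previous in-kernel code ...
    }

    // In the Metal source this wrapper additionally carries the 'kernel' qualifier
    // and [[host_name("kernel_mul_mv_qX_K_f32")]]; it only forwards its arguments.
    void kernel_mul_mv_qX_K_f32(const void * src0, const float * src1, float * dst) {
        kernel_mul_mv_qX_K_f32_impl(src0, src1, dst);
    }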
1977
2850
  #if QK_K == 256
1978
- kernel void kernel_mul_mv_q4_K_f32(
2851
+ void kernel_mul_mv_q4_K_f32_impl(
1979
2852
  device const void * src0,
1980
2853
  device const float * src1,
1981
2854
  device float * dst,
1982
2855
  constant int64_t & ne00,
1983
- constant int64_t & ne01 [[buffer(4)]],
1984
- constant int64_t & ne02 [[buffer(5)]],
1985
- constant int64_t & ne10 [[buffer(9)]],
1986
- constant int64_t & ne12 [[buffer(11)]],
1987
- constant int64_t & ne0 [[buffer(15)]],
1988
- constant int64_t & ne1 [[buffer(16)]],
1989
- constant uint & gqa [[buffer(17)]],
2856
+ constant int64_t & ne01,
2857
+ constant int64_t & ne02,
2858
+ constant int64_t & ne10,
2859
+ constant int64_t & ne12,
2860
+ constant int64_t & ne0,
2861
+ constant int64_t & ne1,
2862
+ constant uint & r2,
2863
+ constant uint & r3,
1990
2864
  uint3 tgpig[[threadgroup_position_in_grid]],
1991
- uint tiisg[[thread_index_in_simdgroup]],
1992
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2865
+ uint tiisg[[thread_index_in_simdgroup]],
2866
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1993
2867
 
1994
2868
  const uint16_t kmask1 = 0x3f3f;
1995
2869
  const uint16_t kmask2 = 0x0f0f;
@@ -1997,26 +2871,32 @@ kernel void kernel_mul_mv_q4_K_f32(
1997
2871
 
1998
2872
  const int ix = tiisg/8; // 0...3
1999
2873
  const int it = tiisg%8; // 0...7
2000
- const int im = it/4; // 0 or 1
2874
+ const int iq = it/4; // 0 or 1
2001
2875
  const int ir = it%4; // 0...3
2002
2876
 
2003
2877
  const int nb = ne00/QK_K;
2004
2878
  const int r0 = tgpig.x;
2005
2879
  const int r1 = tgpig.y;
2006
- const int r2 = tgpig.z;
2880
+ const int im = tgpig.z;
2007
2881
  //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2008
2882
  const int first_row = r0 * N_DST;
2009
2883
  const int ib_row = first_row * nb;
2010
- const uint offset0 = r2/gqa*(nb*ne0);
2884
+
2885
+ const uint i12 = im%ne12;
2886
+ const uint i13 = im/ne12;
2887
+
2888
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2889
+
2011
2890
  device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
2012
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2891
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2892
+
2013
2893
  float yl[16];
2014
2894
  float yh[16];
2015
2895
  float sumf[N_DST]={0.f}, all_sum;
2016
2896
 
2017
2897
  const int step = sizeof(block_q4_K) * nb / 2;
2018
2898
 
2019
- device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
2899
+ device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
2020
2900
 
2021
2901
  uint16_t sc16[4];
2022
2902
  thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
@@ -2031,8 +2911,8 @@ kernel void kernel_mul_mv_q4_K_f32(
2031
2911
  yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
2032
2912
  }
2033
2913
 
2034
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
2035
- device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
2914
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
2915
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
2036
2916
  device const half * dh = &x[ib].d;
2037
2917
 
2038
2918
  for (int row = 0; row < N_DST; row++) {
@@ -2076,23 +2956,24 @@ kernel void kernel_mul_mv_q4_K_f32(
2076
2956
  for (int row = 0; row < N_DST; ++row) {
2077
2957
  all_sum = simd_sum(sumf[row]);
2078
2958
  if (tiisg == 0) {
2079
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
2959
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
2080
2960
  }
2081
2961
  }
2082
2962
  }
2083
2963
  #else
2084
- kernel void kernel_mul_mv_q4_K_f32(
2964
+ void kernel_mul_mv_q4_K_f32_impl(
2085
2965
  device const void * src0,
2086
2966
  device const float * src1,
2087
2967
  device float * dst,
2088
2968
  constant int64_t & ne00,
2089
- constant int64_t & ne01[[buffer(4)]],
2090
- constant int64_t & ne02[[buffer(5)]],
2091
- constant int64_t & ne10[[buffer(9)]],
2092
- constant int64_t & ne12[[buffer(11)]],
2093
- constant int64_t & ne0[[buffer(15)]],
2094
- constant int64_t & ne1[[buffer(16)]],
2095
- constant uint & gqa[[buffer(17)]],
2969
+ constant int64_t & ne01,
2970
+ constant int64_t & ne02,
2971
+ constant int64_t & ne10,
2972
+ constant int64_t & ne12,
2973
+ constant int64_t & ne0,
2974
+ constant int64_t & ne1,
2975
+ constant uint & r2,
2976
+ constant uint & r3,
2096
2977
  uint3 tgpig[[threadgroup_position_in_grid]],
2097
2978
  uint tiisg[[thread_index_in_simdgroup]],
2098
2979
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2103,12 +2984,18 @@ kernel void kernel_mul_mv_q4_K_f32(
2103
2984
  const int nb = ne00/QK_K;
2104
2985
  const int r0 = tgpig.x;
2105
2986
  const int r1 = tgpig.y;
2106
- const int r2 = tgpig.z;
2987
+ const int im = tgpig.z;
2107
2988
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2108
2989
  const int ib_row = first_row * nb;
2109
- const uint offset0 = r2/gqa*(nb*ne0);
2990
+
2991
+ const uint i12 = im%ne12;
2992
+ const uint i13 = im/ne12;
2993
+
2994
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2995
+
2110
2996
  device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
2111
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2997
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2998
+
2112
2999
  float yl[8];
2113
3000
  float yh[8];
2114
3001
  float sumf[N_DST]={0.f}, all_sum;
@@ -2164,13 +3051,14 @@ kernel void kernel_mul_mv_q4_K_f32(
2164
3051
  for (int row = 0; row < N_DST; ++row) {
2165
3052
  all_sum = simd_sum(sumf[row]);
2166
3053
  if (tiisg == 0) {
2167
- dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
3054
+ dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
2168
3055
  }
2169
3056
  }
2170
3057
  }
2171
3058
  #endif
2172
3059
 
2173
- kernel void kernel_mul_mv_q5_K_f32(
3060
+ [[host_name("kernel_mul_mv_q4_K_f32")]]
3061
+ kernel void kernel_mul_mv_q4_K_f32(
2174
3062
  device const void * src0,
2175
3063
  device const float * src1,
2176
3064
  device float * dst,
@@ -2179,23 +3067,49 @@ kernel void kernel_mul_mv_q5_K_f32(
2179
3067
  constant int64_t & ne02[[buffer(5)]],
2180
3068
  constant int64_t & ne10[[buffer(9)]],
2181
3069
  constant int64_t & ne12[[buffer(11)]],
2182
- constant int64_t & ne0[[buffer(15)]],
2183
- constant int64_t & ne1[[buffer(16)]],
2184
- constant uint & gqa[[buffer(17)]],
3070
+ constant int64_t & ne0 [[buffer(15)]],
3071
+ constant int64_t & ne1 [[buffer(16)]],
3072
+ constant uint & r2 [[buffer(17)]],
3073
+ constant uint & r3 [[buffer(18)]],
2185
3074
  uint3 tgpig[[threadgroup_position_in_grid]],
2186
3075
  uint tiisg[[thread_index_in_simdgroup]],
2187
3076
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
2188
3077
 
3078
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3079
+ }
3080
+
3081
+ void kernel_mul_mv_q5_K_f32_impl(
3082
+ device const void * src0,
3083
+ device const float * src1,
3084
+ device float * dst,
3085
+ constant int64_t & ne00,
3086
+ constant int64_t & ne01,
3087
+ constant int64_t & ne02,
3088
+ constant int64_t & ne10,
3089
+ constant int64_t & ne12,
3090
+ constant int64_t & ne0,
3091
+ constant int64_t & ne1,
3092
+ constant uint & r2,
3093
+ constant uint & r3,
3094
+ uint3 tgpig[[threadgroup_position_in_grid]],
3095
+ uint tiisg[[thread_index_in_simdgroup]],
3096
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3097
+
2189
3098
  const int nb = ne00/QK_K;
2190
3099
 
2191
3100
  const int64_t r0 = tgpig.x;
2192
3101
  const int64_t r1 = tgpig.y;
2193
- const int r2 = tgpig.z;
3102
+ const int im = tgpig.z;
2194
3103
 
2195
3104
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
2196
- const uint offset0 = r2/gqa*(nb*ne0);
3105
+
3106
+ const uint i12 = im%ne12;
3107
+ const uint i13 = im/ne12;
3108
+
3109
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3110
+
2197
3111
  device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
2198
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
3112
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2199
3113
 
2200
3114
  float sumf[2]={0.f};
2201
3115
 
@@ -2211,15 +3125,15 @@ kernel void kernel_mul_mv_q5_K_f32(
2211
3125
 
2212
3126
  const int tid = tiisg/4;
2213
3127
  const int ix = tiisg%4;
2214
- const int im = tid/4;
3128
+ const int iq = tid/4;
2215
3129
  const int ir = tid%4;
2216
3130
  const int n = 8;
2217
3131
 
2218
3132
  const int l0 = n*ir;
2219
- const int q_offset = 32*im + l0;
2220
- const int y_offset = 64*im + l0;
3133
+ const int q_offset = 32*iq + l0;
3134
+ const int y_offset = 64*iq + l0;
2221
3135
 
2222
- const uint8_t hm1 = 1u << (2*im);
3136
+ const uint8_t hm1 = 1u << (2*iq);
2223
3137
  const uint8_t hm2 = hm1 << 1;
2224
3138
  const uint8_t hm3 = hm1 << 4;
2225
3139
  const uint8_t hm4 = hm2 << 4;
@@ -2234,7 +3148,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2234
3148
  device const uint8_t * q1 = x[i].qs + q_offset;
2235
3149
  device const uint8_t * qh = x[i].qh + l0;
2236
3150
  device const half * dh = &x[i].d;
2237
- device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
3151
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
2238
3152
 
2239
3153
  device const float * y2 = y1 + 128;
2240
3154
  float4 sumy = {0.f, 0.f, 0.f, 0.f};
@@ -2290,7 +3204,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2290
3204
 
2291
3205
  const int il = 4 * (tiisg/8); // 0, 4, 8, 12
2292
3206
  const int ix = tiisg%8;
2293
- const int im = il/8; // 0, 0, 1, 1
3207
+ const int iq = il/8; // 0, 0, 1, 1
2294
3208
  const int in = il%8; // 0, 4, 0, 4
2295
3209
 
2296
3210
  device const float * y = yy + ix*QK_K + il;
@@ -2315,7 +3229,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2315
3229
 
2316
3230
  float2 acc = {0.f, 0.f};
2317
3231
  for (int l = 0; l < 4; ++l) {
2318
- const uint8_t hl = h[l] >> im;
3232
+ const uint8_t hl = h[l] >> iq;
2319
3233
  acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
2320
3234
  + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
2321
3235
  acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
@@ -2337,27 +3251,48 @@ kernel void kernel_mul_mv_q5_K_f32(
2337
3251
  for (int row = 0; row < 2; ++row) {
2338
3252
  const float tot = simd_sum(sumf[row]);
2339
3253
  if (tiisg == 0) {
2340
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
3254
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
2341
3255
  }
2342
3256
  }
3257
+ }
3258
+
3259
+ [[host_name("kernel_mul_mv_q5_K_f32")]]
3260
+ kernel void kernel_mul_mv_q5_K_f32(
3261
+ device const void * src0,
3262
+ device const float * src1,
3263
+ device float * dst,
3264
+ constant int64_t & ne00,
3265
+ constant int64_t & ne01[[buffer(4)]],
3266
+ constant int64_t & ne02[[buffer(5)]],
3267
+ constant int64_t & ne10[[buffer(9)]],
3268
+ constant int64_t & ne12[[buffer(11)]],
3269
+ constant int64_t & ne0 [[buffer(15)]],
3270
+ constant int64_t & ne1 [[buffer(16)]],
3271
+ constant uint & r2 [[buffer(17)]],
3272
+ constant uint & r3 [[buffer(18)]],
3273
+ uint3 tgpig[[threadgroup_position_in_grid]],
3274
+ uint tiisg[[thread_index_in_simdgroup]],
3275
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2343
3276
 
3277
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2344
3278
  }
2345
3279
 
2346
- kernel void kernel_mul_mv_q6_K_f32(
3280
+ void kernel_mul_mv_q6_K_f32_impl(
2347
3281
  device const void * src0,
2348
3282
  device const float * src1,
2349
3283
  device float * dst,
2350
3284
  constant int64_t & ne00,
2351
- constant int64_t & ne01[[buffer(4)]],
2352
- constant int64_t & ne02[[buffer(5)]],
2353
- constant int64_t & ne10[[buffer(9)]],
2354
- constant int64_t & ne12[[buffer(11)]],
2355
- constant int64_t & ne0[[buffer(15)]],
2356
- constant int64_t & ne1[[buffer(16)]],
2357
- constant uint & gqa[[buffer(17)]],
3285
+ constant int64_t & ne01,
3286
+ constant int64_t & ne02,
3287
+ constant int64_t & ne10,
3288
+ constant int64_t & ne12,
3289
+ constant int64_t & ne0,
3290
+ constant int64_t & ne1,
3291
+ constant uint & r2,
3292
+ constant uint & r3,
2358
3293
  uint3 tgpig[[threadgroup_position_in_grid]],
2359
- uint tiisg[[thread_index_in_simdgroup]],
2360
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3294
+ uint tiisg[[thread_index_in_simdgroup]],
3295
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2361
3296
 
2362
3297
  const uint8_t kmask1 = 0x03;
2363
3298
  const uint8_t kmask2 = 0x0C;
@@ -2368,12 +3303,17 @@ kernel void kernel_mul_mv_q6_K_f32(
2368
3303
 
2369
3304
  const int64_t r0 = tgpig.x;
2370
3305
  const int64_t r1 = tgpig.y;
2371
- const int r2 = tgpig.z;
3306
+ const int im = tgpig.z;
2372
3307
 
2373
3308
  const int row = 2 * r0 + sgitg;
2374
- const uint offset0 = r2/gqa*(nb*ne0);
3309
+
3310
+ const uint i12 = im%ne12;
3311
+ const uint i13 = im/ne12;
3312
+
3313
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3314
+
2375
3315
  device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
2376
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
3316
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2377
3317
 
2378
3318
  float sumf = 0;
2379
3319
 
@@ -2439,10 +3379,31 @@ kernel void kernel_mul_mv_q6_K_f32(
2439
3379
 
2440
3380
  const float tot = simd_sum(sumf);
2441
3381
  if (tiisg == 0) {
2442
- dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
3382
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
2443
3383
  }
2444
3384
  }
2445
3385
 
3386
+ [[host_name("kernel_mul_mv_q6_K_f32")]]
3387
+ kernel void kernel_mul_mv_q6_K_f32(
3388
+ device const void * src0,
3389
+ device const float * src1,
3390
+ device float * dst,
3391
+ constant int64_t & ne00,
3392
+ constant int64_t & ne01[[buffer(4)]],
3393
+ constant int64_t & ne02[[buffer(5)]],
3394
+ constant int64_t & ne10[[buffer(9)]],
3395
+ constant int64_t & ne12[[buffer(11)]],
3396
+ constant int64_t & ne0 [[buffer(15)]],
3397
+ constant int64_t & ne1 [[buffer(16)]],
3398
+ constant uint & r2 [[buffer(17)]],
3399
+ constant uint & r3 [[buffer(18)]],
3400
+ uint3 tgpig[[threadgroup_position_in_grid]],
3401
+ uint tiisg[[thread_index_in_simdgroup]],
3402
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3403
+
3404
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3405
+ }
3406
+
2446
3407
  //============================= templates and their specializations =============================
2447
3408
 
2448
3409
  // NOTE: this is not dequantizing - we are simply fitting the template
@@ -2560,10 +3521,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
2560
3521
 
2561
3522
  template <typename type4x4>
2562
3523
  void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
2563
- const half d = xb->d;
2564
- const half min = xb->dmin;
3524
+ const float d = xb->d;
3525
+ const float min = xb->dmin;
2565
3526
  device const uint8_t * q = (device const uint8_t *)xb->qs;
2566
- half dl, ml;
3527
+ float dl, ml;
2567
3528
  uint8_t sc = xb->scales[il];
2568
3529
 
2569
3530
  #if QK_K == 256
@@ -2633,10 +3594,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
2633
3594
  q = q + (il/4) * 32 + 16 * (il&1);
2634
3595
  il = il & 3;
2635
3596
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
2636
- const half d = il < 2 ? xb->d : xb->d / 16.h;
2637
- const half min = xb->dmin;
2638
- const half dl = d * sc[0];
2639
- const half ml = min * sc[1];
3597
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
3598
+ const float min = xb->dmin;
3599
+ const float dl = d * sc[0];
3600
+ const float ml = min * sc[1];
2640
3601
  #else
2641
3602
  q = q + 16 * (il&1);
2642
3603
  device const uint8_t * s = xb->scales;
@@ -2663,13 +3624,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
2663
3624
  uint8_t ul = 1 << (il/2);
2664
3625
  il = il & 3;
2665
3626
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
2666
- const half d = il < 2 ? xb->d : xb->d / 16.h;
2667
- const half min = xb->dmin;
2668
- const half dl = d * sc[0];
2669
- const half ml = min * sc[1];
3627
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
3628
+ const float min = xb->dmin;
3629
+ const float dl = d * sc[0];
3630
+ const float ml = min * sc[1];
2670
3631
 
2671
- const ushort mask = il<2 ? 0x0F : 0xF0;
2672
- const half qh_val = il<2 ? 16.h : 256.h;
3632
+ const ushort mask = il<2 ? 0x0F : 0xF0;
3633
+ const float qh_val = il<2 ? 16.f : 256.f;
2673
3634
  for (int i = 0; i < 16; ++i) {
2674
3635
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
2675
3636
  }
@@ -2717,22 +3678,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
2717
3678
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
2718
3679
  kernel void kernel_get_rows(
2719
3680
  device const void * src0,
2720
- device const int * src1,
3681
+ device const char * src1,
2721
3682
  device float * dst,
2722
3683
  constant int64_t & ne00,
2723
3684
  constant uint64_t & nb01,
3685
+ constant uint64_t & nb02,
3686
+ constant int64_t & ne10,
3687
+ constant uint64_t & nb10,
3688
+ constant uint64_t & nb11,
2724
3689
  constant uint64_t & nb1,
2725
- uint tgpig[[threadgroup_position_in_grid]],
3690
+ constant uint64_t & nb2,
3691
+ uint3 tgpig[[threadgroup_position_in_grid]],
2726
3692
  uint tiitg[[thread_index_in_threadgroup]],
2727
- uint tptg[[threads_per_threadgroup]]) {
2728
- const int i = tgpig;
2729
- const int r = ((device int32_t *) src1)[i];
3693
+ uint3 tptg [[threads_per_threadgroup]]) {
3694
+ //const int64_t i = tgpig;
3695
+ //const int64_t r = ((device int32_t *) src1)[i];
3696
+
3697
+ const int64_t i10 = tgpig.x;
3698
+ const int64_t i11 = tgpig.y;
2730
3699
 
2731
- for (int ind = tiitg; ind < ne00/16; ind += tptg) {
3700
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3701
+
3702
+ const int64_t i02 = i11;
3703
+
3704
+ for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
2732
3705
  float4x4 temp;
2733
3706
  dequantize_func(
2734
- ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
2735
- *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
3707
+ ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
3708
+ *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
3709
+ }
3710
+ }
3711
+
3712
+ kernel void kernel_get_rows_f32(
3713
+ device const void * src0,
3714
+ device const char * src1,
3715
+ device float * dst,
3716
+ constant int64_t & ne00,
3717
+ constant uint64_t & nb01,
3718
+ constant uint64_t & nb02,
3719
+ constant int64_t & ne10,
3720
+ constant uint64_t & nb10,
3721
+ constant uint64_t & nb11,
3722
+ constant uint64_t & nb1,
3723
+ constant uint64_t & nb2,
3724
+ uint3 tgpig[[threadgroup_position_in_grid]],
3725
+ uint tiitg[[thread_index_in_threadgroup]],
3726
+ uint3 tptg [[threads_per_threadgroup]]) {
3727
+ const int64_t i10 = tgpig.x;
3728
+ const int64_t i11 = tgpig.y;
3729
+
3730
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3731
+
3732
+ const int64_t i02 = i11;
3733
+
3734
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3735
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3736
+ ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
3737
+ }
3738
+ }
3739
+
3740
+ kernel void kernel_get_rows_f16(
3741
+ device const void * src0,
3742
+ device const char * src1,
3743
+ device float * dst,
3744
+ constant int64_t & ne00,
3745
+ constant uint64_t & nb01,
3746
+ constant uint64_t & nb02,
3747
+ constant int64_t & ne10,
3748
+ constant uint64_t & nb10,
3749
+ constant uint64_t & nb11,
3750
+ constant uint64_t & nb1,
3751
+ constant uint64_t & nb2,
3752
+ uint3 tgpig[[threadgroup_position_in_grid]],
3753
+ uint tiitg[[thread_index_in_threadgroup]],
3754
+ uint3 tptg [[threads_per_threadgroup]]) {
3755
+ const int64_t i10 = tgpig.x;
3756
+ const int64_t i11 = tgpig.y;
3757
+
3758
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3759
+
3760
+ const int64_t i02 = i11;
3761
+
3762
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3763
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3764
+ ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
2736
3765
  }
2737
3766
  }
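The rewritten get_rows kernels read row ids from src1 as a strided 2-D int32 tensor (instead of a flat int array) and gather from the i02-th (= i11-th) matrix of src0. A hedged C++ sketch of the same gather for the float case, assuming contiguous toy strides rather than the kernel's byte strides nb10/nb11, nb01/nb02 and nb1/nb2:

    #include <cstdint>
    #include <cstring>

    // Sketch of kernel_get_rows_f32's addressing: for each (i10, i11) the row id r
    // comes from src1, and row r of the i02-th (= i11-th) src0 matrix is copied to
    // the (i10, i11) row of dst.
    void get_rows_f32(const float * src0, const int32_t * src1, float * dst,
                      int64_t ne00, int64_t ne01, int64_t ne10, int64_t ne11) {
        for (int64_t i11 = 0; i11 < ne11; ++i11) {
            for (int64_t i10 = 0; i10 < ne10; ++i10) {
                const int32_t r   = src1[i11*ne10 + i10];
                const int64_t i02 = i11; // one src0 matrix per src1 row, as in the kernel
                memcpy(dst  + (i11*ne10 + i10)*ne00,
                       src0 + (i02*ne01 + r)*ne00,
                       ne00*sizeof(float));
            }
        }
    }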
2738
3767
 
@@ -2749,24 +3778,25 @@ kernel void kernel_get_rows(
2749
3778
 
2750
3779
  // each block_q contains 16*nl weights
2751
3780
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
2752
- kernel void kernel_mul_mm(device const uchar * src0,
2753
- device const uchar * src1,
2754
- device float * dst,
2755
- constant int64_t & ne00,
2756
- constant int64_t & ne02,
2757
- constant int64_t & nb01,
2758
- constant int64_t & nb02,
2759
- constant int64_t & ne12,
2760
- constant int64_t & nb10,
2761
- constant int64_t & nb11,
2762
- constant int64_t & nb12,
2763
- constant int64_t & ne0,
2764
- constant int64_t & ne1,
2765
- constant uint & gqa,
2766
- threadgroup uchar * shared_memory [[threadgroup(0)]],
2767
- uint3 tgpig[[threadgroup_position_in_grid]],
2768
- uint tiitg[[thread_index_in_threadgroup]],
2769
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3781
+ void kernel_mul_mm_impl(device const uchar * src0,
3782
+ device const uchar * src1,
3783
+ device float * dst,
3784
+ constant int64_t & ne00,
3785
+ constant int64_t & ne02,
3786
+ constant int64_t & nb01,
3787
+ constant int64_t & nb02,
3788
+ constant int64_t & ne12,
3789
+ constant int64_t & nb10,
3790
+ constant int64_t & nb11,
3791
+ constant int64_t & nb12,
3792
+ constant int64_t & ne0,
3793
+ constant int64_t & ne1,
3794
+ constant uint & r2,
3795
+ constant uint & r3,
3796
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3797
+ uint3 tgpig[[threadgroup_position_in_grid]],
3798
+ uint tiitg[[thread_index_in_threadgroup]],
3799
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2770
3800
 
2771
3801
  threadgroup half * sa = (threadgroup half *)(shared_memory);
2772
3802
  threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
@@ -2792,7 +3822,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
2792
3822
 
2793
3823
  short il = (tiitg % THREAD_PER_ROW);
2794
3824
 
2795
- uint offset0 = im/gqa*nb02;
3825
+ const uint i12 = im%ne12;
3826
+ const uint i13 = im/ne12;
3827
+
3828
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
2796
3829
  ushort offset1 = il/nl;
2797
3830
 
2798
3831
  device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
@@ -2876,17 +3909,137 @@ kernel void kernel_mul_mm(device const uchar * src0,
2876
3909
  }
2877
3910
  }
2878
3911
 
3912
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3913
+ kernel void kernel_mul_mm(device const uchar * src0,
3914
+ device const uchar * src1,
3915
+ device float * dst,
3916
+ constant int64_t & ne00,
3917
+ constant int64_t & ne02,
3918
+ constant int64_t & nb01,
3919
+ constant int64_t & nb02,
3920
+ constant int64_t & ne12,
3921
+ constant int64_t & nb10,
3922
+ constant int64_t & nb11,
3923
+ constant int64_t & nb12,
3924
+ constant int64_t & ne0,
3925
+ constant int64_t & ne1,
3926
+ constant uint & r2,
3927
+ constant uint & r3,
3928
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3929
+ uint3 tgpig[[threadgroup_position_in_grid]],
3930
+ uint tiitg[[thread_index_in_threadgroup]],
3931
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3932
+ kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3933
+ src0,
3934
+ src1,
3935
+ dst,
3936
+ ne00,
3937
+ ne02,
3938
+ nb01,
3939
+ nb02,
3940
+ ne12,
3941
+ nb10,
3942
+ nb11,
3943
+ nb12,
3944
+ ne0,
3945
+ ne1,
3946
+ r2,
3947
+ r3,
3948
+ shared_memory,
3949
+ tgpig,
3950
+ tiitg,
3951
+ sgitg);
3952
+ }
3953
+
3954
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3955
+ kernel void kernel_mul_mm_id(
3956
+ device const uchar * ids,
3957
+ device const uchar * src1,
3958
+ device uchar * dst,
3959
+ constant int64_t & nbi1,
3960
+ constant int64_t & ne00,
3961
+ constant int64_t & ne02,
3962
+ constant int64_t & nb01,
3963
+ constant int64_t & nb02,
3964
+ constant int64_t & ne12,
3965
+ constant int64_t & ne13,
3966
+ constant int64_t & nb10,
3967
+ constant int64_t & nb11,
3968
+ constant int64_t & nb12,
3969
+ constant int64_t & ne0,
3970
+ constant int64_t & ne1,
3971
+ constant int64_t & nb1,
3972
+ constant uint & r2,
3973
+ constant uint & r3,
3974
+ constant int & idx,
3975
+ device const uchar * src00,
3976
+ device const uchar * src01,
3977
+ device const uchar * src02,
3978
+ device const uchar * src03,
3979
+ device const uchar * src04,
3980
+ device const uchar * src05,
3981
+ device const uchar * src06,
3982
+ device const uchar * src07,
3983
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3984
+ uint3 tgpig[[threadgroup_position_in_grid]],
3985
+ uint tiitg[[thread_index_in_threadgroup]],
3986
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3987
+ device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
3988
+
3989
+ const int64_t bid = tgpig.z/(ne12*ne13);
3990
+
3991
+ tgpig.z = tgpig.z%(ne12*ne13);
3992
+
3993
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
3994
+
3995
+ kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3996
+ src0[id],
3997
+ src1 + bid*nb11,
3998
+ (device float *) (dst + bid*nb1),
3999
+ ne00,
4000
+ ne02,
4001
+ nb01,
4002
+ nb02,
4003
+ ne12,
4004
+ nb10,
4005
+ nb11,
4006
+ nb12,
4007
+ ne0,
4008
+ ne1,
4009
+ r2,
4010
+ r3,
4011
+ shared_memory,
4012
+ tgpig,
4013
+ tiitg,
4014
+ sgitg);
4015
+ }
4016
+
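kernel_mul_mm_id receives eight candidate src0 buffers plus an ids tensor and forwards exactly one of them to kernel_mul_mm_impl. A minimal C++ sketch of that selection, with the worker signature condensed for illustration (the real kernel also reduces tgpig.z modulo ne12*ne13 before forwarding):

    #include <cstdint>

    // Stand-in for kernel_mul_mm_impl with its argument list condensed.
    typedef void (*mul_mm_impl_t)(const unsigned char * src0, const unsigned char * src1, float * dst);

    // Sketch of the indirection shared by kernel_mul_mm_id and the
    // kernel_mul_mv_id_* entry points: tgpig.z encodes the batch id bid, the
    // expert index is the idx-th int32 of row bid of ids, and it selects one of
    // the eight src0 pointers.
    void mul_mm_id(const unsigned char * ids, const unsigned char * src1, unsigned char * dst,
                   int64_t nbi1, int64_t nb11, int64_t nb1,
                   int64_t ne12, int64_t ne13, int idx,
                   const unsigned char * const src0[8],
                   int64_t tgpig_z, mul_mm_impl_t impl) {
        const int64_t bid = tgpig_z/(ne12*ne13);

        const int32_t id = ((const int32_t *)(ids + bid*nbi1))[idx];

        impl(src0[id], src1 + bid*nb11, (float *)(dst + bid*nb1));
    }

The same bid/id lookup reappears verbatim in every kernel_mul_mv_id_* entry point defined at the end of this file.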
2879
4017
  #if QK_K == 256
2880
4018
  #define QK_NL 16
2881
4019
  #else
2882
4020
  #define QK_NL 4
2883
4021
  #endif
2884
4022
 
2885
- typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
2886
- constant uint64_t &, constant uint64_t &, uint, uint, uint);
4023
+ //
4024
+ // get rows
4025
+ //
2887
4026
 
2888
- template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
2889
- template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
4027
+ typedef void (get_rows_t)(
4028
+ device const void * src0,
4029
+ device const char * src1,
4030
+ device float * dst,
4031
+ constant int64_t & ne00,
4032
+ constant uint64_t & nb01,
4033
+ constant uint64_t & nb02,
4034
+ constant int64_t & ne10,
4035
+ constant uint64_t & nb10,
4036
+ constant uint64_t & nb11,
4037
+ constant uint64_t & nb1,
4038
+ constant uint64_t & nb2,
4039
+ uint3, uint, uint3);
4040
+
4041
+ //template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
4042
+ //template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
2890
4043
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
2891
4044
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
2892
4045
  template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
@@ -2898,6 +4051,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
2898
4051
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
2899
4052
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
2900
4053
 
4054
+ //
4055
+ // matrix-matrix multiplication
4056
+ //
4057
+
2901
4058
  typedef void (mat_mm_t)(
2902
4059
  device const uchar * src0,
2903
4060
  device const uchar * src1,
@@ -2912,8 +4069,10 @@ typedef void (mat_mm_t)(
2912
4069
  constant int64_t & nb12,
2913
4070
  constant int64_t & ne0,
2914
4071
  constant int64_t & ne1,
2915
- constant uint & gqa,
2916
- threadgroup uchar *, uint3, uint, uint);
4072
+ constant uint & r2,
4073
+ constant uint & r3,
4074
+ threadgroup uchar *,
4075
+ uint3, uint, uint);
2917
4076
 
2918
4077
  template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
2919
4078
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
@@ -2927,3 +4086,823 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
2927
4086
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
2928
4087
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
2929
4088
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
4089
+
4090
+ //
4091
+ // indirect matrix-matrix multiplication
4092
+ //
4093
+
4094
+ typedef void (mat_mm_id_t)(
4095
+ device const uchar * ids,
4096
+ device const uchar * src1,
4097
+ device uchar * dst,
4098
+ constant int64_t & nbi1,
4099
+ constant int64_t & ne00,
4100
+ constant int64_t & ne02,
4101
+ constant int64_t & nb01,
4102
+ constant int64_t & nb02,
4103
+ constant int64_t & ne12,
4104
+ constant int64_t & ne13,
4105
+ constant int64_t & nb10,
4106
+ constant int64_t & nb11,
4107
+ constant int64_t & nb12,
4108
+ constant int64_t & ne0,
4109
+ constant int64_t & ne1,
4110
+ constant int64_t & nb1,
4111
+ constant uint & r2,
4112
+ constant uint & r3,
4113
+ constant int & idx,
4114
+ device const uchar * src00,
4115
+ device const uchar * src01,
4116
+ device const uchar * src02,
4117
+ device const uchar * src03,
4118
+ device const uchar * src04,
4119
+ device const uchar * src05,
4120
+ device const uchar * src06,
4121
+ device const uchar * src07,
4122
+ threadgroup uchar *,
4123
+ uint3, uint, uint);
4124
+
4125
+ template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
4126
+ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
4127
+ template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
4128
+ template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
4129
+ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
4130
+ template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
4131
+ template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
4132
+ template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
4133
+ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
4134
+ template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
4135
+ template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
4136
+ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
4137
+
4138
+ //
4139
+ // matrix-vector multiplication
4140
+ //
4141
+
4142
+ [[host_name("kernel_mul_mv_id_f32_f32")]]
4143
+ kernel void kernel_mul_mv_id_f32_f32(
4144
+ device const char * ids,
4145
+ device const char * src1,
4146
+ device uchar * dst,
4147
+ constant int64_t & nbi1,
4148
+ constant int64_t & ne00,
4149
+ constant int64_t & ne01,
4150
+ constant int64_t & ne02,
4151
+ constant uint64_t & nb00,
4152
+ constant uint64_t & nb01,
4153
+ constant uint64_t & nb02,
4154
+ constant int64_t & ne10,
4155
+ constant int64_t & ne11,
4156
+ constant int64_t & ne12,
4157
+ constant int64_t & ne13,
4158
+ constant uint64_t & nb10,
4159
+ constant uint64_t & nb11,
4160
+ constant uint64_t & nb12,
4161
+ constant int64_t & ne0,
4162
+ constant int64_t & ne1,
4163
+ constant int64_t & nb1,
4164
+ constant uint & r2,
4165
+ constant uint & r3,
4166
+ constant int & idx,
4167
+ device const char * src00,
4168
+ device const char * src01,
4169
+ device const char * src02,
4170
+ device const char * src03,
4171
+ device const char * src04,
4172
+ device const char * src05,
4173
+ device const char * src06,
4174
+ device const char * src07,
4175
+ uint3 tgpig[[threadgroup_position_in_grid]],
4176
+ uint tiitg[[thread_index_in_threadgroup]],
4177
+ uint tiisg[[thread_index_in_simdgroup]],
4178
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4179
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4180
+
4181
+ const int64_t bid = tgpig.z/(ne12*ne13);
4182
+
4183
+ tgpig.z = tgpig.z%(ne12*ne13);
4184
+
4185
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4186
+
4187
+ kernel_mul_mv_f32_f32_impl(
4188
+ src0[id],
4189
+ src1 + bid*nb11,
4190
+ (device float *) (dst + bid*nb1),
4191
+ ne00,
4192
+ ne01,
4193
+ ne02,
4194
+ nb00,
4195
+ nb01,
4196
+ nb02,
4197
+ ne10,
4198
+ ne11,
4199
+ ne12,
4200
+ nb10,
4201
+ nb11,
4202
+ nb12,
4203
+ ne0,
4204
+ ne1,
4205
+ r2,
4206
+ r3,
4207
+ tgpig,
4208
+ tiisg);
4209
+ }
4210
+
4211
+ [[host_name("kernel_mul_mv_id_f16_f32")]]
4212
+ kernel void kernel_mul_mv_id_f16_f32(
4213
+ device const char * ids,
4214
+ device const char * src1,
4215
+ device uchar * dst,
4216
+ constant int64_t & nbi1,
4217
+ constant int64_t & ne00,
4218
+ constant int64_t & ne01,
4219
+ constant int64_t & ne02,
4220
+ constant uint64_t & nb00,
4221
+ constant uint64_t & nb01,
4222
+ constant uint64_t & nb02,
4223
+ constant int64_t & ne10,
4224
+ constant int64_t & ne11,
4225
+ constant int64_t & ne12,
4226
+ constant int64_t & ne13,
4227
+ constant uint64_t & nb10,
4228
+ constant uint64_t & nb11,
4229
+ constant uint64_t & nb12,
4230
+ constant int64_t & ne0,
4231
+ constant int64_t & ne1,
4232
+ constant int64_t & nb1,
4233
+ constant uint & r2,
4234
+ constant uint & r3,
4235
+ constant int & idx,
4236
+ device const char * src00,
4237
+ device const char * src01,
4238
+ device const char * src02,
4239
+ device const char * src03,
4240
+ device const char * src04,
4241
+ device const char * src05,
4242
+ device const char * src06,
4243
+ device const char * src07,
4244
+ uint3 tgpig[[threadgroup_position_in_grid]],
4245
+ uint tiitg[[thread_index_in_threadgroup]],
4246
+ uint tiisg[[thread_index_in_simdgroup]],
4247
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4248
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4249
+
4250
+ const int64_t bid = tgpig.z/(ne12*ne13);
4251
+
4252
+ tgpig.z = tgpig.z%(ne12*ne13);
4253
+
4254
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4255
+
4256
+ kernel_mul_mv_f16_f32_impl(
4257
+ src0[id],
4258
+ src1 + bid*nb11,
4259
+ (device float *) (dst + bid*nb1),
4260
+ ne00,
4261
+ ne01,
4262
+ ne02,
4263
+ nb00,
4264
+ nb01,
4265
+ nb02,
4266
+ ne10,
4267
+ ne11,
4268
+ ne12,
4269
+ nb10,
4270
+ nb11,
4271
+ nb12,
4272
+ ne0,
4273
+ ne1,
4274
+ r2,
4275
+ r3,
4276
+ tgpig,
4277
+ tiisg);
4278
+ }
4279
+
4280
+ [[host_name("kernel_mul_mv_id_q8_0_f32")]]
4281
+ kernel void kernel_mul_mv_id_q8_0_f32(
4282
+ device const char * ids,
4283
+ device const char * src1,
4284
+ device uchar * dst,
4285
+ constant int64_t & nbi1,
4286
+ constant int64_t & ne00,
4287
+ constant int64_t & ne01,
4288
+ constant int64_t & ne02,
4289
+ constant uint64_t & nb00,
4290
+ constant uint64_t & nb01,
4291
+ constant uint64_t & nb02,
4292
+ constant int64_t & ne10,
4293
+ constant int64_t & ne11,
4294
+ constant int64_t & ne12,
4295
+ constant int64_t & ne13,
4296
+ constant uint64_t & nb10,
4297
+ constant uint64_t & nb11,
4298
+ constant uint64_t & nb12,
4299
+ constant int64_t & ne0,
4300
+ constant int64_t & ne1,
4301
+ constant int64_t & nb1,
4302
+ constant uint & r2,
4303
+ constant uint & r3,
4304
+ constant int & idx,
4305
+ device const char * src00,
4306
+ device const char * src01,
4307
+ device const char * src02,
4308
+ device const char * src03,
4309
+ device const char * src04,
4310
+ device const char * src05,
4311
+ device const char * src06,
4312
+ device const char * src07,
4313
+ uint3 tgpig[[threadgroup_position_in_grid]],
4314
+ uint tiitg[[thread_index_in_threadgroup]],
4315
+ uint tiisg[[thread_index_in_simdgroup]],
4316
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4317
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4318
+
4319
+ const int64_t bid = tgpig.z/(ne12*ne13);
4320
+
4321
+ tgpig.z = tgpig.z%(ne12*ne13);
4322
+
4323
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4324
+
4325
+ kernel_mul_mv_q8_0_f32_impl(
4326
+ src0[id],
4327
+ (device const float *) (src1 + bid*nb11),
4328
+ (device float *) ( dst + bid*nb1),
4329
+ ne00,
4330
+ ne01,
4331
+ ne02,
4332
+ ne10,
4333
+ ne12,
4334
+ ne0,
4335
+ ne1,
4336
+ r2,
4337
+ r3,
4338
+ tgpig,
4339
+ tiisg,
4340
+ sgitg);
4341
+ }
4342
+
4343
+ [[host_name("kernel_mul_mv_id_q4_0_f32")]]
4344
+ kernel void kernel_mul_mv_id_q4_0_f32(
4345
+ device const char * ids,
4346
+ device const char * src1,
4347
+ device uchar * dst,
4348
+ constant int64_t & nbi1,
4349
+ constant int64_t & ne00,
4350
+ constant int64_t & ne01,
4351
+ constant int64_t & ne02,
4352
+ constant uint64_t & nb00,
4353
+ constant uint64_t & nb01,
4354
+ constant uint64_t & nb02,
4355
+ constant int64_t & ne10,
4356
+ constant int64_t & ne11,
4357
+ constant int64_t & ne12,
4358
+ constant int64_t & ne13,
4359
+ constant uint64_t & nb10,
4360
+ constant uint64_t & nb11,
4361
+ constant uint64_t & nb12,
4362
+ constant int64_t & ne0,
4363
+ constant int64_t & ne1,
4364
+ constant int64_t & nb1,
4365
+ constant uint & r2,
4366
+ constant uint & r3,
4367
+ constant int & idx,
4368
+ device const char * src00,
4369
+ device const char * src01,
4370
+ device const char * src02,
4371
+ device const char * src03,
4372
+ device const char * src04,
4373
+ device const char * src05,
4374
+ device const char * src06,
4375
+ device const char * src07,
4376
+ uint3 tgpig[[threadgroup_position_in_grid]],
4377
+ uint tiitg[[thread_index_in_threadgroup]],
4378
+ uint tiisg[[thread_index_in_simdgroup]],
4379
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4380
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4381
+
4382
+ const int64_t bid = tgpig.z/(ne12*ne13);
4383
+
4384
+ tgpig.z = tgpig.z%(ne12*ne13);
4385
+
4386
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4387
+
4388
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4389
+ src0[id],
4390
+ (device const float *) (src1 + bid*nb11),
4391
+ (device float *) ( dst + bid*nb1),
4392
+ ne00,
4393
+ ne01,
4394
+ ne02,
4395
+ ne10,
4396
+ ne12,
4397
+ ne0,
4398
+ ne1,
4399
+ r2,
4400
+ r3,
4401
+ tgpig,
4402
+ tiisg,
4403
+ sgitg);
4404
+ }
4405
+
4406
+ [[host_name("kernel_mul_mv_id_q4_1_f32")]]
4407
+ kernel void kernel_mul_mv_id_q4_1_f32(
4408
+ device const char * ids,
4409
+ device const char * src1,
4410
+ device uchar * dst,
4411
+ constant int64_t & nbi1,
4412
+ constant int64_t & ne00,
4413
+ constant int64_t & ne01,
4414
+ constant int64_t & ne02,
4415
+ constant uint64_t & nb00,
4416
+ constant uint64_t & nb01,
4417
+ constant uint64_t & nb02,
4418
+ constant int64_t & ne10,
4419
+ constant int64_t & ne11,
4420
+ constant int64_t & ne12,
4421
+ constant int64_t & ne13,
4422
+ constant uint64_t & nb10,
4423
+ constant uint64_t & nb11,
4424
+ constant uint64_t & nb12,
4425
+ constant int64_t & ne0,
4426
+ constant int64_t & ne1,
4427
+ constant int64_t & nb1,
4428
+ constant uint & r2,
4429
+ constant uint & r3,
4430
+ constant int & idx,
4431
+ device const char * src00,
4432
+ device const char * src01,
4433
+ device const char * src02,
4434
+ device const char * src03,
4435
+ device const char * src04,
4436
+ device const char * src05,
4437
+ device const char * src06,
4438
+ device const char * src07,
4439
+ uint3 tgpig[[threadgroup_position_in_grid]],
4440
+ uint tiitg[[thread_index_in_threadgroup]],
4441
+ uint tiisg[[thread_index_in_simdgroup]],
4442
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4443
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4444
+
4445
+ const int64_t bid = tgpig.z/(ne12*ne13);
4446
+
4447
+ tgpig.z = tgpig.z%(ne12*ne13);
4448
+
4449
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4450
+
4451
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4452
+ src0[id],
4453
+ (device const float *) (src1 + bid*nb11),
4454
+ (device float *) ( dst + bid*nb1),
4455
+ ne00,
4456
+ ne01,
4457
+ ne02,
4458
+ ne10,
4459
+ ne12,
4460
+ ne0,
4461
+ ne1,
4462
+ r2,
4463
+ r3,
4464
+ tgpig,
4465
+ tiisg,
4466
+ sgitg);
4467
+ }
4468
+
4469
+ [[host_name("kernel_mul_mv_id_q5_0_f32")]]
4470
+ kernel void kernel_mul_mv_id_q5_0_f32(
4471
+ device const char * ids,
4472
+ device const char * src1,
4473
+ device uchar * dst,
4474
+ constant int64_t & nbi1,
4475
+ constant int64_t & ne00,
4476
+ constant int64_t & ne01,
4477
+ constant int64_t & ne02,
4478
+ constant uint64_t & nb00,
4479
+ constant uint64_t & nb01,
4480
+ constant uint64_t & nb02,
4481
+ constant int64_t & ne10,
4482
+ constant int64_t & ne11,
4483
+ constant int64_t & ne12,
4484
+ constant int64_t & ne13,
4485
+ constant uint64_t & nb10,
4486
+ constant uint64_t & nb11,
4487
+ constant uint64_t & nb12,
4488
+ constant int64_t & ne0,
4489
+ constant int64_t & ne1,
4490
+ constant int64_t & nb1,
4491
+ constant uint & r2,
4492
+ constant uint & r3,
4493
+ constant int & idx,
4494
+ device const char * src00,
4495
+ device const char * src01,
4496
+ device const char * src02,
4497
+ device const char * src03,
4498
+ device const char * src04,
4499
+ device const char * src05,
4500
+ device const char * src06,
4501
+ device const char * src07,
4502
+ uint3 tgpig[[threadgroup_position_in_grid]],
4503
+ uint tiitg[[thread_index_in_threadgroup]],
4504
+ uint tiisg[[thread_index_in_simdgroup]],
4505
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4506
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4507
+
4508
+ const int64_t bid = tgpig.z/(ne12*ne13);
4509
+
4510
+ tgpig.z = tgpig.z%(ne12*ne13);
4511
+
4512
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4513
+
4514
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4515
+ src0[id],
4516
+ (device const float *) (src1 + bid*nb11),
4517
+ (device float *) ( dst + bid*nb1),
4518
+ ne00,
4519
+ ne01,
4520
+ ne02,
4521
+ ne10,
4522
+ ne12,
4523
+ ne0,
4524
+ ne1,
4525
+ r2,
4526
+ r3,
4527
+ tgpig,
4528
+ tiisg,
4529
+ sgitg);
4530
+ }
4531
+
4532
+ [[host_name("kernel_mul_mv_id_q5_1_f32")]]
4533
+ kernel void kernel_mul_mv_id_q5_1_f32(
4534
+ device const char * ids,
4535
+ device const char * src1,
4536
+ device uchar * dst,
4537
+ constant int64_t & nbi1,
4538
+ constant int64_t & ne00,
4539
+ constant int64_t & ne01,
4540
+ constant int64_t & ne02,
4541
+ constant uint64_t & nb00,
4542
+ constant uint64_t & nb01,
4543
+ constant uint64_t & nb02,
4544
+ constant int64_t & ne10,
4545
+ constant int64_t & ne11,
4546
+ constant int64_t & ne12,
4547
+ constant int64_t & ne13,
4548
+ constant uint64_t & nb10,
4549
+ constant uint64_t & nb11,
4550
+ constant uint64_t & nb12,
4551
+ constant int64_t & ne0,
4552
+ constant int64_t & ne1,
4553
+ constant int64_t & nb1,
4554
+ constant uint & r2,
4555
+ constant uint & r3,
4556
+ constant int & idx,
4557
+ device const char * src00,
4558
+ device const char * src01,
4559
+ device const char * src02,
4560
+ device const char * src03,
4561
+ device const char * src04,
4562
+ device const char * src05,
4563
+ device const char * src06,
4564
+ device const char * src07,
4565
+ uint3 tgpig[[threadgroup_position_in_grid]],
4566
+ uint tiitg[[thread_index_in_threadgroup]],
4567
+ uint tiisg[[thread_index_in_simdgroup]],
4568
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4569
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4570
+
4571
+ const int64_t bid = tgpig.z/(ne12*ne13);
4572
+
4573
+ tgpig.z = tgpig.z%(ne12*ne13);
4574
+
4575
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4576
+
4577
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4578
+ src0[id],
4579
+ (device const float *) (src1 + bid*nb11),
4580
+ (device float *) ( dst + bid*nb1),
4581
+ ne00,
4582
+ ne01,
4583
+ ne02,
4584
+ ne10,
4585
+ ne12,
4586
+ ne0,
4587
+ ne1,
4588
+ r2,
4589
+ r3,
4590
+ tgpig,
4591
+ tiisg,
4592
+ sgitg);
4593
+ }
4594
+
4595
+ [[host_name("kernel_mul_mv_id_q2_K_f32")]]
4596
+ kernel void kernel_mul_mv_id_q2_K_f32(
4597
+ device const char * ids,
4598
+ device const char * src1,
4599
+ device uchar * dst,
4600
+ constant int64_t & nbi1,
4601
+ constant int64_t & ne00,
4602
+ constant int64_t & ne01,
4603
+ constant int64_t & ne02,
4604
+ constant uint64_t & nb00,
4605
+ constant uint64_t & nb01,
4606
+ constant uint64_t & nb02,
4607
+ constant int64_t & ne10,
4608
+ constant int64_t & ne11,
4609
+ constant int64_t & ne12,
4610
+ constant int64_t & ne13,
4611
+ constant uint64_t & nb10,
4612
+ constant uint64_t & nb11,
4613
+ constant uint64_t & nb12,
4614
+ constant int64_t & ne0,
4615
+ constant int64_t & ne1,
4616
+ constant int64_t & nb1,
4617
+ constant uint & r2,
4618
+ constant uint & r3,
4619
+ constant int & idx,
4620
+ device const char * src00,
4621
+ device const char * src01,
4622
+ device const char * src02,
4623
+ device const char * src03,
4624
+ device const char * src04,
4625
+ device const char * src05,
4626
+ device const char * src06,
4627
+ device const char * src07,
4628
+ uint3 tgpig[[threadgroup_position_in_grid]],
4629
+ uint tiitg[[thread_index_in_threadgroup]],
4630
+ uint tiisg[[thread_index_in_simdgroup]],
4631
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4632
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4633
+
4634
+ const int64_t bid = tgpig.z/(ne12*ne13);
4635
+
4636
+ tgpig.z = tgpig.z%(ne12*ne13);
4637
+
4638
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4639
+
4640
+ kernel_mul_mv_q2_K_f32_impl(
4641
+ src0[id],
4642
+ (device const float *) (src1 + bid*nb11),
4643
+ (device float *) ( dst + bid*nb1),
4644
+ ne00,
4645
+ ne01,
4646
+ ne02,
4647
+ ne10,
4648
+ ne12,
4649
+ ne0,
4650
+ ne1,
4651
+ r2,
4652
+ r3,
4653
+ tgpig,
4654
+ tiisg,
4655
+ sgitg);
4656
+ }
4657
+
4658
+ [[host_name("kernel_mul_mv_id_q3_K_f32")]]
4659
+ kernel void kernel_mul_mv_id_q3_K_f32(
4660
+ device const char * ids,
4661
+ device const char * src1,
4662
+ device uchar * dst,
4663
+ constant int64_t & nbi1,
4664
+ constant int64_t & ne00,
4665
+ constant int64_t & ne01,
4666
+ constant int64_t & ne02,
4667
+ constant uint64_t & nb00,
4668
+ constant uint64_t & nb01,
4669
+ constant uint64_t & nb02,
4670
+ constant int64_t & ne10,
4671
+ constant int64_t & ne11,
4672
+ constant int64_t & ne12,
4673
+ constant int64_t & ne13,
4674
+ constant uint64_t & nb10,
4675
+ constant uint64_t & nb11,
4676
+ constant uint64_t & nb12,
4677
+ constant int64_t & ne0,
4678
+ constant int64_t & ne1,
4679
+ constant int64_t & nb1,
4680
+ constant uint & r2,
4681
+ constant uint & r3,
4682
+ constant int & idx,
4683
+ device const char * src00,
4684
+ device const char * src01,
4685
+ device const char * src02,
4686
+ device const char * src03,
4687
+ device const char * src04,
4688
+ device const char * src05,
4689
+ device const char * src06,
4690
+ device const char * src07,
4691
+ uint3 tgpig[[threadgroup_position_in_grid]],
4692
+ uint tiitg[[thread_index_in_threadgroup]],
4693
+ uint tiisg[[thread_index_in_simdgroup]],
4694
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4695
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4696
+
4697
+ const int64_t bid = tgpig.z/(ne12*ne13);
4698
+
4699
+ tgpig.z = tgpig.z%(ne12*ne13);
4700
+
4701
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4702
+
4703
+ kernel_mul_mv_q3_K_f32_impl(
4704
+ src0[id],
4705
+ (device const float *) (src1 + bid*nb11),
4706
+ (device float *) ( dst + bid*nb1),
4707
+ ne00,
4708
+ ne01,
4709
+ ne02,
4710
+ ne10,
4711
+ ne12,
4712
+ ne0,
4713
+ ne1,
4714
+ r2,
4715
+ r3,
4716
+ tgpig,
4717
+ tiisg,
4718
+ sgitg);
4719
+ }
4720
+
4721
+ [[host_name("kernel_mul_mv_id_q4_K_f32")]]
4722
+ kernel void kernel_mul_mv_id_q4_K_f32(
4723
+ device const char * ids,
4724
+ device const char * src1,
4725
+ device uchar * dst,
4726
+ constant int64_t & nbi1,
4727
+ constant int64_t & ne00,
4728
+ constant int64_t & ne01,
4729
+ constant int64_t & ne02,
4730
+ constant uint64_t & nb00,
4731
+ constant uint64_t & nb01,
4732
+ constant uint64_t & nb02,
4733
+ constant int64_t & ne10,
4734
+ constant int64_t & ne11,
4735
+ constant int64_t & ne12,
4736
+ constant int64_t & ne13,
4737
+ constant uint64_t & nb10,
4738
+ constant uint64_t & nb11,
4739
+ constant uint64_t & nb12,
4740
+ constant int64_t & ne0,
4741
+ constant int64_t & ne1,
4742
+ constant int64_t & nb1,
4743
+ constant uint & r2,
4744
+ constant uint & r3,
4745
+ constant int & idx,
4746
+ device const char * src00,
4747
+ device const char * src01,
4748
+ device const char * src02,
4749
+ device const char * src03,
4750
+ device const char * src04,
4751
+ device const char * src05,
4752
+ device const char * src06,
4753
+ device const char * src07,
4754
+ uint3 tgpig[[threadgroup_position_in_grid]],
4755
+ uint tiitg[[thread_index_in_threadgroup]],
4756
+ uint tiisg[[thread_index_in_simdgroup]],
4757
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4758
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4759
+
4760
+ const int64_t bid = tgpig.z/(ne12*ne13);
4761
+
4762
+ tgpig.z = tgpig.z%(ne12*ne13);
4763
+
4764
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4765
+
4766
+ kernel_mul_mv_q4_K_f32_impl(
4767
+ src0[id],
4768
+ (device const float *) (src1 + bid*nb11),
4769
+ (device float *) ( dst + bid*nb1),
4770
+ ne00,
4771
+ ne01,
4772
+ ne02,
4773
+ ne10,
4774
+ ne12,
4775
+ ne0,
4776
+ ne1,
4777
+ r2,
4778
+ r3,
4779
+ tgpig,
4780
+ tiisg,
4781
+ sgitg);
4782
+ }
4783
+
4784
+ [[host_name("kernel_mul_mv_id_q5_K_f32")]]
4785
+ kernel void kernel_mul_mv_id_q5_K_f32(
4786
+ device const char * ids,
4787
+ device const char * src1,
4788
+ device uchar * dst,
4789
+ constant int64_t & nbi1,
4790
+ constant int64_t & ne00,
4791
+ constant int64_t & ne01,
4792
+ constant int64_t & ne02,
4793
+ constant uint64_t & nb00,
4794
+ constant uint64_t & nb01,
4795
+ constant uint64_t & nb02,
4796
+ constant int64_t & ne10,
4797
+ constant int64_t & ne11,
4798
+ constant int64_t & ne12,
4799
+ constant int64_t & ne13,
4800
+ constant uint64_t & nb10,
4801
+ constant uint64_t & nb11,
4802
+ constant uint64_t & nb12,
4803
+ constant int64_t & ne0,
4804
+ constant int64_t & ne1,
4805
+ constant int64_t & nb1,
4806
+ constant uint & r2,
4807
+ constant uint & r3,
4808
+ constant int & idx,
4809
+ device const char * src00,
4810
+ device const char * src01,
4811
+ device const char * src02,
4812
+ device const char * src03,
4813
+ device const char * src04,
4814
+ device const char * src05,
4815
+ device const char * src06,
4816
+ device const char * src07,
4817
+ uint3 tgpig[[threadgroup_position_in_grid]],
4818
+ uint tiitg[[thread_index_in_threadgroup]],
4819
+ uint tiisg[[thread_index_in_simdgroup]],
4820
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4821
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4822
+
4823
+ const int64_t bid = tgpig.z/(ne12*ne13);
4824
+
4825
+ tgpig.z = tgpig.z%(ne12*ne13);
4826
+
4827
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4828
+
4829
+ kernel_mul_mv_q5_K_f32_impl(
4830
+ src0[id],
4831
+ (device const float *) (src1 + bid*nb11),
4832
+ (device float *) ( dst + bid*nb1),
4833
+ ne00,
4834
+ ne01,
4835
+ ne02,
4836
+ ne10,
4837
+ ne12,
4838
+ ne0,
4839
+ ne1,
4840
+ r2,
4841
+ r3,
4842
+ tgpig,
4843
+ tiisg,
4844
+ sgitg);
4845
+ }
4846
+
4847
+ [[host_name("kernel_mul_mv_id_q6_K_f32")]]
4848
+ kernel void kernel_mul_mv_id_q6_K_f32(
4849
+ device const char * ids,
4850
+ device const char * src1,
4851
+ device uchar * dst,
4852
+ constant int64_t & nbi1,
4853
+ constant int64_t & ne00,
4854
+ constant int64_t & ne01,
4855
+ constant int64_t & ne02,
4856
+ constant uint64_t & nb00,
4857
+ constant uint64_t & nb01,
4858
+ constant uint64_t & nb02,
4859
+ constant int64_t & ne10,
4860
+ constant int64_t & ne11,
4861
+ constant int64_t & ne12,
4862
+ constant int64_t & ne13,
4863
+ constant uint64_t & nb10,
4864
+ constant uint64_t & nb11,
4865
+ constant uint64_t & nb12,
4866
+ constant int64_t & ne0,
4867
+ constant int64_t & ne1,
4868
+ constant int64_t & nb1,
4869
+ constant uint & r2,
4870
+ constant uint & r3,
4871
+ constant int & idx,
4872
+ device const char * src00,
4873
+ device const char * src01,
4874
+ device const char * src02,
4875
+ device const char * src03,
4876
+ device const char * src04,
4877
+ device const char * src05,
4878
+ device const char * src06,
4879
+ device const char * src07,
4880
+ uint3 tgpig[[threadgroup_position_in_grid]],
4881
+ uint tiitg[[thread_index_in_threadgroup]],
4882
+ uint tiisg[[thread_index_in_simdgroup]],
4883
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4884
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4885
+
4886
+ const int64_t bid = tgpig.z/(ne12*ne13);
4887
+
4888
+ tgpig.z = tgpig.z%(ne12*ne13);
4889
+
4890
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4891
+
4892
+ kernel_mul_mv_q6_K_f32_impl(
4893
+ src0[id],
4894
+ (device const float *) (src1 + bid*nb11),
4895
+ (device float *) ( dst + bid*nb1),
4896
+ ne00,
4897
+ ne01,
4898
+ ne02,
4899
+ ne10,
4900
+ ne12,
4901
+ ne0,
4902
+ ne1,
4903
+ r2,
4904
+ r3,
4905
+ tgpig,
4906
+ tiisg,
4907
+ sgitg);
4908
+ }