whisper.rn 0.4.0-rc.4 → 0.4.0-rc.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +5 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/WhisperContext.java +57 -134
  6. package/android/src/main/jni-utils.h +76 -0
  7. package/android/src/main/jni.cpp +188 -112
  8. package/cpp/README.md +1 -1
  9. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  10. package/cpp/coreml/whisper-encoder.h +4 -0
  11. package/cpp/coreml/whisper-encoder.mm +4 -2
  12. package/cpp/ggml-alloc.c +55 -19
  13. package/cpp/ggml-alloc.h +8 -1
  14. package/cpp/ggml-backend-impl.h +46 -21
  15. package/cpp/ggml-backend.c +563 -156
  16. package/cpp/ggml-backend.h +62 -17
  17. package/cpp/ggml-impl.h +1 -1
  18. package/cpp/ggml-metal-whisper.metal +2444 -359
  19. package/cpp/ggml-metal.h +7 -1
  20. package/cpp/ggml-metal.m +1105 -197
  21. package/cpp/ggml-quants.c +66 -61
  22. package/cpp/ggml-quants.h +40 -40
  23. package/cpp/ggml.c +1040 -1590
  24. package/cpp/ggml.h +109 -30
  25. package/cpp/rn-audioutils.cpp +68 -0
  26. package/cpp/rn-audioutils.h +14 -0
  27. package/cpp/rn-whisper-log.h +11 -0
  28. package/cpp/rn-whisper.cpp +143 -59
  29. package/cpp/rn-whisper.h +48 -15
  30. package/cpp/whisper.cpp +1635 -928
  31. package/cpp/whisper.h +55 -10
  32. package/ios/RNWhisper.mm +7 -7
  33. package/ios/RNWhisperAudioUtils.h +0 -2
  34. package/ios/RNWhisperAudioUtils.m +0 -56
  35. package/ios/RNWhisperContext.h +3 -11
  36. package/ios/RNWhisperContext.mm +68 -137
  37. package/lib/commonjs/index.js.map +1 -1
  38. package/lib/commonjs/version.json +1 -1
  39. package/lib/module/index.js.map +1 -1
  40. package/lib/module/version.json +1 -1
  41. package/lib/typescript/index.d.ts +5 -0
  42. package/lib/typescript/index.d.ts.map +1 -1
  43. package/package.json +6 -5
  44. package/src/index.ts +5 -0
  45. package/src/version.json +1 -1
  46. package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +0 -4
  47. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +0 -8
  48. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
  49. package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +0 -19
@@ -3,6 +3,8 @@
3
3
  using namespace metal;
4
4
 
5
5
  #define MAX(x, y) ((x) > (y) ? (x) : (y))
6
+ #define MIN(x, y) ((x) < (y) ? (x) : (y))
7
+ #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
6
8
 
7
9
  #define QK4_0 32
8
10
  #define QR4_0 2
@@ -39,8 +41,15 @@ typedef struct {
39
41
  int8_t qs[QK8_0]; // quants
40
42
  } block_q8_0;
41
43
 
42
- // general-purpose kernel for addition of two tensors
43
- // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
44
+ #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
45
+
46
+ enum ggml_sort_order {
47
+ GGML_SORT_ASC,
48
+ GGML_SORT_DESC,
49
+ };
50
+
51
+ // general-purpose kernel for addition, multiplication and division of two tensors
52
+ // pros: works for non-contiguous tensors, supports broadcast across all dims
44
53
  // cons: not very efficient
45
54
  kernel void kernel_add(
46
55
  device const char * src0,
@@ -70,6 +79,7 @@ kernel void kernel_add(
70
79
  constant int64_t & nb1,
71
80
  constant int64_t & nb2,
72
81
  constant int64_t & nb3,
82
+ constant int64_t & offs,
73
83
  uint3 tgpig[[threadgroup_position_in_grid]],
74
84
  uint3 tpitg[[thread_position_in_threadgroup]],
75
85
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -81,16 +91,111 @@ kernel void kernel_add(
81
91
  const int64_t i12 = i02 % ne12;
82
92
  const int64_t i11 = i01 % ne11;
83
93
 
84
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
85
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
86
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
94
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
95
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
96
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
97
+
98
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
99
+ const int i10 = i0 % ne10;
100
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
101
+ }
102
+ }
103
+
104
+ kernel void kernel_mul(
105
+ device const char * src0,
106
+ device const char * src1,
107
+ device char * dst,
108
+ constant int64_t & ne00,
109
+ constant int64_t & ne01,
110
+ constant int64_t & ne02,
111
+ constant int64_t & ne03,
112
+ constant int64_t & nb00,
113
+ constant int64_t & nb01,
114
+ constant int64_t & nb02,
115
+ constant int64_t & nb03,
116
+ constant int64_t & ne10,
117
+ constant int64_t & ne11,
118
+ constant int64_t & ne12,
119
+ constant int64_t & ne13,
120
+ constant int64_t & nb10,
121
+ constant int64_t & nb11,
122
+ constant int64_t & nb12,
123
+ constant int64_t & nb13,
124
+ constant int64_t & ne0,
125
+ constant int64_t & ne1,
126
+ constant int64_t & ne2,
127
+ constant int64_t & ne3,
128
+ constant int64_t & nb0,
129
+ constant int64_t & nb1,
130
+ constant int64_t & nb2,
131
+ constant int64_t & nb3,
132
+ uint3 tgpig[[threadgroup_position_in_grid]],
133
+ uint3 tpitg[[thread_position_in_threadgroup]],
134
+ uint3 ntg[[threads_per_threadgroup]]) {
135
+ const int64_t i03 = tgpig.z;
136
+ const int64_t i02 = tgpig.y;
137
+ const int64_t i01 = tgpig.x;
138
+
139
+ const int64_t i13 = i03 % ne13;
140
+ const int64_t i12 = i02 % ne12;
141
+ const int64_t i11 = i01 % ne11;
142
+
143
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
144
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
145
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
87
146
 
88
147
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
89
- ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
148
+ const int i10 = i0 % ne10;
149
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
150
+ }
151
+ }
152
+
153
+ kernel void kernel_div(
154
+ device const char * src0,
155
+ device const char * src1,
156
+ device char * dst,
157
+ constant int64_t & ne00,
158
+ constant int64_t & ne01,
159
+ constant int64_t & ne02,
160
+ constant int64_t & ne03,
161
+ constant int64_t & nb00,
162
+ constant int64_t & nb01,
163
+ constant int64_t & nb02,
164
+ constant int64_t & nb03,
165
+ constant int64_t & ne10,
166
+ constant int64_t & ne11,
167
+ constant int64_t & ne12,
168
+ constant int64_t & ne13,
169
+ constant int64_t & nb10,
170
+ constant int64_t & nb11,
171
+ constant int64_t & nb12,
172
+ constant int64_t & nb13,
173
+ constant int64_t & ne0,
174
+ constant int64_t & ne1,
175
+ constant int64_t & ne2,
176
+ constant int64_t & ne3,
177
+ constant int64_t & nb0,
178
+ constant int64_t & nb1,
179
+ constant int64_t & nb2,
180
+ constant int64_t & nb3,
181
+ uint3 tgpig[[threadgroup_position_in_grid]],
182
+ uint3 tpitg[[thread_position_in_threadgroup]],
183
+ uint3 ntg[[threads_per_threadgroup]]) {
184
+ const int64_t i03 = tgpig.z;
185
+ const int64_t i02 = tgpig.y;
186
+ const int64_t i01 = tgpig.x;
187
+
188
+ const int64_t i13 = i03 % ne13;
189
+ const int64_t i12 = i02 % ne12;
190
+ const int64_t i11 = i01 % ne11;
90
191
 
91
- src0_ptr += ntg.x*nb00;
92
- src1_ptr += ntg.x*nb10;
93
- dst_ptr += ntg.x*nb0;
192
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
193
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
194
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
195
+
196
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
197
+ const int i10 = i0 % ne10;
198
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
94
199
  }
95
200
  }
96
201
 
@@ -100,28 +205,27 @@ kernel void kernel_add_row(
100
205
  device const float4 * src0,
101
206
  device const float4 * src1,
102
207
  device float4 * dst,
103
- constant int64_t & nb [[buffer(27)]],
208
+ constant int64_t & nb [[buffer(28)]],
104
209
  uint tpig[[thread_position_in_grid]]) {
105
210
  dst[tpig] = src0[tpig] + src1[tpig % nb];
106
211
  }
107
212
 
108
- kernel void kernel_mul(
213
+ kernel void kernel_mul_row(
109
214
  device const float4 * src0,
110
215
  device const float4 * src1,
111
216
  device float4 * dst,
217
+ constant int64_t & nb [[buffer(28)]],
112
218
  uint tpig[[thread_position_in_grid]]) {
113
- dst[tpig] = src0[tpig] * src1[tpig];
219
+ dst[tpig] = src0[tpig] * src1[tpig % nb];
114
220
  }
115
221
 
116
- // assumption: src1 is a row
117
- // broadcast src1 into src0
118
- kernel void kernel_mul_row(
222
+ kernel void kernel_div_row(
119
223
  device const float4 * src0,
120
224
  device const float4 * src1,
121
225
  device float4 * dst,
122
- constant int64_t & nb,
226
+ constant int64_t & nb [[buffer(28)]],
123
227
  uint tpig[[thread_position_in_grid]]) {
124
- dst[tpig] = src0[tpig] * src1[tpig % nb];
228
+ dst[tpig] = src0[tpig] / src1[tpig % nb];
125
229
  }
126
230
 
127
231
  kernel void kernel_scale(
@@ -140,14 +244,6 @@ kernel void kernel_scale_4(
140
244
  dst[tpig] = src0[tpig] * scale;
141
245
  }
142
246
 
143
- kernel void kernel_silu(
144
- device const float4 * src0,
145
- device float4 * dst,
146
- uint tpig[[thread_position_in_grid]]) {
147
- device const float4 & x = src0[tpig];
148
- dst[tpig] = x / (1.0f + exp(-x));
149
- }
150
-
151
247
  kernel void kernel_relu(
152
248
  device const float * src0,
153
249
  device float * dst,
@@ -155,15 +251,17 @@ kernel void kernel_relu(
155
251
  dst[tpig] = max(0.0f, src0[tpig]);
156
252
  }
157
253
 
158
- kernel void kernel_sqr(
254
+ kernel void kernel_tanh(
159
255
  device const float * src0,
160
256
  device float * dst,
161
257
  uint tpig[[thread_position_in_grid]]) {
162
- dst[tpig] = src0[tpig] * src0[tpig];
258
+ device const float & x = src0[tpig];
259
+ dst[tpig] = precise::tanh(x);
163
260
  }
164
261
 
165
- constant float GELU_COEF_A = 0.044715f;
166
- constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
262
+ constant float GELU_COEF_A = 0.044715f;
263
+ constant float GELU_QUICK_COEF = -1.702f;
264
+ constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
167
265
 
168
266
  kernel void kernel_gelu(
169
267
  device const float4 * src0,
@@ -178,12 +276,86 @@ kernel void kernel_gelu(
178
276
  dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
179
277
  }
180
278
 
279
+ kernel void kernel_gelu_quick(
280
+ device const float4 * src0,
281
+ device float4 * dst,
282
+ uint tpig[[thread_position_in_grid]]) {
283
+ device const float4 & x = src0[tpig];
284
+
285
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
286
+ }
287
+
288
+ kernel void kernel_silu(
289
+ device const float4 * src0,
290
+ device float4 * dst,
291
+ uint tpig[[thread_position_in_grid]]) {
292
+ device const float4 & x = src0[tpig];
293
+ dst[tpig] = x / (1.0f + exp(-x));
294
+ }
295
+
296
+ kernel void kernel_sqr(
297
+ device const float * src0,
298
+ device float * dst,
299
+ uint tpig[[thread_position_in_grid]]) {
300
+ dst[tpig] = src0[tpig] * src0[tpig];
301
+ }
302
+
303
+ kernel void kernel_sum_rows(
304
+ device const float * src0,
305
+ device float * dst,
306
+ constant int64_t & ne00,
307
+ constant int64_t & ne01,
308
+ constant int64_t & ne02,
309
+ constant int64_t & ne03,
310
+ constant int64_t & nb00,
311
+ constant int64_t & nb01,
312
+ constant int64_t & nb02,
313
+ constant int64_t & nb03,
314
+ constant int64_t & ne10,
315
+ constant int64_t & ne11,
316
+ constant int64_t & ne12,
317
+ constant int64_t & ne13,
318
+ constant int64_t & nb10,
319
+ constant int64_t & nb11,
320
+ constant int64_t & nb12,
321
+ constant int64_t & nb13,
322
+ constant int64_t & ne0,
323
+ constant int64_t & ne1,
324
+ constant int64_t & ne2,
325
+ constant int64_t & ne3,
326
+ constant int64_t & nb0,
327
+ constant int64_t & nb1,
328
+ constant int64_t & nb2,
329
+ constant int64_t & nb3,
330
+ uint3 tpig[[thread_position_in_grid]]) {
331
+ int64_t i3 = tpig.z;
332
+ int64_t i2 = tpig.y;
333
+ int64_t i1 = tpig.x;
334
+
335
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
336
+ return;
337
+ }
338
+
339
+ device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
340
+ device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
341
+
342
+ float row_sum = 0;
343
+
344
+ for (int64_t i0 = 0; i0 < ne00; i0++) {
345
+ row_sum += src_row[i0];
346
+ }
347
+
348
+ dst_row[0] = row_sum;
349
+ }
350
+
181
351
  kernel void kernel_soft_max(
182
352
  device const float * src0,
353
+ device const float * src1,
183
354
  device float * dst,
184
355
  constant int64_t & ne00,
185
356
  constant int64_t & ne01,
186
357
  constant int64_t & ne02,
358
+ constant float & scale,
187
359
  threadgroup float * buf [[threadgroup(0)]],
188
360
  uint tgpig[[threadgroup_position_in_grid]],
189
361
  uint tpitg[[thread_position_in_threadgroup]],
@@ -194,73 +366,82 @@ kernel void kernel_soft_max(
194
366
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
195
367
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
196
368
 
197
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
198
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
369
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
370
+ device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
371
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
199
372
 
200
373
  // parallel max
201
- float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
374
+ float lmax = -INFINITY;
202
375
 
203
- for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
204
- lmax = MAX(lmax, psrc0[i00]);
376
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
377
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
205
378
  }
206
379
 
207
- float max = simd_max(lmax);
208
- if (tiisg == 0) {
209
- buf[sgitg] = max;
210
- }
380
+ // find the max value in the block
381
+ float max_val = simd_max(lmax);
382
+ if (ntg > N_SIMDWIDTH) {
383
+ if (sgitg == 0) {
384
+ buf[tiisg] = -INFINITY;
385
+ }
211
386
 
212
- threadgroup_barrier(mem_flags::mem_threadgroup);
387
+ threadgroup_barrier(mem_flags::mem_threadgroup);
213
388
 
214
- // broadcast, simd group number is ntg / 32
215
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
216
- if (tpitg < i) {
217
- buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
218
- }
219
- }
389
+ if (tiisg == 0) {
390
+ buf[sgitg] = max_val;
391
+ }
220
392
 
221
- threadgroup_barrier(mem_flags::mem_threadgroup);
393
+ threadgroup_barrier(mem_flags::mem_threadgroup);
222
394
 
223
- max = buf[0];
395
+ max_val = buf[tiisg];
396
+ max_val = simd_max(max_val);
397
+ }
224
398
 
225
399
  // parallel sum
226
400
  float lsum = 0.0f;
227
401
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
228
- const float exp_psrc0 = exp(psrc0[i00] - max);
402
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
229
403
  lsum += exp_psrc0;
230
- // Remember the result of exp here. exp is expensive, so we really do not
231
- // wish to compute it twice.
232
404
  pdst[i00] = exp_psrc0;
233
405
  }
234
406
 
407
+ // This barrier fixes a failing test
408
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
409
+ threadgroup_barrier(mem_flags::mem_none);
410
+
235
411
  float sum = simd_sum(lsum);
236
- if (tiisg == 0) {
237
- buf[sgitg] = sum;
238
- }
239
412
 
240
- threadgroup_barrier(mem_flags::mem_threadgroup);
413
+ if (ntg > N_SIMDWIDTH) {
414
+ if (sgitg == 0) {
415
+ buf[tiisg] = 0.0f;
416
+ }
241
417
 
242
- // broadcast, simd group number is ntg / 32
243
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
244
- if (tpitg < i) {
245
- buf[tpitg] += buf[tpitg + i];
246
- }
247
- }
418
+ threadgroup_barrier(mem_flags::mem_threadgroup);
248
419
 
249
- threadgroup_barrier(mem_flags::mem_threadgroup);
420
+ if (tiisg == 0) {
421
+ buf[sgitg] = sum;
422
+ }
423
+
424
+ threadgroup_barrier(mem_flags::mem_threadgroup);
425
+
426
+ sum = buf[tiisg];
427
+ sum = simd_sum(sum);
428
+ }
250
429
 
251
- sum = buf[0];
430
+ const float inv_sum = 1.0f/sum;
252
431
 
253
432
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
254
- pdst[i00] /= sum;
433
+ pdst[i00] *= inv_sum;
255
434
  }
256
435
  }
257
436
 
258
437
  kernel void kernel_soft_max_4(
259
438
  device const float * src0,
439
+ device const float * src1,
260
440
  device float * dst,
261
441
  constant int64_t & ne00,
262
442
  constant int64_t & ne01,
263
443
  constant int64_t & ne02,
444
+ constant float & scale,
264
445
  threadgroup float * buf [[threadgroup(0)]],
265
446
  uint tgpig[[threadgroup_position_in_grid]],
266
447
  uint tpitg[[thread_position_in_threadgroup]],
@@ -271,64 +452,74 @@ kernel void kernel_soft_max_4(
271
452
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
272
453
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
273
454
 
274
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
275
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
455
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
456
+ device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
457
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
276
458
 
277
459
  // parallel max
278
- float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
460
+ float4 lmax4 = -INFINITY;
279
461
 
280
- for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
281
- lmax4 = fmax(lmax4, psrc4[i00]);
462
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
463
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
282
464
  }
283
465
 
284
466
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
285
- float max = simd_max(lmax);
286
- if (tiisg == 0) {
287
- buf[sgitg] = max;
288
- }
289
467
 
290
- threadgroup_barrier(mem_flags::mem_threadgroup);
468
+ float max_val = simd_max(lmax);
469
+ if (ntg > N_SIMDWIDTH) {
470
+ if (sgitg == 0) {
471
+ buf[tiisg] = -INFINITY;
472
+ }
291
473
 
292
- // broadcast, simd group number is ntg / 32
293
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
294
- if (tpitg < i) {
295
- buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
296
- }
297
- }
474
+ threadgroup_barrier(mem_flags::mem_threadgroup);
298
475
 
299
- threadgroup_barrier(mem_flags::mem_threadgroup);
476
+ if (tiisg == 0) {
477
+ buf[sgitg] = max_val;
478
+ }
479
+
480
+ threadgroup_barrier(mem_flags::mem_threadgroup);
300
481
 
301
- max = buf[0];
482
+ max_val = buf[tiisg];
483
+ max_val = simd_max(max_val);
484
+ }
302
485
 
303
486
  // parallel sum
304
487
  float4 lsum4 = 0.0f;
305
488
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
306
- const float4 exp_psrc4 = exp(psrc4[i00] - max);
489
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
307
490
  lsum4 += exp_psrc4;
308
491
  pdst4[i00] = exp_psrc4;
309
492
  }
310
493
 
311
494
  const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
495
+
496
+ // This barrier fixes a failing test
497
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
498
+ threadgroup_barrier(mem_flags::mem_none);
499
+
312
500
  float sum = simd_sum(lsum);
313
- if (tiisg == 0) {
314
- buf[sgitg] = sum;
315
- }
316
501
 
317
- threadgroup_barrier(mem_flags::mem_threadgroup);
502
+ if (ntg > N_SIMDWIDTH) {
503
+ if (sgitg == 0) {
504
+ buf[tiisg] = 0.0f;
505
+ }
318
506
 
319
- // broadcast, simd group number is ntg / 32
320
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
321
- if (tpitg < i) {
322
- buf[tpitg] += buf[tpitg + i];
323
- }
324
- }
507
+ threadgroup_barrier(mem_flags::mem_threadgroup);
325
508
 
326
- threadgroup_barrier(mem_flags::mem_threadgroup);
509
+ if (tiisg == 0) {
510
+ buf[sgitg] = sum;
511
+ }
512
+
513
+ threadgroup_barrier(mem_flags::mem_threadgroup);
514
+
515
+ sum = buf[tiisg];
516
+ sum = simd_sum(sum);
517
+ }
327
518
 
328
- sum = buf[0];
519
+ const float inv_sum = 1.0f/sum;
329
520
 
330
521
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
331
- pdst4[i00] /= sum;
522
+ pdst4[i00] *= inv_sum;
332
523
  }
333
524
  }
334
525
 
@@ -435,14 +626,13 @@ kernel void kernel_rms_norm(
435
626
  constant int64_t & ne00,
436
627
  constant uint64_t & nb01,
437
628
  constant float & eps,
438
- threadgroup float * sum [[threadgroup(0)]],
629
+ threadgroup float * buf [[threadgroup(0)]],
439
630
  uint tgpig[[threadgroup_position_in_grid]],
440
631
  uint tpitg[[thread_position_in_threadgroup]],
441
632
  uint sgitg[[simdgroup_index_in_threadgroup]],
442
633
  uint tiisg[[thread_index_in_simdgroup]],
443
634
  uint ntg[[threads_per_threadgroup]]) {
444
- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
445
- device const float * x_scalar = (device const float *) x;
635
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
446
636
 
447
637
  float4 sumf = 0;
448
638
  float all_sum = 0;
@@ -453,52 +643,130 @@ kernel void kernel_rms_norm(
453
643
  }
454
644
  all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
455
645
  all_sum = simd_sum(all_sum);
456
- if (tiisg == 0) {
457
- sum[sgitg] = all_sum;
458
- }
646
+ if (ntg > N_SIMDWIDTH) {
647
+ if (sgitg == 0) {
648
+ buf[tiisg] = 0.0f;
649
+ }
459
650
 
460
- threadgroup_barrier(mem_flags::mem_threadgroup);
651
+ threadgroup_barrier(mem_flags::mem_threadgroup);
461
652
 
462
- // broadcast, simd group number is ntg / 32
463
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
464
- if (tpitg < i) {
465
- sum[tpitg] += sum[tpitg + i];
466
- }
467
- }
468
- if (tpitg == 0) {
469
- for (int i = 4 * (ne00 / 4); i < ne00; i++) {
470
- sum[0] += x_scalar[i];
653
+ if (tiisg == 0) {
654
+ buf[sgitg] = all_sum;
471
655
  }
472
- sum[0] /= ne00;
473
- }
474
656
 
475
- threadgroup_barrier(mem_flags::mem_threadgroup);
657
+ threadgroup_barrier(mem_flags::mem_threadgroup);
658
+
659
+ all_sum = buf[tiisg];
660
+ all_sum = simd_sum(all_sum);
661
+ }
476
662
 
477
- const float mean = sum[0];
663
+ const float mean = all_sum/ne00;
478
664
  const float scale = 1.0f/sqrt(mean + eps);
479
665
 
480
666
  device float4 * y = (device float4 *) (dst + tgpig*ne00);
481
- device float * y_scalar = (device float *) y;
482
667
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
483
668
  y[i00] = x[i00] * scale;
484
669
  }
485
- if (tpitg == 0) {
486
- for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
487
- y_scalar[i00] = x_scalar[i00] * scale;
488
- }
489
- }
490
670
  }
491
671
 
492
- // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
493
- // il indicates where the q4 quants begin (0 or QK4_0/4)
494
- // we assume that the yl's have been multiplied with the appropriate scale factor
495
- // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
496
- inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
497
- float d = qb_curr->d;
498
-
499
- float2 acc = 0.f;
500
-
501
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
672
+ kernel void kernel_group_norm(
673
+ device const float * src0,
674
+ device float * dst,
675
+ constant int64_t & ne00,
676
+ constant int64_t & ne01,
677
+ constant int64_t & ne02,
678
+ constant uint64_t & nb00,
679
+ constant uint64_t & nb01,
680
+ constant uint64_t & nb02,
681
+ constant int32_t & n_groups,
682
+ constant float & eps,
683
+ threadgroup float * buf [[threadgroup(0)]],
684
+ uint tgpig[[threadgroup_position_in_grid]],
685
+ uint tpitg[[thread_position_in_threadgroup]],
686
+ uint sgitg[[simdgroup_index_in_threadgroup]],
687
+ uint tiisg[[thread_index_in_simdgroup]],
688
+ uint ntg[[threads_per_threadgroup]]) {
689
+ const int64_t ne = ne00*ne01*ne02;
690
+ const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
691
+
692
+ int start = tgpig * gs;
693
+ int end = start + gs;
694
+
695
+ start += tpitg;
696
+
697
+ if (end >= ne) {
698
+ end = ne;
699
+ }
700
+
701
+ float tmp = 0.0f; // partial sum for thread in warp
702
+
703
+ for (int j = start; j < end; j += ntg) {
704
+ tmp += src0[j];
705
+ }
706
+
707
+ threadgroup_barrier(mem_flags::mem_threadgroup);
708
+ tmp = simd_sum(tmp);
709
+ if (ntg > N_SIMDWIDTH) {
710
+ if (sgitg == 0) {
711
+ buf[tiisg] = 0.0f;
712
+ }
713
+
714
+ threadgroup_barrier(mem_flags::mem_threadgroup);
715
+
716
+ if (tiisg == 0) {
717
+ buf[sgitg] = tmp;
718
+ }
719
+
720
+ threadgroup_barrier(mem_flags::mem_threadgroup);
721
+
722
+ tmp = buf[tiisg];
723
+ tmp = simd_sum(tmp);
724
+ }
725
+
726
+ const float mean = tmp / gs;
727
+ tmp = 0.0f;
728
+
729
+ for (int j = start; j < end; j += ntg) {
730
+ float xi = src0[j] - mean;
731
+ dst[j] = xi;
732
+ tmp += xi * xi;
733
+ }
734
+
735
+ tmp = simd_sum(tmp);
736
+ if (ntg > N_SIMDWIDTH) {
737
+ if (sgitg == 0) {
738
+ buf[tiisg] = 0.0f;
739
+ }
740
+
741
+ threadgroup_barrier(mem_flags::mem_threadgroup);
742
+
743
+ if (tiisg == 0) {
744
+ buf[sgitg] = tmp;
745
+ }
746
+
747
+ threadgroup_barrier(mem_flags::mem_threadgroup);
748
+
749
+ tmp = buf[tiisg];
750
+ tmp = simd_sum(tmp);
751
+ }
752
+
753
+ const float variance = tmp / gs;
754
+ const float scale = 1.0f/sqrt(variance + eps);
755
+ for (int j = start; j < end; j += ntg) {
756
+ dst[j] *= scale;
757
+ }
758
+ }
759
+
760
+ // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
761
+ // il indicates where the q4 quants begin (0 or QK4_0/4)
762
+ // we assume that the yl's have been multiplied with the appropriate scale factor
763
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
764
+ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
765
+ float d = qb_curr->d;
766
+
767
+ float2 acc = 0.f;
768
+
769
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
502
770
 
503
771
  for (int i = 0; i < 8; i+=2) {
504
772
  acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
@@ -576,15 +844,25 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
576
844
  // putting them in the kernel cause a significant performance penalty
577
845
  #define N_DST 4 // each SIMD group works on 4 rows
578
846
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
579
- #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
580
847
  //Note: This is a template, but strictly speaking it only applies to
581
848
  // quantizations where the block size is 32. It also does not
582
849
  // giard against the number of rows not being divisible by
583
850
  // N_DST, so this is another explicit assumption of the implementation.
584
851
  template<typename block_q_type, int nr, int nsg, int nw>
585
- void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
586
- int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
587
- uint3 tgpig, uint tiisg, uint sgitg) {
852
+ void mul_vec_q_n_f32_impl(
853
+ device const void * src0,
854
+ device const float * src1,
855
+ device float * dst,
856
+ int64_t ne00,
857
+ int64_t ne01,
858
+ int64_t ne02,
859
+ int64_t ne10,
860
+ int64_t ne12,
861
+ int64_t ne0,
862
+ int64_t ne1,
863
+ uint r2,
864
+ uint r3,
865
+ uint3 tgpig, uint tiisg, uint sgitg) {
588
866
  const int nb = ne00/QK4_0;
589
867
 
590
868
  const int r0 = tgpig.x;
@@ -593,7 +871,10 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
593
871
 
594
872
  const int first_row = (r0 * nsg + sgitg) * nr;
595
873
 
596
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
874
+ const uint i12 = im%ne12;
875
+ const uint i13 = im/ne12;
876
+
877
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
597
878
 
598
879
  device const block_q_type * x = (device const block_q_type *) src0 + offset0;
599
880
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
@@ -643,13 +924,14 @@ kernel void kernel_mul_mv_q4_0_f32(
643
924
  constant int64_t & ne02[[buffer(5)]],
644
925
  constant int64_t & ne10[[buffer(9)]],
645
926
  constant int64_t & ne12[[buffer(11)]],
646
- constant int64_t & ne0[[buffer(15)]],
647
- constant int64_t & ne1[[buffer(16)]],
648
- constant uint & gqa[[buffer(17)]],
927
+ constant int64_t & ne0 [[buffer(15)]],
928
+ constant int64_t & ne1 [[buffer(16)]],
929
+ constant uint & r2 [[buffer(17)]],
930
+ constant uint & r3 [[buffer(18)]],
649
931
  uint3 tgpig[[threadgroup_position_in_grid]],
650
932
  uint tiisg[[thread_index_in_simdgroup]],
651
933
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
652
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
934
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
653
935
  }
654
936
 
655
937
  kernel void kernel_mul_mv_q4_1_f32(
@@ -661,13 +943,14 @@ kernel void kernel_mul_mv_q4_1_f32(
661
943
  constant int64_t & ne02[[buffer(5)]],
662
944
  constant int64_t & ne10[[buffer(9)]],
663
945
  constant int64_t & ne12[[buffer(11)]],
664
- constant int64_t & ne0[[buffer(15)]],
665
- constant int64_t & ne1[[buffer(16)]],
666
- constant uint & gqa[[buffer(17)]],
946
+ constant int64_t & ne0 [[buffer(15)]],
947
+ constant int64_t & ne1 [[buffer(16)]],
948
+ constant uint & r2 [[buffer(17)]],
949
+ constant uint & r3 [[buffer(18)]],
667
950
  uint3 tgpig[[threadgroup_position_in_grid]],
668
951
  uint tiisg[[thread_index_in_simdgroup]],
669
952
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
670
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
953
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
671
954
  }
672
955
 
673
956
  kernel void kernel_mul_mv_q5_0_f32(
@@ -679,13 +962,14 @@ kernel void kernel_mul_mv_q5_0_f32(
679
962
  constant int64_t & ne02[[buffer(5)]],
680
963
  constant int64_t & ne10[[buffer(9)]],
681
964
  constant int64_t & ne12[[buffer(11)]],
682
- constant int64_t & ne0[[buffer(15)]],
683
- constant int64_t & ne1[[buffer(16)]],
684
- constant uint & gqa[[buffer(17)]],
965
+ constant int64_t & ne0 [[buffer(15)]],
966
+ constant int64_t & ne1 [[buffer(16)]],
967
+ constant uint & r2 [[buffer(17)]],
968
+ constant uint & r3 [[buffer(18)]],
685
969
  uint3 tgpig[[threadgroup_position_in_grid]],
686
970
  uint tiisg[[thread_index_in_simdgroup]],
687
971
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
688
- mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
972
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
689
973
  }
690
974
 
691
975
  kernel void kernel_mul_mv_q5_1_f32(
@@ -697,33 +981,35 @@ kernel void kernel_mul_mv_q5_1_f32(
697
981
  constant int64_t & ne02[[buffer(5)]],
698
982
  constant int64_t & ne10[[buffer(9)]],
699
983
  constant int64_t & ne12[[buffer(11)]],
700
- constant int64_t & ne0[[buffer(15)]],
701
- constant int64_t & ne1[[buffer(16)]],
702
- constant uint & gqa[[buffer(17)]],
984
+ constant int64_t & ne0 [[buffer(15)]],
985
+ constant int64_t & ne1 [[buffer(16)]],
986
+ constant uint & r2 [[buffer(17)]],
987
+ constant uint & r3 [[buffer(18)]],
703
988
  uint3 tgpig[[threadgroup_position_in_grid]],
704
989
  uint tiisg[[thread_index_in_simdgroup]],
705
990
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
706
- mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
991
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
707
992
  }
708
993
 
709
994
 
710
995
  #define NB_Q8_0 8
711
996
 
712
- kernel void kernel_mul_mv_q8_0_f32(
997
+ void kernel_mul_mv_q8_0_f32_impl(
713
998
  device const void * src0,
714
999
  device const float * src1,
715
1000
  device float * dst,
716
1001
  constant int64_t & ne00,
717
- constant int64_t & ne01[[buffer(4)]],
718
- constant int64_t & ne02[[buffer(5)]],
719
- constant int64_t & ne10[[buffer(9)]],
720
- constant int64_t & ne12[[buffer(11)]],
721
- constant int64_t & ne0[[buffer(15)]],
722
- constant int64_t & ne1[[buffer(16)]],
723
- constant uint & gqa[[buffer(17)]],
1002
+ constant int64_t & ne01,
1003
+ constant int64_t & ne02,
1004
+ constant int64_t & ne10,
1005
+ constant int64_t & ne12,
1006
+ constant int64_t & ne0,
1007
+ constant int64_t & ne1,
1008
+ constant uint & r2,
1009
+ constant uint & r3,
724
1010
  uint3 tgpig[[threadgroup_position_in_grid]],
725
- uint tiisg[[thread_index_in_simdgroup]],
726
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
1011
+ uint tiisg[[thread_index_in_simdgroup]],
1012
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
727
1013
  const int nr = N_DST;
728
1014
  const int nsg = N_SIMDGROUP;
729
1015
  const int nw = N_SIMDWIDTH;
@@ -732,8 +1018,14 @@ kernel void kernel_mul_mv_q8_0_f32(
732
1018
  const int r0 = tgpig.x;
733
1019
  const int r1 = tgpig.y;
734
1020
  const int im = tgpig.z;
1021
+
735
1022
  const int first_row = (r0 * nsg + sgitg) * nr;
736
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
1023
+
1024
+ const uint i12 = im%ne12;
1025
+ const uint i13 = im/ne12;
1026
+
1027
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
1028
+
737
1029
  device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
738
1030
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
739
1031
 
@@ -771,9 +1063,29 @@ kernel void kernel_mul_mv_q8_0_f32(
771
1063
  }
772
1064
  }
773
1065
 
1066
+ [[host_name("kernel_mul_mv_q8_0_f32")]]
1067
+ kernel void kernel_mul_mv_q8_0_f32(
1068
+ device const void * src0,
1069
+ device const float * src1,
1070
+ device float * dst,
1071
+ constant int64_t & ne00,
1072
+ constant int64_t & ne01,
1073
+ constant int64_t & ne02,
1074
+ constant int64_t & ne10,
1075
+ constant int64_t & ne12,
1076
+ constant int64_t & ne0,
1077
+ constant int64_t & ne1,
1078
+ constant uint & r2 [[buffer(17)]],
1079
+ constant uint & r3 [[buffer(18)]],
1080
+ uint3 tgpig[[threadgroup_position_in_grid]],
1081
+ uint tiisg[[thread_index_in_simdgroup]],
1082
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1083
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1084
+ }
1085
+
774
1086
  #define N_F32_F32 4
775
1087
 
776
- kernel void kernel_mul_mv_f32_f32(
1088
+ void kernel_mul_mv_f32_f32_impl(
777
1089
  device const char * src0,
778
1090
  device const char * src1,
779
1091
  device float * dst,
@@ -791,14 +1103,21 @@ kernel void kernel_mul_mv_f32_f32(
791
1103
  constant uint64_t & nb12,
792
1104
  constant int64_t & ne0,
793
1105
  constant int64_t & ne1,
1106
+ constant uint & r2,
1107
+ constant uint & r3,
794
1108
  uint3 tgpig[[threadgroup_position_in_grid]],
795
- uint tiisg[[thread_index_in_simdgroup]]) {
1109
+ uint tiisg[[thread_index_in_simdgroup]]) {
796
1110
 
797
1111
  const int64_t r0 = tgpig.x;
798
1112
  const int64_t rb = tgpig.y*N_F32_F32;
799
1113
  const int64_t im = tgpig.z;
800
1114
 
801
- device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1115
+ const uint i12 = im%ne12;
1116
+ const uint i13 = im/ne12;
1117
+
1118
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1119
+
1120
+ device const float * x = (device const float *) (src0 + offset0);
802
1121
 
803
1122
  if (ne00 < 128) {
804
1123
  for (int row = 0; row < N_F32_F32; ++row) {
@@ -844,7 +1163,8 @@ kernel void kernel_mul_mv_f32_f32(
844
1163
  }
845
1164
  }
846
1165
 
847
- kernel void kernel_mul_mv_f16_f32_1row(
1166
+ [[host_name("kernel_mul_mv_f32_f32")]]
1167
+ kernel void kernel_mul_mv_f32_f32(
848
1168
  device const char * src0,
849
1169
  device const char * src1,
850
1170
  device float * dst,
@@ -862,6 +1182,113 @@ kernel void kernel_mul_mv_f16_f32_1row(
862
1182
  constant uint64_t & nb12,
863
1183
  constant int64_t & ne0,
864
1184
  constant int64_t & ne1,
1185
+ constant uint & r2 [[buffer(17)]],
1186
+ constant uint & r3 [[buffer(18)]],
1187
+ uint3 tgpig[[threadgroup_position_in_grid]],
1188
+ uint tiisg[[thread_index_in_simdgroup]]) {
1189
+ kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1190
+ }
1191
+
1192
+ #define N_F16_F16 4
1193
+
1194
+ kernel void kernel_mul_mv_f16_f16(
1195
+ device const char * src0,
1196
+ device const char * src1,
1197
+ device float * dst,
1198
+ constant int64_t & ne00,
1199
+ constant int64_t & ne01,
1200
+ constant int64_t & ne02,
1201
+ constant uint64_t & nb00,
1202
+ constant uint64_t & nb01,
1203
+ constant uint64_t & nb02,
1204
+ constant int64_t & ne10,
1205
+ constant int64_t & ne11,
1206
+ constant int64_t & ne12,
1207
+ constant uint64_t & nb10,
1208
+ constant uint64_t & nb11,
1209
+ constant uint64_t & nb12,
1210
+ constant int64_t & ne0,
1211
+ constant int64_t & ne1,
1212
+ constant uint & r2 [[buffer(17)]],
1213
+ constant uint & r3 [[buffer(18)]],
1214
+ uint3 tgpig[[threadgroup_position_in_grid]],
1215
+ uint tiisg[[thread_index_in_simdgroup]]) {
1216
+
1217
+ const int64_t r0 = tgpig.x;
1218
+ const int64_t rb = tgpig.y*N_F16_F16;
1219
+ const int64_t im = tgpig.z;
1220
+
1221
+ const uint i12 = im%ne12;
1222
+ const uint i13 = im/ne12;
1223
+
1224
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1225
+
1226
+ device const half * x = (device const half *) (src0 + offset0);
1227
+
1228
+ if (ne00 < 128) {
1229
+ for (int row = 0; row < N_F16_F16; ++row) {
1230
+ int r1 = rb + row;
1231
+ if (r1 >= ne11) {
1232
+ break;
1233
+ }
1234
+
1235
+ device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
1236
+
1237
+ float sumf = 0;
1238
+ for (int i = tiisg; i < ne00; i += 32) {
1239
+ sumf += (half) x[i] * (half) y[i];
1240
+ }
1241
+
1242
+ float all_sum = simd_sum(sumf);
1243
+ if (tiisg == 0) {
1244
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1245
+ }
1246
+ }
1247
+ } else {
1248
+ device const half4 * x4 = (device const half4 *)x;
1249
+ for (int row = 0; row < N_F16_F16; ++row) {
1250
+ int r1 = rb + row;
1251
+ if (r1 >= ne11) {
1252
+ break;
1253
+ }
1254
+
1255
+ device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
1256
+ device const half4 * y4 = (device const half4 *) y;
1257
+
1258
+ float sumf = 0;
1259
+ for (int i = tiisg; i < ne00/4; i += 32) {
1260
+ for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
1261
+ }
1262
+
1263
+ float all_sum = simd_sum(sumf);
1264
+ if (tiisg == 0) {
1265
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
1266
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1267
+ }
1268
+ }
1269
+ }
1270
+ }
1271
+
1272
+ void kernel_mul_mv_f16_f32_1row_impl(
1273
+ device const char * src0,
1274
+ device const char * src1,
1275
+ device float * dst,
1276
+ constant int64_t & ne00,
1277
+ constant int64_t & ne01,
1278
+ constant int64_t & ne02,
1279
+ constant uint64_t & nb00,
1280
+ constant uint64_t & nb01,
1281
+ constant uint64_t & nb02,
1282
+ constant int64_t & ne10,
1283
+ constant int64_t & ne11,
1284
+ constant int64_t & ne12,
1285
+ constant uint64_t & nb10,
1286
+ constant uint64_t & nb11,
1287
+ constant uint64_t & nb12,
1288
+ constant int64_t & ne0,
1289
+ constant int64_t & ne1,
1290
+ constant uint & r2,
1291
+ constant uint & r3,
865
1292
  uint3 tgpig[[threadgroup_position_in_grid]],
866
1293
  uint tiisg[[thread_index_in_simdgroup]]) {
867
1294
 
@@ -869,7 +1296,12 @@ kernel void kernel_mul_mv_f16_f32_1row(
869
1296
  const int64_t r1 = tgpig.y;
870
1297
  const int64_t im = tgpig.z;
871
1298
 
872
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1299
+ const uint i12 = im%ne12;
1300
+ const uint i13 = im/ne12;
1301
+
1302
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1303
+
1304
+ device const half * x = (device const half *) (src0 + offset0);
873
1305
  device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
874
1306
 
875
1307
  float sumf = 0;
@@ -893,12 +1325,37 @@ kernel void kernel_mul_mv_f16_f32_1row(
893
1325
  dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
894
1326
  }
895
1327
  }
1328
+ }
896
1329
 
1330
+ [[host_name("kernel_mul_mv_f16_f32_1row")]]
1331
+ kernel void kernel_mul_mv_f16_f32_1row(
1332
+ device const char * src0,
1333
+ device const char * src1,
1334
+ device float * dst,
1335
+ constant int64_t & ne00,
1336
+ constant int64_t & ne01,
1337
+ constant int64_t & ne02,
1338
+ constant uint64_t & nb00,
1339
+ constant uint64_t & nb01,
1340
+ constant uint64_t & nb02,
1341
+ constant int64_t & ne10,
1342
+ constant int64_t & ne11,
1343
+ constant int64_t & ne12,
1344
+ constant uint64_t & nb10,
1345
+ constant uint64_t & nb11,
1346
+ constant uint64_t & nb12,
1347
+ constant int64_t & ne0,
1348
+ constant int64_t & ne1,
1349
+ constant uint & r2 [[buffer(17)]],
1350
+ constant uint & r3 [[buffer(18)]],
1351
+ uint3 tgpig[[threadgroup_position_in_grid]],
1352
+ uint tiisg[[thread_index_in_simdgroup]]) {
1353
+ kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
897
1354
  }
898
1355
 
899
1356
  #define N_F16_F32 4
900
1357
 
901
- kernel void kernel_mul_mv_f16_f32(
1358
+ void kernel_mul_mv_f16_f32_impl(
902
1359
  device const char * src0,
903
1360
  device const char * src1,
904
1361
  device float * dst,
@@ -916,6 +1373,8 @@ kernel void kernel_mul_mv_f16_f32(
916
1373
  constant uint64_t & nb12,
917
1374
  constant int64_t & ne0,
918
1375
  constant int64_t & ne1,
1376
+ constant uint & r2,
1377
+ constant uint & r3,
919
1378
  uint3 tgpig[[threadgroup_position_in_grid]],
920
1379
  uint tiisg[[thread_index_in_simdgroup]]) {
921
1380
 
@@ -923,7 +1382,12 @@ kernel void kernel_mul_mv_f16_f32(
923
1382
  const int64_t rb = tgpig.y*N_F16_F32;
924
1383
  const int64_t im = tgpig.z;
925
1384
 
926
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1385
+ const uint i12 = im%ne12;
1386
+ const uint i13 = im/ne12;
1387
+
1388
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1389
+
1390
+ device const half * x = (device const half *) (src0 + offset0);
927
1391
 
928
1392
  if (ne00 < 128) {
929
1393
  for (int row = 0; row < N_F16_F32; ++row) {
@@ -969,6 +1433,32 @@ kernel void kernel_mul_mv_f16_f32(
969
1433
  }
970
1434
  }
971
1435
 
1436
+ [[host_name("kernel_mul_mv_f16_f32")]]
1437
+ kernel void kernel_mul_mv_f16_f32(
1438
+ device const char * src0,
1439
+ device const char * src1,
1440
+ device float * dst,
1441
+ constant int64_t & ne00,
1442
+ constant int64_t & ne01,
1443
+ constant int64_t & ne02,
1444
+ constant uint64_t & nb00,
1445
+ constant uint64_t & nb01,
1446
+ constant uint64_t & nb02,
1447
+ constant int64_t & ne10,
1448
+ constant int64_t & ne11,
1449
+ constant int64_t & ne12,
1450
+ constant uint64_t & nb10,
1451
+ constant uint64_t & nb11,
1452
+ constant uint64_t & nb12,
1453
+ constant int64_t & ne0,
1454
+ constant int64_t & ne1,
1455
+ constant uint & r2 [[buffer(17)]],
1456
+ constant uint & r3 [[buffer(18)]],
1457
+ uint3 tgpig[[threadgroup_position_in_grid]],
1458
+ uint tiisg[[thread_index_in_simdgroup]]) {
1459
+ kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1460
+ }
1461
+
972
1462
  // Assumes row size (ne00) is a multiple of 4
973
1463
  kernel void kernel_mul_mv_f16_f32_l4(
974
1464
  device const char * src0,
@@ -988,6 +1478,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
988
1478
  constant uint64_t & nb12,
989
1479
  constant int64_t & ne0,
990
1480
  constant int64_t & ne1,
1481
+ constant uint & r2 [[buffer(17)]],
1482
+ constant uint & r3 [[buffer(18)]],
991
1483
  uint3 tgpig[[threadgroup_position_in_grid]],
992
1484
  uint tiisg[[thread_index_in_simdgroup]]) {
993
1485
 
@@ -995,7 +1487,12 @@ kernel void kernel_mul_mv_f16_f32_l4(
995
1487
  const int64_t r0 = tgpig.x;
996
1488
  const int64_t im = tgpig.z;
997
1489
 
998
- device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1490
+ const uint i12 = im%ne12;
1491
+ const uint i13 = im/ne12;
1492
+
1493
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1494
+
1495
+ device const half4 * x4 = (device const half4 *) (src0 + offset0);
999
1496
 
1000
1497
  for (int r1 = 0; r1 < nrows; ++r1) {
1001
1498
  device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
@@ -1047,17 +1544,21 @@ kernel void kernel_alibi_f32(
1047
1544
  const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1048
1545
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1049
1546
  const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1547
+ const int64_t k = i3*ne3 + i2;
1050
1548
 
1051
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1052
1549
  float m_k;
1053
- if (i2 < n_heads_log2_floor) {
1054
- m_k = pow(m0, i2 + 1);
1550
+ if (k < n_heads_log2_floor) {
1551
+ m_k = pow(m0, k + 1);
1055
1552
  } else {
1056
- m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
1553
+ m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
1057
1554
  }
1555
+
1556
+ device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
1557
+ device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
1058
1558
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1059
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1060
- dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
1559
+ const float src_v = *(device float *)(src_row + i00*nb00);
1560
+ device float * dst_v = (device float *)(dst_row + i00*nb0);
1561
+ *dst_v = i00 * m_k + src_v;
1061
1562
  }
1062
1563
  }
1063
1564
 
@@ -1213,25 +1714,333 @@ kernel void kernel_rope(
1213
1714
 
1214
1715
  const int64_t i0 = ib*n_dims + ic/2;
1215
1716
 
1216
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1217
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1717
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1718
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1719
+
1720
+ const float x0 = src[0];
1721
+ const float x1 = src[n_dims/2];
1722
+
1723
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
1724
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1725
+ }
1726
+ }
1727
+ }
1728
+ }
1729
+
1730
+ template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
1731
+ template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
1732
+
1733
+ kernel void kernel_im2col_f16(
1734
+ device const float * x,
1735
+ device half * dst,
1736
+ constant int32_t & ofs0,
1737
+ constant int32_t & ofs1,
1738
+ constant int32_t & IW,
1739
+ constant int32_t & IH,
1740
+ constant int32_t & CHW,
1741
+ constant int32_t & s0,
1742
+ constant int32_t & s1,
1743
+ constant int32_t & p0,
1744
+ constant int32_t & p1,
1745
+ constant int32_t & d0,
1746
+ constant int32_t & d1,
1747
+ uint3 tgpig[[threadgroup_position_in_grid]],
1748
+ uint3 tgpg[[threadgroups_per_grid]],
1749
+ uint3 tpitg[[thread_position_in_threadgroup]],
1750
+ uint3 ntg[[threads_per_threadgroup]]) {
1751
+ const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
1752
+ const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
1753
+
1754
+ const int32_t offset_dst =
1755
+ (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
1756
+ (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
1757
+
1758
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
1759
+ dst[offset_dst] = 0.0f;
1760
+ } else {
1761
+ const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
1762
+ dst[offset_dst] = x[offset_src + iih * IW + iiw];
1763
+ }
1764
+ }
1765
+
1766
+ kernel void kernel_upscale_f32(
1767
+ device const char * src0,
1768
+ device char * dst,
1769
+ constant int64_t & ne00,
1770
+ constant int64_t & ne01,
1771
+ constant int64_t & ne02,
1772
+ constant int64_t & ne03,
1773
+ constant uint64_t & nb00,
1774
+ constant uint64_t & nb01,
1775
+ constant uint64_t & nb02,
1776
+ constant uint64_t & nb03,
1777
+ constant int64_t & ne0,
1778
+ constant int64_t & ne1,
1779
+ constant int64_t & ne2,
1780
+ constant int64_t & ne3,
1781
+ constant uint64_t & nb0,
1782
+ constant uint64_t & nb1,
1783
+ constant uint64_t & nb2,
1784
+ constant uint64_t & nb3,
1785
+ constant int32_t & sf,
1786
+ uint3 tgpig[[threadgroup_position_in_grid]],
1787
+ uint3 tpitg[[thread_position_in_threadgroup]],
1788
+ uint3 ntg[[threads_per_threadgroup]]) {
1789
+
1790
+ const int64_t i3 = tgpig.z;
1791
+ const int64_t i2 = tgpig.y;
1792
+ const int64_t i1 = tgpig.x;
1793
+
1794
+ const int64_t i03 = i3;
1795
+ const int64_t i02 = i2;
1796
+ const int64_t i01 = i1/sf;
1797
+
1798
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
1799
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
1800
+
1801
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1802
+ dst_ptr[i0] = src0_ptr[i0/sf];
1803
+ }
1804
+ }
1805
+
1806
+ kernel void kernel_pad_f32(
1807
+ device const char * src0,
1808
+ device char * dst,
1809
+ constant int64_t & ne00,
1810
+ constant int64_t & ne01,
1811
+ constant int64_t & ne02,
1812
+ constant int64_t & ne03,
1813
+ constant uint64_t & nb00,
1814
+ constant uint64_t & nb01,
1815
+ constant uint64_t & nb02,
1816
+ constant uint64_t & nb03,
1817
+ constant int64_t & ne0,
1818
+ constant int64_t & ne1,
1819
+ constant int64_t & ne2,
1820
+ constant int64_t & ne3,
1821
+ constant uint64_t & nb0,
1822
+ constant uint64_t & nb1,
1823
+ constant uint64_t & nb2,
1824
+ constant uint64_t & nb3,
1825
+ uint3 tgpig[[threadgroup_position_in_grid]],
1826
+ uint3 tpitg[[thread_position_in_threadgroup]],
1827
+ uint3 ntg[[threads_per_threadgroup]]) {
1828
+
1829
+ const int64_t i3 = tgpig.z;
1830
+ const int64_t i2 = tgpig.y;
1831
+ const int64_t i1 = tgpig.x;
1832
+
1833
+ const int64_t i03 = i3;
1834
+ const int64_t i02 = i2;
1835
+ const int64_t i01 = i1;
1836
+
1837
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
1838
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
1839
+
1840
+ if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
1841
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1842
+ if (i0 < ne00) {
1843
+ dst_ptr[i0] = src0_ptr[i0];
1844
+ } else {
1845
+ dst_ptr[i0] = 0.0f;
1846
+ }
1847
+ }
1848
+
1849
+ return;
1850
+ }
1851
+
1852
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1853
+ dst_ptr[i0] = 0.0f;
1854
+ }
1855
+ }
1856
+
1857
+ // bitonic sort implementation following the CUDA kernels as reference
1858
+ typedef void (argsort_t)(
1859
+ device const float * x,
1860
+ device int32_t * dst,
1861
+ constant int64_t & ncols,
1862
+ uint3 tgpig[[threadgroup_position_in_grid]],
1863
+ uint3 tpitg[[thread_position_in_threadgroup]]);
1864
+
1865
+ template<ggml_sort_order order>
1866
+ kernel void kernel_argsort_f32_i32(
1867
+ device const float * x,
1868
+ device int32_t * dst,
1869
+ constant int64_t & ncols,
1870
+ uint3 tgpig[[threadgroup_position_in_grid]],
1871
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
1872
+ // bitonic sort
1873
+ int col = tpitg[0];
1874
+ int row = tgpig[1];
1875
+
1876
+ if (col >= ncols) return;
1877
+
1878
+ device const float * x_row = x + row * ncols;
1879
+ device int32_t * dst_row = dst + row * ncols;
1880
+
1881
+ // initialize indices
1882
+ if (col < ncols) {
1883
+ dst_row[col] = col;
1884
+ }
1885
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1886
+
1887
+ for (int k = 2; k <= ncols; k *= 2) {
1888
+ for (int j = k / 2; j > 0; j /= 2) {
1889
+ int ixj = col ^ j;
1890
+ if (ixj > col) {
1891
+ if ((col & k) == 0) {
1892
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
1893
+ SWAP(dst_row[col], dst_row[ixj]);
1894
+ }
1895
+ } else {
1896
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
1897
+ SWAP(dst_row[col], dst_row[ixj]);
1898
+ }
1899
+ }
1900
+ }
1901
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1902
+ }
1903
+ }
1904
+ }
1905
+
1906
+ template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
1907
+ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
1908
+
1909
+ kernel void kernel_leaky_relu_f32(
1910
+ device const float * src0,
1911
+ device float * dst,
1912
+ constant float & slope,
1913
+ uint tpig[[thread_position_in_grid]]) {
1914
+ dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
1915
+ }
1916
+
1917
+ kernel void kernel_cpy_f16_f16(
1918
+ device const half * src0,
1919
+ device half * dst,
1920
+ constant int64_t & ne00,
1921
+ constant int64_t & ne01,
1922
+ constant int64_t & ne02,
1923
+ constant int64_t & ne03,
1924
+ constant uint64_t & nb00,
1925
+ constant uint64_t & nb01,
1926
+ constant uint64_t & nb02,
1927
+ constant uint64_t & nb03,
1928
+ constant int64_t & ne0,
1929
+ constant int64_t & ne1,
1930
+ constant int64_t & ne2,
1931
+ constant int64_t & ne3,
1932
+ constant uint64_t & nb0,
1933
+ constant uint64_t & nb1,
1934
+ constant uint64_t & nb2,
1935
+ constant uint64_t & nb3,
1936
+ uint3 tgpig[[threadgroup_position_in_grid]],
1937
+ uint3 tpitg[[thread_position_in_threadgroup]],
1938
+ uint3 ntg[[threads_per_threadgroup]]) {
1939
+ const int64_t i03 = tgpig[2];
1940
+ const int64_t i02 = tgpig[1];
1941
+ const int64_t i01 = tgpig[0];
1942
+
1943
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1944
+
1945
+ const int64_t i3 = n / (ne2*ne1*ne0);
1946
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1947
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1948
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1949
+
1950
+ device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1951
+
1952
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1953
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1954
+ dst_data[i00] = src[0];
1955
+ }
1956
+ }
1957
+
1958
+ kernel void kernel_cpy_f16_f32(
1959
+ device const half * src0,
1960
+ device float * dst,
1961
+ constant int64_t & ne00,
1962
+ constant int64_t & ne01,
1963
+ constant int64_t & ne02,
1964
+ constant int64_t & ne03,
1965
+ constant uint64_t & nb00,
1966
+ constant uint64_t & nb01,
1967
+ constant uint64_t & nb02,
1968
+ constant uint64_t & nb03,
1969
+ constant int64_t & ne0,
1970
+ constant int64_t & ne1,
1971
+ constant int64_t & ne2,
1972
+ constant int64_t & ne3,
1973
+ constant uint64_t & nb0,
1974
+ constant uint64_t & nb1,
1975
+ constant uint64_t & nb2,
1976
+ constant uint64_t & nb3,
1977
+ uint3 tgpig[[threadgroup_position_in_grid]],
1978
+ uint3 tpitg[[thread_position_in_threadgroup]],
1979
+ uint3 ntg[[threads_per_threadgroup]]) {
1980
+ const int64_t i03 = tgpig[2];
1981
+ const int64_t i02 = tgpig[1];
1982
+ const int64_t i01 = tgpig[0];
1983
+
1984
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1985
+
1986
+ const int64_t i3 = n / (ne2*ne1*ne0);
1987
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1988
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1989
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1990
+
1991
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1992
+
1993
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1994
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1995
+ dst_data[i00] = src[0];
1996
+ }
1997
+ }
1998
+
1999
+ kernel void kernel_cpy_f32_f16(
2000
+ device const float * src0,
2001
+ device half * dst,
2002
+ constant int64_t & ne00,
2003
+ constant int64_t & ne01,
2004
+ constant int64_t & ne02,
2005
+ constant int64_t & ne03,
2006
+ constant uint64_t & nb00,
2007
+ constant uint64_t & nb01,
2008
+ constant uint64_t & nb02,
2009
+ constant uint64_t & nb03,
2010
+ constant int64_t & ne0,
2011
+ constant int64_t & ne1,
2012
+ constant int64_t & ne2,
2013
+ constant int64_t & ne3,
2014
+ constant uint64_t & nb0,
2015
+ constant uint64_t & nb1,
2016
+ constant uint64_t & nb2,
2017
+ constant uint64_t & nb3,
2018
+ uint3 tgpig[[threadgroup_position_in_grid]],
2019
+ uint3 tpitg[[thread_position_in_threadgroup]],
2020
+ uint3 ntg[[threads_per_threadgroup]]) {
2021
+ const int64_t i03 = tgpig[2];
2022
+ const int64_t i02 = tgpig[1];
2023
+ const int64_t i01 = tgpig[0];
2024
+
2025
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2026
+
2027
+ const int64_t i3 = n / (ne2*ne1*ne0);
2028
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2029
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2030
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2031
+
2032
+ device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1218
2033
 
1219
- const float x0 = src[0];
1220
- const float x1 = src[n_dims/2];
2034
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2035
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1221
2036
 
1222
- dst_data[0] = x0*cos_theta - x1*sin_theta;
1223
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1224
- }
1225
- }
2037
+ dst_data[i00] = src[0];
1226
2038
  }
1227
2039
  }
1228
2040
 
1229
- template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
1230
- template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
1231
-
1232
- kernel void kernel_cpy_f16_f16(
1233
- device const half * src0,
1234
- device half * dst,
2041
+ kernel void kernel_cpy_f32_f32(
2042
+ device const float * src0,
2043
+ device float * dst,
1235
2044
  constant int64_t & ne00,
1236
2045
  constant int64_t & ne01,
1237
2046
  constant int64_t & ne02,
@@ -1262,17 +2071,18 @@ kernel void kernel_cpy_f16_f16(
1262
2071
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1263
2072
  const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1264
2073
 
1265
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2074
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1266
2075
 
1267
2076
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1268
- device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2077
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2078
+
1269
2079
  dst_data[i00] = src[0];
1270
2080
  }
1271
2081
  }
1272
2082
 
1273
- kernel void kernel_cpy_f32_f16(
2083
+ kernel void kernel_cpy_f32_q8_0(
1274
2084
  device const float * src0,
1275
- device half * dst,
2085
+ device void * dst,
1276
2086
  constant int64_t & ne00,
1277
2087
  constant int64_t & ne01,
1278
2088
  constant int64_t & ne02,
@@ -1301,20 +2111,36 @@ kernel void kernel_cpy_f32_f16(
1301
2111
  const int64_t i3 = n / (ne2*ne1*ne0);
1302
2112
  const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1303
2113
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1304
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2114
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
1305
2115
 
1306
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2116
+ device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1307
2117
 
1308
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2118
+ for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
1309
2119
  device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1310
2120
 
1311
- dst_data[i00] = src[0];
2121
+ float amax = 0.0f; // absolute max
2122
+
2123
+ for (int j = 0; j < QK8_0; j++) {
2124
+ const float v = src[j];
2125
+ amax = MAX(amax, fabs(v));
2126
+ }
2127
+
2128
+ const float d = amax / ((1 << 7) - 1);
2129
+ const float id = d ? 1.0f/d : 0.0f;
2130
+
2131
+ dst_data[i00/QK8_0].d = d;
2132
+
2133
+ for (int j = 0; j < QK8_0; ++j) {
2134
+ const float x0 = src[j]*id;
2135
+
2136
+ dst_data[i00/QK8_0].qs[j] = round(x0);
2137
+ }
1312
2138
  }
1313
2139
  }
1314
2140
 
1315
- kernel void kernel_cpy_f32_f32(
2141
+ kernel void kernel_cpy_f32_q4_0(
1316
2142
  device const float * src0,
1317
- device float * dst,
2143
+ device void * dst,
1318
2144
  constant int64_t & ne00,
1319
2145
  constant int64_t & ne01,
1320
2146
  constant int64_t & ne02,
@@ -1343,21 +2169,112 @@ kernel void kernel_cpy_f32_f32(
1343
2169
  const int64_t i3 = n / (ne2*ne1*ne0);
1344
2170
  const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1345
2171
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1346
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2172
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
1347
2173
 
1348
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2174
+ device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1349
2175
 
1350
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2176
+ for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
1351
2177
  device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1352
2178
 
1353
- dst_data[i00] = src[0];
2179
+ float amax = 0.0f; // absolute max
2180
+ float max = 0.0f;
2181
+
2182
+ for (int j = 0; j < QK4_0; j++) {
2183
+ const float v = src[j];
2184
+ if (amax < fabs(v)) {
2185
+ amax = fabs(v);
2186
+ max = v;
2187
+ }
2188
+ }
2189
+
2190
+ const float d = max / -8;
2191
+ const float id = d ? 1.0f/d : 0.0f;
2192
+
2193
+ dst_data[i00/QK4_0].d = d;
2194
+
2195
+ for (int j = 0; j < QK4_0/2; ++j) {
2196
+ const float x0 = src[0 + j]*id;
2197
+ const float x1 = src[QK4_0/2 + j]*id;
2198
+
2199
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
2200
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
2201
+
2202
+ dst_data[i00/QK4_0].qs[j] = xi0;
2203
+ dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
2204
+ }
2205
+ }
2206
+ }
2207
+
2208
+ kernel void kernel_cpy_f32_q4_1(
2209
+ device const float * src0,
2210
+ device void * dst,
2211
+ constant int64_t & ne00,
2212
+ constant int64_t & ne01,
2213
+ constant int64_t & ne02,
2214
+ constant int64_t & ne03,
2215
+ constant uint64_t & nb00,
2216
+ constant uint64_t & nb01,
2217
+ constant uint64_t & nb02,
2218
+ constant uint64_t & nb03,
2219
+ constant int64_t & ne0,
2220
+ constant int64_t & ne1,
2221
+ constant int64_t & ne2,
2222
+ constant int64_t & ne3,
2223
+ constant uint64_t & nb0,
2224
+ constant uint64_t & nb1,
2225
+ constant uint64_t & nb2,
2226
+ constant uint64_t & nb3,
2227
+ uint3 tgpig[[threadgroup_position_in_grid]],
2228
+ uint3 tpitg[[thread_position_in_threadgroup]],
2229
+ uint3 ntg[[threads_per_threadgroup]]) {
2230
+ const int64_t i03 = tgpig[2];
2231
+ const int64_t i02 = tgpig[1];
2232
+ const int64_t i01 = tgpig[0];
2233
+
2234
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2235
+
2236
+ const int64_t i3 = n / (ne2*ne1*ne0);
2237
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2238
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2239
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
2240
+
2241
+ device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2242
+
2243
+ for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
2244
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2245
+
2246
+ float min = FLT_MAX;
2247
+ float max = -FLT_MAX;
2248
+
2249
+ for (int j = 0; j < QK4_1; j++) {
2250
+ const float v = src[j];
2251
+ if (min > v) min = v;
2252
+ if (max < v) max = v;
2253
+ }
2254
+
2255
+ const float d = (max - min) / ((1 << 4) - 1);
2256
+ const float id = d ? 1.0f/d : 0.0f;
2257
+
2258
+ dst_data[i00/QK4_1].d = d;
2259
+ dst_data[i00/QK4_1].m = min;
2260
+
2261
+ for (int j = 0; j < QK4_1/2; ++j) {
2262
+ const float x0 = (src[0 + j] - min)*id;
2263
+ const float x1 = (src[QK4_1/2 + j] - min)*id;
2264
+
2265
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
2266
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
2267
+
2268
+ dst_data[i00/QK4_1].qs[j] = xi0;
2269
+ dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
2270
+ }
1354
2271
  }
1355
2272
  }
1356
2273
 
1357
2274
  kernel void kernel_concat(
1358
- device const char * src0,
1359
- device const char * src1,
1360
- device char * dst,
2275
+ device const char * src0,
2276
+ device const char * src1,
2277
+ device char * dst,
1361
2278
  constant int64_t & ne00,
1362
2279
  constant int64_t & ne01,
1363
2280
  constant int64_t & ne02,
@@ -1394,7 +2311,7 @@ kernel void kernel_concat(
1394
2311
  const int64_t i12 = i02 % ne12;
1395
2312
  const int64_t i11 = i01 % ne11;
1396
2313
 
1397
- device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
2314
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
1398
2315
  device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
1399
2316
  device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
1400
2317
 
@@ -1502,32 +2419,39 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
1502
2419
 
1503
2420
  //====================================== dot products =========================
1504
2421
 
1505
- kernel void kernel_mul_mv_q2_K_f32(
2422
+ void kernel_mul_mv_q2_K_f32_impl(
1506
2423
  device const void * src0,
1507
2424
  device const float * src1,
1508
2425
  device float * dst,
1509
2426
  constant int64_t & ne00,
1510
- constant int64_t & ne01[[buffer(4)]],
1511
- constant int64_t & ne02[[buffer(5)]],
1512
- constant int64_t & ne10[[buffer(9)]],
1513
- constant int64_t & ne12[[buffer(11)]],
1514
- constant int64_t & ne0[[buffer(15)]],
1515
- constant int64_t & ne1[[buffer(16)]],
1516
- constant uint & gqa[[buffer(17)]],
2427
+ constant int64_t & ne01,
2428
+ constant int64_t & ne02,
2429
+ constant int64_t & ne10,
2430
+ constant int64_t & ne12,
2431
+ constant int64_t & ne0,
2432
+ constant int64_t & ne1,
2433
+ constant uint & r2,
2434
+ constant uint & r3,
1517
2435
  uint3 tgpig[[threadgroup_position_in_grid]],
1518
- uint tiisg[[thread_index_in_simdgroup]],
1519
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2436
+ uint tiisg[[thread_index_in_simdgroup]],
2437
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1520
2438
 
1521
2439
  const int nb = ne00/QK_K;
1522
2440
  const int r0 = tgpig.x;
1523
2441
  const int r1 = tgpig.y;
1524
- const int r2 = tgpig.z;
2442
+ const int im = tgpig.z;
1525
2443
 
1526
2444
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1527
2445
  const int ib_row = first_row * nb;
1528
- const uint offset0 = r2/gqa*(nb*ne0);
2446
+
2447
+ const uint i12 = im%ne12;
2448
+ const uint i13 = im/ne12;
2449
+
2450
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2451
+
1529
2452
  device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
1530
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2453
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2454
+
1531
2455
  float yl[32];
1532
2456
  float sumf[N_DST]={0.f}, all_sum;
1533
2457
 
@@ -1536,11 +2460,11 @@ kernel void kernel_mul_mv_q2_K_f32(
1536
2460
  #if QK_K == 256
1537
2461
  const int ix = tiisg/8; // 0...3
1538
2462
  const int it = tiisg%8; // 0...7
1539
- const int im = it/4; // 0 or 1
2463
+ const int iq = it/4; // 0 or 1
1540
2464
  const int ir = it%4; // 0...3
1541
2465
  const int is = (8*ir)/16;// 0 or 1
1542
2466
 
1543
- device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
2467
+ device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
1544
2468
 
1545
2469
  for (int ib = ix; ib < nb; ib += 4) {
1546
2470
 
@@ -1552,8 +2476,8 @@ kernel void kernel_mul_mv_q2_K_f32(
1552
2476
  yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
1553
2477
  }
1554
2478
 
1555
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
1556
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
2479
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
2480
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
1557
2481
  device const half * dh = &x[ib].d;
1558
2482
 
1559
2483
  for (int row = 0; row < N_DST; row++) {
@@ -1640,13 +2564,13 @@ kernel void kernel_mul_mv_q2_K_f32(
1640
2564
  for (int row = 0; row < N_DST; ++row) {
1641
2565
  all_sum = simd_sum(sumf[row]);
1642
2566
  if (tiisg == 0) {
1643
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
2567
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
1644
2568
  }
1645
2569
  }
1646
2570
  }
1647
2571
 
1648
- #if QK_K == 256
1649
- kernel void kernel_mul_mv_q3_K_f32(
2572
+ [[host_name("kernel_mul_mv_q2_K_f32")]]
2573
+ kernel void kernel_mul_mv_q2_K_f32(
1650
2574
  device const void * src0,
1651
2575
  device const float * src1,
1652
2576
  device float * dst,
@@ -1655,23 +2579,50 @@ kernel void kernel_mul_mv_q3_K_f32(
1655
2579
  constant int64_t & ne02[[buffer(5)]],
1656
2580
  constant int64_t & ne10[[buffer(9)]],
1657
2581
  constant int64_t & ne12[[buffer(11)]],
1658
- constant int64_t & ne0[[buffer(15)]],
1659
- constant int64_t & ne1[[buffer(16)]],
1660
- constant uint & gqa[[buffer(17)]],
2582
+ constant int64_t & ne0 [[buffer(15)]],
2583
+ constant int64_t & ne1 [[buffer(16)]],
2584
+ constant uint & r2 [[buffer(17)]],
2585
+ constant uint & r3 [[buffer(18)]],
1661
2586
  uint3 tgpig[[threadgroup_position_in_grid]],
1662
- uint tiisg[[thread_index_in_simdgroup]],
1663
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2587
+ uint tiisg[[thread_index_in_simdgroup]],
2588
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2589
+
2590
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2591
+ }
2592
+
2593
+ #if QK_K == 256
2594
+ void kernel_mul_mv_q3_K_f32_impl(
2595
+ device const void * src0,
2596
+ device const float * src1,
2597
+ device float * dst,
2598
+ constant int64_t & ne00,
2599
+ constant int64_t & ne01,
2600
+ constant int64_t & ne02,
2601
+ constant int64_t & ne10,
2602
+ constant int64_t & ne12,
2603
+ constant int64_t & ne0,
2604
+ constant int64_t & ne1,
2605
+ constant uint & r2,
2606
+ constant uint & r3,
2607
+ uint3 tgpig[[threadgroup_position_in_grid]],
2608
+ uint tiisg[[thread_index_in_simdgroup]],
2609
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1664
2610
 
1665
2611
  const int nb = ne00/QK_K;
1666
2612
 
1667
2613
  const int64_t r0 = tgpig.x;
1668
2614
  const int64_t r1 = tgpig.y;
1669
- const int64_t r2 = tgpig.z;
2615
+ const int64_t im = tgpig.z;
1670
2616
 
1671
2617
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1672
- const uint offset0 = r2/gqa*(nb*ne0);
2618
+
2619
+ const uint i12 = im%ne12;
2620
+ const uint i13 = im/ne12;
2621
+
2622
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2623
+
1673
2624
  device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
1674
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2625
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
1675
2626
 
1676
2627
  float yl[32];
1677
2628
 
@@ -1793,40 +2744,47 @@ kernel void kernel_mul_mv_q3_K_f32(
1793
2744
  }
1794
2745
  if (tiisg == 0) {
1795
2746
  for (int row = 0; row < 2; ++row) {
1796
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
2747
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
1797
2748
  }
1798
2749
  }
1799
2750
  }
1800
2751
  #else
1801
- kernel void kernel_mul_mv_q3_K_f32(
2752
+ void kernel_mul_mv_q3_K_f32_impl(
1802
2753
  device const void * src0,
1803
2754
  device const float * src1,
1804
2755
  device float * dst,
1805
2756
  constant int64_t & ne00,
1806
- constant int64_t & ne01[[buffer(4)]],
1807
- constant int64_t & ne02[[buffer(5)]],
1808
- constant int64_t & ne10[[buffer(9)]],
1809
- constant int64_t & ne12[[buffer(11)]],
1810
- constant int64_t & ne0[[buffer(15)]],
1811
- constant int64_t & ne1[[buffer(16)]],
1812
- constant uint & gqa[[buffer(17)]],
2757
+ constant int64_t & ne01,
2758
+ constant int64_t & ne02,
2759
+ constant int64_t & ne10,
2760
+ constant int64_t & ne12,
2761
+ constant int64_t & ne0,
2762
+ constant int64_t & ne1,
2763
+ constant uint & r2,
2764
+ constant uint & r3,
1813
2765
  uint3 tgpig[[threadgroup_position_in_grid]],
1814
- uint tiisg[[thread_index_in_simdgroup]],
1815
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2766
+ uint tiisg[[thread_index_in_simdgroup]],
2767
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1816
2768
 
1817
2769
  const int nb = ne00/QK_K;
1818
2770
 
1819
2771
  const int64_t r0 = tgpig.x;
1820
2772
  const int64_t r1 = tgpig.y;
1821
- const int64_t r2 = tgpig.z;
2773
+ const int64_t im = tgpig.z;
1822
2774
 
1823
2775
  const int row = 2 * r0 + sgitg;
1824
- const uint offset0 = r2/gqa*(nb*ne0);
2776
+
2777
+ const uint i12 = im%ne12;
2778
+ const uint i13 = im/ne12;
2779
+
2780
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2781
+
1825
2782
  device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
1826
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2783
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2784
+
1827
2785
  const int ix = tiisg/4;
1828
2786
  const int il = 4 * (tiisg%4);// 0, 4, 8, 12
1829
- const int im = il/8; // 0, 0, 1, 1
2787
+ const int iq = il/8; // 0, 0, 1, 1
1830
2788
  const int in = il%8; // 0, 4, 0, 4
1831
2789
 
1832
2790
  float2 sum = {0.f, 0.f};
@@ -1846,7 +2804,7 @@ kernel void kernel_mul_mv_q3_K_f32(
1846
2804
  const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
1847
2805
 
1848
2806
  for (int l = 0; l < 4; l += 2) {
1849
- const uint16_t hm = h[l/2] >> im;
2807
+ const uint16_t hm = h[l/2] >> iq;
1850
2808
  sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
1851
2809
  + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
1852
2810
  + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
@@ -1862,28 +2820,50 @@ kernel void kernel_mul_mv_q3_K_f32(
1862
2820
 
1863
2821
  const float tot = simd_sum(sumf);
1864
2822
  if (tiisg == 0) {
1865
- dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
2823
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
1866
2824
  }
1867
2825
 
1868
2826
  }
1869
2827
  #endif
1870
2828
 
2829
+ [[host_name("kernel_mul_mv_q3_K_f32")]]
2830
+ kernel void kernel_mul_mv_q3_K_f32(
2831
+ device const void * src0,
2832
+ device const float * src1,
2833
+ device float * dst,
2834
+ constant int64_t & ne00,
2835
+ constant int64_t & ne01[[buffer(4)]],
2836
+ constant int64_t & ne02[[buffer(5)]],
2837
+ constant int64_t & ne10[[buffer(9)]],
2838
+ constant int64_t & ne12[[buffer(11)]],
2839
+ constant int64_t & ne0 [[buffer(15)]],
2840
+ constant int64_t & ne1 [[buffer(16)]],
2841
+ constant uint & r2 [[buffer(17)]],
2842
+ constant uint & r3 [[buffer(18)]],
2843
+ uint3 tgpig[[threadgroup_position_in_grid]],
2844
+ uint tiisg[[thread_index_in_simdgroup]],
2845
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2846
+
2847
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2848
+ }
2849
+
1871
2850
  #if QK_K == 256
1872
- kernel void kernel_mul_mv_q4_K_f32(
2851
+ void kernel_mul_mv_q4_K_f32_impl(
1873
2852
  device const void * src0,
1874
2853
  device const float * src1,
1875
2854
  device float * dst,
1876
2855
  constant int64_t & ne00,
1877
- constant int64_t & ne01 [[buffer(4)]],
1878
- constant int64_t & ne02 [[buffer(5)]],
1879
- constant int64_t & ne10 [[buffer(9)]],
1880
- constant int64_t & ne12 [[buffer(11)]],
1881
- constant int64_t & ne0 [[buffer(15)]],
1882
- constant int64_t & ne1 [[buffer(16)]],
1883
- constant uint & gqa [[buffer(17)]],
2856
+ constant int64_t & ne01,
2857
+ constant int64_t & ne02,
2858
+ constant int64_t & ne10,
2859
+ constant int64_t & ne12,
2860
+ constant int64_t & ne0,
2861
+ constant int64_t & ne1,
2862
+ constant uint & r2,
2863
+ constant uint & r3,
1884
2864
  uint3 tgpig[[threadgroup_position_in_grid]],
1885
- uint tiisg[[thread_index_in_simdgroup]],
1886
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2865
+ uint tiisg[[thread_index_in_simdgroup]],
2866
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1887
2867
 
1888
2868
  const uint16_t kmask1 = 0x3f3f;
1889
2869
  const uint16_t kmask2 = 0x0f0f;
@@ -1891,26 +2871,32 @@ kernel void kernel_mul_mv_q4_K_f32(
1891
2871
 
1892
2872
  const int ix = tiisg/8; // 0...3
1893
2873
  const int it = tiisg%8; // 0...7
1894
- const int im = it/4; // 0 or 1
2874
+ const int iq = it/4; // 0 or 1
1895
2875
  const int ir = it%4; // 0...3
1896
2876
 
1897
2877
  const int nb = ne00/QK_K;
1898
2878
  const int r0 = tgpig.x;
1899
2879
  const int r1 = tgpig.y;
1900
- const int r2 = tgpig.z;
2880
+ const int im = tgpig.z;
1901
2881
  //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1902
2882
  const int first_row = r0 * N_DST;
1903
2883
  const int ib_row = first_row * nb;
1904
- const uint offset0 = r2/gqa*(nb*ne0);
2884
+
2885
+ const uint i12 = im%ne12;
2886
+ const uint i13 = im/ne12;
2887
+
2888
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2889
+
1905
2890
  device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
1906
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2891
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2892
+
1907
2893
  float yl[16];
1908
2894
  float yh[16];
1909
2895
  float sumf[N_DST]={0.f}, all_sum;
1910
2896
 
1911
2897
  const int step = sizeof(block_q4_K) * nb / 2;
1912
2898
 
1913
- device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
2899
+ device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
1914
2900
 
1915
2901
  uint16_t sc16[4];
1916
2902
  thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
@@ -1925,8 +2911,8 @@ kernel void kernel_mul_mv_q4_K_f32(
1925
2911
  yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
1926
2912
  }
1927
2913
 
1928
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
1929
- device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
2914
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
2915
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
1930
2916
  device const half * dh = &x[ib].d;
1931
2917
 
1932
2918
  for (int row = 0; row < N_DST; row++) {
@@ -1970,23 +2956,24 @@ kernel void kernel_mul_mv_q4_K_f32(
1970
2956
  for (int row = 0; row < N_DST; ++row) {
1971
2957
  all_sum = simd_sum(sumf[row]);
1972
2958
  if (tiisg == 0) {
1973
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
2959
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
1974
2960
  }
1975
2961
  }
1976
2962
  }
1977
2963
  #else
1978
- kernel void kernel_mul_mv_q4_K_f32(
2964
+ void kernel_mul_mv_q4_K_f32_impl(
1979
2965
  device const void * src0,
1980
2966
  device const float * src1,
1981
2967
  device float * dst,
1982
2968
  constant int64_t & ne00,
1983
- constant int64_t & ne01[[buffer(4)]],
1984
- constant int64_t & ne02[[buffer(5)]],
1985
- constant int64_t & ne10[[buffer(9)]],
1986
- constant int64_t & ne12[[buffer(11)]],
1987
- constant int64_t & ne0[[buffer(15)]],
1988
- constant int64_t & ne1[[buffer(16)]],
1989
- constant uint & gqa[[buffer(17)]],
2969
+ constant int64_t & ne01,
2970
+ constant int64_t & ne02,
2971
+ constant int64_t & ne10,
2972
+ constant int64_t & ne12,
2973
+ constant int64_t & ne0,
2974
+ constant int64_t & ne1,
2975
+ constant uint & r2,
2976
+ constant uint & r3,
1990
2977
  uint3 tgpig[[threadgroup_position_in_grid]],
1991
2978
  uint tiisg[[thread_index_in_simdgroup]],
1992
2979
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1997,12 +2984,18 @@ kernel void kernel_mul_mv_q4_K_f32(
1997
2984
  const int nb = ne00/QK_K;
1998
2985
  const int r0 = tgpig.x;
1999
2986
  const int r1 = tgpig.y;
2000
- const int r2 = tgpig.z;
2987
+ const int im = tgpig.z;
2001
2988
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2002
2989
  const int ib_row = first_row * nb;
2003
- const uint offset0 = r2/gqa*(nb*ne0);
2990
+
2991
+ const uint i12 = im%ne12;
2992
+ const uint i13 = im/ne12;
2993
+
2994
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2995
+
2004
2996
  device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
2005
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2997
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2998
+
2006
2999
  float yl[8];
2007
3000
  float yh[8];
2008
3001
  float sumf[N_DST]={0.f}, all_sum;
@@ -2058,13 +3051,14 @@ kernel void kernel_mul_mv_q4_K_f32(
2058
3051
  for (int row = 0; row < N_DST; ++row) {
2059
3052
  all_sum = simd_sum(sumf[row]);
2060
3053
  if (tiisg == 0) {
2061
- dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
3054
+ dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
2062
3055
  }
2063
3056
  }
2064
3057
  }
2065
3058
  #endif
2066
3059
 
2067
- kernel void kernel_mul_mv_q5_K_f32(
3060
+ [[host_name("kernel_mul_mv_q4_K_f32")]]
3061
+ kernel void kernel_mul_mv_q4_K_f32(
2068
3062
  device const void * src0,
2069
3063
  device const float * src1,
2070
3064
  device float * dst,
@@ -2073,23 +3067,49 @@ kernel void kernel_mul_mv_q5_K_f32(
2073
3067
  constant int64_t & ne02[[buffer(5)]],
2074
3068
  constant int64_t & ne10[[buffer(9)]],
2075
3069
  constant int64_t & ne12[[buffer(11)]],
2076
- constant int64_t & ne0[[buffer(15)]],
2077
- constant int64_t & ne1[[buffer(16)]],
2078
- constant uint & gqa[[buffer(17)]],
3070
+ constant int64_t & ne0 [[buffer(15)]],
3071
+ constant int64_t & ne1 [[buffer(16)]],
3072
+ constant uint & r2 [[buffer(17)]],
3073
+ constant uint & r3 [[buffer(18)]],
2079
3074
  uint3 tgpig[[threadgroup_position_in_grid]],
2080
3075
  uint tiisg[[thread_index_in_simdgroup]],
2081
3076
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
2082
3077
 
3078
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3079
+ }
3080
+
3081
+ void kernel_mul_mv_q5_K_f32_impl(
3082
+ device const void * src0,
3083
+ device const float * src1,
3084
+ device float * dst,
3085
+ constant int64_t & ne00,
3086
+ constant int64_t & ne01,
3087
+ constant int64_t & ne02,
3088
+ constant int64_t & ne10,
3089
+ constant int64_t & ne12,
3090
+ constant int64_t & ne0,
3091
+ constant int64_t & ne1,
3092
+ constant uint & r2,
3093
+ constant uint & r3,
3094
+ uint3 tgpig[[threadgroup_position_in_grid]],
3095
+ uint tiisg[[thread_index_in_simdgroup]],
3096
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3097
+
2083
3098
  const int nb = ne00/QK_K;
2084
3099
 
2085
3100
  const int64_t r0 = tgpig.x;
2086
3101
  const int64_t r1 = tgpig.y;
2087
- const int r2 = tgpig.z;
3102
+ const int im = tgpig.z;
2088
3103
 
2089
3104
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
2090
- const uint offset0 = r2/gqa*(nb*ne0);
3105
+
3106
+ const uint i12 = im%ne12;
3107
+ const uint i13 = im/ne12;
3108
+
3109
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3110
+
2091
3111
  device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
2092
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
3112
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2093
3113
 
2094
3114
  float sumf[2]={0.f};
2095
3115
 
@@ -2105,15 +3125,15 @@ kernel void kernel_mul_mv_q5_K_f32(
2105
3125
 
2106
3126
  const int tid = tiisg/4;
2107
3127
  const int ix = tiisg%4;
2108
- const int im = tid/4;
3128
+ const int iq = tid/4;
2109
3129
  const int ir = tid%4;
2110
3130
  const int n = 8;
2111
3131
 
2112
3132
  const int l0 = n*ir;
2113
- const int q_offset = 32*im + l0;
2114
- const int y_offset = 64*im + l0;
3133
+ const int q_offset = 32*iq + l0;
3134
+ const int y_offset = 64*iq + l0;
2115
3135
 
2116
- const uint8_t hm1 = 1u << (2*im);
3136
+ const uint8_t hm1 = 1u << (2*iq);
2117
3137
  const uint8_t hm2 = hm1 << 1;
2118
3138
  const uint8_t hm3 = hm1 << 4;
2119
3139
  const uint8_t hm4 = hm2 << 4;
@@ -2128,7 +3148,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2128
3148
  device const uint8_t * q1 = x[i].qs + q_offset;
2129
3149
  device const uint8_t * qh = x[i].qh + l0;
2130
3150
  device const half * dh = &x[i].d;
2131
- device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
3151
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
2132
3152
 
2133
3153
  device const float * y2 = y1 + 128;
2134
3154
  float4 sumy = {0.f, 0.f, 0.f, 0.f};
@@ -2184,7 +3204,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2184
3204
 
2185
3205
  const int il = 4 * (tiisg/8); // 0, 4, 8, 12
2186
3206
  const int ix = tiisg%8;
2187
- const int im = il/8; // 0, 0, 1, 1
3207
+ const int iq = il/8; // 0, 0, 1, 1
2188
3208
  const int in = il%8; // 0, 4, 0, 4
2189
3209
 
2190
3210
  device const float * y = yy + ix*QK_K + il;
@@ -2209,7 +3229,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2209
3229
 
2210
3230
  float2 acc = {0.f, 0.f};
2211
3231
  for (int l = 0; l < 4; ++l) {
2212
- const uint8_t hl = h[l] >> im;
3232
+ const uint8_t hl = h[l] >> iq;
2213
3233
  acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
2214
3234
  + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
2215
3235
  acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
@@ -2231,27 +3251,48 @@ kernel void kernel_mul_mv_q5_K_f32(
2231
3251
  for (int row = 0; row < 2; ++row) {
2232
3252
  const float tot = simd_sum(sumf[row]);
2233
3253
  if (tiisg == 0) {
2234
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
3254
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
2235
3255
  }
2236
3256
  }
3257
+ }
3258
+
3259
+ [[host_name("kernel_mul_mv_q5_K_f32")]]
3260
+ kernel void kernel_mul_mv_q5_K_f32(
3261
+ device const void * src0,
3262
+ device const float * src1,
3263
+ device float * dst,
3264
+ constant int64_t & ne00,
3265
+ constant int64_t & ne01[[buffer(4)]],
3266
+ constant int64_t & ne02[[buffer(5)]],
3267
+ constant int64_t & ne10[[buffer(9)]],
3268
+ constant int64_t & ne12[[buffer(11)]],
3269
+ constant int64_t & ne0 [[buffer(15)]],
3270
+ constant int64_t & ne1 [[buffer(16)]],
3271
+ constant uint & r2 [[buffer(17)]],
3272
+ constant uint & r3 [[buffer(18)]],
3273
+ uint3 tgpig[[threadgroup_position_in_grid]],
3274
+ uint tiisg[[thread_index_in_simdgroup]],
3275
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2237
3276
 
3277
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2238
3278
  }
2239
3279
 
2240
- kernel void kernel_mul_mv_q6_K_f32(
3280
+ void kernel_mul_mv_q6_K_f32_impl(
2241
3281
  device const void * src0,
2242
3282
  device const float * src1,
2243
3283
  device float * dst,
2244
3284
  constant int64_t & ne00,
2245
- constant int64_t & ne01[[buffer(4)]],
2246
- constant int64_t & ne02[[buffer(5)]],
2247
- constant int64_t & ne10[[buffer(9)]],
2248
- constant int64_t & ne12[[buffer(11)]],
2249
- constant int64_t & ne0[[buffer(15)]],
2250
- constant int64_t & ne1[[buffer(16)]],
2251
- constant uint & gqa[[buffer(17)]],
3285
+ constant int64_t & ne01,
3286
+ constant int64_t & ne02,
3287
+ constant int64_t & ne10,
3288
+ constant int64_t & ne12,
3289
+ constant int64_t & ne0,
3290
+ constant int64_t & ne1,
3291
+ constant uint & r2,
3292
+ constant uint & r3,
2252
3293
  uint3 tgpig[[threadgroup_position_in_grid]],
2253
- uint tiisg[[thread_index_in_simdgroup]],
2254
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3294
+ uint tiisg[[thread_index_in_simdgroup]],
3295
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2255
3296
 
2256
3297
  const uint8_t kmask1 = 0x03;
2257
3298
  const uint8_t kmask2 = 0x0C;
@@ -2262,12 +3303,17 @@ kernel void kernel_mul_mv_q6_K_f32(
2262
3303
 
2263
3304
  const int64_t r0 = tgpig.x;
2264
3305
  const int64_t r1 = tgpig.y;
2265
- const int r2 = tgpig.z;
3306
+ const int im = tgpig.z;
2266
3307
 
2267
3308
  const int row = 2 * r0 + sgitg;
2268
- const uint offset0 = r2/gqa*(nb*ne0);
3309
+
3310
+ const uint i12 = im%ne12;
3311
+ const uint i13 = im/ne12;
3312
+
3313
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3314
+
2269
3315
  device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
2270
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
3316
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2271
3317
 
2272
3318
  float sumf = 0;
2273
3319
 
@@ -2333,10 +3379,31 @@ kernel void kernel_mul_mv_q6_K_f32(
2333
3379
 
2334
3380
  const float tot = simd_sum(sumf);
2335
3381
  if (tiisg == 0) {
2336
- dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
3382
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
2337
3383
  }
2338
3384
  }
2339
3385
 
3386
+ [[host_name("kernel_mul_mv_q6_K_f32")]]
3387
+ kernel void kernel_mul_mv_q6_K_f32(
3388
+ device const void * src0,
3389
+ device const float * src1,
3390
+ device float * dst,
3391
+ constant int64_t & ne00,
3392
+ constant int64_t & ne01[[buffer(4)]],
3393
+ constant int64_t & ne02[[buffer(5)]],
3394
+ constant int64_t & ne10[[buffer(9)]],
3395
+ constant int64_t & ne12[[buffer(11)]],
3396
+ constant int64_t & ne0 [[buffer(15)]],
3397
+ constant int64_t & ne1 [[buffer(16)]],
3398
+ constant uint & r2 [[buffer(17)]],
3399
+ constant uint & r3 [[buffer(18)]],
3400
+ uint3 tgpig[[threadgroup_position_in_grid]],
3401
+ uint tiisg[[thread_index_in_simdgroup]],
3402
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3403
+
3404
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3405
+ }
3406
+
2340
3407
  //============================= templates and their specializations =============================
2341
3408
 
2342
3409
  // NOTE: this is not dequantizing - we are simply fitting the template
@@ -2454,10 +3521,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
2454
3521
 
2455
3522
  template <typename type4x4>
2456
3523
  void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
2457
- const half d = xb->d;
2458
- const half min = xb->dmin;
3524
+ const float d = xb->d;
3525
+ const float min = xb->dmin;
2459
3526
  device const uint8_t * q = (device const uint8_t *)xb->qs;
2460
- half dl, ml;
3527
+ float dl, ml;
2461
3528
  uint8_t sc = xb->scales[il];
2462
3529
 
2463
3530
  #if QK_K == 256
@@ -2527,10 +3594,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
2527
3594
  q = q + (il/4) * 32 + 16 * (il&1);
2528
3595
  il = il & 3;
2529
3596
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
2530
- const half d = il < 2 ? xb->d : xb->d / 16.h;
2531
- const half min = xb->dmin;
2532
- const half dl = d * sc[0];
2533
- const half ml = min * sc[1];
3597
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
3598
+ const float min = xb->dmin;
3599
+ const float dl = d * sc[0];
3600
+ const float ml = min * sc[1];
2534
3601
  #else
2535
3602
  q = q + 16 * (il&1);
2536
3603
  device const uint8_t * s = xb->scales;
@@ -2557,13 +3624,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
2557
3624
  uint8_t ul = 1 << (il/2);
2558
3625
  il = il & 3;
2559
3626
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
2560
- const half d = il < 2 ? xb->d : xb->d / 16.h;
2561
- const half min = xb->dmin;
2562
- const half dl = d * sc[0];
2563
- const half ml = min * sc[1];
3627
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
3628
+ const float min = xb->dmin;
3629
+ const float dl = d * sc[0];
3630
+ const float ml = min * sc[1];
2564
3631
 
2565
- const ushort mask = il<2 ? 0x0F : 0xF0;
2566
- const half qh_val = il<2 ? 16.h : 256.h;
3632
+ const ushort mask = il<2 ? 0x0F : 0xF0;
3633
+ const float qh_val = il<2 ? 16.f : 256.f;
2567
3634
  for (int i = 0; i < 16; ++i) {
2568
3635
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
2569
3636
  }
@@ -2611,22 +3678,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
2611
3678
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
2612
3679
  kernel void kernel_get_rows(
2613
3680
  device const void * src0,
2614
- device const int * src1,
3681
+ device const char * src1,
2615
3682
  device float * dst,
2616
3683
  constant int64_t & ne00,
2617
3684
  constant uint64_t & nb01,
3685
+ constant uint64_t & nb02,
3686
+ constant int64_t & ne10,
3687
+ constant uint64_t & nb10,
3688
+ constant uint64_t & nb11,
2618
3689
  constant uint64_t & nb1,
2619
- uint tgpig[[threadgroup_position_in_grid]],
3690
+ constant uint64_t & nb2,
3691
+ uint3 tgpig[[threadgroup_position_in_grid]],
2620
3692
  uint tiitg[[thread_index_in_threadgroup]],
2621
- uint tptg[[threads_per_threadgroup]]) {
2622
- const int i = tgpig;
2623
- const int r = ((device int32_t *) src1)[i];
3693
+ uint3 tptg [[threads_per_threadgroup]]) {
3694
+ //const int64_t i = tgpig;
3695
+ //const int64_t r = ((device int32_t *) src1)[i];
3696
+
3697
+ const int64_t i10 = tgpig.x;
3698
+ const int64_t i11 = tgpig.y;
2624
3699
 
2625
- for (int ind = tiitg; ind < ne00/16; ind += tptg) {
3700
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3701
+
3702
+ const int64_t i02 = i11;
3703
+
3704
+ for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
2626
3705
  float4x4 temp;
2627
3706
  dequantize_func(
2628
- ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
2629
- *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
3707
+ ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
3708
+ *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
3709
+ }
3710
+ }
3711
+
3712
+ kernel void kernel_get_rows_f32(
3713
+ device const void * src0,
3714
+ device const char * src1,
3715
+ device float * dst,
3716
+ constant int64_t & ne00,
3717
+ constant uint64_t & nb01,
3718
+ constant uint64_t & nb02,
3719
+ constant int64_t & ne10,
3720
+ constant uint64_t & nb10,
3721
+ constant uint64_t & nb11,
3722
+ constant uint64_t & nb1,
3723
+ constant uint64_t & nb2,
3724
+ uint3 tgpig[[threadgroup_position_in_grid]],
3725
+ uint tiitg[[thread_index_in_threadgroup]],
3726
+ uint3 tptg [[threads_per_threadgroup]]) {
3727
+ const int64_t i10 = tgpig.x;
3728
+ const int64_t i11 = tgpig.y;
3729
+
3730
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3731
+
3732
+ const int64_t i02 = i11;
3733
+
3734
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3735
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3736
+ ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
3737
+ }
3738
+ }
3739
+
3740
+ kernel void kernel_get_rows_f16(
3741
+ device const void * src0,
3742
+ device const char * src1,
3743
+ device float * dst,
3744
+ constant int64_t & ne00,
3745
+ constant uint64_t & nb01,
3746
+ constant uint64_t & nb02,
3747
+ constant int64_t & ne10,
3748
+ constant uint64_t & nb10,
3749
+ constant uint64_t & nb11,
3750
+ constant uint64_t & nb1,
3751
+ constant uint64_t & nb2,
3752
+ uint3 tgpig[[threadgroup_position_in_grid]],
3753
+ uint tiitg[[thread_index_in_threadgroup]],
3754
+ uint3 tptg [[threads_per_threadgroup]]) {
3755
+ const int64_t i10 = tgpig.x;
3756
+ const int64_t i11 = tgpig.y;
3757
+
3758
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3759
+
3760
+ const int64_t i02 = i11;
3761
+
3762
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3763
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3764
+ ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
2630
3765
  }
2631
3766
  }
2632
3767
 
@@ -2643,24 +3778,25 @@ kernel void kernel_get_rows(
2643
3778
 
2644
3779
  // each block_q contains 16*nl weights
2645
3780
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
2646
- kernel void kernel_mul_mm(device const uchar * src0,
2647
- device const uchar * src1,
2648
- device float * dst,
2649
- constant int64_t & ne00,
2650
- constant int64_t & ne02,
2651
- constant int64_t & nb01,
2652
- constant int64_t & nb02,
2653
- constant int64_t & ne12,
2654
- constant int64_t & nb10,
2655
- constant int64_t & nb11,
2656
- constant int64_t & nb12,
2657
- constant int64_t & ne0,
2658
- constant int64_t & ne1,
2659
- constant uint & gqa,
2660
- threadgroup uchar * shared_memory [[threadgroup(0)]],
2661
- uint3 tgpig[[threadgroup_position_in_grid]],
2662
- uint tiitg[[thread_index_in_threadgroup]],
2663
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3781
+ void kernel_mul_mm_impl(device const uchar * src0,
3782
+ device const uchar * src1,
3783
+ device float * dst,
3784
+ constant int64_t & ne00,
3785
+ constant int64_t & ne02,
3786
+ constant int64_t & nb01,
3787
+ constant int64_t & nb02,
3788
+ constant int64_t & ne12,
3789
+ constant int64_t & nb10,
3790
+ constant int64_t & nb11,
3791
+ constant int64_t & nb12,
3792
+ constant int64_t & ne0,
3793
+ constant int64_t & ne1,
3794
+ constant uint & r2,
3795
+ constant uint & r3,
3796
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3797
+ uint3 tgpig[[threadgroup_position_in_grid]],
3798
+ uint tiitg[[thread_index_in_threadgroup]],
3799
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2664
3800
 
2665
3801
  threadgroup half * sa = (threadgroup half *)(shared_memory);
2666
3802
  threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
@@ -2686,7 +3822,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
2686
3822
 
2687
3823
  short il = (tiitg % THREAD_PER_ROW);
2688
3824
 
2689
- uint offset0 = im/gqa*nb02;
3825
+ const uint i12 = im%ne12;
3826
+ const uint i13 = im/ne12;
3827
+
3828
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
2690
3829
  ushort offset1 = il/nl;
2691
3830
 
2692
3831
  device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
@@ -2770,17 +3909,137 @@ kernel void kernel_mul_mm(device const uchar * src0,
2770
3909
  }
2771
3910
  }
2772
3911
 
3912
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3913
+ kernel void kernel_mul_mm(device const uchar * src0,
3914
+ device const uchar * src1,
3915
+ device float * dst,
3916
+ constant int64_t & ne00,
3917
+ constant int64_t & ne02,
3918
+ constant int64_t & nb01,
3919
+ constant int64_t & nb02,
3920
+ constant int64_t & ne12,
3921
+ constant int64_t & nb10,
3922
+ constant int64_t & nb11,
3923
+ constant int64_t & nb12,
3924
+ constant int64_t & ne0,
3925
+ constant int64_t & ne1,
3926
+ constant uint & r2,
3927
+ constant uint & r3,
3928
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3929
+ uint3 tgpig[[threadgroup_position_in_grid]],
3930
+ uint tiitg[[thread_index_in_threadgroup]],
3931
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3932
+ kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3933
+ src0,
3934
+ src1,
3935
+ dst,
3936
+ ne00,
3937
+ ne02,
3938
+ nb01,
3939
+ nb02,
3940
+ ne12,
3941
+ nb10,
3942
+ nb11,
3943
+ nb12,
3944
+ ne0,
3945
+ ne1,
3946
+ r2,
3947
+ r3,
3948
+ shared_memory,
3949
+ tgpig,
3950
+ tiitg,
3951
+ sgitg);
3952
+ }
3953
+
3954
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3955
+ kernel void kernel_mul_mm_id(
3956
+ device const uchar * ids,
3957
+ device const uchar * src1,
3958
+ device uchar * dst,
3959
+ constant int64_t & nbi1,
3960
+ constant int64_t & ne00,
3961
+ constant int64_t & ne02,
3962
+ constant int64_t & nb01,
3963
+ constant int64_t & nb02,
3964
+ constant int64_t & ne12,
3965
+ constant int64_t & ne13,
3966
+ constant int64_t & nb10,
3967
+ constant int64_t & nb11,
3968
+ constant int64_t & nb12,
3969
+ constant int64_t & ne0,
3970
+ constant int64_t & ne1,
3971
+ constant int64_t & nb1,
3972
+ constant uint & r2,
3973
+ constant uint & r3,
3974
+ constant int & idx,
3975
+ device const uchar * src00,
3976
+ device const uchar * src01,
3977
+ device const uchar * src02,
3978
+ device const uchar * src03,
3979
+ device const uchar * src04,
3980
+ device const uchar * src05,
3981
+ device const uchar * src06,
3982
+ device const uchar * src07,
3983
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3984
+ uint3 tgpig[[threadgroup_position_in_grid]],
3985
+ uint tiitg[[thread_index_in_threadgroup]],
3986
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3987
+ device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
3988
+
3989
+ const int64_t bid = tgpig.z/(ne12*ne13);
3990
+
3991
+ tgpig.z = tgpig.z%(ne12*ne13);
3992
+
3993
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
3994
+
3995
+ kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3996
+ src0[id],
3997
+ src1 + bid*nb11,
3998
+ (device float *) (dst + bid*nb1),
3999
+ ne00,
4000
+ ne02,
4001
+ nb01,
4002
+ nb02,
4003
+ ne12,
4004
+ nb10,
4005
+ nb11,
4006
+ nb12,
4007
+ ne0,
4008
+ ne1,
4009
+ r2,
4010
+ r3,
4011
+ shared_memory,
4012
+ tgpig,
4013
+ tiitg,
4014
+ sgitg);
4015
+ }
4016
+
2773
4017
  #if QK_K == 256
2774
4018
  #define QK_NL 16
2775
4019
  #else
2776
4020
  #define QK_NL 4
2777
4021
  #endif
2778
4022
 
2779
- typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
2780
- constant uint64_t &, constant uint64_t &, uint, uint, uint);
4023
+ //
4024
+ // get rows
4025
+ //
2781
4026
 
2782
- template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
2783
- template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
4027
+ typedef void (get_rows_t)(
4028
+ device const void * src0,
4029
+ device const char * src1,
4030
+ device float * dst,
4031
+ constant int64_t & ne00,
4032
+ constant uint64_t & nb01,
4033
+ constant uint64_t & nb02,
4034
+ constant int64_t & ne10,
4035
+ constant uint64_t & nb10,
4036
+ constant uint64_t & nb11,
4037
+ constant uint64_t & nb1,
4038
+ constant uint64_t & nb2,
4039
+ uint3, uint, uint3);
4040
+
4041
+ //template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
4042
+ //template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
2784
4043
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
2785
4044
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
2786
4045
  template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
@@ -2792,6 +4051,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
2792
4051
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
2793
4052
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
2794
4053
 
4054
+ //
4055
+ // matrix-matrix multiplication
4056
+ //
4057
+
2795
4058
  typedef void (mat_mm_t)(
2796
4059
  device const uchar * src0,
2797
4060
  device const uchar * src1,
@@ -2806,8 +4069,10 @@ typedef void (mat_mm_t)(
2806
4069
  constant int64_t & nb12,
2807
4070
  constant int64_t & ne0,
2808
4071
  constant int64_t & ne1,
2809
- constant uint & gqa,
2810
- threadgroup uchar *, uint3, uint, uint);
4072
+ constant uint & r2,
4073
+ constant uint & r3,
4074
+ threadgroup uchar *,
4075
+ uint3, uint, uint);
2811
4076
 
2812
4077
  template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
2813
4078
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
@@ -2821,3 +4086,823 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
2821
4086
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
2822
4087
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
2823
4088
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
4089
+
4090
+ //
4091
+ // indirect matrix-matrix multiplication
4092
+ //
4093
+
4094
+ typedef void (mat_mm_id_t)(
4095
+ device const uchar * ids,
4096
+ device const uchar * src1,
4097
+ device uchar * dst,
4098
+ constant int64_t & nbi1,
4099
+ constant int64_t & ne00,
4100
+ constant int64_t & ne02,
4101
+ constant int64_t & nb01,
4102
+ constant int64_t & nb02,
4103
+ constant int64_t & ne12,
4104
+ constant int64_t & ne13,
4105
+ constant int64_t & nb10,
4106
+ constant int64_t & nb11,
4107
+ constant int64_t & nb12,
4108
+ constant int64_t & ne0,
4109
+ constant int64_t & ne1,
4110
+ constant int64_t & nb1,
4111
+ constant uint & r2,
4112
+ constant uint & r3,
4113
+ constant int & idx,
4114
+ device const uchar * src00,
4115
+ device const uchar * src01,
4116
+ device const uchar * src02,
4117
+ device const uchar * src03,
4118
+ device const uchar * src04,
4119
+ device const uchar * src05,
4120
+ device const uchar * src06,
4121
+ device const uchar * src07,
4122
+ threadgroup uchar *,
4123
+ uint3, uint, uint);
4124
+
4125
+ template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
4126
+ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
4127
+ template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
4128
+ template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
4129
+ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
4130
+ template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
4131
+ template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
4132
+ template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
4133
+ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
4134
+ template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
4135
+ template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
4136
+ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
4137
+
4138
+ //
4139
+ // matrix-vector multiplication
4140
+ //
4141
+
4142
+ [[host_name("kernel_mul_mv_id_f32_f32")]]
4143
+ kernel void kernel_mul_mv_id_f32_f32(
4144
+ device const char * ids,
4145
+ device const char * src1,
4146
+ device uchar * dst,
4147
+ constant int64_t & nbi1,
4148
+ constant int64_t & ne00,
4149
+ constant int64_t & ne01,
4150
+ constant int64_t & ne02,
4151
+ constant uint64_t & nb00,
4152
+ constant uint64_t & nb01,
4153
+ constant uint64_t & nb02,
4154
+ constant int64_t & ne10,
4155
+ constant int64_t & ne11,
4156
+ constant int64_t & ne12,
4157
+ constant int64_t & ne13,
4158
+ constant uint64_t & nb10,
4159
+ constant uint64_t & nb11,
4160
+ constant uint64_t & nb12,
4161
+ constant int64_t & ne0,
4162
+ constant int64_t & ne1,
4163
+ constant int64_t & nb1,
4164
+ constant uint & r2,
4165
+ constant uint & r3,
4166
+ constant int & idx,
4167
+ device const char * src00,
4168
+ device const char * src01,
4169
+ device const char * src02,
4170
+ device const char * src03,
4171
+ device const char * src04,
4172
+ device const char * src05,
4173
+ device const char * src06,
4174
+ device const char * src07,
4175
+ uint3 tgpig[[threadgroup_position_in_grid]],
4176
+ uint tiitg[[thread_index_in_threadgroup]],
4177
+ uint tiisg[[thread_index_in_simdgroup]],
4178
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4179
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4180
+
4181
+ const int64_t bid = tgpig.z/(ne12*ne13);
4182
+
4183
+ tgpig.z = tgpig.z%(ne12*ne13);
4184
+
4185
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4186
+
4187
+ kernel_mul_mv_f32_f32_impl(
4188
+ src0[id],
4189
+ src1 + bid*nb11,
4190
+ (device float *) (dst + bid*nb1),
4191
+ ne00,
4192
+ ne01,
4193
+ ne02,
4194
+ nb00,
4195
+ nb01,
4196
+ nb02,
4197
+ ne10,
4198
+ ne11,
4199
+ ne12,
4200
+ nb10,
4201
+ nb11,
4202
+ nb12,
4203
+ ne0,
4204
+ ne1,
4205
+ r2,
4206
+ r3,
4207
+ tgpig,
4208
+ tiisg);
4209
+ }
4210
+
4211
+ [[host_name("kernel_mul_mv_id_f16_f32")]]
4212
+ kernel void kernel_mul_mv_id_f16_f32(
4213
+ device const char * ids,
4214
+ device const char * src1,
4215
+ device uchar * dst,
4216
+ constant int64_t & nbi1,
4217
+ constant int64_t & ne00,
4218
+ constant int64_t & ne01,
4219
+ constant int64_t & ne02,
4220
+ constant uint64_t & nb00,
4221
+ constant uint64_t & nb01,
4222
+ constant uint64_t & nb02,
4223
+ constant int64_t & ne10,
4224
+ constant int64_t & ne11,
4225
+ constant int64_t & ne12,
4226
+ constant int64_t & ne13,
4227
+ constant uint64_t & nb10,
4228
+ constant uint64_t & nb11,
4229
+ constant uint64_t & nb12,
4230
+ constant int64_t & ne0,
4231
+ constant int64_t & ne1,
4232
+ constant int64_t & nb1,
4233
+ constant uint & r2,
4234
+ constant uint & r3,
4235
+ constant int & idx,
4236
+ device const char * src00,
4237
+ device const char * src01,
4238
+ device const char * src02,
4239
+ device const char * src03,
4240
+ device const char * src04,
4241
+ device const char * src05,
4242
+ device const char * src06,
4243
+ device const char * src07,
4244
+ uint3 tgpig[[threadgroup_position_in_grid]],
4245
+ uint tiitg[[thread_index_in_threadgroup]],
4246
+ uint tiisg[[thread_index_in_simdgroup]],
4247
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4248
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4249
+
4250
+ const int64_t bid = tgpig.z/(ne12*ne13);
4251
+
4252
+ tgpig.z = tgpig.z%(ne12*ne13);
4253
+
4254
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4255
+
4256
+ kernel_mul_mv_f16_f32_impl(
4257
+ src0[id],
4258
+ src1 + bid*nb11,
4259
+ (device float *) (dst + bid*nb1),
4260
+ ne00,
4261
+ ne01,
4262
+ ne02,
4263
+ nb00,
4264
+ nb01,
4265
+ nb02,
4266
+ ne10,
4267
+ ne11,
4268
+ ne12,
4269
+ nb10,
4270
+ nb11,
4271
+ nb12,
4272
+ ne0,
4273
+ ne1,
4274
+ r2,
4275
+ r3,
4276
+ tgpig,
4277
+ tiisg);
4278
+ }
4279
+
4280
+ [[host_name("kernel_mul_mv_id_q8_0_f32")]]
4281
+ kernel void kernel_mul_mv_id_q8_0_f32(
4282
+ device const char * ids,
4283
+ device const char * src1,
4284
+ device uchar * dst,
4285
+ constant int64_t & nbi1,
4286
+ constant int64_t & ne00,
4287
+ constant int64_t & ne01,
4288
+ constant int64_t & ne02,
4289
+ constant uint64_t & nb00,
4290
+ constant uint64_t & nb01,
4291
+ constant uint64_t & nb02,
4292
+ constant int64_t & ne10,
4293
+ constant int64_t & ne11,
4294
+ constant int64_t & ne12,
4295
+ constant int64_t & ne13,
4296
+ constant uint64_t & nb10,
4297
+ constant uint64_t & nb11,
4298
+ constant uint64_t & nb12,
4299
+ constant int64_t & ne0,
4300
+ constant int64_t & ne1,
4301
+ constant int64_t & nb1,
4302
+ constant uint & r2,
4303
+ constant uint & r3,
4304
+ constant int & idx,
4305
+ device const char * src00,
4306
+ device const char * src01,
4307
+ device const char * src02,
4308
+ device const char * src03,
4309
+ device const char * src04,
4310
+ device const char * src05,
4311
+ device const char * src06,
4312
+ device const char * src07,
4313
+ uint3 tgpig[[threadgroup_position_in_grid]],
4314
+ uint tiitg[[thread_index_in_threadgroup]],
4315
+ uint tiisg[[thread_index_in_simdgroup]],
4316
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4317
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4318
+
4319
+ const int64_t bid = tgpig.z/(ne12*ne13);
4320
+
4321
+ tgpig.z = tgpig.z%(ne12*ne13);
4322
+
4323
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4324
+
4325
+ kernel_mul_mv_q8_0_f32_impl(
4326
+ src0[id],
4327
+ (device const float *) (src1 + bid*nb11),
4328
+ (device float *) ( dst + bid*nb1),
4329
+ ne00,
4330
+ ne01,
4331
+ ne02,
4332
+ ne10,
4333
+ ne12,
4334
+ ne0,
4335
+ ne1,
4336
+ r2,
4337
+ r3,
4338
+ tgpig,
4339
+ tiisg,
4340
+ sgitg);
4341
+ }
4342
+
4343
+ [[host_name("kernel_mul_mv_id_q4_0_f32")]]
4344
+ kernel void kernel_mul_mv_id_q4_0_f32(
4345
+ device const char * ids,
4346
+ device const char * src1,
4347
+ device uchar * dst,
4348
+ constant int64_t & nbi1,
4349
+ constant int64_t & ne00,
4350
+ constant int64_t & ne01,
4351
+ constant int64_t & ne02,
4352
+ constant uint64_t & nb00,
4353
+ constant uint64_t & nb01,
4354
+ constant uint64_t & nb02,
4355
+ constant int64_t & ne10,
4356
+ constant int64_t & ne11,
4357
+ constant int64_t & ne12,
4358
+ constant int64_t & ne13,
4359
+ constant uint64_t & nb10,
4360
+ constant uint64_t & nb11,
4361
+ constant uint64_t & nb12,
4362
+ constant int64_t & ne0,
4363
+ constant int64_t & ne1,
4364
+ constant int64_t & nb1,
4365
+ constant uint & r2,
4366
+ constant uint & r3,
4367
+ constant int & idx,
4368
+ device const char * src00,
4369
+ device const char * src01,
4370
+ device const char * src02,
4371
+ device const char * src03,
4372
+ device const char * src04,
4373
+ device const char * src05,
4374
+ device const char * src06,
4375
+ device const char * src07,
4376
+ uint3 tgpig[[threadgroup_position_in_grid]],
4377
+ uint tiitg[[thread_index_in_threadgroup]],
4378
+ uint tiisg[[thread_index_in_simdgroup]],
4379
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4380
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4381
+
4382
+ const int64_t bid = tgpig.z/(ne12*ne13);
4383
+
4384
+ tgpig.z = tgpig.z%(ne12*ne13);
4385
+
4386
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4387
+
4388
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4389
+ src0[id],
4390
+ (device const float *) (src1 + bid*nb11),
4391
+ (device float *) ( dst + bid*nb1),
4392
+ ne00,
4393
+ ne01,
4394
+ ne02,
4395
+ ne10,
4396
+ ne12,
4397
+ ne0,
4398
+ ne1,
4399
+ r2,
4400
+ r3,
4401
+ tgpig,
4402
+ tiisg,
4403
+ sgitg);
4404
+ }
4405
+
4406
+ [[host_name("kernel_mul_mv_id_q4_1_f32")]]
4407
+ kernel void kernel_mul_mv_id_q4_1_f32(
4408
+ device const char * ids,
4409
+ device const char * src1,
4410
+ device uchar * dst,
4411
+ constant int64_t & nbi1,
4412
+ constant int64_t & ne00,
4413
+ constant int64_t & ne01,
4414
+ constant int64_t & ne02,
4415
+ constant uint64_t & nb00,
4416
+ constant uint64_t & nb01,
4417
+ constant uint64_t & nb02,
4418
+ constant int64_t & ne10,
4419
+ constant int64_t & ne11,
4420
+ constant int64_t & ne12,
4421
+ constant int64_t & ne13,
4422
+ constant uint64_t & nb10,
4423
+ constant uint64_t & nb11,
4424
+ constant uint64_t & nb12,
4425
+ constant int64_t & ne0,
4426
+ constant int64_t & ne1,
4427
+ constant int64_t & nb1,
4428
+ constant uint & r2,
4429
+ constant uint & r3,
4430
+ constant int & idx,
4431
+ device const char * src00,
4432
+ device const char * src01,
4433
+ device const char * src02,
4434
+ device const char * src03,
4435
+ device const char * src04,
4436
+ device const char * src05,
4437
+ device const char * src06,
4438
+ device const char * src07,
4439
+ uint3 tgpig[[threadgroup_position_in_grid]],
4440
+ uint tiitg[[thread_index_in_threadgroup]],
4441
+ uint tiisg[[thread_index_in_simdgroup]],
4442
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4443
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4444
+
4445
+ const int64_t bid = tgpig.z/(ne12*ne13);
4446
+
4447
+ tgpig.z = tgpig.z%(ne12*ne13);
4448
+
4449
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4450
+
4451
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4452
+ src0[id],
4453
+ (device const float *) (src1 + bid*nb11),
4454
+ (device float *) ( dst + bid*nb1),
4455
+ ne00,
4456
+ ne01,
4457
+ ne02,
4458
+ ne10,
4459
+ ne12,
4460
+ ne0,
4461
+ ne1,
4462
+ r2,
4463
+ r3,
4464
+ tgpig,
4465
+ tiisg,
4466
+ sgitg);
4467
+ }
4468
+
4469
+ [[host_name("kernel_mul_mv_id_q5_0_f32")]]
4470
+ kernel void kernel_mul_mv_id_q5_0_f32(
4471
+ device const char * ids,
4472
+ device const char * src1,
4473
+ device uchar * dst,
4474
+ constant int64_t & nbi1,
4475
+ constant int64_t & ne00,
4476
+ constant int64_t & ne01,
4477
+ constant int64_t & ne02,
4478
+ constant uint64_t & nb00,
4479
+ constant uint64_t & nb01,
4480
+ constant uint64_t & nb02,
4481
+ constant int64_t & ne10,
4482
+ constant int64_t & ne11,
4483
+ constant int64_t & ne12,
4484
+ constant int64_t & ne13,
4485
+ constant uint64_t & nb10,
4486
+ constant uint64_t & nb11,
4487
+ constant uint64_t & nb12,
4488
+ constant int64_t & ne0,
4489
+ constant int64_t & ne1,
4490
+ constant int64_t & nb1,
4491
+ constant uint & r2,
4492
+ constant uint & r3,
4493
+ constant int & idx,
4494
+ device const char * src00,
4495
+ device const char * src01,
4496
+ device const char * src02,
4497
+ device const char * src03,
4498
+ device const char * src04,
4499
+ device const char * src05,
4500
+ device const char * src06,
4501
+ device const char * src07,
4502
+ uint3 tgpig[[threadgroup_position_in_grid]],
4503
+ uint tiitg[[thread_index_in_threadgroup]],
4504
+ uint tiisg[[thread_index_in_simdgroup]],
4505
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4506
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4507
+
4508
+ const int64_t bid = tgpig.z/(ne12*ne13);
4509
+
4510
+ tgpig.z = tgpig.z%(ne12*ne13);
4511
+
4512
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4513
+
4514
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4515
+ src0[id],
4516
+ (device const float *) (src1 + bid*nb11),
4517
+ (device float *) ( dst + bid*nb1),
4518
+ ne00,
4519
+ ne01,
4520
+ ne02,
4521
+ ne10,
4522
+ ne12,
4523
+ ne0,
4524
+ ne1,
4525
+ r2,
4526
+ r3,
4527
+ tgpig,
4528
+ tiisg,
4529
+ sgitg);
4530
+ }
4531
+
4532
+ [[host_name("kernel_mul_mv_id_q5_1_f32")]]
4533
+ kernel void kernel_mul_mv_id_q5_1_f32(
4534
+ device const char * ids,
4535
+ device const char * src1,
4536
+ device uchar * dst,
4537
+ constant int64_t & nbi1,
4538
+ constant int64_t & ne00,
4539
+ constant int64_t & ne01,
4540
+ constant int64_t & ne02,
4541
+ constant uint64_t & nb00,
4542
+ constant uint64_t & nb01,
4543
+ constant uint64_t & nb02,
4544
+ constant int64_t & ne10,
4545
+ constant int64_t & ne11,
4546
+ constant int64_t & ne12,
4547
+ constant int64_t & ne13,
4548
+ constant uint64_t & nb10,
4549
+ constant uint64_t & nb11,
4550
+ constant uint64_t & nb12,
4551
+ constant int64_t & ne0,
4552
+ constant int64_t & ne1,
4553
+ constant int64_t & nb1,
4554
+ constant uint & r2,
4555
+ constant uint & r3,
4556
+ constant int & idx,
4557
+ device const char * src00,
4558
+ device const char * src01,
4559
+ device const char * src02,
4560
+ device const char * src03,
4561
+ device const char * src04,
4562
+ device const char * src05,
4563
+ device const char * src06,
4564
+ device const char * src07,
4565
+ uint3 tgpig[[threadgroup_position_in_grid]],
4566
+ uint tiitg[[thread_index_in_threadgroup]],
4567
+ uint tiisg[[thread_index_in_simdgroup]],
4568
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4569
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4570
+
4571
+ const int64_t bid = tgpig.z/(ne12*ne13);
4572
+
4573
+ tgpig.z = tgpig.z%(ne12*ne13);
4574
+
4575
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4576
+
4577
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4578
+ src0[id],
4579
+ (device const float *) (src1 + bid*nb11),
4580
+ (device float *) ( dst + bid*nb1),
4581
+ ne00,
4582
+ ne01,
4583
+ ne02,
4584
+ ne10,
4585
+ ne12,
4586
+ ne0,
4587
+ ne1,
4588
+ r2,
4589
+ r3,
4590
+ tgpig,
4591
+ tiisg,
4592
+ sgitg);
4593
+ }
4594
+
4595
+ [[host_name("kernel_mul_mv_id_q2_K_f32")]]
4596
+ kernel void kernel_mul_mv_id_q2_K_f32(
4597
+ device const char * ids,
4598
+ device const char * src1,
4599
+ device uchar * dst,
4600
+ constant int64_t & nbi1,
4601
+ constant int64_t & ne00,
4602
+ constant int64_t & ne01,
4603
+ constant int64_t & ne02,
4604
+ constant uint64_t & nb00,
4605
+ constant uint64_t & nb01,
4606
+ constant uint64_t & nb02,
4607
+ constant int64_t & ne10,
4608
+ constant int64_t & ne11,
4609
+ constant int64_t & ne12,
4610
+ constant int64_t & ne13,
4611
+ constant uint64_t & nb10,
4612
+ constant uint64_t & nb11,
4613
+ constant uint64_t & nb12,
4614
+ constant int64_t & ne0,
4615
+ constant int64_t & ne1,
4616
+ constant int64_t & nb1,
4617
+ constant uint & r2,
4618
+ constant uint & r3,
4619
+ constant int & idx,
4620
+ device const char * src00,
4621
+ device const char * src01,
4622
+ device const char * src02,
4623
+ device const char * src03,
4624
+ device const char * src04,
4625
+ device const char * src05,
4626
+ device const char * src06,
4627
+ device const char * src07,
4628
+ uint3 tgpig[[threadgroup_position_in_grid]],
4629
+ uint tiitg[[thread_index_in_threadgroup]],
4630
+ uint tiisg[[thread_index_in_simdgroup]],
4631
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4632
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4633
+
4634
+ const int64_t bid = tgpig.z/(ne12*ne13);
4635
+
4636
+ tgpig.z = tgpig.z%(ne12*ne13);
4637
+
4638
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4639
+
4640
+ kernel_mul_mv_q2_K_f32_impl(
4641
+ src0[id],
4642
+ (device const float *) (src1 + bid*nb11),
4643
+ (device float *) ( dst + bid*nb1),
4644
+ ne00,
4645
+ ne01,
4646
+ ne02,
4647
+ ne10,
4648
+ ne12,
4649
+ ne0,
4650
+ ne1,
4651
+ r2,
4652
+ r3,
4653
+ tgpig,
4654
+ tiisg,
4655
+ sgitg);
4656
+ }
4657
+
4658
+ [[host_name("kernel_mul_mv_id_q3_K_f32")]]
4659
+ kernel void kernel_mul_mv_id_q3_K_f32(
4660
+ device const char * ids,
4661
+ device const char * src1,
4662
+ device uchar * dst,
4663
+ constant int64_t & nbi1,
4664
+ constant int64_t & ne00,
4665
+ constant int64_t & ne01,
4666
+ constant int64_t & ne02,
4667
+ constant uint64_t & nb00,
4668
+ constant uint64_t & nb01,
4669
+ constant uint64_t & nb02,
4670
+ constant int64_t & ne10,
4671
+ constant int64_t & ne11,
4672
+ constant int64_t & ne12,
4673
+ constant int64_t & ne13,
4674
+ constant uint64_t & nb10,
4675
+ constant uint64_t & nb11,
4676
+ constant uint64_t & nb12,
4677
+ constant int64_t & ne0,
4678
+ constant int64_t & ne1,
4679
+ constant int64_t & nb1,
4680
+ constant uint & r2,
4681
+ constant uint & r3,
4682
+ constant int & idx,
4683
+ device const char * src00,
4684
+ device const char * src01,
4685
+ device const char * src02,
4686
+ device const char * src03,
4687
+ device const char * src04,
4688
+ device const char * src05,
4689
+ device const char * src06,
4690
+ device const char * src07,
4691
+ uint3 tgpig[[threadgroup_position_in_grid]],
4692
+ uint tiitg[[thread_index_in_threadgroup]],
4693
+ uint tiisg[[thread_index_in_simdgroup]],
4694
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4695
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4696
+
4697
+ const int64_t bid = tgpig.z/(ne12*ne13);
4698
+
4699
+ tgpig.z = tgpig.z%(ne12*ne13);
4700
+
4701
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4702
+
4703
+ kernel_mul_mv_q3_K_f32_impl(
4704
+ src0[id],
4705
+ (device const float *) (src1 + bid*nb11),
4706
+ (device float *) ( dst + bid*nb1),
4707
+ ne00,
4708
+ ne01,
4709
+ ne02,
4710
+ ne10,
4711
+ ne12,
4712
+ ne0,
4713
+ ne1,
4714
+ r2,
4715
+ r3,
4716
+ tgpig,
4717
+ tiisg,
4718
+ sgitg);
4719
+ }
4720
+
4721
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
kernel void kernel_mul_mv_id_q4_K_f32(
        device const    char * ids,
        device const    char * src1,
        device         uchar * dst,
        constant     int64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant     int64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    // Candidate expert matrices; the id read from `ids` picks one of them.
    device const char * experts[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // Fold the batch index out of the flattened z coordinate, then keep
    // only the per-matrix z component for the underlying mat-vec kernel.
    const int64_t bid = tgpig.z / (ne12*ne13);
    tgpig.z           = tgpig.z % (ne12*ne13);

    // Expert selected for this batch row (column `idx` of the ids matrix).
    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];

    // Forward to the shared q4_K mat-vec implementation, offsetting the
    // activation row and output row by the batch index.
    kernel_mul_mv_q4_K_f32_impl(
        experts[id],
        (device const float *) (src1 + bid*nb11),
        (device       float *) (dst  + bid*nb1),
        ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3,
        tgpig, tiisg, sgitg);
}
4784
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
kernel void kernel_mul_mv_id_q5_K_f32(
        device const    char * ids,
        device const    char * src1,
        device         uchar * dst,
        constant     int64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant     int64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    // Candidate expert matrices; the id read from `ids` picks one of them.
    device const char * experts[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // Fold the batch index out of the flattened z coordinate, then keep
    // only the per-matrix z component for the underlying mat-vec kernel.
    const int64_t bid = tgpig.z / (ne12*ne13);
    tgpig.z           = tgpig.z % (ne12*ne13);

    // Expert selected for this batch row (column `idx` of the ids matrix).
    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];

    // Forward to the shared q5_K mat-vec implementation, offsetting the
    // activation row and output row by the batch index.
    kernel_mul_mv_q5_K_f32_impl(
        experts[id],
        (device const float *) (src1 + bid*nb11),
        (device       float *) (dst  + bid*nb1),
        ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3,
        tgpig, tiisg, sgitg);
}
4847
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
kernel void kernel_mul_mv_id_q6_K_f32(
        device const    char * ids,
        device const    char * src1,
        device         uchar * dst,
        constant     int64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant     int64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    // Candidate expert matrices; the id read from `ids` picks one of them.
    device const char * experts[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // Fold the batch index out of the flattened z coordinate, then keep
    // only the per-matrix z component for the underlying mat-vec kernel.
    const int64_t bid = tgpig.z / (ne12*ne13);
    tgpig.z           = tgpig.z % (ne12*ne13);

    // Expert selected for this batch row (column `idx` of the ids matrix).
    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];

    // Forward to the shared q6_K mat-vec implementation, offsetting the
    // activation row and output row by the batch index.
    kernel_mul_mv_q6_K_f32_impl(
        experts[id],
        (device const float *) (src1 + bid*nb11),
        (device       float *) (dst  + bid*nb1),
        ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3,
        tgpig, tiisg, sgitg);
}