whisper.rn 0.4.0-rc.4 → 0.4.0-rc.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +5 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/WhisperContext.java +51 -133
  6. package/android/src/main/jni-utils.h +76 -0
  7. package/android/src/main/jni.cpp +187 -112
  8. package/cpp/README.md +1 -1
  9. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  10. package/cpp/coreml/whisper-encoder.h +4 -0
  11. package/cpp/coreml/whisper-encoder.mm +4 -2
  12. package/cpp/ggml-alloc.c +55 -19
  13. package/cpp/ggml-alloc.h +7 -0
  14. package/cpp/ggml-backend-impl.h +46 -21
  15. package/cpp/ggml-backend.c +563 -156
  16. package/cpp/ggml-backend.h +62 -17
  17. package/cpp/ggml-impl.h +1 -1
  18. package/cpp/ggml-metal-whisper.metal +1010 -253
  19. package/cpp/ggml-metal.h +7 -1
  20. package/cpp/ggml-metal.m +618 -187
  21. package/cpp/ggml-quants.c +64 -59
  22. package/cpp/ggml-quants.h +40 -40
  23. package/cpp/ggml.c +751 -1466
  24. package/cpp/ggml.h +90 -25
  25. package/cpp/rn-audioutils.cpp +68 -0
  26. package/cpp/rn-audioutils.h +14 -0
  27. package/cpp/rn-whisper-log.h +11 -0
  28. package/cpp/rn-whisper.cpp +141 -59
  29. package/cpp/rn-whisper.h +47 -15
  30. package/cpp/whisper.cpp +1635 -928
  31. package/cpp/whisper.h +55 -10
  32. package/ios/RNWhisper.mm +7 -7
  33. package/ios/RNWhisperAudioUtils.h +0 -2
  34. package/ios/RNWhisperAudioUtils.m +0 -56
  35. package/ios/RNWhisperContext.h +3 -11
  36. package/ios/RNWhisperContext.mm +62 -134
  37. package/lib/commonjs/version.json +1 -1
  38. package/lib/module/version.json +1 -1
  39. package/package.json +6 -5
  40. package/src/version.json +1 -1
@@ -3,6 +3,8 @@
3
3
  using namespace metal;
4
4
 
5
5
  #define MAX(x, y) ((x) > (y) ? (x) : (y))
6
+ #define MIN(x, y) ((x) < (y) ? (x) : (y))
7
+ #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
6
8
 
7
9
  #define QK4_0 32
8
10
  #define QR4_0 2
@@ -39,8 +41,15 @@ typedef struct {
39
41
  int8_t qs[QK8_0]; // quants
40
42
  } block_q8_0;
41
43
 
42
- // general-purpose kernel for addition of two tensors
43
- // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
44
+ #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
45
+
46
+ enum ggml_sort_order {
47
+ GGML_SORT_ASC,
48
+ GGML_SORT_DESC,
49
+ };
50
+
51
+ // general-purpose kernel for addition, multiplication and division of two tensors
52
+ // pros: works for non-contiguous tensors, supports broadcast across all dims
44
53
  // cons: not very efficient
45
54
  kernel void kernel_add(
46
55
  device const char * src0,
@@ -81,16 +90,111 @@ kernel void kernel_add(
81
90
  const int64_t i12 = i02 % ne12;
82
91
  const int64_t i11 = i01 % ne11;
83
92
 
84
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
85
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
86
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
93
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
94
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
95
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
87
96
 
88
97
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
89
- ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
98
+ const int i10 = i0 % ne10;
99
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
100
+ }
101
+ }
102
+
103
+ kernel void kernel_mul(
104
+ device const char * src0,
105
+ device const char * src1,
106
+ device char * dst,
107
+ constant int64_t & ne00,
108
+ constant int64_t & ne01,
109
+ constant int64_t & ne02,
110
+ constant int64_t & ne03,
111
+ constant int64_t & nb00,
112
+ constant int64_t & nb01,
113
+ constant int64_t & nb02,
114
+ constant int64_t & nb03,
115
+ constant int64_t & ne10,
116
+ constant int64_t & ne11,
117
+ constant int64_t & ne12,
118
+ constant int64_t & ne13,
119
+ constant int64_t & nb10,
120
+ constant int64_t & nb11,
121
+ constant int64_t & nb12,
122
+ constant int64_t & nb13,
123
+ constant int64_t & ne0,
124
+ constant int64_t & ne1,
125
+ constant int64_t & ne2,
126
+ constant int64_t & ne3,
127
+ constant int64_t & nb0,
128
+ constant int64_t & nb1,
129
+ constant int64_t & nb2,
130
+ constant int64_t & nb3,
131
+ uint3 tgpig[[threadgroup_position_in_grid]],
132
+ uint3 tpitg[[thread_position_in_threadgroup]],
133
+ uint3 ntg[[threads_per_threadgroup]]) {
134
+ const int64_t i03 = tgpig.z;
135
+ const int64_t i02 = tgpig.y;
136
+ const int64_t i01 = tgpig.x;
137
+
138
+ const int64_t i13 = i03 % ne13;
139
+ const int64_t i12 = i02 % ne12;
140
+ const int64_t i11 = i01 % ne11;
141
+
142
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
143
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
144
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
90
145
 
91
- src0_ptr += ntg.x*nb00;
92
- src1_ptr += ntg.x*nb10;
93
- dst_ptr += ntg.x*nb0;
146
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
147
+ const int i10 = i0 % ne10;
148
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
149
+ }
150
+ }
151
+
152
+ kernel void kernel_div(
153
+ device const char * src0,
154
+ device const char * src1,
155
+ device char * dst,
156
+ constant int64_t & ne00,
157
+ constant int64_t & ne01,
158
+ constant int64_t & ne02,
159
+ constant int64_t & ne03,
160
+ constant int64_t & nb00,
161
+ constant int64_t & nb01,
162
+ constant int64_t & nb02,
163
+ constant int64_t & nb03,
164
+ constant int64_t & ne10,
165
+ constant int64_t & ne11,
166
+ constant int64_t & ne12,
167
+ constant int64_t & ne13,
168
+ constant int64_t & nb10,
169
+ constant int64_t & nb11,
170
+ constant int64_t & nb12,
171
+ constant int64_t & nb13,
172
+ constant int64_t & ne0,
173
+ constant int64_t & ne1,
174
+ constant int64_t & ne2,
175
+ constant int64_t & ne3,
176
+ constant int64_t & nb0,
177
+ constant int64_t & nb1,
178
+ constant int64_t & nb2,
179
+ constant int64_t & nb3,
180
+ uint3 tgpig[[threadgroup_position_in_grid]],
181
+ uint3 tpitg[[thread_position_in_threadgroup]],
182
+ uint3 ntg[[threads_per_threadgroup]]) {
183
+ const int64_t i03 = tgpig.z;
184
+ const int64_t i02 = tgpig.y;
185
+ const int64_t i01 = tgpig.x;
186
+
187
+ const int64_t i13 = i03 % ne13;
188
+ const int64_t i12 = i02 % ne12;
189
+ const int64_t i11 = i01 % ne11;
190
+
191
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
192
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
193
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
194
+
195
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
196
+ const int i10 = i0 % ne10;
197
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
94
198
  }
95
199
  }
96
200
 
@@ -105,23 +209,22 @@ kernel void kernel_add_row(
105
209
  dst[tpig] = src0[tpig] + src1[tpig % nb];
106
210
  }
107
211
 
108
- kernel void kernel_mul(
212
+ kernel void kernel_mul_row(
109
213
  device const float4 * src0,
110
214
  device const float4 * src1,
111
215
  device float4 * dst,
216
+ constant int64_t & nb [[buffer(27)]],
112
217
  uint tpig[[thread_position_in_grid]]) {
113
- dst[tpig] = src0[tpig] * src1[tpig];
218
+ dst[tpig] = src0[tpig] * src1[tpig % nb];
114
219
  }
115
220
 
116
- // assumption: src1 is a row
117
- // broadcast src1 into src0
118
- kernel void kernel_mul_row(
221
+ kernel void kernel_div_row(
119
222
  device const float4 * src0,
120
223
  device const float4 * src1,
121
224
  device float4 * dst,
122
- constant int64_t & nb,
225
+ constant int64_t & nb [[buffer(27)]],
123
226
  uint tpig[[thread_position_in_grid]]) {
124
- dst[tpig] = src0[tpig] * src1[tpig % nb];
227
+ dst[tpig] = src0[tpig] / src1[tpig % nb];
125
228
  }
126
229
 
127
230
  kernel void kernel_scale(
@@ -162,6 +265,54 @@ kernel void kernel_sqr(
162
265
  dst[tpig] = src0[tpig] * src0[tpig];
163
266
  }
164
267
 
268
+ kernel void kernel_sum_rows(
269
+ device const float * src0,
270
+ device float * dst,
271
+ constant int64_t & ne00,
272
+ constant int64_t & ne01,
273
+ constant int64_t & ne02,
274
+ constant int64_t & ne03,
275
+ constant int64_t & nb00,
276
+ constant int64_t & nb01,
277
+ constant int64_t & nb02,
278
+ constant int64_t & nb03,
279
+ constant int64_t & ne10,
280
+ constant int64_t & ne11,
281
+ constant int64_t & ne12,
282
+ constant int64_t & ne13,
283
+ constant int64_t & nb10,
284
+ constant int64_t & nb11,
285
+ constant int64_t & nb12,
286
+ constant int64_t & nb13,
287
+ constant int64_t & ne0,
288
+ constant int64_t & ne1,
289
+ constant int64_t & ne2,
290
+ constant int64_t & ne3,
291
+ constant int64_t & nb0,
292
+ constant int64_t & nb1,
293
+ constant int64_t & nb2,
294
+ constant int64_t & nb3,
295
+ uint3 tpig[[thread_position_in_grid]]) {
296
+ int64_t i3 = tpig.z;
297
+ int64_t i2 = tpig.y;
298
+ int64_t i1 = tpig.x;
299
+
300
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
301
+ return;
302
+ }
303
+
304
+ device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
305
+ device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
306
+
307
+ float row_sum = 0;
308
+
309
+ for (int64_t i0 = 0; i0 < ne00; i0++) {
310
+ row_sum += src_row[i0];
311
+ }
312
+
313
+ dst_row[0] = row_sum;
314
+ }
315
+
165
316
  constant float GELU_COEF_A = 0.044715f;
166
317
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
167
318
 
@@ -180,10 +331,12 @@ kernel void kernel_gelu(
180
331
 
181
332
  kernel void kernel_soft_max(
182
333
  device const float * src0,
334
+ device const float * src1,
183
335
  device float * dst,
184
336
  constant int64_t & ne00,
185
337
  constant int64_t & ne01,
186
338
  constant int64_t & ne02,
339
+ constant float & scale,
187
340
  threadgroup float * buf [[threadgroup(0)]],
188
341
  uint tgpig[[threadgroup_position_in_grid]],
189
342
  uint tpitg[[thread_position_in_threadgroup]],
@@ -194,73 +347,77 @@ kernel void kernel_soft_max(
194
347
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
195
348
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
196
349
 
197
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
198
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
350
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
351
+ device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
352
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
199
353
 
200
354
  // parallel max
201
- float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
355
+ float lmax = -INFINITY;
202
356
 
203
- for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
204
- lmax = MAX(lmax, psrc0[i00]);
357
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
358
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
205
359
  }
206
360
 
207
- float max = simd_max(lmax);
208
- if (tiisg == 0) {
209
- buf[sgitg] = max;
210
- }
361
+ // find the max value in the block
362
+ float max_val = simd_max(lmax);
363
+ if (ntg > N_SIMDWIDTH) {
364
+ if (sgitg == 0) {
365
+ buf[tiisg] = -INFINITY;
366
+ }
211
367
 
212
- threadgroup_barrier(mem_flags::mem_threadgroup);
368
+ threadgroup_barrier(mem_flags::mem_threadgroup);
213
369
 
214
- // broadcast, simd group number is ntg / 32
215
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
216
- if (tpitg < i) {
217
- buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
218
- }
219
- }
370
+ if (tiisg == 0) {
371
+ buf[sgitg] = max_val;
372
+ }
220
373
 
221
- threadgroup_barrier(mem_flags::mem_threadgroup);
374
+ threadgroup_barrier(mem_flags::mem_threadgroup);
222
375
 
223
- max = buf[0];
376
+ max_val = buf[tiisg];
377
+ max_val = simd_max(max_val);
378
+ }
224
379
 
225
380
  // parallel sum
226
381
  float lsum = 0.0f;
227
382
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
228
- const float exp_psrc0 = exp(psrc0[i00] - max);
383
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
229
384
  lsum += exp_psrc0;
230
- // Remember the result of exp here. exp is expensive, so we really do not
231
- // wish to compute it twice.
232
385
  pdst[i00] = exp_psrc0;
233
386
  }
234
387
 
235
388
  float sum = simd_sum(lsum);
236
- if (tiisg == 0) {
237
- buf[sgitg] = sum;
238
- }
389
+ if (ntg > N_SIMDWIDTH) {
390
+ if (sgitg == 0) {
391
+ buf[tiisg] = 0.0f;
392
+ }
239
393
 
240
- threadgroup_barrier(mem_flags::mem_threadgroup);
394
+ threadgroup_barrier(mem_flags::mem_threadgroup);
241
395
 
242
- // broadcast, simd group number is ntg / 32
243
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
244
- if (tpitg < i) {
245
- buf[tpitg] += buf[tpitg + i];
246
- }
247
- }
396
+ if (tiisg == 0) {
397
+ buf[sgitg] = sum;
398
+ }
248
399
 
249
- threadgroup_barrier(mem_flags::mem_threadgroup);
400
+ threadgroup_barrier(mem_flags::mem_threadgroup);
250
401
 
251
- sum = buf[0];
402
+ sum = buf[tiisg];
403
+ sum = simd_sum(sum);
404
+ }
405
+
406
+ const float inv_sum = 1.0f/sum;
252
407
 
253
408
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
254
- pdst[i00] /= sum;
409
+ pdst[i00] *= inv_sum;
255
410
  }
256
411
  }
257
412
 
258
413
  kernel void kernel_soft_max_4(
259
414
  device const float * src0,
415
+ device const float * src1,
260
416
  device float * dst,
261
417
  constant int64_t & ne00,
262
418
  constant int64_t & ne01,
263
419
  constant int64_t & ne02,
420
+ constant float & scale,
264
421
  threadgroup float * buf [[threadgroup(0)]],
265
422
  uint tgpig[[threadgroup_position_in_grid]],
266
423
  uint tpitg[[thread_position_in_threadgroup]],
@@ -271,64 +428,68 @@ kernel void kernel_soft_max_4(
271
428
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
272
429
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
273
430
 
274
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
275
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
431
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
432
+ device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
433
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
276
434
 
277
435
  // parallel max
278
- float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
436
+ float4 lmax4 = -INFINITY;
279
437
 
280
- for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
281
- lmax4 = fmax(lmax4, psrc4[i00]);
438
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
439
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
282
440
  }
283
441
 
284
442
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
285
- float max = simd_max(lmax);
286
- if (tiisg == 0) {
287
- buf[sgitg] = max;
288
- }
289
443
 
290
- threadgroup_barrier(mem_flags::mem_threadgroup);
444
+ float max_val = simd_max(lmax);
445
+ if (ntg > N_SIMDWIDTH) {
446
+ if (sgitg == 0) {
447
+ buf[tiisg] = -INFINITY;
448
+ }
291
449
 
292
- // broadcast, simd group number is ntg / 32
293
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
294
- if (tpitg < i) {
295
- buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
296
- }
297
- }
450
+ threadgroup_barrier(mem_flags::mem_threadgroup);
298
451
 
299
- threadgroup_barrier(mem_flags::mem_threadgroup);
452
+ if (tiisg == 0) {
453
+ buf[sgitg] = max_val;
454
+ }
300
455
 
301
- max = buf[0];
456
+ threadgroup_barrier(mem_flags::mem_threadgroup);
457
+
458
+ max_val = buf[tiisg];
459
+ max_val = simd_max(max_val);
460
+ }
302
461
 
303
462
  // parallel sum
304
463
  float4 lsum4 = 0.0f;
305
464
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
306
- const float4 exp_psrc4 = exp(psrc4[i00] - max);
465
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
307
466
  lsum4 += exp_psrc4;
308
467
  pdst4[i00] = exp_psrc4;
309
468
  }
310
469
 
311
470
  const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
312
471
  float sum = simd_sum(lsum);
313
- if (tiisg == 0) {
314
- buf[sgitg] = sum;
315
- }
472
+ if (ntg > N_SIMDWIDTH) {
473
+ if (sgitg == 0) {
474
+ buf[tiisg] = 0.0f;
475
+ }
316
476
 
317
- threadgroup_barrier(mem_flags::mem_threadgroup);
477
+ threadgroup_barrier(mem_flags::mem_threadgroup);
318
478
 
319
- // broadcast, simd group number is ntg / 32
320
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
321
- if (tpitg < i) {
322
- buf[tpitg] += buf[tpitg + i];
323
- }
324
- }
479
+ if (tiisg == 0) {
480
+ buf[sgitg] = sum;
481
+ }
325
482
 
326
- threadgroup_barrier(mem_flags::mem_threadgroup);
483
+ threadgroup_barrier(mem_flags::mem_threadgroup);
327
484
 
328
- sum = buf[0];
485
+ sum = buf[tiisg];
486
+ sum = simd_sum(sum);
487
+ }
488
+
489
+ const float inv_sum = 1.0f/sum;
329
490
 
330
491
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
331
- pdst4[i00] /= sum;
492
+ pdst4[i00] *= inv_sum;
332
493
  }
333
494
  }
334
495
 
@@ -435,14 +596,13 @@ kernel void kernel_rms_norm(
435
596
  constant int64_t & ne00,
436
597
  constant uint64_t & nb01,
437
598
  constant float & eps,
438
- threadgroup float * sum [[threadgroup(0)]],
599
+ threadgroup float * buf [[threadgroup(0)]],
439
600
  uint tgpig[[threadgroup_position_in_grid]],
440
601
  uint tpitg[[thread_position_in_threadgroup]],
441
602
  uint sgitg[[simdgroup_index_in_threadgroup]],
442
603
  uint tiisg[[thread_index_in_simdgroup]],
443
604
  uint ntg[[threads_per_threadgroup]]) {
444
- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
445
- device const float * x_scalar = (device const float *) x;
605
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
446
606
 
447
607
  float4 sumf = 0;
448
608
  float all_sum = 0;
@@ -453,40 +613,30 @@ kernel void kernel_rms_norm(
453
613
  }
454
614
  all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
455
615
  all_sum = simd_sum(all_sum);
456
- if (tiisg == 0) {
457
- sum[sgitg] = all_sum;
458
- }
616
+ if (ntg > N_SIMDWIDTH) {
617
+ if (sgitg == 0) {
618
+ buf[tiisg] = 0.0f;
619
+ }
459
620
 
460
- threadgroup_barrier(mem_flags::mem_threadgroup);
621
+ threadgroup_barrier(mem_flags::mem_threadgroup);
461
622
 
462
- // broadcast, simd group number is ntg / 32
463
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
464
- if (tpitg < i) {
465
- sum[tpitg] += sum[tpitg + i];
466
- }
467
- }
468
- if (tpitg == 0) {
469
- for (int i = 4 * (ne00 / 4); i < ne00; i++) {
470
- sum[0] += x_scalar[i];
623
+ if (tiisg == 0) {
624
+ buf[sgitg] = all_sum;
471
625
  }
472
- sum[0] /= ne00;
473
- }
474
626
 
475
- threadgroup_barrier(mem_flags::mem_threadgroup);
627
+ threadgroup_barrier(mem_flags::mem_threadgroup);
476
628
 
477
- const float mean = sum[0];
629
+ all_sum = buf[tiisg];
630
+ all_sum = simd_sum(all_sum);
631
+ }
632
+
633
+ const float mean = all_sum/ne00;
478
634
  const float scale = 1.0f/sqrt(mean + eps);
479
635
 
480
636
  device float4 * y = (device float4 *) (dst + tgpig*ne00);
481
- device float * y_scalar = (device float *) y;
482
637
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
483
638
  y[i00] = x[i00] * scale;
484
639
  }
485
- if (tpitg == 0) {
486
- for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
487
- y_scalar[i00] = x_scalar[i00] * scale;
488
- }
489
- }
490
640
  }
491
641
 
492
642
  // function to calculate the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -576,15 +726,25 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
576
726
  // putting them in the kernel cause a significant performance penalty
577
727
  #define N_DST 4 // each SIMD group works on 4 rows
578
728
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
579
- #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
580
729
  //Note: This is a template, but strictly speaking it only applies to
581
730
  // quantizations where the block size is 32. It also does not
582
731
  // guard against the number of rows not being divisible by
583
732
  // N_DST, so this is another explicit assumption of the implementation.
584
733
  template<typename block_q_type, int nr, int nsg, int nw>
585
- void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
586
- int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
587
- uint3 tgpig, uint tiisg, uint sgitg) {
734
+ void mul_vec_q_n_f32(
735
+ device const void * src0,
736
+ device const float * src1,
737
+ device float * dst,
738
+ int64_t ne00,
739
+ int64_t ne01,
740
+ int64_t ne02,
741
+ int64_t ne10,
742
+ int64_t ne12,
743
+ int64_t ne0,
744
+ int64_t ne1,
745
+ uint r2,
746
+ uint r3,
747
+ uint3 tgpig, uint tiisg, uint sgitg) {
588
748
  const int nb = ne00/QK4_0;
589
749
 
590
750
  const int r0 = tgpig.x;
@@ -593,7 +753,10 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
593
753
 
594
754
  const int first_row = (r0 * nsg + sgitg) * nr;
595
755
 
596
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
756
+ const uint i12 = im%ne12;
757
+ const uint i13 = im/ne12;
758
+
759
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
597
760
 
598
761
  device const block_q_type * x = (device const block_q_type *) src0 + offset0;
599
762
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
@@ -643,13 +806,14 @@ kernel void kernel_mul_mv_q4_0_f32(
643
806
  constant int64_t & ne02[[buffer(5)]],
644
807
  constant int64_t & ne10[[buffer(9)]],
645
808
  constant int64_t & ne12[[buffer(11)]],
646
- constant int64_t & ne0[[buffer(15)]],
647
- constant int64_t & ne1[[buffer(16)]],
648
- constant uint & gqa[[buffer(17)]],
809
+ constant int64_t & ne0 [[buffer(15)]],
810
+ constant int64_t & ne1 [[buffer(16)]],
811
+ constant uint & r2 [[buffer(17)]],
812
+ constant uint & r3 [[buffer(18)]],
649
813
  uint3 tgpig[[threadgroup_position_in_grid]],
650
814
  uint tiisg[[thread_index_in_simdgroup]],
651
815
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
652
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
816
+ mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
653
817
  }
654
818
 
655
819
  kernel void kernel_mul_mv_q4_1_f32(
@@ -661,13 +825,14 @@ kernel void kernel_mul_mv_q4_1_f32(
661
825
  constant int64_t & ne02[[buffer(5)]],
662
826
  constant int64_t & ne10[[buffer(9)]],
663
827
  constant int64_t & ne12[[buffer(11)]],
664
- constant int64_t & ne0[[buffer(15)]],
665
- constant int64_t & ne1[[buffer(16)]],
666
- constant uint & gqa[[buffer(17)]],
828
+ constant int64_t & ne0 [[buffer(15)]],
829
+ constant int64_t & ne1 [[buffer(16)]],
830
+ constant uint & r2 [[buffer(17)]],
831
+ constant uint & r3 [[buffer(18)]],
667
832
  uint3 tgpig[[threadgroup_position_in_grid]],
668
833
  uint tiisg[[thread_index_in_simdgroup]],
669
834
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
670
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
835
+ mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
671
836
  }
672
837
 
673
838
  kernel void kernel_mul_mv_q5_0_f32(
@@ -679,13 +844,14 @@ kernel void kernel_mul_mv_q5_0_f32(
679
844
  constant int64_t & ne02[[buffer(5)]],
680
845
  constant int64_t & ne10[[buffer(9)]],
681
846
  constant int64_t & ne12[[buffer(11)]],
682
- constant int64_t & ne0[[buffer(15)]],
683
- constant int64_t & ne1[[buffer(16)]],
684
- constant uint & gqa[[buffer(17)]],
847
+ constant int64_t & ne0 [[buffer(15)]],
848
+ constant int64_t & ne1 [[buffer(16)]],
849
+ constant uint & r2 [[buffer(17)]],
850
+ constant uint & r3 [[buffer(18)]],
685
851
  uint3 tgpig[[threadgroup_position_in_grid]],
686
852
  uint tiisg[[thread_index_in_simdgroup]],
687
853
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
688
- mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
854
+ mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
689
855
  }
690
856
 
691
857
  kernel void kernel_mul_mv_q5_1_f32(
@@ -697,13 +863,14 @@ kernel void kernel_mul_mv_q5_1_f32(
697
863
  constant int64_t & ne02[[buffer(5)]],
698
864
  constant int64_t & ne10[[buffer(9)]],
699
865
  constant int64_t & ne12[[buffer(11)]],
700
- constant int64_t & ne0[[buffer(15)]],
701
- constant int64_t & ne1[[buffer(16)]],
702
- constant uint & gqa[[buffer(17)]],
866
+ constant int64_t & ne0 [[buffer(15)]],
867
+ constant int64_t & ne1 [[buffer(16)]],
868
+ constant uint & r2 [[buffer(17)]],
869
+ constant uint & r3 [[buffer(18)]],
703
870
  uint3 tgpig[[threadgroup_position_in_grid]],
704
871
  uint tiisg[[thread_index_in_simdgroup]],
705
872
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
706
- mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
873
+ mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
707
874
  }
708
875
 
709
876
 
@@ -718,9 +885,10 @@ kernel void kernel_mul_mv_q8_0_f32(
718
885
  constant int64_t & ne02[[buffer(5)]],
719
886
  constant int64_t & ne10[[buffer(9)]],
720
887
  constant int64_t & ne12[[buffer(11)]],
721
- constant int64_t & ne0[[buffer(15)]],
722
- constant int64_t & ne1[[buffer(16)]],
723
- constant uint & gqa[[buffer(17)]],
888
+ constant int64_t & ne0 [[buffer(15)]],
889
+ constant int64_t & ne1 [[buffer(16)]],
890
+ constant uint & r2 [[buffer(17)]],
891
+ constant uint & r3 [[buffer(18)]],
724
892
  uint3 tgpig[[threadgroup_position_in_grid]],
725
893
  uint tiisg[[thread_index_in_simdgroup]],
726
894
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -732,8 +900,14 @@ kernel void kernel_mul_mv_q8_0_f32(
732
900
  const int r0 = tgpig.x;
733
901
  const int r1 = tgpig.y;
734
902
  const int im = tgpig.z;
903
+
735
904
  const int first_row = (r0 * nsg + sgitg) * nr;
736
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
905
+
906
+ const uint i12 = im%ne12;
907
+ const uint i13 = im/ne12;
908
+
909
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
910
+
737
911
  device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
738
912
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
739
913
 
@@ -791,14 +965,21 @@ kernel void kernel_mul_mv_f32_f32(
791
965
  constant uint64_t & nb12,
792
966
  constant int64_t & ne0,
793
967
  constant int64_t & ne1,
968
+ constant uint & r2 [[buffer(17)]],
969
+ constant uint & r3 [[buffer(18)]],
794
970
  uint3 tgpig[[threadgroup_position_in_grid]],
795
- uint tiisg[[thread_index_in_simdgroup]]) {
971
+ uint tiisg[[thread_index_in_simdgroup]]) {
796
972
 
797
973
  const int64_t r0 = tgpig.x;
798
974
  const int64_t rb = tgpig.y*N_F32_F32;
799
975
  const int64_t im = tgpig.z;
800
976
 
801
- device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
977
+ const uint i12 = im%ne12;
978
+ const uint i13 = im/ne12;
979
+
980
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
981
+
982
+ device const float * x = (device const float *) (src0 + offset0);
802
983
 
803
984
  if (ne00 < 128) {
804
985
  for (int row = 0; row < N_F32_F32; ++row) {
@@ -844,6 +1025,86 @@ kernel void kernel_mul_mv_f32_f32(
844
1025
  }
845
1026
  }
846
1027
 
1028
+ #define N_F16_F16 4
1029
+
1030
+ kernel void kernel_mul_mv_f16_f16(
1031
+ device const char * src0,
1032
+ device const char * src1,
1033
+ device float * dst,
1034
+ constant int64_t & ne00,
1035
+ constant int64_t & ne01,
1036
+ constant int64_t & ne02,
1037
+ constant uint64_t & nb00,
1038
+ constant uint64_t & nb01,
1039
+ constant uint64_t & nb02,
1040
+ constant int64_t & ne10,
1041
+ constant int64_t & ne11,
1042
+ constant int64_t & ne12,
1043
+ constant uint64_t & nb10,
1044
+ constant uint64_t & nb11,
1045
+ constant uint64_t & nb12,
1046
+ constant int64_t & ne0,
1047
+ constant int64_t & ne1,
1048
+ constant uint & r2 [[buffer(17)]],
1049
+ constant uint & r3 [[buffer(18)]],
1050
+ uint3 tgpig[[threadgroup_position_in_grid]],
1051
+ uint tiisg[[thread_index_in_simdgroup]]) {
1052
+
1053
+ const int64_t r0 = tgpig.x;
1054
+ const int64_t rb = tgpig.y*N_F16_F16;
1055
+ const int64_t im = tgpig.z;
1056
+
1057
+ const uint i12 = im%ne12;
1058
+ const uint i13 = im/ne12;
1059
+
1060
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1061
+
1062
+ device const half * x = (device const half *) (src0 + offset0);
1063
+
1064
+ if (ne00 < 128) {
1065
+ for (int row = 0; row < N_F16_F16; ++row) {
1066
+ int r1 = rb + row;
1067
+ if (r1 >= ne11) {
1068
+ break;
1069
+ }
1070
+
1071
+ device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
1072
+
1073
+ float sumf = 0;
1074
+ for (int i = tiisg; i < ne00; i += 32) {
1075
+ sumf += (half) x[i] * (half) y[i];
1076
+ }
1077
+
1078
+ float all_sum = simd_sum(sumf);
1079
+ if (tiisg == 0) {
1080
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1081
+ }
1082
+ }
1083
+ } else {
1084
+ device const half4 * x4 = (device const half4 *)x;
1085
+ for (int row = 0; row < N_F16_F16; ++row) {
1086
+ int r1 = rb + row;
1087
+ if (r1 >= ne11) {
1088
+ break;
1089
+ }
1090
+
1091
+ device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
1092
+ device const half4 * y4 = (device const half4 *) y;
1093
+
1094
+ float sumf = 0;
1095
+ for (int i = tiisg; i < ne00/4; i += 32) {
1096
+ for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
1097
+ }
1098
+
1099
+ float all_sum = simd_sum(sumf);
1100
+ if (tiisg == 0) {
1101
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
1102
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1103
+ }
1104
+ }
1105
+ }
1106
+ }
1107
+
847
1108
  kernel void kernel_mul_mv_f16_f32_1row(
848
1109
  device const char * src0,
849
1110
  device const char * src1,
@@ -862,6 +1123,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
862
1123
  constant uint64_t & nb12,
863
1124
  constant int64_t & ne0,
864
1125
  constant int64_t & ne1,
1126
+ constant uint & r2 [[buffer(17)]],
1127
+ constant uint & r3 [[buffer(18)]],
865
1128
  uint3 tgpig[[threadgroup_position_in_grid]],
866
1129
  uint tiisg[[thread_index_in_simdgroup]]) {
867
1130
 
@@ -869,7 +1132,12 @@ kernel void kernel_mul_mv_f16_f32_1row(
869
1132
  const int64_t r1 = tgpig.y;
870
1133
  const int64_t im = tgpig.z;
871
1134
 
872
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1135
+ const uint i12 = im%ne12;
1136
+ const uint i13 = im/ne12;
1137
+
1138
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1139
+
1140
+ device const half * x = (device const half *) (src0 + offset0);
873
1141
  device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
874
1142
 
875
1143
  float sumf = 0;
@@ -916,6 +1184,8 @@ kernel void kernel_mul_mv_f16_f32(
916
1184
  constant uint64_t & nb12,
917
1185
  constant int64_t & ne0,
918
1186
  constant int64_t & ne1,
1187
+ constant uint & r2 [[buffer(17)]],
1188
+ constant uint & r3 [[buffer(18)]],
919
1189
  uint3 tgpig[[threadgroup_position_in_grid]],
920
1190
  uint tiisg[[thread_index_in_simdgroup]]) {
921
1191
 
@@ -923,7 +1193,12 @@ kernel void kernel_mul_mv_f16_f32(
923
1193
  const int64_t rb = tgpig.y*N_F16_F32;
924
1194
  const int64_t im = tgpig.z;
925
1195
 
926
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1196
+ const uint i12 = im%ne12;
1197
+ const uint i13 = im/ne12;
1198
+
1199
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1200
+
1201
+ device const half * x = (device const half *) (src0 + offset0);
927
1202
 
928
1203
  if (ne00 < 128) {
929
1204
  for (int row = 0; row < N_F16_F32; ++row) {
@@ -988,6 +1263,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
988
1263
  constant uint64_t & nb12,
989
1264
  constant int64_t & ne0,
990
1265
  constant int64_t & ne1,
1266
+ constant uint & r2 [[buffer(17)]],
1267
+ constant uint & r3 [[buffer(18)]],
991
1268
  uint3 tgpig[[threadgroup_position_in_grid]],
992
1269
  uint tiisg[[thread_index_in_simdgroup]]) {
993
1270
 
@@ -995,7 +1272,12 @@ kernel void kernel_mul_mv_f16_f32_l4(
995
1272
  const int64_t r0 = tgpig.x;
996
1273
  const int64_t im = tgpig.z;
997
1274
 
998
- device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1275
+ const uint i12 = im%ne12;
1276
+ const uint i13 = im/ne12;
1277
+
1278
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1279
+
1280
+ device const half4 * x4 = (device const half4 *) (src0 + offset0);
999
1281
 
1000
1282
  for (int r1 = 0; r1 < nrows; ++r1) {
1001
1283
  device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
@@ -1047,17 +1329,21 @@ kernel void kernel_alibi_f32(
1047
1329
  const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1048
1330
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1049
1331
  const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1332
+ const int64_t k = i3*ne3 + i2;
1050
1333
 
1051
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1052
1334
  float m_k;
1053
- if (i2 < n_heads_log2_floor) {
1054
- m_k = pow(m0, i2 + 1);
1335
+ if (k < n_heads_log2_floor) {
1336
+ m_k = pow(m0, k + 1);
1055
1337
  } else {
1056
- m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
1338
+ m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
1057
1339
  }
1340
+
1341
+ device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
1342
+ device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
1058
1343
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1059
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1060
- dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
1344
+ const float src_v = *(device float *)(src_row + i00*nb00);
1345
+ device float * dst_v = (device float *)(dst_row + i00*nb0);
1346
+ *dst_v = i00 * m_k + src_v;
1061
1347
  }
1062
1348
  }
1063
1349
 
@@ -1201,33 +1487,118 @@ kernel void kernel_rope(
1201
1487
  dst_data[1] = x0*sin_theta + x1*cos_theta;
1202
1488
  }
1203
1489
  } else {
1204
- for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
1205
- for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
1490
+ for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
1491
+ for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
1492
+
1493
+ // simplified from `(ib * n_dims + ic) * inv_ndims`
1494
+ const float cur_rot = inv_ndims*ic - ib;
1495
+
1496
+ const float theta = theta_0 * pow(freq_base, cur_rot);
1497
+ float cos_theta, sin_theta;
1498
+ rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1499
+
1500
+ const int64_t i0 = ib*n_dims + ic/2;
1501
+
1502
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1503
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1504
+
1505
+ const float x0 = src[0];
1506
+ const float x1 = src[n_dims/2];
1507
+
1508
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
1509
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1510
+ }
1511
+ }
1512
+ }
1513
+ }
1514
+
1515
+ template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
1516
+ template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
1517
+
1518
// Gathers one f32 input element per thread into the f16 `dst` matrix.
// The sampled input coordinate is `group_index * s + thread_index * d - p` along each
// spatial axis (s/d/p presumably being stride/dilation/padding — confirm with the host code).
// Coordinates that fall outside [0, IW) x [0, IH) produce 0 in dst.
kernel void kernel_im2col_f16(
        device const float * x,
        device       half  * dst,
        constant   int32_t & ofs0,
        constant   int32_t & ofs1,
        constant   int32_t & IW,
        constant   int32_t & IH,
        constant   int32_t & CHW,
        constant   int32_t & s0,
        constant   int32_t & s1,
        constant   int32_t & p0,
        constant   int32_t & p1,
        constant   int32_t & d0,
        constant   int32_t & d1,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3  tgpg[[threadgroups_per_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // input coordinate read by this thread (may be negative / past the edge)
    const int32_t src_w = tgpig[2] * s0 + tpitg[2] * d0 - p0;
    const int32_t src_h = tgpig[1] * s1 + tpitg[1] * d1 - p1;

    // output row: (tpitg[0], tgpig[1], tgpig[2]) flattened; each row is CHW elements wide
    const uint dst_row = tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2];
    // output column: (tgpig[0], tpitg[1], tpitg[2]) flattened
    const uint dst_col = tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2];

    const int32_t offset_dst = dst_row * CHW + dst_col;

    const bool inside = src_h >= 0 && src_h < IH && src_w >= 0 && src_w < IW;
    if (!inside) {
        // out-of-range samples are written as zero
        dst[offset_dst] = 0.0f;
        return;
    }

    const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
    dst[offset_dst] = x[offset_src + src_h * IW + src_w];
}
1206
1550
 
1207
- // simplified from `(ib * n_dims + ic) * inv_ndims`
1208
- const float cur_rot = inv_ndims*ic - ib;
1551
+ // bitonic sort implementation following the CUDA kernels as reference
1552
+ typedef void (argsort_t)(
1553
+ device const float * x,
1554
+ device int32_t * dst,
1555
+ constant int64_t & ncols,
1556
+ uint3 tgpig[[threadgroup_position_in_grid]],
1557
+ uint3 tpitg[[thread_position_in_threadgroup]]);
1209
1558
 
1210
- const float theta = theta_0 * pow(freq_base, cur_rot);
1211
- float cos_theta, sin_theta;
1212
- rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1559
+ template<ggml_sort_order order>
1560
+ kernel void kernel_argsort_f32_i32(
1561
+ device const float * x,
1562
+ device int32_t * dst,
1563
+ constant int64_t & ncols,
1564
+ uint3 tgpig[[threadgroup_position_in_grid]],
1565
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
1566
+ // bitonic sort
1567
+ int col = tpitg[0];
1568
+ int row = tgpig[1];
1213
1569
 
1214
- const int64_t i0 = ib*n_dims + ic/2;
1570
+ if (col >= ncols) return;
1215
1571
 
1216
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1217
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1572
+ device const float * x_row = x + row * ncols;
1573
+ device int32_t * dst_row = dst + row * ncols;
1218
1574
 
1219
- const float x0 = src[0];
1220
- const float x1 = src[n_dims/2];
1575
+ // initialize indices
1576
+ if (col < ncols) {
1577
+ dst_row[col] = col;
1578
+ }
1579
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1221
1580
 
1222
- dst_data[0] = x0*cos_theta - x1*sin_theta;
1223
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1581
+ for (int k = 2; k <= ncols; k *= 2) {
1582
+ for (int j = k / 2; j > 0; j /= 2) {
1583
+ int ixj = col ^ j;
1584
+ if (ixj > col) {
1585
+ if ((col & k) == 0) {
1586
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
1587
+ SWAP(dst_row[col], dst_row[ixj]);
1588
+ }
1589
+ } else {
1590
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
1591
+ SWAP(dst_row[col], dst_row[ixj]);
1592
+ }
1593
+ }
1224
1594
  }
1595
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1225
1596
  }
1226
1597
  }
1227
1598
  }
1228
1599
 
1229
- template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
1230
- template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
1600
+ template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
1601
+ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
1231
1602
 
1232
1603
  kernel void kernel_cpy_f16_f16(
1233
1604
  device const half * src0,
@@ -1354,6 +1725,197 @@ kernel void kernel_cpy_f32_f32(
1354
1725
  }
1355
1726
  }
1356
1727
 
1728
// Converts one f32 row of src0 (selected by tgpig) into q8_0 blocks in dst.
// The thread with index tpitg.x handles source blocks tpitg.x, tpitg.x + ntg.x, ...,
// each block being QK8_0 consecutive values. Per block:
//   d  = max|v| / 127
//   qs = round(v / d)   (all zero when the block is entirely zero)
kernel void kernel_cpy_f32_q8_0(
        device const float * src0,
        device        void * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    // linear element index of the first value of this source row
    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // unravel n into destination coordinates
    const int64_t plane  = ne1*ne0;
    const int64_t volume = ne2*plane;

    const int64_t i3   = n / volume;
    const int64_t rem3 = n - i3*volume;
    const int64_t i2   = rem3 / plane;
    const int64_t rem2 = rem3 - i2*plane;
    const int64_t i1   = rem2 / ne0;
    const int64_t i0   = (rem2 - i1*ne0)/QK8_0;

    device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
    device const char * src_row  = (device const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
        device const float * src = (device const float *)(src_row + i00*nb00);
        device block_q8_0  * blk = dst_data + i00/QK8_0;

        // largest magnitude in the block
        float amax = 0.0f;
        for (int j = 0; j < QK8_0; j++) {
            amax = MAX(amax, fabs(src[j]));
        }

        const float d  = amax / ((1 << 7) - 1);
        const float id = d ? 1.0f/d : 0.0f;

        blk->d = d;

        for (int j = 0; j < QK8_0; ++j) {
            const float scaled = src[j]*id;
            blk->qs[j] = round(scaled);
        }
    }
}
1785
+
1786
// Converts one f32 row of src0 (selected by tgpig) into q4_0 blocks in dst.
// The thread with index tpitg.x handles source blocks tpitg.x, tpitg.x + ntg.x, ...,
// each block being QK4_0 consecutive values. Per block, d = (signed value with the
// largest magnitude) / -8. Each value is stored as a nibble MIN(15, (int8_t)(v/d + 8.5)).
// The first half of the block goes into the low nibbles of qs, the second half into
// the high nibbles.
kernel void kernel_cpy_f32_q4_0(
        device const float * src0,
        device        void * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    // linear element index of the first value of this source row
    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // unravel n into destination coordinates
    const int64_t plane  = ne1*ne0;
    const int64_t volume = ne2*plane;

    const int64_t i3   = n / volume;
    const int64_t rem3 = n - i3*volume;
    const int64_t i2   = rem3 / plane;
    const int64_t rem2 = rem3 - i2*plane;
    const int64_t i1   = rem2 / ne0;
    const int64_t i0   = (rem2 - i1*ne0)/QK4_0;

    device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
    device const char * src_row  = (device const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
        device const float * src = (device const float *)(src_row + i00*nb00);
        device block_q4_0  * blk = dst_data + i00/QK4_0;

        // find the signed value whose magnitude is largest (first one wins on ties)
        float amax = 0.0f;
        float vmax = 0.0f;
        for (int j = 0; j < QK4_0; j++) {
            const float v  = src[j];
            const float av = fabs(v);
            if (av > amax) {
                amax = av;
                vmax = v;
            }
        }

        const float d  = vmax / -8;
        const float id = d ? 1.0f/d : 0.0f;

        blk->d = d;

        for (int j = 0; j < QK4_0/2; ++j) {
            const float lo_f = src[j]*id;
            const float hi_f = src[QK4_0/2 + j]*id;

            const uint8_t lo = MIN(15, (int8_t)(lo_f + 8.5f));
            const uint8_t hi = MIN(15, (int8_t)(hi_f + 8.5f));

            blk->qs[j] = lo | (hi << 4);
        }
    }
}
1852
+
1853
// Converts one f32 row of src0 (selected by tgpig) into q4_1 blocks in dst.
// The thread with index tpitg.x handles source blocks tpitg.x, tpitg.x + ntg.x, ...,
// each block being QK4_1 consecutive values. Per block:
//   m = min(v), d = (max(v) - m) / 15
//   nibble = MIN(15, (int8_t)((v - m)/d + 0.5))
// The first half of the block goes into the low nibbles of qs, the second half into
// the high nibbles.
kernel void kernel_cpy_f32_q4_1(
        device const float * src0,
        device        void * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    // linear element index of the first value of this source row
    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // unravel n into destination coordinates
    const int64_t plane  = ne1*ne0;
    const int64_t volume = ne2*plane;

    const int64_t i3   = n / volume;
    const int64_t rem3 = n - i3*volume;
    const int64_t i2   = rem3 / plane;
    const int64_t rem2 = rem3 - i2*plane;
    const int64_t i1   = rem2 / ne0;
    const int64_t i0   = (rem2 - i1*ne0)/QK4_1;

    device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
    device const char * src_row  = (device const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
        device const float * src = (device const float *)(src_row + i00*nb00);
        device block_q4_1  * blk = dst_data + i00/QK4_1;

        // value range of the block
        float vmin =  FLT_MAX;
        float vmax = -FLT_MAX;
        for (int j = 0; j < QK4_1; j++) {
            const float v = src[j];
            if (v < vmin) {
                vmin = v;
            }
            if (v > vmax) {
                vmax = v;
            }
        }

        const float d  = (vmax - vmin) / ((1 << 4) - 1);
        const float id = d ? 1.0f/d : 0.0f;

        blk->d = d;
        blk->m = vmin;

        for (int j = 0; j < QK4_1/2; ++j) {
            const float lo_f = (src[j] - vmin)*id;
            const float hi_f = (src[QK4_1/2 + j] - vmin)*id;

            const uint8_t lo = MIN(15, (int8_t)(lo_f + 0.5f));
            const uint8_t hi = MIN(15, (int8_t)(hi_f + 0.5f));

            blk->qs[j] = lo | (hi << 4);
        }
    }
}
1918
+
1357
1919
  kernel void kernel_concat(
1358
1920
  device const char * src0,
1359
1921
  device const char * src1,
@@ -1511,23 +2073,30 @@ kernel void kernel_mul_mv_q2_K_f32(
1511
2073
  constant int64_t & ne02[[buffer(5)]],
1512
2074
  constant int64_t & ne10[[buffer(9)]],
1513
2075
  constant int64_t & ne12[[buffer(11)]],
1514
- constant int64_t & ne0[[buffer(15)]],
1515
- constant int64_t & ne1[[buffer(16)]],
1516
- constant uint & gqa[[buffer(17)]],
2076
+ constant int64_t & ne0 [[buffer(15)]],
2077
+ constant int64_t & ne1 [[buffer(16)]],
2078
+ constant uint & r2 [[buffer(17)]],
2079
+ constant uint & r3 [[buffer(18)]],
1517
2080
  uint3 tgpig[[threadgroup_position_in_grid]],
1518
- uint tiisg[[thread_index_in_simdgroup]],
1519
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2081
+ uint tiisg[[thread_index_in_simdgroup]],
2082
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1520
2083
 
1521
2084
  const int nb = ne00/QK_K;
1522
2085
  const int r0 = tgpig.x;
1523
2086
  const int r1 = tgpig.y;
1524
- const int r2 = tgpig.z;
2087
+ const int im = tgpig.z;
1525
2088
 
1526
2089
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1527
2090
  const int ib_row = first_row * nb;
1528
- const uint offset0 = r2/gqa*(nb*ne0);
2091
+
2092
+ const uint i12 = im%ne12;
2093
+ const uint i13 = im/ne12;
2094
+
2095
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2096
+
1529
2097
  device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
1530
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2098
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2099
+
1531
2100
  float yl[32];
1532
2101
  float sumf[N_DST]={0.f}, all_sum;
1533
2102
 
@@ -1536,11 +2105,11 @@ kernel void kernel_mul_mv_q2_K_f32(
1536
2105
  #if QK_K == 256
1537
2106
  const int ix = tiisg/8; // 0...3
1538
2107
  const int it = tiisg%8; // 0...7
1539
- const int im = it/4; // 0 or 1
2108
+ const int iq = it/4; // 0 or 1
1540
2109
  const int ir = it%4; // 0...3
1541
2110
  const int is = (8*ir)/16;// 0 or 1
1542
2111
 
1543
- device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
2112
+ device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
1544
2113
 
1545
2114
  for (int ib = ix; ib < nb; ib += 4) {
1546
2115
 
@@ -1552,8 +2121,8 @@ kernel void kernel_mul_mv_q2_K_f32(
1552
2121
  yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
1553
2122
  }
1554
2123
 
1555
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
1556
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
2124
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
2125
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
1557
2126
  device const half * dh = &x[ib].d;
1558
2127
 
1559
2128
  for (int row = 0; row < N_DST; row++) {
@@ -1640,7 +2209,7 @@ kernel void kernel_mul_mv_q2_K_f32(
1640
2209
  for (int row = 0; row < N_DST; ++row) {
1641
2210
  all_sum = simd_sum(sumf[row]);
1642
2211
  if (tiisg == 0) {
1643
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
2212
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
1644
2213
  }
1645
2214
  }
1646
2215
  }
@@ -1655,9 +2224,10 @@ kernel void kernel_mul_mv_q3_K_f32(
1655
2224
  constant int64_t & ne02[[buffer(5)]],
1656
2225
  constant int64_t & ne10[[buffer(9)]],
1657
2226
  constant int64_t & ne12[[buffer(11)]],
1658
- constant int64_t & ne0[[buffer(15)]],
1659
- constant int64_t & ne1[[buffer(16)]],
1660
- constant uint & gqa[[buffer(17)]],
2227
+ constant int64_t & ne0 [[buffer(15)]],
2228
+ constant int64_t & ne1 [[buffer(16)]],
2229
+ constant uint & r2 [[buffer(17)]],
2230
+ constant uint & r3 [[buffer(18)]],
1661
2231
  uint3 tgpig[[threadgroup_position_in_grid]],
1662
2232
  uint tiisg[[thread_index_in_simdgroup]],
1663
2233
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1666,12 +2236,17 @@ kernel void kernel_mul_mv_q3_K_f32(
1666
2236
 
1667
2237
  const int64_t r0 = tgpig.x;
1668
2238
  const int64_t r1 = tgpig.y;
1669
- const int64_t r2 = tgpig.z;
2239
+ const int64_t im = tgpig.z;
1670
2240
 
1671
2241
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1672
- const uint offset0 = r2/gqa*(nb*ne0);
2242
+
2243
+ const uint i12 = im%ne12;
2244
+ const uint i13 = im/ne12;
2245
+
2246
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2247
+
1673
2248
  device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
1674
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2249
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
1675
2250
 
1676
2251
  float yl[32];
1677
2252
 
@@ -1793,7 +2368,7 @@ kernel void kernel_mul_mv_q3_K_f32(
1793
2368
  }
1794
2369
  if (tiisg == 0) {
1795
2370
  for (int row = 0; row < 2; ++row) {
1796
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
2371
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
1797
2372
  }
1798
2373
  }
1799
2374
  }
@@ -1807,26 +2382,33 @@ kernel void kernel_mul_mv_q3_K_f32(
1807
2382
  constant int64_t & ne02[[buffer(5)]],
1808
2383
  constant int64_t & ne10[[buffer(9)]],
1809
2384
  constant int64_t & ne12[[buffer(11)]],
1810
- constant int64_t & ne0[[buffer(15)]],
1811
- constant int64_t & ne1[[buffer(16)]],
1812
- constant uint & gqa[[buffer(17)]],
2385
+ constant int64_t & ne0 [[buffer(15)]],
2386
+ constant int64_t & ne1 [[buffer(16)]],
2387
+ constant uint & r2 [[buffer(17)]],
2388
+ constant uint & r3 [[buffer(18)]],
1813
2389
  uint3 tgpig[[threadgroup_position_in_grid]],
1814
- uint tiisg[[thread_index_in_simdgroup]],
1815
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2390
+ uint tiisg[[thread_index_in_simdgroup]],
2391
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1816
2392
 
1817
2393
  const int nb = ne00/QK_K;
1818
2394
 
1819
2395
  const int64_t r0 = tgpig.x;
1820
2396
  const int64_t r1 = tgpig.y;
1821
- const int64_t r2 = tgpig.z;
2397
+ const int64_t im = tgpig.z;
1822
2398
 
1823
2399
  const int row = 2 * r0 + sgitg;
1824
- const uint offset0 = r2/gqa*(nb*ne0);
2400
+
2401
+ const uint i12 = im%ne12;
2402
+ const uint i13 = im/ne12;
2403
+
2404
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2405
+
1825
2406
  device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
1826
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2407
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2408
+
1827
2409
  const int ix = tiisg/4;
1828
2410
  const int il = 4 * (tiisg%4);// 0, 4, 8, 12
1829
- const int im = il/8; // 0, 0, 1, 1
2411
+ const int iq = il/8; // 0, 0, 1, 1
1830
2412
  const int in = il%8; // 0, 4, 0, 4
1831
2413
 
1832
2414
  float2 sum = {0.f, 0.f};
@@ -1846,7 +2428,7 @@ kernel void kernel_mul_mv_q3_K_f32(
1846
2428
  const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
1847
2429
 
1848
2430
  for (int l = 0; l < 4; l += 2) {
1849
- const uint16_t hm = h[l/2] >> im;
2431
+ const uint16_t hm = h[l/2] >> iq;
1850
2432
  sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
1851
2433
  + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
1852
2434
  + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
@@ -1862,7 +2444,7 @@ kernel void kernel_mul_mv_q3_K_f32(
1862
2444
 
1863
2445
  const float tot = simd_sum(sumf);
1864
2446
  if (tiisg == 0) {
1865
- dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
2447
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
1866
2448
  }
1867
2449
 
1868
2450
  }
@@ -1880,10 +2462,11 @@ kernel void kernel_mul_mv_q4_K_f32(
1880
2462
  constant int64_t & ne12 [[buffer(11)]],
1881
2463
  constant int64_t & ne0 [[buffer(15)]],
1882
2464
  constant int64_t & ne1 [[buffer(16)]],
1883
- constant uint & gqa [[buffer(17)]],
2465
+ constant uint & r2 [[buffer(17)]],
2466
+ constant uint & r3 [[buffer(18)]],
1884
2467
  uint3 tgpig[[threadgroup_position_in_grid]],
1885
- uint tiisg[[thread_index_in_simdgroup]],
1886
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2468
+ uint tiisg[[thread_index_in_simdgroup]],
2469
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1887
2470
 
1888
2471
  const uint16_t kmask1 = 0x3f3f;
1889
2472
  const uint16_t kmask2 = 0x0f0f;
@@ -1891,26 +2474,32 @@ kernel void kernel_mul_mv_q4_K_f32(
1891
2474
 
1892
2475
  const int ix = tiisg/8; // 0...3
1893
2476
  const int it = tiisg%8; // 0...7
1894
- const int im = it/4; // 0 or 1
2477
+ const int iq = it/4; // 0 or 1
1895
2478
  const int ir = it%4; // 0...3
1896
2479
 
1897
2480
  const int nb = ne00/QK_K;
1898
2481
  const int r0 = tgpig.x;
1899
2482
  const int r1 = tgpig.y;
1900
- const int r2 = tgpig.z;
2483
+ const int im = tgpig.z;
1901
2484
  //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1902
2485
  const int first_row = r0 * N_DST;
1903
2486
  const int ib_row = first_row * nb;
1904
- const uint offset0 = r2/gqa*(nb*ne0);
2487
+
2488
+ const uint i12 = im%ne12;
2489
+ const uint i13 = im/ne12;
2490
+
2491
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2492
+
1905
2493
  device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
1906
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2494
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2495
+
1907
2496
  float yl[16];
1908
2497
  float yh[16];
1909
2498
  float sumf[N_DST]={0.f}, all_sum;
1910
2499
 
1911
2500
  const int step = sizeof(block_q4_K) * nb / 2;
1912
2501
 
1913
- device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
2502
+ device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
1914
2503
 
1915
2504
  uint16_t sc16[4];
1916
2505
  thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
@@ -1925,8 +2514,8 @@ kernel void kernel_mul_mv_q4_K_f32(
1925
2514
  yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
1926
2515
  }
1927
2516
 
1928
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
1929
- device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
2517
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
2518
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
1930
2519
  device const half * dh = &x[ib].d;
1931
2520
 
1932
2521
  for (int row = 0; row < N_DST; row++) {
@@ -1970,7 +2559,7 @@ kernel void kernel_mul_mv_q4_K_f32(
1970
2559
  for (int row = 0; row < N_DST; ++row) {
1971
2560
  all_sum = simd_sum(sumf[row]);
1972
2561
  if (tiisg == 0) {
1973
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
2562
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
1974
2563
  }
1975
2564
  }
1976
2565
  }
@@ -1984,9 +2573,10 @@ kernel void kernel_mul_mv_q4_K_f32(
1984
2573
  constant int64_t & ne02[[buffer(5)]],
1985
2574
  constant int64_t & ne10[[buffer(9)]],
1986
2575
  constant int64_t & ne12[[buffer(11)]],
1987
- constant int64_t & ne0[[buffer(15)]],
1988
- constant int64_t & ne1[[buffer(16)]],
1989
- constant uint & gqa[[buffer(17)]],
2576
+ constant int64_t & ne0 [[buffer(15)]],
2577
+ constant int64_t & ne1 [[buffer(16)]],
2578
+ constant uint & r2 [[buffer(17)]],
2579
+ constant uint & r3 [[buffer(18)]],
1990
2580
  uint3 tgpig[[threadgroup_position_in_grid]],
1991
2581
  uint tiisg[[thread_index_in_simdgroup]],
1992
2582
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1997,12 +2587,18 @@ kernel void kernel_mul_mv_q4_K_f32(
1997
2587
  const int nb = ne00/QK_K;
1998
2588
  const int r0 = tgpig.x;
1999
2589
  const int r1 = tgpig.y;
2000
- const int r2 = tgpig.z;
2590
+ const int im = tgpig.z;
2001
2591
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2002
2592
  const int ib_row = first_row * nb;
2003
- const uint offset0 = r2/gqa*(nb*ne0);
2593
+
2594
+ const uint i12 = im%ne12;
2595
+ const uint i13 = im/ne12;
2596
+
2597
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2598
+
2004
2599
  device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
2005
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2600
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2601
+
2006
2602
  float yl[8];
2007
2603
  float yh[8];
2008
2604
  float sumf[N_DST]={0.f}, all_sum;
@@ -2058,7 +2654,7 @@ kernel void kernel_mul_mv_q4_K_f32(
2058
2654
  for (int row = 0; row < N_DST; ++row) {
2059
2655
  all_sum = simd_sum(sumf[row]);
2060
2656
  if (tiisg == 0) {
2061
- dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
2657
+ dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
2062
2658
  }
2063
2659
  }
2064
2660
  }
@@ -2073,9 +2669,10 @@ kernel void kernel_mul_mv_q5_K_f32(
2073
2669
  constant int64_t & ne02[[buffer(5)]],
2074
2670
  constant int64_t & ne10[[buffer(9)]],
2075
2671
  constant int64_t & ne12[[buffer(11)]],
2076
- constant int64_t & ne0[[buffer(15)]],
2077
- constant int64_t & ne1[[buffer(16)]],
2078
- constant uint & gqa[[buffer(17)]],
2672
+ constant int64_t & ne0 [[buffer(15)]],
2673
+ constant int64_t & ne1 [[buffer(16)]],
2674
+ constant uint & r2 [[buffer(17)]],
2675
+ constant uint & r3 [[buffer(18)]],
2079
2676
  uint3 tgpig[[threadgroup_position_in_grid]],
2080
2677
  uint tiisg[[thread_index_in_simdgroup]],
2081
2678
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2084,12 +2681,17 @@ kernel void kernel_mul_mv_q5_K_f32(
2084
2681
 
2085
2682
  const int64_t r0 = tgpig.x;
2086
2683
  const int64_t r1 = tgpig.y;
2087
- const int r2 = tgpig.z;
2684
+ const int im = tgpig.z;
2088
2685
 
2089
2686
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
2090
- const uint offset0 = r2/gqa*(nb*ne0);
2687
+
2688
+ const uint i12 = im%ne12;
2689
+ const uint i13 = im/ne12;
2690
+
2691
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2692
+
2091
2693
  device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
2092
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2694
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2093
2695
 
2094
2696
  float sumf[2]={0.f};
2095
2697
 
@@ -2105,15 +2707,15 @@ kernel void kernel_mul_mv_q5_K_f32(
2105
2707
 
2106
2708
  const int tid = tiisg/4;
2107
2709
  const int ix = tiisg%4;
2108
- const int im = tid/4;
2710
+ const int iq = tid/4;
2109
2711
  const int ir = tid%4;
2110
2712
  const int n = 8;
2111
2713
 
2112
2714
  const int l0 = n*ir;
2113
- const int q_offset = 32*im + l0;
2114
- const int y_offset = 64*im + l0;
2715
+ const int q_offset = 32*iq + l0;
2716
+ const int y_offset = 64*iq + l0;
2115
2717
 
2116
- const uint8_t hm1 = 1u << (2*im);
2718
+ const uint8_t hm1 = 1u << (2*iq);
2117
2719
  const uint8_t hm2 = hm1 << 1;
2118
2720
  const uint8_t hm3 = hm1 << 4;
2119
2721
  const uint8_t hm4 = hm2 << 4;
@@ -2128,7 +2730,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2128
2730
  device const uint8_t * q1 = x[i].qs + q_offset;
2129
2731
  device const uint8_t * qh = x[i].qh + l0;
2130
2732
  device const half * dh = &x[i].d;
2131
- device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
2733
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
2132
2734
 
2133
2735
  device const float * y2 = y1 + 128;
2134
2736
  float4 sumy = {0.f, 0.f, 0.f, 0.f};
@@ -2184,7 +2786,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2184
2786
 
2185
2787
  const int il = 4 * (tiisg/8); // 0, 4, 8, 12
2186
2788
  const int ix = tiisg%8;
2187
- const int im = il/8; // 0, 0, 1, 1
2789
+ const int iq = il/8; // 0, 0, 1, 1
2188
2790
  const int in = il%8; // 0, 4, 0, 4
2189
2791
 
2190
2792
  device const float * y = yy + ix*QK_K + il;
@@ -2209,7 +2811,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2209
2811
 
2210
2812
  float2 acc = {0.f, 0.f};
2211
2813
  for (int l = 0; l < 4; ++l) {
2212
- const uint8_t hl = h[l] >> im;
2814
+ const uint8_t hl = h[l] >> iq;
2213
2815
  acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
2214
2816
  + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
2215
2817
  acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
@@ -2231,7 +2833,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2231
2833
  for (int row = 0; row < 2; ++row) {
2232
2834
  const float tot = simd_sum(sumf[row]);
2233
2835
  if (tiisg == 0) {
2234
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
2836
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
2235
2837
  }
2236
2838
  }
2237
2839
 
@@ -2246,9 +2848,10 @@ kernel void kernel_mul_mv_q6_K_f32(
2246
2848
  constant int64_t & ne02[[buffer(5)]],
2247
2849
  constant int64_t & ne10[[buffer(9)]],
2248
2850
  constant int64_t & ne12[[buffer(11)]],
2249
- constant int64_t & ne0[[buffer(15)]],
2250
- constant int64_t & ne1[[buffer(16)]],
2251
- constant uint & gqa[[buffer(17)]],
2851
+ constant int64_t & ne0 [[buffer(15)]],
2852
+ constant int64_t & ne1 [[buffer(16)]],
2853
+ constant uint & r2 [[buffer(17)]],
2854
+ constant uint & r3 [[buffer(18)]],
2252
2855
  uint3 tgpig[[threadgroup_position_in_grid]],
2253
2856
  uint tiisg[[thread_index_in_simdgroup]],
2254
2857
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2262,12 +2865,17 @@ kernel void kernel_mul_mv_q6_K_f32(
2262
2865
 
2263
2866
  const int64_t r0 = tgpig.x;
2264
2867
  const int64_t r1 = tgpig.y;
2265
- const int r2 = tgpig.z;
2868
+ const int im = tgpig.z;
2266
2869
 
2267
2870
  const int row = 2 * r0 + sgitg;
2268
- const uint offset0 = r2/gqa*(nb*ne0);
2871
+
2872
+ const uint i12 = im%ne12;
2873
+ const uint i13 = im/ne12;
2874
+
2875
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2876
+
2269
2877
  device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
2270
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2878
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2271
2879
 
2272
2880
  float sumf = 0;
2273
2881
 
@@ -2333,7 +2941,7 @@ kernel void kernel_mul_mv_q6_K_f32(
2333
2941
 
2334
2942
  const float tot = simd_sum(sumf);
2335
2943
  if (tiisg == 0) {
2336
- dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
2944
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
2337
2945
  }
2338
2946
  }
2339
2947
 
@@ -2643,24 +3251,25 @@ kernel void kernel_get_rows(
2643
3251
 
2644
3252
  // each block_q contains 16*nl weights
2645
3253
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
2646
- kernel void kernel_mul_mm(device const uchar * src0,
2647
- device const uchar * src1,
2648
- device float * dst,
2649
- constant int64_t & ne00,
2650
- constant int64_t & ne02,
2651
- constant int64_t & nb01,
2652
- constant int64_t & nb02,
2653
- constant int64_t & ne12,
2654
- constant int64_t & nb10,
2655
- constant int64_t & nb11,
2656
- constant int64_t & nb12,
2657
- constant int64_t & ne0,
2658
- constant int64_t & ne1,
2659
- constant uint & gqa,
2660
- threadgroup uchar * shared_memory [[threadgroup(0)]],
2661
- uint3 tgpig[[threadgroup_position_in_grid]],
2662
- uint tiitg[[thread_index_in_threadgroup]],
2663
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3254
+ void kernel_mul_mm_impl(device const uchar * src0,
3255
+ device const uchar * src1,
3256
+ device float * dst,
3257
+ constant int64_t & ne00,
3258
+ constant int64_t & ne02,
3259
+ constant int64_t & nb01,
3260
+ constant int64_t & nb02,
3261
+ constant int64_t & ne12,
3262
+ constant int64_t & nb10,
3263
+ constant int64_t & nb11,
3264
+ constant int64_t & nb12,
3265
+ constant int64_t & ne0,
3266
+ constant int64_t & ne1,
3267
+ constant uint & r2,
3268
+ constant uint & r3,
3269
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3270
+ uint3 tgpig[[threadgroup_position_in_grid]],
3271
+ uint tiitg[[thread_index_in_threadgroup]],
3272
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2664
3273
 
2665
3274
  threadgroup half * sa = (threadgroup half *)(shared_memory);
2666
3275
  threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
@@ -2686,7 +3295,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
2686
3295
 
2687
3296
  short il = (tiitg % THREAD_PER_ROW);
2688
3297
 
2689
- uint offset0 = im/gqa*nb02;
3298
+ const uint i12 = im%ne12;
3299
+ const uint i13 = im/ne12;
3300
+
3301
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
2690
3302
  ushort offset1 = il/nl;
2691
3303
 
2692
3304
  device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
@@ -2770,14 +3382,116 @@ kernel void kernel_mul_mm(device const uchar * src0,
2770
3382
  }
2771
3383
  }
2772
3384
 
3385
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3386
+ kernel void kernel_mul_mm(device const uchar * src0,
3387
+ device const uchar * src1,
3388
+ device float * dst,
3389
+ constant int64_t & ne00,
3390
+ constant int64_t & ne02,
3391
+ constant int64_t & nb01,
3392
+ constant int64_t & nb02,
3393
+ constant int64_t & ne12,
3394
+ constant int64_t & nb10,
3395
+ constant int64_t & nb11,
3396
+ constant int64_t & nb12,
3397
+ constant int64_t & ne0,
3398
+ constant int64_t & ne1,
3399
+ constant uint & r2,
3400
+ constant uint & r3,
3401
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3402
+ uint3 tgpig[[threadgroup_position_in_grid]],
3403
+ uint tiitg[[thread_index_in_threadgroup]],
3404
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3405
+ kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3406
+ src0,
3407
+ src1,
3408
+ dst,
3409
+ ne00,
3410
+ ne02,
3411
+ nb01,
3412
+ nb02,
3413
+ ne12,
3414
+ nb10,
3415
+ nb11,
3416
+ nb12,
3417
+ ne0,
3418
+ ne1,
3419
+ r2,
3420
+ r3,
3421
+ shared_memory,
3422
+ tgpig,
3423
+ tiitg,
3424
+ sgitg);
3425
+ }
3426
+
3427
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3428
+ kernel void kernel_mul_mm_id(
3429
+ device const int32_t * ids,
3430
+ device const uchar * src1,
3431
+ device float * dst,
3432
+ constant int64_t & ne00,
3433
+ constant int64_t & ne02,
3434
+ constant int64_t & nb01,
3435
+ constant int64_t & nb02,
3436
+ constant int64_t & ne12,
3437
+ constant int64_t & nb10,
3438
+ constant int64_t & nb11,
3439
+ constant int64_t & nb12,
3440
+ constant int64_t & ne0,
3441
+ constant int64_t & ne1,
3442
+ constant uint & r2,
3443
+ constant uint & r3,
3444
+ constant int & idx,
3445
+ device const uchar * src00,
3446
+ device const uchar * src01,
3447
+ device const uchar * src02,
3448
+ device const uchar * src03,
3449
+ device const uchar * src04,
3450
+ device const uchar * src05,
3451
+ device const uchar * src06,
3452
+ device const uchar * src07,
3453
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3454
+ uint3 tgpig[[threadgroup_position_in_grid]],
3455
+ uint tiitg[[thread_index_in_threadgroup]],
3456
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3457
+ device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
3458
+
3459
+ kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3460
+ src0[ids[idx]],
3461
+ src1,
3462
+ dst,
3463
+ ne00,
3464
+ ne02,
3465
+ nb01,
3466
+ nb02,
3467
+ ne12,
3468
+ nb10,
3469
+ nb11,
3470
+ nb12,
3471
+ ne0,
3472
+ ne1,
3473
+ r2,
3474
+ r3,
3475
+ shared_memory,
3476
+ tgpig,
3477
+ tiitg,
3478
+ sgitg);
3479
+ }
3480
+
2773
3481
  #if QK_K == 256
2774
3482
  #define QK_NL 16
2775
3483
  #else
2776
3484
  #define QK_NL 4
2777
3485
  #endif
2778
3486
 
2779
- typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
2780
- constant uint64_t &, constant uint64_t &, uint, uint, uint);
3487
+ typedef void (get_rows_t)(
3488
+ device const void * src0,
3489
+ device const int * src1,
3490
+ device float * dst,
3491
+ constant int64_t & ne00,
3492
+ constant uint64_t & nb01,
3493
+ constant uint64_t & nb1,
3494
+ uint, uint, uint);
2781
3495
 
2782
3496
  template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
2783
3497
  template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
@@ -2806,8 +3520,10 @@ typedef void (mat_mm_t)(
2806
3520
  constant int64_t & nb12,
2807
3521
  constant int64_t & ne0,
2808
3522
  constant int64_t & ne1,
2809
- constant uint & gqa,
2810
- threadgroup uchar *, uint3, uint, uint);
3523
+ constant uint & r2,
3524
+ constant uint & r3,
3525
+ threadgroup uchar *,
3526
+ uint3, uint, uint);
2811
3527
 
2812
3528
  template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
2813
3529
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
@@ -2821,3 +3537,44 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
2821
3537
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
2822
3538
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
2823
3539
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
3540
+
3541
+ typedef void (mat_mm_id_t)(
3542
+ device const int32_t * ids,
3543
+ device const uchar * src1,
3544
+ device float * dst,
3545
+ constant int64_t & ne00,
3546
+ constant int64_t & ne02,
3547
+ constant int64_t & nb01,
3548
+ constant int64_t & nb02,
3549
+ constant int64_t & ne12,
3550
+ constant int64_t & nb10,
3551
+ constant int64_t & nb11,
3552
+ constant int64_t & nb12,
3553
+ constant int64_t & ne0,
3554
+ constant int64_t & ne1,
3555
+ constant uint & r2,
3556
+ constant uint & r3,
3557
+ constant int & idx,
3558
+ device const uchar * src00,
3559
+ device const uchar * src01,
3560
+ device const uchar * src02,
3561
+ device const uchar * src03,
3562
+ device const uchar * src04,
3563
+ device const uchar * src05,
3564
+ device const uchar * src06,
3565
+ device const uchar * src07,
3566
+ threadgroup uchar *,
3567
+ uint3, uint, uint);
3568
+
3569
+ template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
3570
+ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
3571
+ template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
3572
+ template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
3573
+ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
3574
+ template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
3575
+ template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
3576
+ template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
3577
+ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
3578
+ template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
3579
+ template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
3580
+ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;