whisper.rn 0.4.0-rc.3 → 0.4.0-rc.5

This diff shows the content of publicly available package versions that have been released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +7 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +6 -1
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +53 -135
  7. package/android/src/main/jni-utils.h +76 -0
  8. package/android/src/main/jni.cpp +188 -109
  9. package/cpp/README.md +1 -1
  10. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  11. package/cpp/coreml/whisper-encoder.h +4 -0
  12. package/cpp/coreml/whisper-encoder.mm +4 -2
  13. package/cpp/ggml-alloc.c +451 -282
  14. package/cpp/ggml-alloc.h +74 -8
  15. package/cpp/ggml-backend-impl.h +112 -0
  16. package/cpp/ggml-backend.c +1357 -0
  17. package/cpp/ggml-backend.h +181 -0
  18. package/cpp/ggml-impl.h +243 -0
  19. package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal} +1556 -329
  20. package/cpp/ggml-metal.h +28 -1
  21. package/cpp/ggml-metal.m +1128 -308
  22. package/cpp/ggml-quants.c +7382 -0
  23. package/cpp/ggml-quants.h +224 -0
  24. package/cpp/ggml.c +3848 -5245
  25. package/cpp/ggml.h +353 -155
  26. package/cpp/rn-audioutils.cpp +68 -0
  27. package/cpp/rn-audioutils.h +14 -0
  28. package/cpp/rn-whisper-log.h +11 -0
  29. package/cpp/rn-whisper.cpp +141 -59
  30. package/cpp/rn-whisper.h +47 -15
  31. package/cpp/whisper.cpp +1750 -964
  32. package/cpp/whisper.h +97 -15
  33. package/ios/RNWhisper.mm +15 -9
  34. package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +4 -0
  35. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  36. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
  37. package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +19 -0
  38. package/ios/RNWhisperAudioUtils.h +0 -2
  39. package/ios/RNWhisperAudioUtils.m +0 -56
  40. package/ios/RNWhisperContext.h +8 -12
  41. package/ios/RNWhisperContext.mm +132 -138
  42. package/jest/mock.js +1 -1
  43. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  44. package/lib/commonjs/index.js +28 -9
  45. package/lib/commonjs/index.js.map +1 -1
  46. package/lib/commonjs/version.json +1 -1
  47. package/lib/module/NativeRNWhisper.js.map +1 -1
  48. package/lib/module/index.js +28 -9
  49. package/lib/module/index.js.map +1 -1
  50. package/lib/module/version.json +1 -1
  51. package/lib/typescript/NativeRNWhisper.d.ts +7 -1
  52. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  53. package/lib/typescript/index.d.ts +7 -2
  54. package/lib/typescript/index.d.ts.map +1 -1
  55. package/package.json +6 -5
  56. package/src/NativeRNWhisper.ts +8 -1
  57. package/src/index.ts +29 -17
  58. package/src/version.json +1 -1
  59. package/whisper-rn.podspec +1 -2
@@ -3,6 +3,8 @@
  using namespace metal;

  #define MAX(x, y) ((x) > (y) ? (x) : (y))
+ #define MIN(x, y) ((x) < (y) ? (x) : (y))
+ #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }

  #define QK4_0 32
  #define QR4_0 2
@@ -13,23 +15,187 @@ typedef struct {

  #define QK4_1 32
  typedef struct {
- half d; // delta
- half m; // min
+ half d; // delta
+ half m; // min
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
  } block_q4_1;

+ #define QK5_0 32
+ typedef struct {
+ half d; // delta
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_0 / 2]; // nibbles / quants
+ } block_q5_0;
+
+ #define QK5_1 32
+ typedef struct {
+ half d; // delta
+ half m; // min
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_1 / 2]; // nibbles / quants
+ } block_q5_1;
+
  #define QK8_0 32
  typedef struct {
  half d; // delta
  int8_t qs[QK8_0]; // quants
  } block_q8_0;

+ #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
+ enum ggml_sort_order {
+ GGML_SORT_ASC,
+ GGML_SORT_DESC,
+ };
+
+ // general-purpose kernel for addition, multiplication and division of two tensors
+ // pros: works for non-contiguous tensors, supports broadcast across all dims
+ // cons: not very efficient
  kernel void kernel_add(
- device const float4 * src0,
- device const float4 * src1,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] + src1[tpig];
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant int64_t & nb00,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant int64_t & nb0,
+ constant int64_t & nb1,
+ constant int64_t & nb2,
+ constant int64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
+ }
+ }
+
+ kernel void kernel_mul(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant int64_t & nb00,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant int64_t & nb0,
+ constant int64_t & nb1,
+ constant int64_t & nb2,
+ constant int64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
+ }
+ }
+
+ kernel void kernel_div(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant int64_t & nb00,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant int64_t & nb0,
+ constant int64_t & nb1,
+ constant int64_t & nb2,
+ constant int64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
+ }
  }

  // assumption: src1 is a row
@@ -38,34 +204,41 @@ kernel void kernel_add_row(
  device const float4 * src0,
  device const float4 * src1,
  device float4 * dst,
- constant int64_t & nb,
+ constant int64_t & nb [[buffer(27)]],
  uint tpig[[thread_position_in_grid]]) {
  dst[tpig] = src0[tpig] + src1[tpig % nb];
  }

- kernel void kernel_mul(
+ kernel void kernel_mul_row(
  device const float4 * src0,
  device const float4 * src1,
  device float4 * dst,
+ constant int64_t & nb [[buffer(27)]],
  uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * src1[tpig];
+ dst[tpig] = src0[tpig] * src1[tpig % nb];
  }

- // assumption: src1 is a row
- // broadcast src1 into src0
- kernel void kernel_mul_row(
+ kernel void kernel_div_row(
  device const float4 * src0,
  device const float4 * src1,
  device float4 * dst,
- constant int64_t & nb,
+ constant int64_t & nb [[buffer(27)]],
  uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * src1[tpig % nb];
+ dst[tpig] = src0[tpig] / src1[tpig % nb];
  }

  kernel void kernel_scale(
+ device const float * src0,
+ device float * dst,
+ constant float & scale,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * scale;
+ }
+
+ kernel void kernel_scale_4(
  device const float4 * src0,
  device float4 * dst,
- constant float & scale,
+ constant float & scale,
  uint tpig[[thread_position_in_grid]]) {
  dst[tpig] = src0[tpig] * scale;
  }
@@ -85,6 +258,61 @@ kernel void kernel_relu(
  dst[tpig] = max(0.0f, src0[tpig]);
  }

+ kernel void kernel_sqr(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * src0[tpig];
+ }
+
+ kernel void kernel_sum_rows(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant int64_t & nb00,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant int64_t & nb0,
+ constant int64_t & nb1,
+ constant int64_t & nb2,
+ constant int64_t & nb3,
+ uint3 tpig[[thread_position_in_grid]]) {
+ int64_t i3 = tpig.z;
+ int64_t i2 = tpig.y;
+ int64_t i1 = tpig.x;
+
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+ return;
+ }
+
+ device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
+ device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+
+ float row_sum = 0;
+
+ for (int64_t i0 = 0; i0 < ne00; i0++) {
+ row_sum += src_row[i0];
+ }
+
+ dst_row[0] = row_sum;
+ }
+
  constant float GELU_COEF_A = 0.044715f;
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

@@ -103,82 +331,165 @@ kernel void kernel_gelu(

  kernel void kernel_soft_max(
  device const float * src0,
+ device const float * src1,
  device float * dst,
  constant int64_t & ne00,
  constant int64_t & ne01,
  constant int64_t & ne02,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
-
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ constant float & scale,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

  // parallel max
- float lmax = psrc0[tpitg[0]];
- for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
- lmax = MAX(lmax, psrc0[i00]);
+ float lmax = -INFINITY;
+
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
+ }
+
+ // find the max value in the block
+ float max_val = simd_max(lmax);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = -INFINITY;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = max_val;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max_val = buf[tiisg];
+ max_val = simd_max(max_val);
  }
- const float max = simd_max(lmax);

  // parallel sum
  float lsum = 0.0f;
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
- const float exp_psrc0 = exp(psrc0[i00] - max);
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
  lsum += exp_psrc0;
- // Remember the result of exp here. exp is expensive, so we really do not
- // whish to compute it twice.
  pdst[i00] = exp_psrc0;
  }

- const float sum = simd_sum(lsum);
+ float sum = simd_sum(lsum);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);

- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
- pdst[i00] /= sum;
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[tiisg];
+ sum = simd_sum(sum);
+ }
+
+ const float inv_sum = 1.0f/sum;
+
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ pdst[i00] *= inv_sum;
  }
  }

  kernel void kernel_soft_max_4(
  device const float * src0,
+ device const float * src1,
  device float * dst,
  constant int64_t & ne00,
  constant int64_t & ne01,
  constant int64_t & ne02,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
-
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ constant float & scale,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

  // parallel max
- float4 lmax4 = psrc4[tpitg[0]];
- for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
- lmax4 = fmax(lmax4, psrc4[i00]);
+ float4 lmax4 = -INFINITY;
+
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
  }
- float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));

- const float max = simd_max(lmax);
+ const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+
+ float max_val = simd_max(lmax);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = -INFINITY;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = max_val;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max_val = buf[tiisg];
+ max_val = simd_max(max_val);
+ }

  // parallel sum
  float4 lsum4 = 0.0f;
- for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
- const float4 exp_psrc4 = exp(psrc4[i00] - max);
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
  lsum4 += exp_psrc4;
  pdst4[i00] = exp_psrc4;
  }
- float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];

- const float sum = simd_sum(lsum);
+ const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+ float sum = simd_sum(lsum);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }

- for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
- pdst4[i00] /= sum;
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[tiisg];
+ sum = simd_sum(sum);
+ }
+
+ const float inv_sum = 1.0f/sum;
+
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ pdst4[i00] *= inv_sum;
  }
  }

@@ -197,7 +508,7 @@ kernel void kernel_diag_mask_inf(
  dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
  } else {
  dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
- }
+ }
  }

  kernel void kernel_diag_mask_inf_8(
@@ -285,16 +596,16 @@ kernel void kernel_rms_norm(
  constant int64_t & ne00,
  constant uint64_t & nb01,
  constant float & eps,
- threadgroup float * sum [[threadgroup(0)]],
+ threadgroup float * buf [[threadgroup(0)]],
  uint tgpig[[threadgroup_position_in_grid]],
  uint tpitg[[thread_position_in_threadgroup]],
  uint sgitg[[simdgroup_index_in_threadgroup]],
  uint tiisg[[thread_index_in_simdgroup]],
  uint ntg[[threads_per_threadgroup]]) {
  device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
- device const float * x_scalar = (device const float *) x;
- float4 sumf=0;
- float all_sum=0;
+
+ float4 sumf = 0;
+ float all_sum = 0;

  // parallel sum
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
@@ -302,35 +613,30 @@ kernel void kernel_rms_norm(
  }
  all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
  all_sum = simd_sum(all_sum);
- if (tiisg == 0) {
- sum[sgitg] = all_sum;
- }
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }

- threadgroup_barrier(mem_flags::mem_threadgroup);
- // broadcast, simd group number is ntg / 32
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
- if (tpitg < i) {
- sum[tpitg] += sum[tpitg + i];
- }
- }
- if (tpitg == 0) {
- for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
- sum[0] /= ne00;
- }
+ threadgroup_barrier(mem_flags::mem_threadgroup);

- threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (tiisg == 0) {
+ buf[sgitg] = all_sum;
+ }

- const float mean = sum[0];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ all_sum = buf[tiisg];
+ all_sum = simd_sum(all_sum);
+ }
+
+ const float mean = all_sum/ne00;
  const float scale = 1.0f/sqrt(mean + eps);

  device float4 * y = (device float4 *) (dst + tgpig*ne00);
- device float * y_scalar = (device float *) y;
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
  y[i00] = x[i00] * scale;
  }
- if (tpitg == 0) {
- for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
- }
  }

  // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -339,8 +645,11 @@ kernel void kernel_rms_norm(
  // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
  inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
  float d = qb_curr->d;
+
  float2 acc = 0.f;
+
  device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+
  for (int i = 0; i < 8; i+=2) {
  acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
  + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -357,8 +666,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
  inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
  float d = qb_curr->d;
  float m = qb_curr->m;
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
  float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
  for (int i = 0; i < 8; i+=2) {
  acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
  + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -368,31 +680,92 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
  return d * (acc[0] + acc[1]) + sumy * m;
  }

+ // function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
+ // il indicates where the q5 quants begin (0 or QK5_0/4)
+ // we assume that the yl's have been multiplied with the appropriate scale factor
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+ inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
+ + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
+ acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ }
+ return d * (sumy * -16.f + acc[0] + acc[1]);
+ }
+
+ // function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
+ // il indicates where the q5 quants begin (0 or QK5_1/4)
+ // we assume that the yl's have been multiplied with the appropriate scale factor
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+ float m = qb_curr->m;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
+ + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
+ acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ }
+ return d * (acc[0] + acc[1]) + sumy * m;
+ }
+
  // putting them in the kernel cause a significant performance penalty
- #define N_DST 4 // each SIMD group works on 4 rows
- #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
- #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+ #define N_DST 4 // each SIMD group works on 4 rows
+ #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
  //Note: This is a template, but strictly speaking it only applies to
  // quantizations where the block size is 32. It also does not
  // giard against the number of rows not being divisible by
  // N_DST, so this is another explicit assumption of the implementation.
  template<typename block_q_type, int nr, int nsg, int nw>
- void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
- int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
- uint3 tgpig, uint tiisg, uint sgitg) {
+ void mul_vec_q_n_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ uint3 tgpig, uint tiisg, uint sgitg) {
  const int nb = ne00/QK4_0;
+
  const int r0 = tgpig.x;
  const int r1 = tgpig.y;
  const int im = tgpig.z;
+
  const int first_row = (r0 * nsg + sgitg) * nr;
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
  device const block_q_type * x = (device const block_q_type *) src0 + offset0;
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
- float yl[16]; // src1 vector cache
- float sumf[nr]={0.f};

- const int ix = tiisg/2;
- const int il = 8*(tiisg%2);
+ float yl[16]; // src1 vector cache
+ float sumf[nr] = {0.f};
+
+ const int ix = (tiisg/2);
+ const int il = (tiisg%2)*8;

  device const float * yb = y + ix * QK4_0 + il;

@@ -403,6 +776,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
  sumy += yb[i] + yb[i+1];
  yl[i+0] = yb[i+ 0];
  yl[i+1] = yb[i+ 1]/256.f;
+
  sumy += yb[i+16] + yb[i+17];
  yl[i+8] = yb[i+16]/16.f;
  yl[i+9] = yb[i+17]/4096.f;
@@ -418,12 +792,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
  for (int row = 0; row < nr; ++row) {
  const float tot = simd_sum(sumf[row]);
  if (tiisg == 0 && first_row + row < ne01) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+ dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
  }
  }
  }

- kernel void kernel_mul_mat_q4_0_f32(
+ kernel void kernel_mul_mv_q4_0_f32(
  device const void * src0,
  device const float * src1,
  device float * dst,
@@ -432,16 +806,17 @@ kernel void kernel_mul_mat_q4_0_f32(
  constant int64_t & ne02[[buffer(5)]],
  constant int64_t & ne10[[buffer(9)]],
  constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
  }

- kernel void kernel_mul_mat_q4_1_f32(
+ kernel void kernel_mul_mv_q4_1_f32(
  device const void * src0,
  device const float * src1,
  device float * dst,
@@ -450,18 +825,58 @@ kernel void kernel_mul_mat_q4_1_f32(
  constant int64_t & ne02[[buffer(5)]],
  constant int64_t & ne10[[buffer(9)]],
  constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint tiisg[[thread_index_in_simdgroup]],
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
  }

+ kernel void kernel_mul_mv_q5_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ }
+
+ kernel void kernel_mul_mv_q5_1_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ }
+
+
  #define NB_Q8_0 8

- kernel void kernel_mul_mat_q8_0_f32(
+ kernel void kernel_mul_mv_q8_0_f32(
  device const void * src0,
  device const float * src1,
  device float * dst,
@@ -470,9 +885,10 @@ kernel void kernel_mul_mat_q8_0_f32(
  constant int64_t & ne02[[buffer(5)]],
  constant int64_t & ne10[[buffer(9)]],
  constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint tiisg[[thread_index_in_simdgroup]],
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -484,8 +900,14 @@ kernel void kernel_mul_mat_q8_0_f32(
  const int r0 = tgpig.x;
  const int r1 = tgpig.y;
  const int im = tgpig.z;
+
  const int first_row = (r0 * nsg + sgitg) * nr;
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
  device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;

@@ -525,7 +947,7 @@ kernel void kernel_mul_mat_q8_0_f32(

  #define N_F32_F32 4

- kernel void kernel_mul_mat_f32_f32(
+ kernel void kernel_mul_mv_f32_f32(
  device const char * src0,
  device const char * src1,
  device float * dst,
@@ -543,14 +965,21 @@ kernel void kernel_mul_mat_f32_f32(
  constant uint64_t & nb12,
  constant int64_t & ne0,
  constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]]) {

  const int64_t r0 = tgpig.x;
  const int64_t rb = tgpig.y*N_F32_F32;
  const int64_t im = tgpig.z;

- device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const float * x = (device const float *) (src0 + offset0);

  if (ne00 < 128) {
  for (int row = 0; row < N_F32_F32; ++row) {
@@ -596,7 +1025,9 @@ kernel void kernel_mul_mat_f32_f32(
  }
  }

- kernel void kernel_mul_mat_f16_f32_1row(
+ #define N_F16_F16 4
+
+ kernel void kernel_mul_mv_f16_f16(
  device const char * src0,
  device const char * src1,
  device float * dst,
@@ -614,14 +1045,99 @@ kernel void kernel_mul_mat_f16_f32_1row(
  constant uint64_t & nb12,
  constant int64_t & ne0,
  constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]]) {
+
+ const int64_t r0 = tgpig.x;
+ const int64_t rb = tgpig.y*N_F16_F16;
+ const int64_t im = tgpig.z;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const half * x = (device const half *) (src0 + offset0);
+
+ if (ne00 < 128) {
+ for (int row = 0; row < N_F16_F16; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00; i += 32) {
+ sumf += (half) x[i] * (half) y[i];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ } else {
+ device const half4 * x4 = (device const half4 *)x;
+ for (int row = 0; row < N_F16_F16; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
+ device const half4 * y4 = (device const half4 *) y;
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ }
+ }
+
+ kernel void kernel_mul_mv_f16_f32_1row(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {

  const int64_t r0 = tgpig.x;
  const int64_t r1 = tgpig.y;
  const int64_t im = tgpig.z;

- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const half * x = (device const half *) (src0 + offset0);
  device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);

  float sumf = 0;
@@ -650,7 +1166,7 @@ kernel void kernel_mul_mat_f16_f32_1row(

  #define N_F16_F32 4

- kernel void kernel_mul_mat_f16_f32(
+ kernel void kernel_mul_mv_f16_f32(
  device const char * src0,
  device const char * src1,
  device float * dst,
@@ -668,6 +1184,8 @@ kernel void kernel_mul_mat_f16_f32(
  constant uint64_t & nb12,
  constant int64_t & ne0,
  constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint tiisg[[thread_index_in_simdgroup]]) {

@@ -675,7 +1193,12 @@ kernel void kernel_mul_mat_f16_f32(
  const int64_t rb = tgpig.y*N_F16_F32;
  const int64_t im = tgpig.z;

- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const half * x = (device const half *) (src0 + offset0);

  if (ne00 < 128) {
  for (int row = 0; row < N_F16_F32; ++row) {
@@ -722,7 +1245,7 @@ kernel void kernel_mul_mat_f16_f32(
  }

  // Assumes row size (ne00) is a multiple of 4
- kernel void kernel_mul_mat_f16_f32_l4(
+ kernel void kernel_mul_mv_f16_f32_l4(
  device const char * src0,
  device const char * src1,
  device float * dst,
@@ -740,33 +1263,387 @@ kernel void kernel_mul_mat_f16_f32_l4(
  constant uint64_t & nb12,
  constant int64_t & ne0,
  constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+
+ const int nrows = ne11;
+ const int64_t r0 = tgpig.x;
+ const int64_t im = tgpig.z;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const half4 * x4 = (device const half4 *) (src0 + offset0);
+
+ for (int r1 = 0; r1 < nrows; ++r1) {
+ device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ }
+
+ kernel void kernel_alibi_f32(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant float & m0,
+ constant float & m1,
+ constant int & n_heads_log2_floor,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+ const int64_t k = i3*ne3 + i2;
+
+ float m_k;
+ if (k < n_heads_log2_floor) {
+ m_k = pow(m0, k + 1);
+ } else {
+ m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
+ }
+
+ device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
+ device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ const float src_v = *(device float *)(src_row + i00*nb00);
+ device float * dst_v = (device float *)(dst_row + i00*nb0);
+ *dst_v = i00 * m_k + src_v;
+ }
+ }
+
+ static float rope_yarn_ramp(const float low, const float high, const int i0) {
+ const float y = (i0 / 2 - low) / max(0.001f, high - low);
+ return 1.0f - min(1.0f, max(0.0f, y));
+ }
+
+ // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+ // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+ static void rope_yarn(
+ float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+ thread float * cos_theta, thread float * sin_theta
+ ) {
+ // Get n-d rotational scaling corrected for extrapolation
+ float theta_interp = freq_scale * theta_extrap;
+ float theta = theta_interp;
+ if (ext_factor != 0.0f) {
+ float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+ // Get n-d magnitude scaling corrected for interpolation
+ mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+ }
+ *cos_theta = cos(theta) * mscale;
+ *sin_theta = sin(theta) * mscale;
+ }
+
+ // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+ // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+ static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
+ return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+ }
+
+ static void rope_yarn_corr_dims(
+ int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+ ) {
+ // start and end correction dims
+ dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
+ dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+ }
+
+ typedef void (rope_t)(
+ device const void * src0,
+ device const int32_t * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int & n_past,
+ constant int & n_dims,
+ constant int & mode,
+ constant int & n_orig_ctx,
+ constant float & freq_base,
+ constant float & freq_scale,
+ constant float & ext_factor,
+ constant float & attn_factor,
+ constant float & beta_fast,
+ constant float & beta_slow,
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg[[threads_per_threadgroup]],
+ uint3 tgpig[[threadgroup_position_in_grid]]);
+
+ template<typename T>
+ kernel void kernel_rope(
+ device const void * src0,
+ device const int32_t * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int & n_past,
+ constant int & n_dims,
+ constant int & mode,
+ constant int & n_orig_ctx,
+ constant float & freq_base,
+ constant float & freq_scale,
+ constant float & ext_factor,
+ constant float & attn_factor,
+ constant float & beta_fast,
+ constant float & beta_slow,
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg[[threads_per_threadgroup]],
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
+ const int64_t i3 = tgpig[2];
+ const int64_t i2 = tgpig[1];
+ const int64_t i1 = tgpig[0];
+
+ const bool is_neox = mode & 2;
+
+ float corr_dims[2];
+ rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+
+ device const int32_t * pos = src1;
+
+ const int64_t p = pos[i2];
+
+ const float theta_0 = (float)p;
+ const float inv_ndims = -1.f/n_dims;
+
+ if (!is_neox) {
+ for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+
+ const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
+ float cos_theta, sin_theta;
+ rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ const T x0 = src[0];
+ const T x1 = src[1];
+
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
+ }
+ } else {
+ for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+ for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
+
+ // simplified from `(ib * n_dims + ic) * inv_ndims`
+ const float cur_rot = inv_ndims*ic - ib;
+
+ const float theta = theta_0 * pow(freq_base, cur_rot);
+ float cos_theta, sin_theta;
+ rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+ const int64_t i0 = ib*n_dims + ic/2;
+
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ const float x0 = src[0];
+ const float x1 = src[n_dims/2];
+
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+ }
+ }
+ }
+ }
+
+ template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
+ template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
+
+ kernel void kernel_im2col_f16(
+ device const float * x,
+ device half * dst,
+ constant int32_t & ofs0,
+ constant int32_t & ofs1,
+ constant int32_t & IW,
+ constant int32_t & IH,
+ constant int32_t & CHW,
+ constant int32_t & s0,
+ constant int32_t & s1,
+ constant int32_t & p0,
+ constant int32_t & p1,
+ constant int32_t & d0,
+ constant int32_t & d1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
+ const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
+
+ const int32_t offset_dst =
+ (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+ (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
+
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ dst[offset_dst] = 0.0f;
+ } else {
+ const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
+ dst[offset_dst] = x[offset_src + iih * IW + iiw];
+ }
+ }
+
+ // bitonic sort implementation following the CUDA kernels as reference
+ typedef void (argsort_t)(
+ device const float * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]);
+
+ template<ggml_sort_order order>
+ kernel void kernel_argsort_f32_i32(
+ device const float * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
+ // bitonic sort
+ int col = tpitg[0];
+ int row = tgpig[1];
+
+ if (col >= ncols) return;
+
+ device const float * x_row = x + row * ncols;
+ device int32_t * dst_row = dst + row * ncols;
+
+ // initialize indices
+ if (col < ncols) {
+ dst_row[col] = col;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (int k = 2; k <= ncols; k *= 2) {
+ for (int j = k / 2; j > 0; j /= 2) {
+ int ixj = col ^ j;
+ if (ixj > col) {
+ if ((col & k) == 0) {
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ } else {
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ }
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ }
+ }
+
+ template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
+ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
+
+ kernel void kernel_cpy_f16_f16(
+ device const half * src0,
+ device half * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
  uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]) {
-
- const int nrows = ne11;
- const int64_t r0 = tgpig.x;
- const int64_t im = tgpig.z;
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];

- device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

- for (int r1 = 0; r1 < nrows; ++r1) {
- device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);

- float sumf = 0;
- for (int i = tiisg; i < ne00/4; i += 32) {
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
- }
+ device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

- float all_sum = simd_sum(sumf);
- if (tiisg == 0) {
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
- }
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ dst_data[i00] = src[0];
  }
  }

- kernel void kernel_alibi_f32(
+ kernel void kernel_cpy_f32_f16(
  device const float * src0,
- device float * dst,
+ device half * dst,
  constant int64_t & ne00,
  constant int64_t & ne01,
  constant int64_t & ne02,
@@ -783,7 +1660,6 @@ kernel void kernel_alibi_f32(
783
1660
  constant uint64_t & nb1,
784
1661
  constant uint64_t & nb2,
785
1662
  constant uint64_t & nb3,
786
- constant float & m0,
787
1663
  uint3 tgpig[[threadgroup_position_in_grid]],
788
1664
  uint3 tpitg[[thread_position_in_threadgroup]],
789
1665
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -798,16 +1674,17 @@ kernel void kernel_alibi_f32(
798
1674
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
799
1675
  const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
800
1676
 
801
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
802
- float m_k = pow(m0, i2 + 1);
1677
+ device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1678
+
803
1679
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
804
1680
  device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
805
- dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
1681
+
1682
+ dst_data[i00] = src[0];
806
1683
  }
807
1684
  }
808
1685
 
809
- kernel void kernel_rope(
810
- device const void * src0,
1686
+ kernel void kernel_cpy_f32_f32(
1687
+ device const float * src0,
811
1688
  device float * dst,
812
1689
  constant int64_t & ne00,
813
1690
  constant int64_t & ne01,
@@ -825,67 +1702,32 @@ kernel void kernel_rope(
825
1702
  constant uint64_t & nb1,
826
1703
  constant uint64_t & nb2,
827
1704
  constant uint64_t & nb3,
828
- constant int & n_past,
829
- constant int & n_dims,
830
- constant int & mode,
831
- constant float & freq_base,
832
- constant float & freq_scale,
833
- uint tiitg[[thread_index_in_threadgroup]],
834
- uint3 tptg[[threads_per_threadgroup]],
835
- uint3 tgpig[[threadgroup_position_in_grid]]) {
836
- const int64_t i3 = tgpig[2];
837
- const int64_t i2 = tgpig[1];
838
- const int64_t i1 = tgpig[0];
839
-
840
- const bool is_neox = mode & 2;
841
-
842
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
843
-
844
- const float theta_0 = freq_scale * (float)p;
845
- const float inv_ndims = -1.f/n_dims;
846
-
847
- if (!is_neox) {
848
- for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
849
-
850
- const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
851
- const float cos_theta = cos(theta);
852
- const float sin_theta = sin(theta);
853
-
854
- device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
855
- device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
856
-
857
- const float x0 = src[0];
858
- const float x1 = src[1];
859
-
860
- dst_data[0] = x0*cos_theta - x1*sin_theta;
861
- dst_data[1] = x0*sin_theta + x1*cos_theta;
862
- }
863
- } else {
864
- for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
865
- for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
1705
+ uint3 tgpig[[threadgroup_position_in_grid]],
1706
+ uint3 tpitg[[thread_position_in_threadgroup]],
1707
+ uint3 ntg[[threads_per_threadgroup]]) {
1708
+ const int64_t i03 = tgpig[2];
1709
+ const int64_t i02 = tgpig[1];
1710
+ const int64_t i01 = tgpig[0];
866
1711
 
867
- const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
868
- const float cos_theta = cos(theta);
869
- const float sin_theta = sin(theta);
1712
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
870
1713
 
871
- const int64_t i0 = ib*n_dims + ic/2;
1714
+ const int64_t i3 = n / (ne2*ne1*ne0);
1715
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1716
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1717
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
872
1718
 
873
- device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
874
- device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1719
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
875
1720
 
876
- const float x0 = src[0];
877
- const float x1 = src[n_dims/2];
1721
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1722
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
878
1723
 
879
- dst_data[0] = x0*cos_theta - x1*sin_theta;
880
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
881
- }
882
- }
1724
+ dst_data[i00] = src[0];
883
1725
  }
884
1726
  }
885
1727
 
886
- kernel void kernel_cpy_f16_f16(
887
- device const half * src0,
888
- device half * dst,
1728
+ kernel void kernel_cpy_f32_q8_0(
1729
+ device const float * src0,
1730
+ device void * dst,
889
1731
  constant int64_t & ne00,
890
1732
  constant int64_t & ne01,
891
1733
  constant int64_t & ne02,
@@ -914,19 +1756,36 @@ kernel void kernel_cpy_f16_f16(
914
1756
  const int64_t i3 = n / (ne2*ne1*ne0);
915
1757
  const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
916
1758
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
917
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1759
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
918
1760
 
919
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1761
+ device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
920
1762
 
921
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
922
- device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
923
- dst_data[i00] = src[0];
1763
+ for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
1764
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1765
+
1766
+ float amax = 0.0f; // absolute max
1767
+
1768
+ for (int j = 0; j < QK8_0; j++) {
1769
+ const float v = src[j];
1770
+ amax = MAX(amax, fabs(v));
1771
+ }
1772
+
1773
+ const float d = amax / ((1 << 7) - 1);
1774
+ const float id = d ? 1.0f/d : 0.0f;
1775
+
1776
+ dst_data[i00/QK8_0].d = d;
1777
+
1778
+ for (int j = 0; j < QK8_0; ++j) {
1779
+ const float x0 = src[j]*id;
1780
+
1781
+ dst_data[i00/QK8_0].qs[j] = round(x0);
1782
+ }
924
1783
  }
925
1784
  }
926
1785
 
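kernel_cpy_f32_q8_0 quantizes on the fly while copying: each block of QK8_0 = 32 floats stores one scale d = amax/127 plus 32 signed bytes, so a value round-trips as x ≈ d·qs[j]. A hedged CPU sketch of the same block math (plain C++; the float d field here stands in for the GPU-side half):

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

constexpr int QK8_0 = 32;

struct block_q8_0 {          // conceptual mirror of the Metal-side layout
    float  d;                // scale (half on the GPU; float here for brevity)
    int8_t qs[QK8_0];        // quantized values
};

// Quantize one block of 32 floats the same way kernel_cpy_f32_q8_0 does.
static block_q8_0 quantize_q8_0(const float * src) {
    block_q8_0 b;
    float amax = 0.0f;                       // absolute max of the block
    for (int j = 0; j < QK8_0; j++) {
        amax = std::fmax(amax, std::fabs(src[j]));
    }
    b.d = amax / 127.0f;                     // (1 << 7) - 1
    const float id = b.d ? 1.0f / b.d : 0.0f;
    for (int j = 0; j < QK8_0; j++) {
        b.qs[j] = (int8_t) std::round(src[j] * id);
    }
    return b;
}

int main() {
    float x[QK8_0];
    for (int j = 0; j < QK8_0; j++) x[j] = 0.01f * (j - 16);
    const block_q8_0 b = quantize_q8_0(x);
    // round-trip: x[j] ~= d * qs[j]
    std::printf("d = %f, x[0] = %f, dequant = %f\n", b.d, x[0], b.d * b.qs[0]);
}
```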
- kernel void kernel_cpy_f32_f16(
+ kernel void kernel_cpy_f32_q4_0(
          device const float * src0,
-         device        half * dst,
+         device        void * dst,
          constant  int64_t & ne00,
          constant  int64_t & ne01,
          constant  int64_t & ne02,
@@ -955,20 +1814,45 @@ kernel void kernel_cpy_f32_f16(
      const int64_t i3 = n / (ne2*ne1*ne0);
      const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
      const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;

-     device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+     device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

-     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+     for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
          device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

-         dst_data[i00] = src[0];
+         float amax = 0.0f; // absolute max
+         float max  = 0.0f;
+
+         for (int j = 0; j < QK4_0; j++) {
+             const float v = src[j];
+             if (amax < fabs(v)) {
+                 amax = fabs(v);
+                 max  = v;
+             }
+         }
+
+         const float d  = max / -8;
+         const float id = d ? 1.0f/d : 0.0f;
+
+         dst_data[i00/QK4_0].d = d;
+
+         for (int j = 0; j < QK4_0/2; ++j) {
+             const float x0 = src[0       + j]*id;
+             const float x1 = src[QK4_0/2 + j]*id;
+
+             const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+             const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+             dst_data[i00/QK4_0].qs[j]  = xi0;
+             dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
+         }
      }
  }

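kernel_cpy_f32_q4_0 applies the same structure with 4-bit nibbles: the scale is derived from the signed value of largest magnitude (d = max/-8), so quantized values land in 0..15 around an implicit offset of 8, and two of them pack into each qs byte. A small C++ sketch of one block (illustrative only):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

constexpr int QK4_0 = 32;

struct block_q4_0 {
    float   d;               // scale (half on the GPU)
    uint8_t qs[QK4_0 / 2];   // two 4-bit values per byte
};

// Same scheme as kernel_cpy_f32_q4_0: d comes from the signed value of largest
// magnitude so it maps to nibble 0 or 15, giving x ~= d * (q - 8).
static block_q4_0 quantize_q4_0(const float * src) {
    block_q4_0 b;
    float amax = 0.0f, max = 0.0f;
    for (int j = 0; j < QK4_0; j++) {
        if (amax < std::fabs(src[j])) { amax = std::fabs(src[j]); max = src[j]; }
    }
    b.d = max / -8.0f;
    const float id = b.d ? 1.0f / b.d : 0.0f;
    for (int j = 0; j < QK4_0 / 2; ++j) {
        const uint8_t xi0 = std::min(15, (int) (src[j] * id + 8.5f));
        const uint8_t xi1 = std::min(15, (int) (src[QK4_0 / 2 + j] * id + 8.5f));
        b.qs[j] = xi0 | (xi1 << 4);   // first half in low nibbles, second in high
    }
    return b;
}

int main() {
    float x[QK4_0];
    for (int j = 0; j < QK4_0; j++) x[j] = 0.1f * (j % 8) - 0.4f;
    const block_q4_0 b = quantize_q4_0(x);
    const int q0 = b.qs[0] & 0x0F;
    std::printf("x[0] = %f, dequant = %f\n", x[0], b.d * (q0 - 8));
}
```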
- kernel void kernel_cpy_f32_f32(
+ kernel void kernel_cpy_f32_q4_1(
          device const float * src0,
-         device       float * dst,
+         device        void * dst,
          constant  int64_t & ne00,
          constant  int64_t & ne01,
          constant  int64_t & ne02,
@@ -997,14 +1881,94 @@ kernel void kernel_cpy_f32_f32(
      const int64_t i3 = n / (ne2*ne1*ne0);
      const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
      const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;

-     device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+     device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

-     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+     for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
          device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

-         dst_data[i00] = src[0];
+         float min = FLT_MAX;
+         float max = -FLT_MAX;
+
+         for (int j = 0; j < QK4_1; j++) {
+             const float v = src[j];
+             if (min > v) min = v;
+             if (max < v) max = v;
+         }
+
+         const float d  = (max - min) / ((1 << 4) - 1);
+         const float id = d ? 1.0f/d : 0.0f;
+
+         dst_data[i00/QK4_1].d = d;
+         dst_data[i00/QK4_1].m = min;
+
+         for (int j = 0; j < QK4_1/2; ++j) {
+             const float x0 = (src[0       + j] - min)*id;
+             const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+             const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+             const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+             dst_data[i00/QK4_1].qs[j]  = xi0;
+             dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
+         }
+     }
+ }
+
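kernel_cpy_f32_q4_1 is the affine variant: it stores both a scale and the block minimum, mapping q = 0 to min and q = 15 to max, so no range is wasted on one-sided blocks. A compact C++ illustration (values made up):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

constexpr int QK4_1 = 32;

// Affine 4-bit scheme used by kernel_cpy_f32_q4_1: x ~= d*q + m with
// m = min(block) and d = (max - min)/15.
int main() {
    float x[QK4_1];
    for (int j = 0; j < QK4_1; j++) x[j] = 1.0f + 0.02f * j;   // min 1.0, max 1.62
    float mn = x[0], mx = x[0];
    for (float v : x) { mn = std::min(mn, v); mx = std::max(mx, v); }

    const float d  = (mx - mn) / 15.0f;     // (1 << 4) - 1
    const float id = d ? 1.0f / d : 0.0f;

    // quantize and dequantize a single element to show the round trip
    const uint8_t q = (uint8_t) std::min(15, (int) ((x[7] - mn) * id + 0.5f));
    std::printf("x[7] = %f, dequant = %f\n", x[7], d * q + mn);
}
```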
+ kernel void kernel_concat(
+         device const char * src0,
+         device const char * src1,
+         device       char * dst,
+         constant  int64_t & ne00,
+         constant  int64_t & ne01,
+         constant  int64_t & ne02,
+         constant  int64_t & ne03,
+         constant uint64_t & nb00,
+         constant uint64_t & nb01,
+         constant uint64_t & nb02,
+         constant uint64_t & nb03,
+         constant  int64_t & ne10,
+         constant  int64_t & ne11,
+         constant  int64_t & ne12,
+         constant  int64_t & ne13,
+         constant uint64_t & nb10,
+         constant uint64_t & nb11,
+         constant uint64_t & nb12,
+         constant uint64_t & nb13,
+         constant  int64_t & ne0,
+         constant  int64_t & ne1,
+         constant  int64_t & ne2,
+         constant  int64_t & ne3,
+         constant uint64_t & nb0,
+         constant uint64_t & nb1,
+         constant uint64_t & nb2,
+         constant uint64_t & nb3,
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         uint3 tpitg[[thread_position_in_threadgroup]],
+         uint3   ntg[[threads_per_threadgroup]]) {
+
+     const int64_t i03 = tgpig.z;
+     const int64_t i02 = tgpig.y;
+     const int64_t i01 = tgpig.x;
+
+     const int64_t i13 = i03 % ne13;
+     const int64_t i12 = i02 % ne12;
+     const int64_t i11 = i01 % ne11;
+
+     device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
+     device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
+     device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
+
+     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+         if (i02 < ne02) {
+             ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
+             src0_ptr += ntg.x*nb00;
+         } else {
+             ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
+             src1_ptr += ntg.x*nb10;
+         }
+         dst_ptr += ntg.x*nb0;
      }
  }

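kernel_concat writes dst by reading src0 for planes below the ne02 boundary and src1 above it, i.e. a concatenation along dim 2 (the modulo indexing additionally allows broadcast-style reuse of src1). A simplified CPU sketch for contiguous floats, assuming the common case where src1's planes simply continue after src0's (indices illustrative only):

```cpp
#include <cstdio>
#include <vector>

int main() {
    const int ne0 = 4, ne02 = 2, ne12 = 3;   // src0: 2 planes, src1: 3 planes
    std::vector<float> src0(ne0 * ne02, 1.0f);
    std::vector<float> src1(ne0 * ne12, 2.0f);
    std::vector<float> dst(ne0 * (ne02 + ne12));

    // dst plane i2 reads src0 while i2 < ne02, then switches to src1,
    // mirroring the kernel's `if (i02 < ne02)` branch.
    for (int i2 = 0; i2 < ne02 + ne12; ++i2) {
        for (int i0 = 0; i0 < ne0; ++i0) {
            dst[i2 * ne0 + i0] = (i2 < ne02) ? src0[i2 * ne0 + i0]
                                             : src1[(i2 - ne02) * ne0 + i0];
        }
    }
    std::printf("plane 0 -> %g, plane %d -> %g\n", dst[0], ne02, dst[ne02 * ne0]);
}
```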
@@ -1100,7 +2064,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {

  //====================================== dot products =========================

- kernel void kernel_mul_mat_q2_K_f32(
+ kernel void kernel_mul_mv_q2_K_f32(
          device const  void * src0,
          device const float * src1,
          device       float * dst,
@@ -1109,23 +2073,30 @@ kernel void kernel_mul_mat_q2_K_f32(
          constant   int64_t & ne02[[buffer(5)]],
          constant   int64_t & ne10[[buffer(9)]],
          constant   int64_t & ne12[[buffer(11)]],
-         constant   int64_t & ne0[[buffer(15)]],
-         constant   int64_t & ne1[[buffer(16)]],
-         constant   uint    & gqa[[buffer(17)]],
+         constant   int64_t & ne0 [[buffer(15)]],
+         constant   int64_t & ne1 [[buffer(16)]],
+         constant   uint    & r2  [[buffer(17)]],
+         constant   uint    & r3  [[buffer(18)]],
          uint3 tgpig[[threadgroup_position_in_grid]],
-         uint tiisg[[thread_index_in_simdgroup]],
-         uint sgitg[[simdgroup_index_in_threadgroup]]) {
+         uint  tiisg[[thread_index_in_simdgroup]],
+         uint  sgitg[[simdgroup_index_in_threadgroup]]) {

      const int nb = ne00/QK_K;
      const int r0 = tgpig.x;
      const int r1 = tgpig.y;
-     const int r2 = tgpig.z;
+     const int im = tgpig.z;

      const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
      const int ib_row = first_row * nb;
-     const uint offset0 = r2/gqa*(nb*ne0);
+
+     const uint i12 = im%ne12;
+     const uint i13 = im/ne12;
+
+     const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
      device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
-     device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+     device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
      float yl[32];
      float sumf[N_DST]={0.f}, all_sum;

@@ -1134,11 +2105,11 @@ kernel void kernel_mul_mat_q2_K_f32(
  #if QK_K == 256
      const int ix = tiisg/8;  // 0...3
      const int it = tiisg%8;  // 0...7
-     const int im = it/4;     // 0 or 1
+     const int iq = it/4;     // 0 or 1
      const int ir = it%4;     // 0...3
      const int is = (8*ir)/16;// 0 or 1

-     device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
+     device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;

      for (int ib = ix; ib < nb; ib += 4) {

@@ -1150,8 +2121,8 @@ kernel void kernel_mul_mat_q2_K_f32(
              yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
          }

-         device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales + 8*im + is;
-         device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+         device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales + 8*iq + is;
+         device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
          device const half * dh = &x[ib].d;

          for (int row = 0; row < N_DST; row++) {
@@ -1238,13 +2209,13 @@ kernel void kernel_mul_mat_q2_K_f32(
      for (int row = 0; row < N_DST; ++row) {
          all_sum = simd_sum(sumf[row]);
          if (tiisg == 0) {
-             dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
+             dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
          }
      }
  }

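The recurring change in these renamed mul_mv kernels replaces the single gqa divisor with two broadcast ratios, r2 = ne12/ne02 and r3 = ne13/ne03, so src1 batches in dims 2 and 3 can each map back to a shared src0 matrix (grouped-query-attention style sharing in both batch dimensions). A C++ sketch of just the new offset math, with made-up sizes:

```cpp
#include <cstdio>

int main() {
    const int ne02 = 4, ne03 = 2;   // src0 batches (e.g. KV heads x sequences)
    const int ne12 = 8, ne13 = 2;   // src1 batches (e.g. Q heads x sequences)
    const int r2 = ne12 / ne02;     // dim-2 broadcast ratio
    const int r3 = ne13 / ne03;     // dim-3 broadcast ratio
    const int nb = 16, ne01 = 32;   // blocks per row, rows per matrix

    // im enumerates src1's (dim2, dim3) batches; i12/r2 and i13/r3 recover the
    // src0 matrix shared by r2*r3 consecutive batches, as in the kernels above.
    for (int im = 0; im < ne12 * ne13; ++im) {
        const int i12 = im % ne12;
        const int i13 = im / ne12;
        const int offset0 = (i12 / r2) * (nb * ne01) + (i13 / r3) * (nb * ne01 * ne02);
        if (im < 4) std::printf("im=%d -> src0 offset %d\n", im, offset0);
    }
}
```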
  #if QK_K == 256
- kernel void kernel_mul_mat_q3_K_f32(
+ kernel void kernel_mul_mv_q3_K_f32(
          device const  void * src0,
          device const float * src1,
          device       float * dst,
@@ -1253,9 +2224,10 @@ kernel void kernel_mul_mat_q3_K_f32(
          constant   int64_t & ne02[[buffer(5)]],
          constant   int64_t & ne10[[buffer(9)]],
          constant   int64_t & ne12[[buffer(11)]],
-         constant   int64_t & ne0[[buffer(15)]],
-         constant   int64_t & ne1[[buffer(16)]],
-         constant   uint    & gqa[[buffer(17)]],
+         constant   int64_t & ne0 [[buffer(15)]],
+         constant   int64_t & ne1 [[buffer(16)]],
+         constant   uint    & r2  [[buffer(17)]],
+         constant   uint    & r3  [[buffer(18)]],
          uint3 tgpig[[threadgroup_position_in_grid]],
          uint  tiisg[[thread_index_in_simdgroup]],
          uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1264,17 +2236,22 @@ kernel void kernel_mul_mat_q3_K_f32(

      const int64_t r0 = tgpig.x;
      const int64_t r1 = tgpig.y;
-     const int64_t r2 = tgpig.z;
+     const int64_t im = tgpig.z;

      const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
-     const uint offset0 = r2/gqa*(nb*ne0);
+
+     const uint i12 = im%ne12;
+     const uint i13 = im/ne12;
+
+     const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
      device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
-     device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+     device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;

      float yl[32];

-     const uint16_t kmask1 = 0x3030;
-     const uint16_t kmask2 = 0x0f0f;
+     //const uint16_t kmask1 = 0x3030;
+     //const uint16_t kmask2 = 0x0f0f;

      const int tid = tiisg/4;
      const int ix = tiisg%4;
@@ -1391,12 +2368,12 @@ kernel void kernel_mul_mat_q3_K_f32(
      }
      if (tiisg == 0) {
          for (int row = 0; row < 2; ++row) {
-             dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
+             dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
          }
      }
  }
  #else
- kernel void kernel_mul_mat_q3_K_f32(
+ kernel void kernel_mul_mv_q3_K_f32(
          device const  void * src0,
          device const float * src1,
          device       float * dst,
@@ -1405,26 +2382,33 @@ kernel void kernel_mul_mat_q3_K_f32(
          constant   int64_t & ne02[[buffer(5)]],
          constant   int64_t & ne10[[buffer(9)]],
          constant   int64_t & ne12[[buffer(11)]],
-         constant   int64_t & ne0[[buffer(15)]],
-         constant   int64_t & ne1[[buffer(16)]],
-         constant   uint    & gqa[[buffer(17)]],
+         constant   int64_t & ne0 [[buffer(15)]],
+         constant   int64_t & ne1 [[buffer(16)]],
+         constant   uint    & r2  [[buffer(17)]],
+         constant   uint    & r3  [[buffer(18)]],
          uint3 tgpig[[threadgroup_position_in_grid]],
-         uint tiisg[[thread_index_in_simdgroup]],
-         uint sgitg[[simdgroup_index_in_threadgroup]]) {
+         uint  tiisg[[thread_index_in_simdgroup]],
+         uint  sgitg[[simdgroup_index_in_threadgroup]]) {

      const int nb = ne00/QK_K;

      const int64_t r0 = tgpig.x;
      const int64_t r1 = tgpig.y;
-     const int64_t r2 = tgpig.z;
+     const int64_t im = tgpig.z;

      const int row = 2 * r0 + sgitg;
-     const uint offset0 = r2/gqa*(nb*ne0);
+
+     const uint i12 = im%ne12;
+     const uint i13 = im/ne12;
+
+     const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
      device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
-     device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+     device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
      const int ix = tiisg/4;
      const int il = 4 * (tiisg%4);// 0, 4, 8, 12
-     const int im = il/8;         // 0, 0, 1, 1
+     const int iq = il/8;         // 0, 0, 1, 1
      const int in = il%8;         // 0, 4, 0, 4

      float2 sum = {0.f, 0.f};
@@ -1444,7 +2428,7 @@ kernel void kernel_mul_mat_q3_K_f32(
      const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;

      for (int l = 0; l < 4; l += 2) {
-         const uint16_t hm = h[l/2] >> im;
+         const uint16_t hm = h[l/2] >> iq;
          sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
                  + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
                  + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
@@ -1460,14 +2444,14 @@ kernel void kernel_mul_mat_q3_K_f32(

      const float tot = simd_sum(sumf);
      if (tiisg == 0) {
-         dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
+         dst[r1*ne0 + im*ne0*ne1 + row] = tot;
      }

  }
  #endif

  #if QK_K == 256
- kernel void kernel_mul_mat_q4_K_f32(
+ kernel void kernel_mul_mv_q4_K_f32(
          device const  void * src0,
          device const float * src1,
          device       float * dst,
@@ -1478,10 +2462,11 @@ kernel void kernel_mul_mat_q4_K_f32(
          constant   int64_t & ne12 [[buffer(11)]],
          constant   int64_t & ne0 [[buffer(15)]],
          constant   int64_t & ne1 [[buffer(16)]],
-         constant   uint    & gqa [[buffer(17)]],
+         constant   uint    & r2 [[buffer(17)]],
+         constant   uint    & r3 [[buffer(18)]],
          uint3 tgpig[[threadgroup_position_in_grid]],
-         uint tiisg[[thread_index_in_simdgroup]],
-         uint sgitg[[simdgroup_index_in_threadgroup]]) {
+         uint  tiisg[[thread_index_in_simdgroup]],
+         uint  sgitg[[simdgroup_index_in_threadgroup]]) {

      const uint16_t kmask1 = 0x3f3f;
      const uint16_t kmask2 = 0x0f0f;
@@ -1489,26 +2474,32 @@ kernel void kernel_mul_mat_q4_K_f32(

      const int ix = tiisg/8;  // 0...3
      const int it = tiisg%8;  // 0...7
-     const int im = it/4;     // 0 or 1
+     const int iq = it/4;     // 0 or 1
      const int ir = it%4;     // 0...3

      const int nb = ne00/QK_K;
      const int r0 = tgpig.x;
      const int r1 = tgpig.y;
-     const int r2 = tgpig.z;
+     const int im = tgpig.z;
      //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
      const int first_row = r0 * N_DST;
      const int ib_row = first_row * nb;
-     const uint offset0 = r2/gqa*(nb*ne0);
+
+     const uint i12 = im%ne12;
+     const uint i13 = im/ne12;
+
+     const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
      device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
-     device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+     device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
      float yl[16];
      float yh[16];
      float sumf[N_DST]={0.f}, all_sum;

      const int step = sizeof(block_q4_K) * nb / 2;

-     device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
+     device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;

      uint16_t sc16[4];
      thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
@@ -1523,8 +2514,8 @@ kernel void kernel_mul_mat_q4_K_f32(
          yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
      }

-         device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
-         device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+         device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
+         device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
          device const half * dh = &x[ib].d;

          for (int row = 0; row < N_DST; row++) {
@@ -1568,12 +2559,12 @@ kernel void kernel_mul_mat_q4_K_f32(
      for (int row = 0; row < N_DST; ++row) {
          all_sum = simd_sum(sumf[row]);
          if (tiisg == 0) {
-             dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
+             dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
          }
      }
  }
  #else
- kernel void kernel_mul_mat_q4_K_f32(
+ kernel void kernel_mul_mv_q4_K_f32(
          device const  void * src0,
          device const float * src1,
          device       float * dst,
@@ -1582,9 +2573,10 @@ kernel void kernel_mul_mat_q4_K_f32(
          constant   int64_t & ne02[[buffer(5)]],
          constant   int64_t & ne10[[buffer(9)]],
          constant   int64_t & ne12[[buffer(11)]],
-         constant   int64_t & ne0[[buffer(15)]],
-         constant   int64_t & ne1[[buffer(16)]],
-         constant   uint    & gqa[[buffer(17)]],
+         constant   int64_t & ne0 [[buffer(15)]],
+         constant   int64_t & ne1 [[buffer(16)]],
+         constant   uint    & r2  [[buffer(17)]],
+         constant   uint    & r3  [[buffer(18)]],
          uint3 tgpig[[threadgroup_position_in_grid]],
          uint  tiisg[[thread_index_in_simdgroup]],
          uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1595,12 +2587,18 @@ kernel void kernel_mul_mat_q4_K_f32(
      const int nb = ne00/QK_K;
      const int r0 = tgpig.x;
      const int r1 = tgpig.y;
-     const int r2 = tgpig.z;
+     const int im = tgpig.z;
      const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
      const int ib_row = first_row * nb;
-     const uint offset0 = r2/gqa*(nb*ne0);
+
+     const uint i12 = im%ne12;
+     const uint i13 = im/ne12;
+
+     const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
      device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
-     device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+     device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
      float yl[8];
      float yh[8];
      float sumf[N_DST]={0.f}, all_sum;
@@ -1656,13 +2654,13 @@ kernel void kernel_mul_mat_q4_K_f32(
      for (int row = 0; row < N_DST; ++row) {
          all_sum = simd_sum(sumf[row]);
          if (tiisg == 0) {
-             dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
+             dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
          }
      }
  }
  #endif

- kernel void kernel_mul_mat_q5_K_f32(
+ kernel void kernel_mul_mv_q5_K_f32(
          device const  void * src0,
          device const float * src1,
          device       float * dst,
@@ -1671,9 +2669,10 @@ kernel void kernel_mul_mat_q5_K_f32(
          constant   int64_t & ne02[[buffer(5)]],
          constant   int64_t & ne10[[buffer(9)]],
          constant   int64_t & ne12[[buffer(11)]],
-         constant   int64_t & ne0[[buffer(15)]],
-         constant   int64_t & ne1[[buffer(16)]],
-         constant   uint    & gqa[[buffer(17)]],
+         constant   int64_t & ne0 [[buffer(15)]],
+         constant   int64_t & ne1 [[buffer(16)]],
+         constant   uint    & r2  [[buffer(17)]],
+         constant   uint    & r3  [[buffer(18)]],
          uint3 tgpig[[threadgroup_position_in_grid]],
          uint  tiisg[[thread_index_in_simdgroup]],
          uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1682,12 +2681,17 @@ kernel void kernel_mul_mat_q5_K_f32(

      const int64_t r0 = tgpig.x;
      const int64_t r1 = tgpig.y;
-     const int r2 = tgpig.z;
+     const int im = tgpig.z;

      const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
-     const uint offset0 = r2/gqa*(nb*ne0);
+
+     const uint i12 = im%ne12;
+     const uint i13 = im/ne12;
+
+     const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
      device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
-     device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+     device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;

      float sumf[2]={0.f};

@@ -1703,15 +2707,15 @@ kernel void kernel_mul_mat_q5_K_f32(

      const int tid = tiisg/4;
      const int ix = tiisg%4;
-     const int im = tid/4;
+     const int iq = tid/4;
      const int ir = tid%4;
      const int n = 8;

      const int l0 = n*ir;
-     const int q_offset = 32*im + l0;
-     const int y_offset = 64*im + l0;
+     const int q_offset = 32*iq + l0;
+     const int y_offset = 64*iq + l0;

-     const uint8_t hm1 = 1u << (2*im);
+     const uint8_t hm1 = 1u << (2*iq);
      const uint8_t hm2 = hm1 << 1;
      const uint8_t hm3 = hm1 << 4;
      const uint8_t hm4 = hm2 << 4;
@@ -1726,7 +2730,7 @@ kernel void kernel_mul_mat_q5_K_f32(
          device const uint8_t  * q1 = x[i].qs + q_offset;
          device const uint8_t  * qh = x[i].qh + l0;
          device const half     * dh = &x[i].d;
-         device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
+         device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;

          device const float * y2 = y1 + 128;
          float4 sumy = {0.f, 0.f, 0.f, 0.f};
@@ -1782,7 +2786,7 @@ kernel void kernel_mul_mat_q5_K_f32(

      const int il = 4 * (tiisg/8);  // 0, 4, 8, 12
      const int ix = tiisg%8;
-     const int im = il/8;           // 0, 0, 1, 1
+     const int iq = il/8;           // 0, 0, 1, 1
      const int in = il%8;           // 0, 4, 0, 4

      device const float * y = yy + ix*QK_K + il;
@@ -1807,7 +2811,7 @@ kernel void kernel_mul_mat_q5_K_f32(

          float2 acc = {0.f, 0.f};
          for (int l = 0; l < 4; ++l) {
-             const uint8_t hl = h[l] >> im;
+             const uint8_t hl = h[l] >> iq;
              acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
                      + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
              acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
@@ -1829,13 +2833,13 @@ kernel void kernel_mul_mat_q5_K_f32(
      for (int row = 0; row < 2; ++row) {
          const float tot = simd_sum(sumf[row]);
          if (tiisg == 0) {
-             dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+             dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
          }
      }

  }

- kernel void kernel_mul_mat_q6_K_f32(
+ kernel void kernel_mul_mv_q6_K_f32(
          device const  void * src0,
          device const float * src1,
          device       float * dst,
@@ -1844,9 +2848,10 @@ kernel void kernel_mul_mat_q6_K_f32(
          constant   int64_t & ne02[[buffer(5)]],
          constant   int64_t & ne10[[buffer(9)]],
          constant   int64_t & ne12[[buffer(11)]],
-         constant   int64_t & ne0[[buffer(15)]],
-         constant   int64_t & ne1[[buffer(16)]],
-         constant   uint    & gqa[[buffer(17)]],
+         constant   int64_t & ne0 [[buffer(15)]],
+         constant   int64_t & ne1 [[buffer(16)]],
+         constant   uint    & r2  [[buffer(17)]],
+         constant   uint    & r3  [[buffer(18)]],
          uint3 tgpig[[threadgroup_position_in_grid]],
          uint  tiisg[[thread_index_in_simdgroup]],
          uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1860,12 +2865,17 @@ kernel void kernel_mul_mat_q6_K_f32(

      const int64_t r0 = tgpig.x;
      const int64_t r1 = tgpig.y;
-     const int r2 = tgpig.z;
+     const int im = tgpig.z;

      const int row = 2 * r0 + sgitg;
-     const uint offset0 = r2/gqa*(nb*ne0);
+
+     const uint i12 = im%ne12;
+     const uint i13 = im/ne12;
+
+     const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
      device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
-     device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+     device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;

      float sumf = 0;

@@ -1931,7 +2941,7 @@ kernel void kernel_mul_mat_q6_K_f32(

      const float tot = simd_sum(sumf);
      if (tiisg == 0) {
-         dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
+         dst[r1*ne0 + im*ne0*ne1 + row] = tot;
      }
  }

@@ -1984,6 +2994,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
      }
  }

+ template <typename type4x4>
+ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+     device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+     const float d = xb->d;
+     const float md = -16.h * xb->d;
+     const ushort mask = il ? 0x00F0 : 0x000F;
+
+     const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+     const int x_mv = il ? 4 : 0;
+
+     const int gh_mv = il ? 12 : 0;
+     const int gh_bk = il ?  0 : 4;
+
+     for (int i = 0; i < 8; i++) {
+         // extract the 5-th bits for x0 and x1
+         const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+         const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+         // combine the 4-bits from qs with the 5th bit
+         const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+         const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+         reg[i/2][2*(i%2)+0] = d * x0 + md;
+         reg[i/2][2*(i%2)+1] = d * x1 + md;
+     }
+ }
+
+ template <typename type4x4>
+ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+     device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+     const float d = xb->d;
+     const float m = xb->m;
+     const ushort mask = il ? 0x00F0 : 0x000F;
+
+     const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+     const int x_mv = il ? 4 : 0;
+
+     const int gh_mv = il ? 12 : 0;
+     const int gh_bk = il ?  0 : 4;
+
+     for (int i = 0; i < 8; i++) {
+         // extract the 5-th bits for x0 and x1
+         const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+         const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+         // combine the 4-bits from qs with the 5th bit
+         const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+         const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+         reg[i/2][2*(i%2)+0] = d * x0 + m;
+         reg[i/2][2*(i%2)+1] = d * x1 + m;
+     }
+ }
+
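The two new dequantizers undo the 5-bit packing of q5_0/q5_1: each weight keeps 4 bits in a qs nibble and its 5th bit in the shared 32-bit qh word; q5_0 then applies the symmetric offset -16·d while q5_1 adds the stored minimum m. A C++ illustration of decoding one low-half weight (numbers made up):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const float    d  = 0.25f;       // block scale
    const uint32_t qh = 0x00000001;  // weight 0 has its 5th bit set
    const uint8_t  lo = 0x07;        // low nibble of weight 0 from qs

    // same shift-and-mask as dequantize_q5_0 for il == 0, weight i == 0
    const uint8_t bit5 = ((qh >> 0) << 4) & 0x10;   // -> 0x10
    const int32_t q    = lo | bit5;                 // -> 23
    const float   x    = d * q + (-16.0f * d);      // q5_0's symmetric offset

    std::printf("q = %d, x = %f\n", q, x);          // q = 23, x = 1.75
}
```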
  template <typename type4x4>
  void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
      device const int8_t * qs = ((device const int8_t *)xb->qs);
@@ -2173,7 +3239,7 @@ kernel void kernel_get_rows(
  }

  #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
- #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
+ #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
  #define BLOCK_SIZE_K 32
  #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
  #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
@@ -2185,24 +3251,25 @@ kernel void kernel_get_rows(

  // each block_q contains 16*nl weights
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
- kernel void kernel_mul_mm(device const uchar * src0,
-                           device const uchar * src1,
-                           device       float * dst,
-                           constant   int64_t & ne00,
-                           constant   int64_t & ne02,
-                           constant   int64_t & nb01,
-                           constant   int64_t & nb02,
-                           constant   int64_t & ne12,
-                           constant   int64_t & nb10,
-                           constant   int64_t & nb11,
-                           constant   int64_t & nb12,
-                           constant   int64_t & ne0,
-                           constant   int64_t & ne1,
-                           constant      uint & gqa,
-                           threadgroup  uchar * shared_memory [[threadgroup(0)]],
-                           uint3 tgpig[[threadgroup_position_in_grid]],
-                           uint  tiitg[[thread_index_in_threadgroup]],
-                           uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+ void kernel_mul_mm_impl(device const uchar * src0,
+                         device const uchar * src1,
+                         device       float * dst,
+                         constant   int64_t & ne00,
+                         constant   int64_t & ne02,
+                         constant   int64_t & nb01,
+                         constant   int64_t & nb02,
+                         constant   int64_t & ne12,
+                         constant   int64_t & nb10,
+                         constant   int64_t & nb11,
+                         constant   int64_t & nb12,
+                         constant   int64_t & ne0,
+                         constant   int64_t & ne1,
+                         constant      uint & r2,
+                         constant      uint & r3,
+                         threadgroup  uchar * shared_memory [[threadgroup(0)]],
+                         uint3 tgpig[[threadgroup_position_in_grid]],
+                         uint  tiitg[[thread_index_in_threadgroup]],
+                         uint  sgitg[[simdgroup_index_in_threadgroup]]) {

      threadgroup half  * sa = (threadgroup half  *)(shared_memory);
      threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
@@ -2210,9 +3277,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
      const uint r0 = tgpig.y;
      const uint r1 = tgpig.x;
      const uint im = tgpig.z;
+
      // if this block is of 64x32 shape or smaller
      short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
      short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
      // a thread shouldn't load data outside of the matrix
      short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
      short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
@@ -2226,7 +3295,10 @@ kernel void kernel_mul_mm(device const uchar * src0,

      short il = (tiitg % THREAD_PER_ROW);

-     uint offset0 = im/gqa*nb02;
+     const uint i12 = im%ne12;
+     const uint i13 = im/ne12;
+
+     uint   offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
      ushort offset1 = il/nl;

      device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
@@ -2236,26 +3308,30 @@ kernel void kernel_mul_mm(device const uchar * src0,
                                + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));

      for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
-         //load data and store to threadgroup memory
+         // load data and store to threadgroup memory
          half4x4 temp_a;
          dequantize_func(x, il, temp_a);
          threadgroup_barrier(mem_flags::mem_threadgroup);
+
          #pragma unroll(16)
          for (int i = 0; i < 16; i++) {
              *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
-             + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
-             + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+                     + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+                     + (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
          }
-         *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
-         = *((device float2x4 *)y);
+
+         *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
          il = (il + 2 < nl) ? il + 2 : il % 2;
          x  = (il < 2) ? x + (2+nl-1)/nl : x;
          y += BLOCK_SIZE_K;

          threadgroup_barrier(mem_flags::mem_threadgroup);
-         //load matrices from threadgroup memory and conduct outer products
+
+         // load matrices from threadgroup memory and conduct outer products
          threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
          threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
          #pragma unroll(4)
          for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
              #pragma unroll(4)
@@ -2270,6 +3346,7 @@ kernel void kernel_mul_mm(device const uchar * src0,

              lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
              lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
              #pragma unroll(8)
              for (int i = 0; i < 8; i++){
                  simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
@@ -2278,25 +3355,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
      }

      if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
-         device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
-                         + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
+         device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg &  1)) \
+                          + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
          for (int i = 0; i < 8; i++) {
              simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
          }
      } else {
          // block is smaller than 64x32, we should avoid writing data outside of the matrix
          threadgroup_barrier(mem_flags::mem_threadgroup);
-         threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
+         threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
                                        + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
          for (int i = 0; i < 8; i++) {
              simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
          }

          threadgroup_barrier(mem_flags::mem_threadgroup);
-         device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
-         if (sgitg==0) {
+
+         device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+         if (sgitg == 0) {
              for (int i = 0; i < n_rows; i++) {
-                 for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
+                 for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                      *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                  }
              }
@@ -2304,19 +3382,123 @@ kernel void kernel_mul_mm(device const uchar * src0,
          }
      }
  }

+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+ kernel void kernel_mul_mm(device const uchar * src0,
+                           device const uchar * src1,
+                           device       float * dst,
+                           constant   int64_t & ne00,
+                           constant   int64_t & ne02,
+                           constant   int64_t & nb01,
+                           constant   int64_t & nb02,
+                           constant   int64_t & ne12,
+                           constant   int64_t & nb10,
+                           constant   int64_t & nb11,
+                           constant   int64_t & nb12,
+                           constant   int64_t & ne0,
+                           constant   int64_t & ne1,
+                           constant      uint & r2,
+                           constant      uint & r3,
+                           threadgroup  uchar * shared_memory [[threadgroup(0)]],
+                           uint3 tgpig[[threadgroup_position_in_grid]],
+                           uint  tiitg[[thread_index_in_threadgroup]],
+                           uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+     kernel_mul_mm_impl<block_q, nl, dequantize_func>(
+         src0,
+         src1,
+         dst,
+         ne00,
+         ne02,
+         nb01,
+         nb02,
+         ne12,
+         nb10,
+         nb11,
+         nb12,
+         ne0,
+         ne1,
+         r2,
+         r3,
+         shared_memory,
+         tgpig,
+         tiitg,
+         sgitg);
+ }
+
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+ kernel void kernel_mul_mm_id(
+         device const int32_t * ids,
+         device const   uchar * src1,
+         device         float * dst,
+         constant     int64_t & ne00,
+         constant     int64_t & ne02,
+         constant     int64_t & nb01,
+         constant     int64_t & nb02,
+         constant     int64_t & ne12,
+         constant     int64_t & nb10,
+         constant     int64_t & nb11,
+         constant     int64_t & nb12,
+         constant     int64_t & ne0,
+         constant     int64_t & ne1,
+         constant        uint & r2,
+         constant        uint & r3,
+         constant         int & idx,
+         device const   uchar * src00,
+         device const   uchar * src01,
+         device const   uchar * src02,
+         device const   uchar * src03,
+         device const   uchar * src04,
+         device const   uchar * src05,
+         device const   uchar * src06,
+         device const   uchar * src07,
+         threadgroup    uchar * shared_memory [[threadgroup(0)]],
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         uint  tiitg[[thread_index_in_threadgroup]],
+         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+     device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+     kernel_mul_mm_impl<block_q, nl, dequantize_func>(
+         src0[ids[idx]],
+         src1,
+         dst,
+         ne00,
+         ne02,
+         nb01,
+         nb02,
+         ne12,
+         nb10,
+         nb11,
+         nb12,
+         ne0,
+         ne1,
+         r2,
+         r3,
+         shared_memory,
+         tgpig,
+         tiitg,
+         sgitg);
+ }
+
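kernel_mul_mm_id only adds an indirection in front of the shared implementation: the host binds up to eight candidate weight matrices plus an ids tensor, and the kernel forwards src0[ids[idx]] to kernel_mul_mm_impl unchanged (the pattern used for mixture-of-experts style routing). In miniature, with toy stand-ins:

```cpp
#include <cstdio>

// A 1x1 "matmul" standing in for kernel_mul_mm_impl; illustrative only.
static float matmul_1x1(const float * w, float x) { return w[0] * x; }

int main() {
    const float w0[] = {2.0f}, w1[] = {3.0f};    // two "expert" matrices
    const float * experts[8] = {w0, w1};         // src00..src07
    const int ids[] = {1, 0};                    // per-token expert choice
    const int idx = 0;                           // which id this launch uses

    // kernel_mul_mm_id's body, in miniature: pick the expert, then run the
    // shared implementation unchanged.
    const float y = matmul_1x1(experts[ids[idx]], 5.0f);
    std::printf("y = %g\n", y);                  // expert 1 selected -> 15
}
```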
  #if QK_K == 256
  #define QK_NL 16
  #else
  #define QK_NL 4
  #endif

- typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
-                           constant uint64_t &, constant uint64_t &, uint, uint, uint);
+ typedef void (get_rows_t)(
+         device const  void * src0,
+         device const   int * src1,
+         device       float * dst,
+         constant   int64_t & ne00,
+         constant  uint64_t & nb01,
+         constant  uint64_t & nb1,
+         uint, uint, uint);

  template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
  template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
+ template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
  template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
  template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
  template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -2338,16 +3520,61 @@ typedef void (mat_mm_t)(
      constant int64_t & nb12,
      constant int64_t & ne0,
      constant int64_t & ne1,
-     constant uint & gqa,
-     threadgroup uchar *, uint3, uint, uint);
+     constant uint & r2,
+     constant uint & r3,
+     threadgroup uchar *,
+     uint3, uint, uint);

  template [[host_name("kernel_mul_mm_f32_f32")]]  kernel mat_mm_t kernel_mul_mm<float4x4,   1, dequantize_f32>;
  template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1, dequantize_f16>;
  template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
  template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
+ template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
  template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
  template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
  template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
+
+ typedef void (mat_mm_id_t)(
+         device const int32_t * ids,
+         device const   uchar * src1,
+         device         float * dst,
+         constant     int64_t & ne00,
+         constant     int64_t & ne02,
+         constant     int64_t & nb01,
+         constant     int64_t & nb02,
+         constant     int64_t & ne12,
+         constant     int64_t & nb10,
+         constant     int64_t & nb11,
+         constant     int64_t & nb12,
+         constant     int64_t & ne0,
+         constant     int64_t & ne1,
+         constant        uint & r2,
+         constant        uint & r3,
+         constant         int & idx,
+         device const   uchar * src00,
+         device const   uchar * src01,
+         device const   uchar * src02,
+         device const   uchar * src03,
+         device const   uchar * src04,
+         device const   uchar * src05,
+         device const   uchar * src06,
+         device const   uchar * src07,
+         threadgroup    uchar *,
+         uint3, uint, uint);
+
+ template [[host_name("kernel_mul_mm_id_f32_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<float4x4,   1, dequantize_f32>;
+ template [[host_name("kernel_mul_mm_id_f16_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<half4x4,    1, dequantize_f16>;
+ template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
+ template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
+ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
+ template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
+ template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
+ template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
+ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
+ template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
+ template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
+ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;