whisper.rn 0.4.0-rc.7 → 0.4.0-rc.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. package/android/src/main/CMakeLists.txt +2 -1
  2. package/android/src/main/java/com/rnwhisper/AudioUtils.java +27 -12
  3. package/android/src/main/java/com/rnwhisper/RNWhisper.java +75 -34
  4. package/android/src/main/java/com/rnwhisper/WhisperContext.java +20 -3
  5. package/android/src/main/jni.cpp +29 -1
  6. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  7. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  8. package/cpp/coreml/whisper-encoder.mm +1 -1
  9. package/cpp/ggml-aarch64.c +3209 -0
  10. package/cpp/ggml-aarch64.h +39 -0
  11. package/cpp/ggml-alloc.c +732 -494
  12. package/cpp/ggml-alloc.h +47 -63
  13. package/cpp/ggml-backend-impl.h +162 -47
  14. package/cpp/ggml-backend.cpp +2635 -0
  15. package/cpp/ggml-backend.h +216 -71
  16. package/cpp/ggml-common.h +1853 -0
  17. package/cpp/ggml-cpu-impl.h +614 -0
  18. package/cpp/ggml-impl.h +144 -178
  19. package/cpp/ggml-metal.h +14 -60
  20. package/cpp/ggml-metal.m +3437 -2097
  21. package/cpp/ggml-quants.c +12559 -4189
  22. package/cpp/ggml-quants.h +135 -212
  23. package/cpp/ggml-whisper.metallib +0 -0
  24. package/cpp/ggml.c +9029 -5219
  25. package/cpp/ggml.h +673 -338
  26. package/cpp/rn-whisper.cpp +91 -0
  27. package/cpp/rn-whisper.h +2 -0
  28. package/cpp/whisper.cpp +1476 -675
  29. package/cpp/whisper.h +84 -28
  30. package/ios/RNWhisper.mm +124 -37
  31. package/ios/RNWhisperAudioUtils.h +1 -0
  32. package/ios/RNWhisperAudioUtils.m +20 -13
  33. package/ios/RNWhisperContext.h +3 -2
  34. package/ios/RNWhisperContext.mm +41 -8
  35. package/jest/mock.js +9 -1
  36. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  37. package/lib/commonjs/index.js +48 -19
  38. package/lib/commonjs/index.js.map +1 -1
  39. package/lib/commonjs/version.json +1 -1
  40. package/lib/module/NativeRNWhisper.js.map +1 -1
  41. package/lib/module/index.js +48 -19
  42. package/lib/module/index.js.map +1 -1
  43. package/lib/module/version.json +1 -1
  44. package/lib/typescript/NativeRNWhisper.d.ts +6 -3
  45. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  46. package/lib/typescript/index.d.ts +25 -3
  47. package/lib/typescript/index.d.ts.map +1 -1
  48. package/package.json +6 -5
  49. package/src/NativeRNWhisper.ts +12 -3
  50. package/src/index.ts +63 -24
  51. package/src/version.json +1 -1
  52. package/whisper-rn.podspec +9 -2
  53. package/cpp/ggml-backend.c +0 -1357
  54. package/cpp/ggml-metal-whisper.metal +0 -4908
@@ -1,4908 +0,0 @@
1
- #include <metal_stdlib>
2
-
3
- using namespace metal;
4
-
5
- #define MAX(x, y) ((x) > (y) ? (x) : (y))
6
- #define MIN(x, y) ((x) < (y) ? (x) : (y))
7
- #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
8
-
9
- #define QK4_0 32
10
- #define QR4_0 2
11
- typedef struct {
12
- half d; // delta
13
- uint8_t qs[QK4_0 / 2]; // nibbles / quants
14
- } block_q4_0;
15
-
16
- #define QK4_1 32
17
- typedef struct {
18
- half d; // delta
19
- half m; // min
20
- uint8_t qs[QK4_1 / 2]; // nibbles / quants
21
- } block_q4_1;
22
-
23
- #define QK5_0 32
24
- typedef struct {
25
- half d; // delta
26
- uint8_t qh[4]; // 5-th bit of quants
27
- uint8_t qs[QK5_0 / 2]; // nibbles / quants
28
- } block_q5_0;
29
-
30
- #define QK5_1 32
31
- typedef struct {
32
- half d; // delta
33
- half m; // min
34
- uint8_t qh[4]; // 5-th bit of quants
35
- uint8_t qs[QK5_1 / 2]; // nibbles / quants
36
- } block_q5_1;
37
-
38
- #define QK8_0 32
39
- typedef struct {
40
- half d; // delta
41
- int8_t qs[QK8_0]; // quants
42
- } block_q8_0;
43
-
44
- #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
45
-
46
- enum ggml_sort_order {
47
- GGML_SORT_ASC,
48
- GGML_SORT_DESC,
49
- };
50
-
51
- // general-purpose kernel for addition, multiplication and division of two tensors
52
- // pros: works for non-contiguous tensors, supports broadcast across all dims
53
- // cons: not very efficient
54
- kernel void kernel_add(
55
- device const char * src0,
56
- device const char * src1,
57
- device char * dst,
58
- constant int64_t & ne00,
59
- constant int64_t & ne01,
60
- constant int64_t & ne02,
61
- constant int64_t & ne03,
62
- constant int64_t & nb00,
63
- constant int64_t & nb01,
64
- constant int64_t & nb02,
65
- constant int64_t & nb03,
66
- constant int64_t & ne10,
67
- constant int64_t & ne11,
68
- constant int64_t & ne12,
69
- constant int64_t & ne13,
70
- constant int64_t & nb10,
71
- constant int64_t & nb11,
72
- constant int64_t & nb12,
73
- constant int64_t & nb13,
74
- constant int64_t & ne0,
75
- constant int64_t & ne1,
76
- constant int64_t & ne2,
77
- constant int64_t & ne3,
78
- constant int64_t & nb0,
79
- constant int64_t & nb1,
80
- constant int64_t & nb2,
81
- constant int64_t & nb3,
82
- constant int64_t & offs,
83
- uint3 tgpig[[threadgroup_position_in_grid]],
84
- uint3 tpitg[[thread_position_in_threadgroup]],
85
- uint3 ntg[[threads_per_threadgroup]]) {
86
- const int64_t i03 = tgpig.z;
87
- const int64_t i02 = tgpig.y;
88
- const int64_t i01 = tgpig.x;
89
-
90
- const int64_t i13 = i03 % ne13;
91
- const int64_t i12 = i02 % ne12;
92
- const int64_t i11 = i01 % ne11;
93
-
94
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
95
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
96
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
97
-
98
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
99
- const int i10 = i0 % ne10;
100
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
101
- }
102
- }
103
-
104
- kernel void kernel_mul(
105
- device const char * src0,
106
- device const char * src1,
107
- device char * dst,
108
- constant int64_t & ne00,
109
- constant int64_t & ne01,
110
- constant int64_t & ne02,
111
- constant int64_t & ne03,
112
- constant int64_t & nb00,
113
- constant int64_t & nb01,
114
- constant int64_t & nb02,
115
- constant int64_t & nb03,
116
- constant int64_t & ne10,
117
- constant int64_t & ne11,
118
- constant int64_t & ne12,
119
- constant int64_t & ne13,
120
- constant int64_t & nb10,
121
- constant int64_t & nb11,
122
- constant int64_t & nb12,
123
- constant int64_t & nb13,
124
- constant int64_t & ne0,
125
- constant int64_t & ne1,
126
- constant int64_t & ne2,
127
- constant int64_t & ne3,
128
- constant int64_t & nb0,
129
- constant int64_t & nb1,
130
- constant int64_t & nb2,
131
- constant int64_t & nb3,
132
- uint3 tgpig[[threadgroup_position_in_grid]],
133
- uint3 tpitg[[thread_position_in_threadgroup]],
134
- uint3 ntg[[threads_per_threadgroup]]) {
135
- const int64_t i03 = tgpig.z;
136
- const int64_t i02 = tgpig.y;
137
- const int64_t i01 = tgpig.x;
138
-
139
- const int64_t i13 = i03 % ne13;
140
- const int64_t i12 = i02 % ne12;
141
- const int64_t i11 = i01 % ne11;
142
-
143
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
144
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
145
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
146
-
147
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
148
- const int i10 = i0 % ne10;
149
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
150
- }
151
- }
152
-
153
- kernel void kernel_div(
154
- device const char * src0,
155
- device const char * src1,
156
- device char * dst,
157
- constant int64_t & ne00,
158
- constant int64_t & ne01,
159
- constant int64_t & ne02,
160
- constant int64_t & ne03,
161
- constant int64_t & nb00,
162
- constant int64_t & nb01,
163
- constant int64_t & nb02,
164
- constant int64_t & nb03,
165
- constant int64_t & ne10,
166
- constant int64_t & ne11,
167
- constant int64_t & ne12,
168
- constant int64_t & ne13,
169
- constant int64_t & nb10,
170
- constant int64_t & nb11,
171
- constant int64_t & nb12,
172
- constant int64_t & nb13,
173
- constant int64_t & ne0,
174
- constant int64_t & ne1,
175
- constant int64_t & ne2,
176
- constant int64_t & ne3,
177
- constant int64_t & nb0,
178
- constant int64_t & nb1,
179
- constant int64_t & nb2,
180
- constant int64_t & nb3,
181
- uint3 tgpig[[threadgroup_position_in_grid]],
182
- uint3 tpitg[[thread_position_in_threadgroup]],
183
- uint3 ntg[[threads_per_threadgroup]]) {
184
- const int64_t i03 = tgpig.z;
185
- const int64_t i02 = tgpig.y;
186
- const int64_t i01 = tgpig.x;
187
-
188
- const int64_t i13 = i03 % ne13;
189
- const int64_t i12 = i02 % ne12;
190
- const int64_t i11 = i01 % ne11;
191
-
192
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
193
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
194
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
195
-
196
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
197
- const int i10 = i0 % ne10;
198
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
199
- }
200
- }
201
-
202
- // assumption: src1 is a row
203
- // broadcast src1 into src0
204
- kernel void kernel_add_row(
205
- device const float4 * src0,
206
- device const float4 * src1,
207
- device float4 * dst,
208
- constant int64_t & nb [[buffer(28)]],
209
- uint tpig[[thread_position_in_grid]]) {
210
- dst[tpig] = src0[tpig] + src1[tpig % nb];
211
- }
212
-
213
- kernel void kernel_mul_row(
214
- device const float4 * src0,
215
- device const float4 * src1,
216
- device float4 * dst,
217
- constant int64_t & nb [[buffer(28)]],
218
- uint tpig[[thread_position_in_grid]]) {
219
- dst[tpig] = src0[tpig] * src1[tpig % nb];
220
- }
221
-
222
- kernel void kernel_div_row(
223
- device const float4 * src0,
224
- device const float4 * src1,
225
- device float4 * dst,
226
- constant int64_t & nb [[buffer(28)]],
227
- uint tpig[[thread_position_in_grid]]) {
228
- dst[tpig] = src0[tpig] / src1[tpig % nb];
229
- }
230
-
231
- kernel void kernel_scale(
232
- device const float * src0,
233
- device float * dst,
234
- constant float & scale,
235
- uint tpig[[thread_position_in_grid]]) {
236
- dst[tpig] = src0[tpig] * scale;
237
- }
238
-
239
- kernel void kernel_scale_4(
240
- device const float4 * src0,
241
- device float4 * dst,
242
- constant float & scale,
243
- uint tpig[[thread_position_in_grid]]) {
244
- dst[tpig] = src0[tpig] * scale;
245
- }
246
-
247
- kernel void kernel_relu(
248
- device const float * src0,
249
- device float * dst,
250
- uint tpig[[thread_position_in_grid]]) {
251
- dst[tpig] = max(0.0f, src0[tpig]);
252
- }
253
-
254
- kernel void kernel_tanh(
255
- device const float * src0,
256
- device float * dst,
257
- uint tpig[[thread_position_in_grid]]) {
258
- device const float & x = src0[tpig];
259
- dst[tpig] = precise::tanh(x);
260
- }
261
-
262
- constant float GELU_COEF_A = 0.044715f;
263
- constant float GELU_QUICK_COEF = -1.702f;
264
- constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
265
-
266
- kernel void kernel_gelu(
267
- device const float4 * src0,
268
- device float4 * dst,
269
- uint tpig[[thread_position_in_grid]]) {
270
- device const float4 & x = src0[tpig];
271
-
272
- // BEWARE !!!
273
- // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
274
- // This was observed with Falcon 7B and 40B models
275
- //
276
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
277
- }
278
-
279
- kernel void kernel_gelu_quick(
280
- device const float4 * src0,
281
- device float4 * dst,
282
- uint tpig[[thread_position_in_grid]]) {
283
- device const float4 & x = src0[tpig];
284
-
285
- dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
286
- }
287
-
288
- kernel void kernel_silu(
289
- device const float4 * src0,
290
- device float4 * dst,
291
- uint tpig[[thread_position_in_grid]]) {
292
- device const float4 & x = src0[tpig];
293
- dst[tpig] = x / (1.0f + exp(-x));
294
- }
295
-
296
- kernel void kernel_sqr(
297
- device const float * src0,
298
- device float * dst,
299
- uint tpig[[thread_position_in_grid]]) {
300
- dst[tpig] = src0[tpig] * src0[tpig];
301
- }
302
-
303
- kernel void kernel_sum_rows(
304
- device const float * src0,
305
- device float * dst,
306
- constant int64_t & ne00,
307
- constant int64_t & ne01,
308
- constant int64_t & ne02,
309
- constant int64_t & ne03,
310
- constant int64_t & nb00,
311
- constant int64_t & nb01,
312
- constant int64_t & nb02,
313
- constant int64_t & nb03,
314
- constant int64_t & ne10,
315
- constant int64_t & ne11,
316
- constant int64_t & ne12,
317
- constant int64_t & ne13,
318
- constant int64_t & nb10,
319
- constant int64_t & nb11,
320
- constant int64_t & nb12,
321
- constant int64_t & nb13,
322
- constant int64_t & ne0,
323
- constant int64_t & ne1,
324
- constant int64_t & ne2,
325
- constant int64_t & ne3,
326
- constant int64_t & nb0,
327
- constant int64_t & nb1,
328
- constant int64_t & nb2,
329
- constant int64_t & nb3,
330
- uint3 tpig[[thread_position_in_grid]]) {
331
- int64_t i3 = tpig.z;
332
- int64_t i2 = tpig.y;
333
- int64_t i1 = tpig.x;
334
-
335
- if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
336
- return;
337
- }
338
-
339
- device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
340
- device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
341
-
342
- float row_sum = 0;
343
-
344
- for (int64_t i0 = 0; i0 < ne00; i0++) {
345
- row_sum += src_row[i0];
346
- }
347
-
348
- dst_row[0] = row_sum;
349
- }
350
-
351
- kernel void kernel_soft_max(
352
- device const float * src0,
353
- device const float * src1,
354
- device float * dst,
355
- constant int64_t & ne00,
356
- constant int64_t & ne01,
357
- constant int64_t & ne02,
358
- constant float & scale,
359
- threadgroup float * buf [[threadgroup(0)]],
360
- uint tgpig[[threadgroup_position_in_grid]],
361
- uint tpitg[[thread_position_in_threadgroup]],
362
- uint sgitg[[simdgroup_index_in_threadgroup]],
363
- uint tiisg[[thread_index_in_simdgroup]],
364
- uint ntg[[threads_per_threadgroup]]) {
365
- const int64_t i03 = (tgpig) / (ne02*ne01);
366
- const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
367
- const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
368
-
369
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
370
- device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
371
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
372
-
373
- // parallel max
374
- float lmax = -INFINITY;
375
-
376
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
377
- lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
378
- }
379
-
380
- // find the max value in the block
381
- float max_val = simd_max(lmax);
382
- if (ntg > N_SIMDWIDTH) {
383
- if (sgitg == 0) {
384
- buf[tiisg] = -INFINITY;
385
- }
386
-
387
- threadgroup_barrier(mem_flags::mem_threadgroup);
388
-
389
- if (tiisg == 0) {
390
- buf[sgitg] = max_val;
391
- }
392
-
393
- threadgroup_barrier(mem_flags::mem_threadgroup);
394
-
395
- max_val = buf[tiisg];
396
- max_val = simd_max(max_val);
397
- }
398
-
399
- // parallel sum
400
- float lsum = 0.0f;
401
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
402
- const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
403
- lsum += exp_psrc0;
404
- pdst[i00] = exp_psrc0;
405
- }
406
-
407
- // This barrier fixes a failing test
408
- // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
409
- threadgroup_barrier(mem_flags::mem_none);
410
-
411
- float sum = simd_sum(lsum);
412
-
413
- if (ntg > N_SIMDWIDTH) {
414
- if (sgitg == 0) {
415
- buf[tiisg] = 0.0f;
416
- }
417
-
418
- threadgroup_barrier(mem_flags::mem_threadgroup);
419
-
420
- if (tiisg == 0) {
421
- buf[sgitg] = sum;
422
- }
423
-
424
- threadgroup_barrier(mem_flags::mem_threadgroup);
425
-
426
- sum = buf[tiisg];
427
- sum = simd_sum(sum);
428
- }
429
-
430
- const float inv_sum = 1.0f/sum;
431
-
432
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
433
- pdst[i00] *= inv_sum;
434
- }
435
- }
436
-
437
- kernel void kernel_soft_max_4(
438
- device const float * src0,
439
- device const float * src1,
440
- device float * dst,
441
- constant int64_t & ne00,
442
- constant int64_t & ne01,
443
- constant int64_t & ne02,
444
- constant float & scale,
445
- threadgroup float * buf [[threadgroup(0)]],
446
- uint tgpig[[threadgroup_position_in_grid]],
447
- uint tpitg[[thread_position_in_threadgroup]],
448
- uint sgitg[[simdgroup_index_in_threadgroup]],
449
- uint tiisg[[thread_index_in_simdgroup]],
450
- uint ntg[[threads_per_threadgroup]]) {
451
- const int64_t i03 = (tgpig) / (ne02*ne01);
452
- const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
453
- const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
454
-
455
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
456
- device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
457
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
458
-
459
- // parallel max
460
- float4 lmax4 = -INFINITY;
461
-
462
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
463
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
464
- }
465
-
466
- const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
467
-
468
- float max_val = simd_max(lmax);
469
- if (ntg > N_SIMDWIDTH) {
470
- if (sgitg == 0) {
471
- buf[tiisg] = -INFINITY;
472
- }
473
-
474
- threadgroup_barrier(mem_flags::mem_threadgroup);
475
-
476
- if (tiisg == 0) {
477
- buf[sgitg] = max_val;
478
- }
479
-
480
- threadgroup_barrier(mem_flags::mem_threadgroup);
481
-
482
- max_val = buf[tiisg];
483
- max_val = simd_max(max_val);
484
- }
485
-
486
- // parallel sum
487
- float4 lsum4 = 0.0f;
488
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
489
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
490
- lsum4 += exp_psrc4;
491
- pdst4[i00] = exp_psrc4;
492
- }
493
-
494
- const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
495
-
496
- // This barrier fixes a failing test
497
- // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
498
- threadgroup_barrier(mem_flags::mem_none);
499
-
500
- float sum = simd_sum(lsum);
501
-
502
- if (ntg > N_SIMDWIDTH) {
503
- if (sgitg == 0) {
504
- buf[tiisg] = 0.0f;
505
- }
506
-
507
- threadgroup_barrier(mem_flags::mem_threadgroup);
508
-
509
- if (tiisg == 0) {
510
- buf[sgitg] = sum;
511
- }
512
-
513
- threadgroup_barrier(mem_flags::mem_threadgroup);
514
-
515
- sum = buf[tiisg];
516
- sum = simd_sum(sum);
517
- }
518
-
519
- const float inv_sum = 1.0f/sum;
520
-
521
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
522
- pdst4[i00] *= inv_sum;
523
- }
524
- }
525
-
526
- kernel void kernel_diag_mask_inf(
527
- device const float * src0,
528
- device float * dst,
529
- constant int64_t & ne00,
530
- constant int64_t & ne01,
531
- constant int & n_past,
532
- uint3 tpig[[thread_position_in_grid]]) {
533
- const int64_t i02 = tpig[2];
534
- const int64_t i01 = tpig[1];
535
- const int64_t i00 = tpig[0];
536
-
537
- if (i00 > n_past + i01) {
538
- dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
539
- } else {
540
- dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
541
- }
542
- }
543
-
544
- kernel void kernel_diag_mask_inf_8(
545
- device const float4 * src0,
546
- device float4 * dst,
547
- constant int64_t & ne00,
548
- constant int64_t & ne01,
549
- constant int & n_past,
550
- uint3 tpig[[thread_position_in_grid]]) {
551
-
552
- const int64_t i = 2*tpig[0];
553
-
554
- dst[i+0] = src0[i+0];
555
- dst[i+1] = src0[i+1];
556
- int64_t i4 = 4*i;
557
- const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
558
- const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
559
- const int64_t i00 = i4;
560
- for (int k = 3; k >= 0; --k) {
561
- if (i00 + 4 + k <= n_past + i01) {
562
- break;
563
- }
564
- dst[i+1][k] = -INFINITY;
565
- if (i00 + k > n_past + i01) {
566
- dst[i][k] = -INFINITY;
567
- }
568
- }
569
- }
570
-
571
- kernel void kernel_norm(
572
- device const void * src0,
573
- device float * dst,
574
- constant int64_t & ne00,
575
- constant uint64_t & nb01,
576
- constant float & eps,
577
- threadgroup float * sum [[threadgroup(0)]],
578
- uint tgpig[[threadgroup_position_in_grid]],
579
- uint tpitg[[thread_position_in_threadgroup]],
580
- uint ntg[[threads_per_threadgroup]]) {
581
- device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
582
- // MEAN
583
- // parallel sum
584
- sum[tpitg] = 0.0f;
585
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
586
- sum[tpitg] += x[i00];
587
- }
588
- // reduce
589
- threadgroup_barrier(mem_flags::mem_threadgroup);
590
- for (uint i = ntg/2; i > 0; i /= 2) {
591
- if (tpitg < i) {
592
- sum[tpitg] += sum[tpitg + i];
593
- }
594
- threadgroup_barrier(mem_flags::mem_threadgroup);
595
- }
596
- const float mean = sum[0] / ne00;
597
-
598
- // recenter and VARIANCE
599
- threadgroup_barrier(mem_flags::mem_threadgroup);
600
- device float * y = dst + tgpig*ne00;
601
- sum[tpitg] = 0.0f;
602
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
603
- y[i00] = x[i00] - mean;
604
- sum[tpitg] += y[i00] * y[i00];
605
- }
606
-
607
- // reduce
608
- threadgroup_barrier(mem_flags::mem_threadgroup);
609
- for (uint i = ntg/2; i > 0; i /= 2) {
610
- if (tpitg < i) {
611
- sum[tpitg] += sum[tpitg + i];
612
- }
613
- threadgroup_barrier(mem_flags::mem_threadgroup);
614
- }
615
- const float variance = sum[0] / ne00;
616
-
617
- const float scale = 1.0f/sqrt(variance + eps);
618
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
619
- y[i00] = y[i00] * scale;
620
- }
621
- }
622
-
623
- kernel void kernel_rms_norm(
624
- device const void * src0,
625
- device float * dst,
626
- constant int64_t & ne00,
627
- constant uint64_t & nb01,
628
- constant float & eps,
629
- threadgroup float * buf [[threadgroup(0)]],
630
- uint tgpig[[threadgroup_position_in_grid]],
631
- uint tpitg[[thread_position_in_threadgroup]],
632
- uint sgitg[[simdgroup_index_in_threadgroup]],
633
- uint tiisg[[thread_index_in_simdgroup]],
634
- uint ntg[[threads_per_threadgroup]]) {
635
- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
636
-
637
- float4 sumf = 0;
638
- float all_sum = 0;
639
-
640
- // parallel sum
641
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
642
- sumf += x[i00] * x[i00];
643
- }
644
- all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
645
- all_sum = simd_sum(all_sum);
646
- if (ntg > N_SIMDWIDTH) {
647
- if (sgitg == 0) {
648
- buf[tiisg] = 0.0f;
649
- }
650
-
651
- threadgroup_barrier(mem_flags::mem_threadgroup);
652
-
653
- if (tiisg == 0) {
654
- buf[sgitg] = all_sum;
655
- }
656
-
657
- threadgroup_barrier(mem_flags::mem_threadgroup);
658
-
659
- all_sum = buf[tiisg];
660
- all_sum = simd_sum(all_sum);
661
- }
662
-
663
- const float mean = all_sum/ne00;
664
- const float scale = 1.0f/sqrt(mean + eps);
665
-
666
- device float4 * y = (device float4 *) (dst + tgpig*ne00);
667
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
668
- y[i00] = x[i00] * scale;
669
- }
670
- }
671
-
672
- kernel void kernel_group_norm(
673
- device const float * src0,
674
- device float * dst,
675
- constant int64_t & ne00,
676
- constant int64_t & ne01,
677
- constant int64_t & ne02,
678
- constant uint64_t & nb00,
679
- constant uint64_t & nb01,
680
- constant uint64_t & nb02,
681
- constant int32_t & n_groups,
682
- constant float & eps,
683
- threadgroup float * buf [[threadgroup(0)]],
684
- uint tgpig[[threadgroup_position_in_grid]],
685
- uint tpitg[[thread_position_in_threadgroup]],
686
- uint sgitg[[simdgroup_index_in_threadgroup]],
687
- uint tiisg[[thread_index_in_simdgroup]],
688
- uint ntg[[threads_per_threadgroup]]) {
689
- const int64_t ne = ne00*ne01*ne02;
690
- const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
691
-
692
- int start = tgpig * gs;
693
- int end = start + gs;
694
-
695
- start += tpitg;
696
-
697
- if (end >= ne) {
698
- end = ne;
699
- }
700
-
701
- float tmp = 0.0f; // partial sum for thread in warp
702
-
703
- for (int j = start; j < end; j += ntg) {
704
- tmp += src0[j];
705
- }
706
-
707
- threadgroup_barrier(mem_flags::mem_threadgroup);
708
- tmp = simd_sum(tmp);
709
- if (ntg > N_SIMDWIDTH) {
710
- if (sgitg == 0) {
711
- buf[tiisg] = 0.0f;
712
- }
713
-
714
- threadgroup_barrier(mem_flags::mem_threadgroup);
715
-
716
- if (tiisg == 0) {
717
- buf[sgitg] = tmp;
718
- }
719
-
720
- threadgroup_barrier(mem_flags::mem_threadgroup);
721
-
722
- tmp = buf[tiisg];
723
- tmp = simd_sum(tmp);
724
- }
725
-
726
- const float mean = tmp / gs;
727
- tmp = 0.0f;
728
-
729
- for (int j = start; j < end; j += ntg) {
730
- float xi = src0[j] - mean;
731
- dst[j] = xi;
732
- tmp += xi * xi;
733
- }
734
-
735
- tmp = simd_sum(tmp);
736
- if (ntg > N_SIMDWIDTH) {
737
- if (sgitg == 0) {
738
- buf[tiisg] = 0.0f;
739
- }
740
-
741
- threadgroup_barrier(mem_flags::mem_threadgroup);
742
-
743
- if (tiisg == 0) {
744
- buf[sgitg] = tmp;
745
- }
746
-
747
- threadgroup_barrier(mem_flags::mem_threadgroup);
748
-
749
- tmp = buf[tiisg];
750
- tmp = simd_sum(tmp);
751
- }
752
-
753
- const float variance = tmp / gs;
754
- const float scale = 1.0f/sqrt(variance + eps);
755
- for (int j = start; j < end; j += ntg) {
756
- dst[j] *= scale;
757
- }
758
- }
759
-
760
- // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
761
- // il indicates where the q4 quants begin (0 or QK4_0/4)
762
- // we assume that the yl's have been multiplied with the appropriate scale factor
763
- // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
764
- inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
765
- float d = qb_curr->d;
766
-
767
- float2 acc = 0.f;
768
-
769
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
770
-
771
- for (int i = 0; i < 8; i+=2) {
772
- acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
773
- + yl[i + 1] * (qs[i / 2] & 0x0F00);
774
- acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
775
- + yl[i + 9] * (qs[i / 2] & 0xF000);
776
- }
777
- return d * (sumy * -8.f + acc[0] + acc[1]);
778
- }
779
-
780
- // function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
781
- // il indicates where the q4 quants begin (0 or QK4_0/4)
782
- // we assume that the yl's have been multiplied with the appropriate scale factor
783
- // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
784
- inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
785
- float d = qb_curr->d;
786
- float m = qb_curr->m;
787
-
788
- float2 acc = 0.f;
789
-
790
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
791
-
792
- for (int i = 0; i < 8; i+=2) {
793
- acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
794
- + yl[i + 1] * (qs[i / 2] & 0x0F00);
795
- acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
796
- + yl[i + 9] * (qs[i / 2] & 0xF000);
797
- }
798
- return d * (acc[0] + acc[1]) + sumy * m;
799
- }
800
-
801
- // function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
802
- // il indicates where the q5 quants begin (0 or QK5_0/4)
803
- // we assume that the yl's have been multiplied with the appropriate scale factor
804
- // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
805
- inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
806
- float d = qb_curr->d;
807
-
808
- float2 acc = 0.f;
809
-
810
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
811
- const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
812
-
813
- for (int i = 0; i < 8; i+=2) {
814
- acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
815
- + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
816
- acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
817
- + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
818
- }
819
- return d * (sumy * -16.f + acc[0] + acc[1]);
820
- }
821
-
822
- // function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
823
- // il indicates where the q5 quants begin (0 or QK5_1/4)
824
- // we assume that the yl's have been multiplied with the appropriate scale factor
825
- // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
826
- inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
827
- float d = qb_curr->d;
828
- float m = qb_curr->m;
829
-
830
- float2 acc = 0.f;
831
-
832
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
833
- const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
834
-
835
- for (int i = 0; i < 8; i+=2) {
836
- acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
837
- + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
838
- acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
839
- + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
840
- }
841
- return d * (acc[0] + acc[1]) + sumy * m;
842
- }
843
-
844
- // putting them in the kernel cause a significant performance penalty
845
- #define N_DST 4 // each SIMD group works on 4 rows
846
- #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
847
- //Note: This is a template, but strictly speaking it only applies to
848
- // quantizations where the block size is 32. It also does not
849
- // giard against the number of rows not being divisible by
850
- // N_DST, so this is another explicit assumption of the implementation.
851
- template<typename block_q_type, int nr, int nsg, int nw>
852
- void mul_vec_q_n_f32_impl(
853
- device const void * src0,
854
- device const float * src1,
855
- device float * dst,
856
- int64_t ne00,
857
- int64_t ne01,
858
- int64_t ne02,
859
- int64_t ne10,
860
- int64_t ne12,
861
- int64_t ne0,
862
- int64_t ne1,
863
- uint r2,
864
- uint r3,
865
- uint3 tgpig, uint tiisg, uint sgitg) {
866
- const int nb = ne00/QK4_0;
867
-
868
- const int r0 = tgpig.x;
869
- const int r1 = tgpig.y;
870
- const int im = tgpig.z;
871
-
872
- const int first_row = (r0 * nsg + sgitg) * nr;
873
-
874
- const uint i12 = im%ne12;
875
- const uint i13 = im/ne12;
876
-
877
- const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
878
-
879
- device const block_q_type * x = (device const block_q_type *) src0 + offset0;
880
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
881
-
882
- float yl[16]; // src1 vector cache
883
- float sumf[nr] = {0.f};
884
-
885
- const int ix = (tiisg/2);
886
- const int il = (tiisg%2)*8;
887
-
888
- device const float * yb = y + ix * QK4_0 + il;
889
-
890
- // each thread in a SIMD group deals with half a block.
891
- for (int ib = ix; ib < nb; ib += nw/2) {
892
- float sumy = 0;
893
- for (int i = 0; i < 8; i += 2) {
894
- sumy += yb[i] + yb[i+1];
895
- yl[i+0] = yb[i+ 0];
896
- yl[i+1] = yb[i+ 1]/256.f;
897
-
898
- sumy += yb[i+16] + yb[i+17];
899
- yl[i+8] = yb[i+16]/16.f;
900
- yl[i+9] = yb[i+17]/4096.f;
901
- }
902
-
903
- for (int row = 0; row < nr; row++) {
904
- sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
905
- }
906
-
907
- yb += QK4_0 * 16;
908
- }
909
-
910
- for (int row = 0; row < nr; ++row) {
911
- const float tot = simd_sum(sumf[row]);
912
- if (tiisg == 0 && first_row + row < ne01) {
913
- dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
914
- }
915
- }
916
- }
917
-
918
- kernel void kernel_mul_mv_q4_0_f32(
919
- device const void * src0,
920
- device const float * src1,
921
- device float * dst,
922
- constant int64_t & ne00,
923
- constant int64_t & ne01[[buffer(4)]],
924
- constant int64_t & ne02[[buffer(5)]],
925
- constant int64_t & ne10[[buffer(9)]],
926
- constant int64_t & ne12[[buffer(11)]],
927
- constant int64_t & ne0 [[buffer(15)]],
928
- constant int64_t & ne1 [[buffer(16)]],
929
- constant uint & r2 [[buffer(17)]],
930
- constant uint & r3 [[buffer(18)]],
931
- uint3 tgpig[[threadgroup_position_in_grid]],
932
- uint tiisg[[thread_index_in_simdgroup]],
933
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
934
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
935
- }
936
-
937
- kernel void kernel_mul_mv_q4_1_f32(
938
- device const void * src0,
939
- device const float * src1,
940
- device float * dst,
941
- constant int64_t & ne00,
942
- constant int64_t & ne01[[buffer(4)]],
943
- constant int64_t & ne02[[buffer(5)]],
944
- constant int64_t & ne10[[buffer(9)]],
945
- constant int64_t & ne12[[buffer(11)]],
946
- constant int64_t & ne0 [[buffer(15)]],
947
- constant int64_t & ne1 [[buffer(16)]],
948
- constant uint & r2 [[buffer(17)]],
949
- constant uint & r3 [[buffer(18)]],
950
- uint3 tgpig[[threadgroup_position_in_grid]],
951
- uint tiisg[[thread_index_in_simdgroup]],
952
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
953
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
954
- }
955
-
956
- kernel void kernel_mul_mv_q5_0_f32(
957
- device const void * src0,
958
- device const float * src1,
959
- device float * dst,
960
- constant int64_t & ne00,
961
- constant int64_t & ne01[[buffer(4)]],
962
- constant int64_t & ne02[[buffer(5)]],
963
- constant int64_t & ne10[[buffer(9)]],
964
- constant int64_t & ne12[[buffer(11)]],
965
- constant int64_t & ne0 [[buffer(15)]],
966
- constant int64_t & ne1 [[buffer(16)]],
967
- constant uint & r2 [[buffer(17)]],
968
- constant uint & r3 [[buffer(18)]],
969
- uint3 tgpig[[threadgroup_position_in_grid]],
970
- uint tiisg[[thread_index_in_simdgroup]],
971
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
972
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
973
- }
974
-
975
- kernel void kernel_mul_mv_q5_1_f32(
976
- device const void * src0,
977
- device const float * src1,
978
- device float * dst,
979
- constant int64_t & ne00,
980
- constant int64_t & ne01[[buffer(4)]],
981
- constant int64_t & ne02[[buffer(5)]],
982
- constant int64_t & ne10[[buffer(9)]],
983
- constant int64_t & ne12[[buffer(11)]],
984
- constant int64_t & ne0 [[buffer(15)]],
985
- constant int64_t & ne1 [[buffer(16)]],
986
- constant uint & r2 [[buffer(17)]],
987
- constant uint & r3 [[buffer(18)]],
988
- uint3 tgpig[[threadgroup_position_in_grid]],
989
- uint tiisg[[thread_index_in_simdgroup]],
990
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
991
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
992
- }
993
-
994
-
995
- #define NB_Q8_0 8
996
-
997
- void kernel_mul_mv_q8_0_f32_impl(
998
- device const void * src0,
999
- device const float * src1,
1000
- device float * dst,
1001
- constant int64_t & ne00,
1002
- constant int64_t & ne01,
1003
- constant int64_t & ne02,
1004
- constant int64_t & ne10,
1005
- constant int64_t & ne12,
1006
- constant int64_t & ne0,
1007
- constant int64_t & ne1,
1008
- constant uint & r2,
1009
- constant uint & r3,
1010
- uint3 tgpig[[threadgroup_position_in_grid]],
1011
- uint tiisg[[thread_index_in_simdgroup]],
1012
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
1013
- const int nr = N_DST;
1014
- const int nsg = N_SIMDGROUP;
1015
- const int nw = N_SIMDWIDTH;
1016
-
1017
- const int nb = ne00/QK8_0;
1018
- const int r0 = tgpig.x;
1019
- const int r1 = tgpig.y;
1020
- const int im = tgpig.z;
1021
-
1022
- const int first_row = (r0 * nsg + sgitg) * nr;
1023
-
1024
- const uint i12 = im%ne12;
1025
- const uint i13 = im/ne12;
1026
-
1027
- const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
1028
-
1029
- device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
1030
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
1031
-
1032
- float yl[NB_Q8_0];
1033
- float sumf[nr]={0.f};
1034
-
1035
- const int ix = tiisg/4;
1036
- const int il = tiisg%4;
1037
-
1038
- device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
1039
-
1040
- // each thread in a SIMD group deals with NB_Q8_0 quants at a time
1041
- for (int ib = ix; ib < nb; ib += nw/4) {
1042
- for (int i = 0; i < NB_Q8_0; ++i) {
1043
- yl[i] = yb[i];
1044
- }
1045
-
1046
- for (int row = 0; row < nr; row++) {
1047
- device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
1048
- float sumq = 0.f;
1049
- for (int iq = 0; iq < NB_Q8_0; ++iq) {
1050
- sumq += qs[iq] * yl[iq];
1051
- }
1052
- sumf[row] += sumq*x[ib+row*nb].d;
1053
- }
1054
-
1055
- yb += NB_Q8_0 * nw;
1056
- }
1057
-
1058
- for (int row = 0; row < nr; ++row) {
1059
- const float tot = simd_sum(sumf[row]);
1060
- if (tiisg == 0 && first_row + row < ne01) {
1061
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
1062
- }
1063
- }
1064
- }
1065
-
1066
- [[host_name("kernel_mul_mv_q8_0_f32")]]
1067
- kernel void kernel_mul_mv_q8_0_f32(
1068
- device const void * src0,
1069
- device const float * src1,
1070
- device float * dst,
1071
- constant int64_t & ne00,
1072
- constant int64_t & ne01,
1073
- constant int64_t & ne02,
1074
- constant int64_t & ne10,
1075
- constant int64_t & ne12,
1076
- constant int64_t & ne0,
1077
- constant int64_t & ne1,
1078
- constant uint & r2 [[buffer(17)]],
1079
- constant uint & r3 [[buffer(18)]],
1080
- uint3 tgpig[[threadgroup_position_in_grid]],
1081
- uint tiisg[[thread_index_in_simdgroup]],
1082
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
1083
- kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1084
- }
1085
-
1086
- #define N_F32_F32 4
1087
-
1088
- void kernel_mul_mv_f32_f32_impl(
1089
- device const char * src0,
1090
- device const char * src1,
1091
- device float * dst,
1092
- constant int64_t & ne00,
1093
- constant int64_t & ne01,
1094
- constant int64_t & ne02,
1095
- constant uint64_t & nb00,
1096
- constant uint64_t & nb01,
1097
- constant uint64_t & nb02,
1098
- constant int64_t & ne10,
1099
- constant int64_t & ne11,
1100
- constant int64_t & ne12,
1101
- constant uint64_t & nb10,
1102
- constant uint64_t & nb11,
1103
- constant uint64_t & nb12,
1104
- constant int64_t & ne0,
1105
- constant int64_t & ne1,
1106
- constant uint & r2,
1107
- constant uint & r3,
1108
- uint3 tgpig[[threadgroup_position_in_grid]],
1109
- uint tiisg[[thread_index_in_simdgroup]]) {
1110
-
1111
- const int64_t r0 = tgpig.x;
1112
- const int64_t rb = tgpig.y*N_F32_F32;
1113
- const int64_t im = tgpig.z;
1114
-
1115
- const uint i12 = im%ne12;
1116
- const uint i13 = im/ne12;
1117
-
1118
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1119
-
1120
- device const float * x = (device const float *) (src0 + offset0);
1121
-
1122
- if (ne00 < 128) {
1123
- for (int row = 0; row < N_F32_F32; ++row) {
1124
- int r1 = rb + row;
1125
- if (r1 >= ne11) {
1126
- break;
1127
- }
1128
-
1129
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1130
-
1131
- float sumf = 0;
1132
- for (int i = tiisg; i < ne00; i += 32) {
1133
- sumf += (float) x[i] * (float) y[i];
1134
- }
1135
-
1136
- float all_sum = simd_sum(sumf);
1137
- if (tiisg == 0) {
1138
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1139
- }
1140
- }
1141
- } else {
1142
- device const float4 * x4 = (device const float4 *)x;
1143
- for (int row = 0; row < N_F32_F32; ++row) {
1144
- int r1 = rb + row;
1145
- if (r1 >= ne11) {
1146
- break;
1147
- }
1148
-
1149
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1150
- device const float4 * y4 = (device const float4 *) y;
1151
-
1152
- float sumf = 0;
1153
- for (int i = tiisg; i < ne00/4; i += 32) {
1154
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
1155
- }
1156
-
1157
- float all_sum = simd_sum(sumf);
1158
- if (tiisg == 0) {
1159
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
1160
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1161
- }
1162
- }
1163
- }
1164
- }
1165
-
1166
- [[host_name("kernel_mul_mv_f32_f32")]]
1167
- kernel void kernel_mul_mv_f32_f32(
1168
- device const char * src0,
1169
- device const char * src1,
1170
- device float * dst,
1171
- constant int64_t & ne00,
1172
- constant int64_t & ne01,
1173
- constant int64_t & ne02,
1174
- constant uint64_t & nb00,
1175
- constant uint64_t & nb01,
1176
- constant uint64_t & nb02,
1177
- constant int64_t & ne10,
1178
- constant int64_t & ne11,
1179
- constant int64_t & ne12,
1180
- constant uint64_t & nb10,
1181
- constant uint64_t & nb11,
1182
- constant uint64_t & nb12,
1183
- constant int64_t & ne0,
1184
- constant int64_t & ne1,
1185
- constant uint & r2 [[buffer(17)]],
1186
- constant uint & r3 [[buffer(18)]],
1187
- uint3 tgpig[[threadgroup_position_in_grid]],
1188
- uint tiisg[[thread_index_in_simdgroup]]) {
1189
- kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1190
- }
1191
-
1192
- #define N_F16_F16 4
1193
-
1194
- kernel void kernel_mul_mv_f16_f16(
1195
- device const char * src0,
1196
- device const char * src1,
1197
- device float * dst,
1198
- constant int64_t & ne00,
1199
- constant int64_t & ne01,
1200
- constant int64_t & ne02,
1201
- constant uint64_t & nb00,
1202
- constant uint64_t & nb01,
1203
- constant uint64_t & nb02,
1204
- constant int64_t & ne10,
1205
- constant int64_t & ne11,
1206
- constant int64_t & ne12,
1207
- constant uint64_t & nb10,
1208
- constant uint64_t & nb11,
1209
- constant uint64_t & nb12,
1210
- constant int64_t & ne0,
1211
- constant int64_t & ne1,
1212
- constant uint & r2 [[buffer(17)]],
1213
- constant uint & r3 [[buffer(18)]],
1214
- uint3 tgpig[[threadgroup_position_in_grid]],
1215
- uint tiisg[[thread_index_in_simdgroup]]) {
1216
-
1217
- const int64_t r0 = tgpig.x;
1218
- const int64_t rb = tgpig.y*N_F16_F16;
1219
- const int64_t im = tgpig.z;
1220
-
1221
- const uint i12 = im%ne12;
1222
- const uint i13 = im/ne12;
1223
-
1224
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1225
-
1226
- device const half * x = (device const half *) (src0 + offset0);
1227
-
1228
- if (ne00 < 128) {
1229
- for (int row = 0; row < N_F16_F16; ++row) {
1230
- int r1 = rb + row;
1231
- if (r1 >= ne11) {
1232
- break;
1233
- }
1234
-
1235
- device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
1236
-
1237
- float sumf = 0;
1238
- for (int i = tiisg; i < ne00; i += 32) {
1239
- sumf += (half) x[i] * (half) y[i];
1240
- }
1241
-
1242
- float all_sum = simd_sum(sumf);
1243
- if (tiisg == 0) {
1244
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1245
- }
1246
- }
1247
- } else {
1248
- device const half4 * x4 = (device const half4 *)x;
1249
- for (int row = 0; row < N_F16_F16; ++row) {
1250
- int r1 = rb + row;
1251
- if (r1 >= ne11) {
1252
- break;
1253
- }
1254
-
1255
- device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
1256
- device const half4 * y4 = (device const half4 *) y;
1257
-
1258
- float sumf = 0;
1259
- for (int i = tiisg; i < ne00/4; i += 32) {
1260
- for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
1261
- }
1262
-
1263
- float all_sum = simd_sum(sumf);
1264
- if (tiisg == 0) {
1265
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
1266
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1267
- }
1268
- }
1269
- }
1270
- }
1271
-
1272
- void kernel_mul_mv_f16_f32_1row_impl(
1273
- device const char * src0,
1274
- device const char * src1,
1275
- device float * dst,
1276
- constant int64_t & ne00,
1277
- constant int64_t & ne01,
1278
- constant int64_t & ne02,
1279
- constant uint64_t & nb00,
1280
- constant uint64_t & nb01,
1281
- constant uint64_t & nb02,
1282
- constant int64_t & ne10,
1283
- constant int64_t & ne11,
1284
- constant int64_t & ne12,
1285
- constant uint64_t & nb10,
1286
- constant uint64_t & nb11,
1287
- constant uint64_t & nb12,
1288
- constant int64_t & ne0,
1289
- constant int64_t & ne1,
1290
- constant uint & r2,
1291
- constant uint & r3,
1292
- uint3 tgpig[[threadgroup_position_in_grid]],
1293
- uint tiisg[[thread_index_in_simdgroup]]) {
1294
-
1295
- const int64_t r0 = tgpig.x;
1296
- const int64_t r1 = tgpig.y;
1297
- const int64_t im = tgpig.z;
1298
-
1299
- const uint i12 = im%ne12;
1300
- const uint i13 = im/ne12;
1301
-
1302
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1303
-
1304
- device const half * x = (device const half *) (src0 + offset0);
1305
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1306
-
1307
- float sumf = 0;
1308
- if (ne00 < 128) {
1309
- for (int i = tiisg; i < ne00; i += 32) {
1310
- sumf += (float) x[i] * (float) y[i];
1311
- }
1312
- float all_sum = simd_sum(sumf);
1313
- if (tiisg == 0) {
1314
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1315
- }
1316
- } else {
1317
- device const half4 * x4 = (device const half4 *) x;
1318
- device const float4 * y4 = (device const float4 *) y;
1319
- for (int i = tiisg; i < ne00/4; i += 32) {
1320
- for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
1321
- }
1322
- float all_sum = simd_sum(sumf);
1323
- if (tiisg == 0) {
1324
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
1325
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1326
- }
1327
- }
1328
- }
1329
-
1330
- [[host_name("kernel_mul_mv_f16_f32_1row")]]
1331
- kernel void kernel_mul_mv_f16_f32_1row(
1332
- device const char * src0,
1333
- device const char * src1,
1334
- device float * dst,
1335
- constant int64_t & ne00,
1336
- constant int64_t & ne01,
1337
- constant int64_t & ne02,
1338
- constant uint64_t & nb00,
1339
- constant uint64_t & nb01,
1340
- constant uint64_t & nb02,
1341
- constant int64_t & ne10,
1342
- constant int64_t & ne11,
1343
- constant int64_t & ne12,
1344
- constant uint64_t & nb10,
1345
- constant uint64_t & nb11,
1346
- constant uint64_t & nb12,
1347
- constant int64_t & ne0,
1348
- constant int64_t & ne1,
1349
- constant uint & r2 [[buffer(17)]],
1350
- constant uint & r3 [[buffer(18)]],
1351
- uint3 tgpig[[threadgroup_position_in_grid]],
1352
- uint tiisg[[thread_index_in_simdgroup]]) {
1353
- kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1354
- }
1355
-
1356
- #define N_F16_F32 4
1357
-
1358
- void kernel_mul_mv_f16_f32_impl(
1359
- device const char * src0,
1360
- device const char * src1,
1361
- device float * dst,
1362
- constant int64_t & ne00,
1363
- constant int64_t & ne01,
1364
- constant int64_t & ne02,
1365
- constant uint64_t & nb00,
1366
- constant uint64_t & nb01,
1367
- constant uint64_t & nb02,
1368
- constant int64_t & ne10,
1369
- constant int64_t & ne11,
1370
- constant int64_t & ne12,
1371
- constant uint64_t & nb10,
1372
- constant uint64_t & nb11,
1373
- constant uint64_t & nb12,
1374
- constant int64_t & ne0,
1375
- constant int64_t & ne1,
1376
- constant uint & r2,
1377
- constant uint & r3,
1378
- uint3 tgpig[[threadgroup_position_in_grid]],
1379
- uint tiisg[[thread_index_in_simdgroup]]) {
1380
-
1381
- const int64_t r0 = tgpig.x;
1382
- const int64_t rb = tgpig.y*N_F16_F32;
1383
- const int64_t im = tgpig.z;
1384
-
1385
- const uint i12 = im%ne12;
1386
- const uint i13 = im/ne12;
1387
-
1388
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1389
-
1390
- device const half * x = (device const half *) (src0 + offset0);
1391
-
1392
- if (ne00 < 128) {
1393
- for (int row = 0; row < N_F16_F32; ++row) {
1394
- int r1 = rb + row;
1395
- if (r1 >= ne11) {
1396
- break;
1397
- }
1398
-
1399
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1400
-
1401
- float sumf = 0;
1402
- for (int i = tiisg; i < ne00; i += 32) {
1403
- sumf += (float) x[i] * (float) y[i];
1404
- }
1405
-
1406
- float all_sum = simd_sum(sumf);
1407
- if (tiisg == 0) {
1408
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1409
- }
1410
- }
1411
- } else {
1412
- device const half4 * x4 = (device const half4 *)x;
1413
- for (int row = 0; row < N_F16_F32; ++row) {
1414
- int r1 = rb + row;
1415
- if (r1 >= ne11) {
1416
- break;
1417
- }
1418
-
1419
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1420
- device const float4 * y4 = (device const float4 *) y;
1421
-
1422
- float sumf = 0;
1423
- for (int i = tiisg; i < ne00/4; i += 32) {
1424
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
1425
- }
1426
-
1427
- float all_sum = simd_sum(sumf);
1428
- if (tiisg == 0) {
1429
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
1430
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1431
- }
1432
- }
1433
- }
1434
- }
1435
-
1436
- [[host_name("kernel_mul_mv_f16_f32")]]
1437
- kernel void kernel_mul_mv_f16_f32(
1438
- device const char * src0,
1439
- device const char * src1,
1440
- device float * dst,
1441
- constant int64_t & ne00,
1442
- constant int64_t & ne01,
1443
- constant int64_t & ne02,
1444
- constant uint64_t & nb00,
1445
- constant uint64_t & nb01,
1446
- constant uint64_t & nb02,
1447
- constant int64_t & ne10,
1448
- constant int64_t & ne11,
1449
- constant int64_t & ne12,
1450
- constant uint64_t & nb10,
1451
- constant uint64_t & nb11,
1452
- constant uint64_t & nb12,
1453
- constant int64_t & ne0,
1454
- constant int64_t & ne1,
1455
- constant uint & r2 [[buffer(17)]],
1456
- constant uint & r3 [[buffer(18)]],
1457
- uint3 tgpig[[threadgroup_position_in_grid]],
1458
- uint tiisg[[thread_index_in_simdgroup]]) {
1459
- kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1460
- }
1461
-
1462
- // Assumes row size (ne00) is a multiple of 4
1463
- kernel void kernel_mul_mv_f16_f32_l4(
1464
- device const char * src0,
1465
- device const char * src1,
1466
- device float * dst,
1467
- constant int64_t & ne00,
1468
- constant int64_t & ne01,
1469
- constant int64_t & ne02,
1470
- constant uint64_t & nb00,
1471
- constant uint64_t & nb01,
1472
- constant uint64_t & nb02,
1473
- constant int64_t & ne10,
1474
- constant int64_t & ne11,
1475
- constant int64_t & ne12,
1476
- constant uint64_t & nb10,
1477
- constant uint64_t & nb11,
1478
- constant uint64_t & nb12,
1479
- constant int64_t & ne0,
1480
- constant int64_t & ne1,
1481
- constant uint & r2 [[buffer(17)]],
1482
- constant uint & r3 [[buffer(18)]],
1483
- uint3 tgpig[[threadgroup_position_in_grid]],
1484
- uint tiisg[[thread_index_in_simdgroup]]) {
1485
-
1486
- const int nrows = ne11;
1487
- const int64_t r0 = tgpig.x;
1488
- const int64_t im = tgpig.z;
1489
-
1490
- const uint i12 = im%ne12;
1491
- const uint i13 = im/ne12;
1492
-
1493
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1494
-
1495
- device const half4 * x4 = (device const half4 *) (src0 + offset0);
1496
-
1497
- for (int r1 = 0; r1 < nrows; ++r1) {
1498
- device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
1499
-
1500
- float sumf = 0;
1501
- for (int i = tiisg; i < ne00/4; i += 32) {
1502
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
1503
- }
1504
-
1505
- float all_sum = simd_sum(sumf);
1506
- if (tiisg == 0) {
1507
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1508
- }
1509
- }
1510
- }
1511
-
1512
- kernel void kernel_alibi_f32(
1513
- device const float * src0,
1514
- device float * dst,
1515
- constant int64_t & ne00,
1516
- constant int64_t & ne01,
1517
- constant int64_t & ne02,
1518
- constant int64_t & ne03,
1519
- constant uint64_t & nb00,
1520
- constant uint64_t & nb01,
1521
- constant uint64_t & nb02,
1522
- constant uint64_t & nb03,
1523
- constant int64_t & ne0,
1524
- constant int64_t & ne1,
1525
- constant int64_t & ne2,
1526
- constant int64_t & ne3,
1527
- constant uint64_t & nb0,
1528
- constant uint64_t & nb1,
1529
- constant uint64_t & nb2,
1530
- constant uint64_t & nb3,
1531
- constant float & m0,
1532
- constant float & m1,
1533
- constant int & n_heads_log2_floor,
1534
- uint3 tgpig[[threadgroup_position_in_grid]],
1535
- uint3 tpitg[[thread_position_in_threadgroup]],
1536
- uint3 ntg[[threads_per_threadgroup]]) {
1537
- const int64_t i03 = tgpig[2];
1538
- const int64_t i02 = tgpig[1];
1539
- const int64_t i01 = tgpig[0];
1540
-
1541
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1542
-
1543
- const int64_t i3 = n / (ne2*ne1*ne0);
1544
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1545
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1546
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1547
- const int64_t k = i3*ne3 + i2;
1548
-
1549
- float m_k;
1550
- if (k < n_heads_log2_floor) {
1551
- m_k = pow(m0, k + 1);
1552
- } else {
1553
- m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
1554
- }
1555
-
1556
- device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
1557
- device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
1558
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1559
- const float src_v = *(device float *)(src_row + i00*nb00);
1560
- device float * dst_v = (device float *)(dst_row + i00*nb0);
1561
- *dst_v = i00 * m_k + src_v;
1562
- }
1563
- }
1564
-
1565
- static float rope_yarn_ramp(const float low, const float high, const int i0) {
1566
- const float y = (i0 / 2 - low) / max(0.001f, high - low);
1567
- return 1.0f - min(1.0f, max(0.0f, y));
1568
- }
1569
-
1570
- // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
1571
- // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
1572
- static void rope_yarn(
1573
- float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
1574
- thread float * cos_theta, thread float * sin_theta
1575
- ) {
1576
- // Get n-d rotational scaling corrected for extrapolation
1577
- float theta_interp = freq_scale * theta_extrap;
1578
- float theta = theta_interp;
1579
- if (ext_factor != 0.0f) {
1580
- float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
1581
- theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
1582
-
1583
- // Get n-d magnitude scaling corrected for interpolation
1584
- mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
1585
- }
1586
- *cos_theta = cos(theta) * mscale;
1587
- *sin_theta = sin(theta) * mscale;
1588
- }
1589
-
1590
- // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
1591
- // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
1592
- static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
1593
- return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
1594
- }
1595
-
1596
- static void rope_yarn_corr_dims(
1597
- int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
1598
- ) {
1599
- // start and end correction dims
1600
- dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
1601
- dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
1602
- }
1603
-
1604
- typedef void (rope_t)(
1605
- device const void * src0,
1606
- device const int32_t * src1,
1607
- device float * dst,
1608
- constant int64_t & ne00,
1609
- constant int64_t & ne01,
1610
- constant int64_t & ne02,
1611
- constant int64_t & ne03,
1612
- constant uint64_t & nb00,
1613
- constant uint64_t & nb01,
1614
- constant uint64_t & nb02,
1615
- constant uint64_t & nb03,
1616
- constant int64_t & ne0,
1617
- constant int64_t & ne1,
1618
- constant int64_t & ne2,
1619
- constant int64_t & ne3,
1620
- constant uint64_t & nb0,
1621
- constant uint64_t & nb1,
1622
- constant uint64_t & nb2,
1623
- constant uint64_t & nb3,
1624
- constant int & n_past,
1625
- constant int & n_dims,
1626
- constant int & mode,
1627
- constant int & n_orig_ctx,
1628
- constant float & freq_base,
1629
- constant float & freq_scale,
1630
- constant float & ext_factor,
1631
- constant float & attn_factor,
1632
- constant float & beta_fast,
1633
- constant float & beta_slow,
1634
- uint tiitg[[thread_index_in_threadgroup]],
1635
- uint3 tptg[[threads_per_threadgroup]],
1636
- uint3 tgpig[[threadgroup_position_in_grid]]);
1637
-
1638
- template<typename T>
1639
- kernel void kernel_rope(
1640
- device const void * src0,
1641
- device const int32_t * src1,
1642
- device float * dst,
1643
- constant int64_t & ne00,
1644
- constant int64_t & ne01,
1645
- constant int64_t & ne02,
1646
- constant int64_t & ne03,
1647
- constant uint64_t & nb00,
1648
- constant uint64_t & nb01,
1649
- constant uint64_t & nb02,
1650
- constant uint64_t & nb03,
1651
- constant int64_t & ne0,
1652
- constant int64_t & ne1,
1653
- constant int64_t & ne2,
1654
- constant int64_t & ne3,
1655
- constant uint64_t & nb0,
1656
- constant uint64_t & nb1,
1657
- constant uint64_t & nb2,
1658
- constant uint64_t & nb3,
1659
- constant int & n_past,
1660
- constant int & n_dims,
1661
- constant int & mode,
1662
- constant int & n_orig_ctx,
1663
- constant float & freq_base,
1664
- constant float & freq_scale,
1665
- constant float & ext_factor,
1666
- constant float & attn_factor,
1667
- constant float & beta_fast,
1668
- constant float & beta_slow,
1669
- uint tiitg[[thread_index_in_threadgroup]],
1670
- uint3 tptg[[threads_per_threadgroup]],
1671
- uint3 tgpig[[threadgroup_position_in_grid]]) {
1672
- const int64_t i3 = tgpig[2];
1673
- const int64_t i2 = tgpig[1];
1674
- const int64_t i1 = tgpig[0];
1675
-
1676
- const bool is_neox = mode & 2;
1677
-
1678
- float corr_dims[2];
1679
- rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
1680
-
1681
- device const int32_t * pos = src1;
1682
-
1683
- const int64_t p = pos[i2];
1684
-
1685
- const float theta_0 = (float)p;
1686
- const float inv_ndims = -1.f/n_dims;
1687
-
1688
- if (!is_neox) {
1689
- for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1690
-
1691
- const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
1692
- float cos_theta, sin_theta;
1693
- rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1694
-
1695
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1696
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1697
-
1698
- const T x0 = src[0];
1699
- const T x1 = src[1];
1700
-
1701
- dst_data[0] = x0*cos_theta - x1*sin_theta;
1702
- dst_data[1] = x0*sin_theta + x1*cos_theta;
1703
- }
1704
- } else {
1705
- for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
1706
- for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
1707
-
1708
- // simplified from `(ib * n_dims + ic) * inv_ndims`
1709
- const float cur_rot = inv_ndims*ic - ib;
1710
-
1711
- const float theta = theta_0 * pow(freq_base, cur_rot);
1712
- float cos_theta, sin_theta;
1713
- rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1714
-
1715
- const int64_t i0 = ib*n_dims + ic/2;
1716
-
1717
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1718
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1719
-
1720
- const float x0 = src[0];
1721
- const float x1 = src[n_dims/2];
1722
-
1723
- dst_data[0] = x0*cos_theta - x1*sin_theta;
1724
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1725
- }
1726
- }
1727
- }
1728
- }
1729
-
1730
- template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
1731
- template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
1732
-
1733
- kernel void kernel_im2col_f16(
1734
- device const float * x,
1735
- device half * dst,
1736
- constant int32_t & ofs0,
1737
- constant int32_t & ofs1,
1738
- constant int32_t & IW,
1739
- constant int32_t & IH,
1740
- constant int32_t & CHW,
1741
- constant int32_t & s0,
1742
- constant int32_t & s1,
1743
- constant int32_t & p0,
1744
- constant int32_t & p1,
1745
- constant int32_t & d0,
1746
- constant int32_t & d1,
1747
- uint3 tgpig[[threadgroup_position_in_grid]],
1748
- uint3 tgpg[[threadgroups_per_grid]],
1749
- uint3 tpitg[[thread_position_in_threadgroup]],
1750
- uint3 ntg[[threads_per_threadgroup]]) {
1751
- const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
1752
- const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
1753
-
1754
- const int32_t offset_dst =
1755
- (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
1756
- (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
1757
-
1758
- if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
1759
- dst[offset_dst] = 0.0f;
1760
- } else {
1761
- const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
1762
- dst[offset_dst] = x[offset_src + iih * IW + iiw];
1763
- }
1764
- }
1765
-
1766
- kernel void kernel_upscale_f32(
1767
- device const char * src0,
1768
- device char * dst,
1769
- constant int64_t & ne00,
1770
- constant int64_t & ne01,
1771
- constant int64_t & ne02,
1772
- constant int64_t & ne03,
1773
- constant uint64_t & nb00,
1774
- constant uint64_t & nb01,
1775
- constant uint64_t & nb02,
1776
- constant uint64_t & nb03,
1777
- constant int64_t & ne0,
1778
- constant int64_t & ne1,
1779
- constant int64_t & ne2,
1780
- constant int64_t & ne3,
1781
- constant uint64_t & nb0,
1782
- constant uint64_t & nb1,
1783
- constant uint64_t & nb2,
1784
- constant uint64_t & nb3,
1785
- constant int32_t & sf,
1786
- uint3 tgpig[[threadgroup_position_in_grid]],
1787
- uint3 tpitg[[thread_position_in_threadgroup]],
1788
- uint3 ntg[[threads_per_threadgroup]]) {
1789
-
1790
- const int64_t i3 = tgpig.z;
1791
- const int64_t i2 = tgpig.y;
1792
- const int64_t i1 = tgpig.x;
1793
-
1794
- const int64_t i03 = i3;
1795
- const int64_t i02 = i2;
1796
- const int64_t i01 = i1/sf;
1797
-
1798
- device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
1799
- device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
1800
-
1801
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1802
- dst_ptr[i0] = src0_ptr[i0/sf];
1803
- }
1804
- }
1805
-
1806
- kernel void kernel_pad_f32(
1807
- device const char * src0,
1808
- device char * dst,
1809
- constant int64_t & ne00,
1810
- constant int64_t & ne01,
1811
- constant int64_t & ne02,
1812
- constant int64_t & ne03,
1813
- constant uint64_t & nb00,
1814
- constant uint64_t & nb01,
1815
- constant uint64_t & nb02,
1816
- constant uint64_t & nb03,
1817
- constant int64_t & ne0,
1818
- constant int64_t & ne1,
1819
- constant int64_t & ne2,
1820
- constant int64_t & ne3,
1821
- constant uint64_t & nb0,
1822
- constant uint64_t & nb1,
1823
- constant uint64_t & nb2,
1824
- constant uint64_t & nb3,
1825
- uint3 tgpig[[threadgroup_position_in_grid]],
1826
- uint3 tpitg[[thread_position_in_threadgroup]],
1827
- uint3 ntg[[threads_per_threadgroup]]) {
1828
-
1829
- const int64_t i3 = tgpig.z;
1830
- const int64_t i2 = tgpig.y;
1831
- const int64_t i1 = tgpig.x;
1832
-
1833
- const int64_t i03 = i3;
1834
- const int64_t i02 = i2;
1835
- const int64_t i01 = i1;
1836
-
1837
- device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
1838
- device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
1839
-
1840
- if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
1841
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1842
- if (i0 < ne00) {
1843
- dst_ptr[i0] = src0_ptr[i0];
1844
- } else {
1845
- dst_ptr[i0] = 0.0f;
1846
- }
1847
- }
1848
-
1849
- return;
1850
- }
1851
-
1852
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1853
- dst_ptr[i0] = 0.0f;
1854
- }
1855
- }
1856
-
1857
- // bitonic sort implementation following the CUDA kernels as reference
1858
- typedef void (argsort_t)(
1859
- device const float * x,
1860
- device int32_t * dst,
1861
- constant int64_t & ncols,
1862
- uint3 tgpig[[threadgroup_position_in_grid]],
1863
- uint3 tpitg[[thread_position_in_threadgroup]]);
1864
-
1865
- template<ggml_sort_order order>
1866
- kernel void kernel_argsort_f32_i32(
1867
- device const float * x,
1868
- device int32_t * dst,
1869
- constant int64_t & ncols,
1870
- uint3 tgpig[[threadgroup_position_in_grid]],
1871
- uint3 tpitg[[thread_position_in_threadgroup]]) {
1872
- // bitonic sort
1873
- int col = tpitg[0];
1874
- int row = tgpig[1];
1875
-
1876
- if (col >= ncols) return;
1877
-
1878
- device const float * x_row = x + row * ncols;
1879
- device int32_t * dst_row = dst + row * ncols;
1880
-
1881
- // initialize indices
1882
- if (col < ncols) {
1883
- dst_row[col] = col;
1884
- }
1885
- threadgroup_barrier(mem_flags::mem_threadgroup);
1886
-
1887
- for (int k = 2; k <= ncols; k *= 2) {
1888
- for (int j = k / 2; j > 0; j /= 2) {
1889
- int ixj = col ^ j;
1890
- if (ixj > col) {
1891
- if ((col & k) == 0) {
1892
- if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
1893
- SWAP(dst_row[col], dst_row[ixj]);
1894
- }
1895
- } else {
1896
- if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
1897
- SWAP(dst_row[col], dst_row[ixj]);
1898
- }
1899
- }
1900
- }
1901
- threadgroup_barrier(mem_flags::mem_threadgroup);
1902
- }
1903
- }
1904
- }
1905
-
1906
- template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
1907
- template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
1908
-
1909
- kernel void kernel_leaky_relu_f32(
1910
- device const float * src0,
1911
- device float * dst,
1912
- constant float & slope,
1913
- uint tpig[[thread_position_in_grid]]) {
1914
- dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
1915
- }
1916
-
1917
- kernel void kernel_cpy_f16_f16(
1918
- device const half * src0,
1919
- device half * dst,
1920
- constant int64_t & ne00,
1921
- constant int64_t & ne01,
1922
- constant int64_t & ne02,
1923
- constant int64_t & ne03,
1924
- constant uint64_t & nb00,
1925
- constant uint64_t & nb01,
1926
- constant uint64_t & nb02,
1927
- constant uint64_t & nb03,
1928
- constant int64_t & ne0,
1929
- constant int64_t & ne1,
1930
- constant int64_t & ne2,
1931
- constant int64_t & ne3,
1932
- constant uint64_t & nb0,
1933
- constant uint64_t & nb1,
1934
- constant uint64_t & nb2,
1935
- constant uint64_t & nb3,
1936
- uint3 tgpig[[threadgroup_position_in_grid]],
1937
- uint3 tpitg[[thread_position_in_threadgroup]],
1938
- uint3 ntg[[threads_per_threadgroup]]) {
1939
- const int64_t i03 = tgpig[2];
1940
- const int64_t i02 = tgpig[1];
1941
- const int64_t i01 = tgpig[0];
1942
-
1943
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1944
-
1945
- const int64_t i3 = n / (ne2*ne1*ne0);
1946
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1947
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1948
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1949
-
1950
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1951
-
1952
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1953
- device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1954
- dst_data[i00] = src[0];
1955
- }
1956
- }
1957
-
1958
- kernel void kernel_cpy_f16_f32(
1959
- device const half * src0,
1960
- device float * dst,
1961
- constant int64_t & ne00,
1962
- constant int64_t & ne01,
1963
- constant int64_t & ne02,
1964
- constant int64_t & ne03,
1965
- constant uint64_t & nb00,
1966
- constant uint64_t & nb01,
1967
- constant uint64_t & nb02,
1968
- constant uint64_t & nb03,
1969
- constant int64_t & ne0,
1970
- constant int64_t & ne1,
1971
- constant int64_t & ne2,
1972
- constant int64_t & ne3,
1973
- constant uint64_t & nb0,
1974
- constant uint64_t & nb1,
1975
- constant uint64_t & nb2,
1976
- constant uint64_t & nb3,
1977
- uint3 tgpig[[threadgroup_position_in_grid]],
1978
- uint3 tpitg[[thread_position_in_threadgroup]],
1979
- uint3 ntg[[threads_per_threadgroup]]) {
1980
- const int64_t i03 = tgpig[2];
1981
- const int64_t i02 = tgpig[1];
1982
- const int64_t i01 = tgpig[0];
1983
-
1984
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1985
-
1986
- const int64_t i3 = n / (ne2*ne1*ne0);
1987
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1988
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1989
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1990
-
1991
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1992
-
1993
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1994
- device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1995
- dst_data[i00] = src[0];
1996
- }
1997
- }
1998
-
1999
- kernel void kernel_cpy_f32_f16(
2000
- device const float * src0,
2001
- device half * dst,
2002
- constant int64_t & ne00,
2003
- constant int64_t & ne01,
2004
- constant int64_t & ne02,
2005
- constant int64_t & ne03,
2006
- constant uint64_t & nb00,
2007
- constant uint64_t & nb01,
2008
- constant uint64_t & nb02,
2009
- constant uint64_t & nb03,
2010
- constant int64_t & ne0,
2011
- constant int64_t & ne1,
2012
- constant int64_t & ne2,
2013
- constant int64_t & ne3,
2014
- constant uint64_t & nb0,
2015
- constant uint64_t & nb1,
2016
- constant uint64_t & nb2,
2017
- constant uint64_t & nb3,
2018
- uint3 tgpig[[threadgroup_position_in_grid]],
2019
- uint3 tpitg[[thread_position_in_threadgroup]],
2020
- uint3 ntg[[threads_per_threadgroup]]) {
2021
- const int64_t i03 = tgpig[2];
2022
- const int64_t i02 = tgpig[1];
2023
- const int64_t i01 = tgpig[0];
2024
-
2025
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2026
-
2027
- const int64_t i3 = n / (ne2*ne1*ne0);
2028
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2029
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2030
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2031
-
2032
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2033
-
2034
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2035
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2036
-
2037
- dst_data[i00] = src[0];
2038
- }
2039
- }
2040
-
2041
- kernel void kernel_cpy_f32_f32(
2042
- device const float * src0,
2043
- device float * dst,
2044
- constant int64_t & ne00,
2045
- constant int64_t & ne01,
2046
- constant int64_t & ne02,
2047
- constant int64_t & ne03,
2048
- constant uint64_t & nb00,
2049
- constant uint64_t & nb01,
2050
- constant uint64_t & nb02,
2051
- constant uint64_t & nb03,
2052
- constant int64_t & ne0,
2053
- constant int64_t & ne1,
2054
- constant int64_t & ne2,
2055
- constant int64_t & ne3,
2056
- constant uint64_t & nb0,
2057
- constant uint64_t & nb1,
2058
- constant uint64_t & nb2,
2059
- constant uint64_t & nb3,
2060
- uint3 tgpig[[threadgroup_position_in_grid]],
2061
- uint3 tpitg[[thread_position_in_threadgroup]],
2062
- uint3 ntg[[threads_per_threadgroup]]) {
2063
- const int64_t i03 = tgpig[2];
2064
- const int64_t i02 = tgpig[1];
2065
- const int64_t i01 = tgpig[0];
2066
-
2067
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2068
-
2069
- const int64_t i3 = n / (ne2*ne1*ne0);
2070
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2071
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2072
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2073
-
2074
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2075
-
2076
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2077
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2078
-
2079
- dst_data[i00] = src[0];
2080
- }
2081
- }
2082
-
2083
- kernel void kernel_cpy_f32_q8_0(
2084
- device const float * src0,
2085
- device void * dst,
2086
- constant int64_t & ne00,
2087
- constant int64_t & ne01,
2088
- constant int64_t & ne02,
2089
- constant int64_t & ne03,
2090
- constant uint64_t & nb00,
2091
- constant uint64_t & nb01,
2092
- constant uint64_t & nb02,
2093
- constant uint64_t & nb03,
2094
- constant int64_t & ne0,
2095
- constant int64_t & ne1,
2096
- constant int64_t & ne2,
2097
- constant int64_t & ne3,
2098
- constant uint64_t & nb0,
2099
- constant uint64_t & nb1,
2100
- constant uint64_t & nb2,
2101
- constant uint64_t & nb3,
2102
- uint3 tgpig[[threadgroup_position_in_grid]],
2103
- uint3 tpitg[[thread_position_in_threadgroup]],
2104
- uint3 ntg[[threads_per_threadgroup]]) {
2105
- const int64_t i03 = tgpig[2];
2106
- const int64_t i02 = tgpig[1];
2107
- const int64_t i01 = tgpig[0];
2108
-
2109
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2110
-
2111
- const int64_t i3 = n / (ne2*ne1*ne0);
2112
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2113
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2114
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
2115
-
2116
- device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2117
-
2118
- for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
2119
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2120
-
2121
- float amax = 0.0f; // absolute max
2122
-
2123
- for (int j = 0; j < QK8_0; j++) {
2124
- const float v = src[j];
2125
- amax = MAX(amax, fabs(v));
2126
- }
2127
-
2128
- const float d = amax / ((1 << 7) - 1);
2129
- const float id = d ? 1.0f/d : 0.0f;
2130
-
2131
- dst_data[i00/QK8_0].d = d;
2132
-
2133
- for (int j = 0; j < QK8_0; ++j) {
2134
- const float x0 = src[j]*id;
2135
-
2136
- dst_data[i00/QK8_0].qs[j] = round(x0);
2137
- }
2138
- }
2139
- }
2140
-
2141
- kernel void kernel_cpy_f32_q4_0(
2142
- device const float * src0,
2143
- device void * dst,
2144
- constant int64_t & ne00,
2145
- constant int64_t & ne01,
2146
- constant int64_t & ne02,
2147
- constant int64_t & ne03,
2148
- constant uint64_t & nb00,
2149
- constant uint64_t & nb01,
2150
- constant uint64_t & nb02,
2151
- constant uint64_t & nb03,
2152
- constant int64_t & ne0,
2153
- constant int64_t & ne1,
2154
- constant int64_t & ne2,
2155
- constant int64_t & ne3,
2156
- constant uint64_t & nb0,
2157
- constant uint64_t & nb1,
2158
- constant uint64_t & nb2,
2159
- constant uint64_t & nb3,
2160
- uint3 tgpig[[threadgroup_position_in_grid]],
2161
- uint3 tpitg[[thread_position_in_threadgroup]],
2162
- uint3 ntg[[threads_per_threadgroup]]) {
2163
- const int64_t i03 = tgpig[2];
2164
- const int64_t i02 = tgpig[1];
2165
- const int64_t i01 = tgpig[0];
2166
-
2167
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2168
-
2169
- const int64_t i3 = n / (ne2*ne1*ne0);
2170
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2171
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2172
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
2173
-
2174
- device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2175
-
2176
- for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
2177
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2178
-
2179
- float amax = 0.0f; // absolute max
2180
- float max = 0.0f;
2181
-
2182
- for (int j = 0; j < QK4_0; j++) {
2183
- const float v = src[j];
2184
- if (amax < fabs(v)) {
2185
- amax = fabs(v);
2186
- max = v;
2187
- }
2188
- }
2189
-
2190
- const float d = max / -8;
2191
- const float id = d ? 1.0f/d : 0.0f;
2192
-
2193
- dst_data[i00/QK4_0].d = d;
2194
-
2195
- for (int j = 0; j < QK4_0/2; ++j) {
2196
- const float x0 = src[0 + j]*id;
2197
- const float x1 = src[QK4_0/2 + j]*id;
2198
-
2199
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
2200
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
2201
-
2202
- dst_data[i00/QK4_0].qs[j] = xi0;
2203
- dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
2204
- }
2205
- }
2206
- }
2207
-
2208
- kernel void kernel_cpy_f32_q4_1(
2209
- device const float * src0,
2210
- device void * dst,
2211
- constant int64_t & ne00,
2212
- constant int64_t & ne01,
2213
- constant int64_t & ne02,
2214
- constant int64_t & ne03,
2215
- constant uint64_t & nb00,
2216
- constant uint64_t & nb01,
2217
- constant uint64_t & nb02,
2218
- constant uint64_t & nb03,
2219
- constant int64_t & ne0,
2220
- constant int64_t & ne1,
2221
- constant int64_t & ne2,
2222
- constant int64_t & ne3,
2223
- constant uint64_t & nb0,
2224
- constant uint64_t & nb1,
2225
- constant uint64_t & nb2,
2226
- constant uint64_t & nb3,
2227
- uint3 tgpig[[threadgroup_position_in_grid]],
2228
- uint3 tpitg[[thread_position_in_threadgroup]],
2229
- uint3 ntg[[threads_per_threadgroup]]) {
2230
- const int64_t i03 = tgpig[2];
2231
- const int64_t i02 = tgpig[1];
2232
- const int64_t i01 = tgpig[0];
2233
-
2234
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2235
-
2236
- const int64_t i3 = n / (ne2*ne1*ne0);
2237
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2238
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2239
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
2240
-
2241
- device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2242
-
2243
- for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
2244
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2245
-
2246
- float min = FLT_MAX;
2247
- float max = -FLT_MAX;
2248
-
2249
- for (int j = 0; j < QK4_1; j++) {
2250
- const float v = src[j];
2251
- if (min > v) min = v;
2252
- if (max < v) max = v;
2253
- }
2254
-
2255
- const float d = (max - min) / ((1 << 4) - 1);
2256
- const float id = d ? 1.0f/d : 0.0f;
2257
-
2258
- dst_data[i00/QK4_1].d = d;
2259
- dst_data[i00/QK4_1].m = min;
2260
-
2261
- for (int j = 0; j < QK4_1/2; ++j) {
2262
- const float x0 = (src[0 + j] - min)*id;
2263
- const float x1 = (src[QK4_1/2 + j] - min)*id;
2264
-
2265
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
2266
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
2267
-
2268
- dst_data[i00/QK4_1].qs[j] = xi0;
2269
- dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
2270
- }
2271
- }
2272
- }
2273
-
2274
- kernel void kernel_concat(
2275
- device const char * src0,
2276
- device const char * src1,
2277
- device char * dst,
2278
- constant int64_t & ne00,
2279
- constant int64_t & ne01,
2280
- constant int64_t & ne02,
2281
- constant int64_t & ne03,
2282
- constant uint64_t & nb00,
2283
- constant uint64_t & nb01,
2284
- constant uint64_t & nb02,
2285
- constant uint64_t & nb03,
2286
- constant int64_t & ne10,
2287
- constant int64_t & ne11,
2288
- constant int64_t & ne12,
2289
- constant int64_t & ne13,
2290
- constant uint64_t & nb10,
2291
- constant uint64_t & nb11,
2292
- constant uint64_t & nb12,
2293
- constant uint64_t & nb13,
2294
- constant int64_t & ne0,
2295
- constant int64_t & ne1,
2296
- constant int64_t & ne2,
2297
- constant int64_t & ne3,
2298
- constant uint64_t & nb0,
2299
- constant uint64_t & nb1,
2300
- constant uint64_t & nb2,
2301
- constant uint64_t & nb3,
2302
- uint3 tgpig[[threadgroup_position_in_grid]],
2303
- uint3 tpitg[[thread_position_in_threadgroup]],
2304
- uint3 ntg[[threads_per_threadgroup]]) {
2305
-
2306
- const int64_t i03 = tgpig.z;
2307
- const int64_t i02 = tgpig.y;
2308
- const int64_t i01 = tgpig.x;
2309
-
2310
- const int64_t i13 = i03 % ne13;
2311
- const int64_t i12 = i02 % ne12;
2312
- const int64_t i11 = i01 % ne11;
2313
-
2314
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
2315
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
2316
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
2317
-
2318
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
2319
- if (i02 < ne02) {
2320
- ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
2321
- src0_ptr += ntg.x*nb00;
2322
- } else {
2323
- ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
2324
- src1_ptr += ntg.x*nb10;
2325
- }
2326
- dst_ptr += ntg.x*nb0;
2327
- }
2328
- }
2329
-
2330
- //============================================ k-quants ======================================================
2331
-
2332
- #ifndef QK_K
2333
- #define QK_K 256
2334
- #else
2335
- static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
2336
- #endif
2337
-
2338
- #if QK_K == 256
2339
- #define K_SCALE_SIZE 12
2340
- #else
2341
- #define K_SCALE_SIZE 4
2342
- #endif
2343
-
2344
- typedef struct {
2345
- uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
2346
- uint8_t qs[QK_K/4]; // quants
2347
- half d; // super-block scale for quantized scales
2348
- half dmin; // super-block scale for quantized mins
2349
- } block_q2_K;
2350
- // 84 bytes / block
2351
-
2352
- typedef struct {
2353
- uint8_t hmask[QK_K/8]; // quants - high bit
2354
- uint8_t qs[QK_K/4]; // quants - low 2 bits
2355
- #if QK_K == 64
2356
- uint8_t scales[2];
2357
- #else
2358
- uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
2359
- #endif
2360
- half d; // super-block scale
2361
- } block_q3_K;
2362
-
2363
- #if QK_K == 64
2364
- typedef struct {
2365
- half d[2]; // super-block scales/mins
2366
- uint8_t scales[2];
2367
- uint8_t qs[QK_K/2]; // 4-bit quants
2368
- } block_q4_K;
2369
- #else
2370
- typedef struct {
2371
- half d; // super-block scale for quantized scales
2372
- half dmin; // super-block scale for quantized mins
2373
- uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
2374
- uint8_t qs[QK_K/2]; // 4--bit quants
2375
- } block_q4_K;
2376
- #endif
2377
-
2378
- #if QK_K == 64
2379
- typedef struct {
2380
- half d; // super-block scales/mins
2381
- int8_t scales[QK_K/16]; // 8-bit block scales
2382
- uint8_t qh[QK_K/8]; // quants, high bit
2383
- uint8_t qs[QK_K/2]; // quants, low 4 bits
2384
- } block_q5_K;
2385
- #else
2386
- typedef struct {
2387
- half d; // super-block scale for quantized scales
2388
- half dmin; // super-block scale for quantized mins
2389
- uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
2390
- uint8_t qh[QK_K/8]; // quants, high bit
2391
- uint8_t qs[QK_K/2]; // quants, low 4 bits
2392
- } block_q5_K;
2393
- // 176 bytes / block
2394
- #endif
2395
-
2396
- typedef struct {
2397
- uint8_t ql[QK_K/2]; // quants, lower 4 bits
2398
- uint8_t qh[QK_K/4]; // quants, upper 2 bits
2399
- int8_t scales[QK_K/16]; // scales, quantized with 8 bits
2400
- half d; // super-block scale
2401
- } block_q6_K;
2402
- // 210 bytes / block
2403
-
2404
- static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
2405
- uchar4 r;
2406
- if (j < 4) {
2407
- r[0] = q[j+0] & 63;
2408
- r[2] = q[j+1] & 63;
2409
- r[1] = q[j+4] & 63;
2410
- r[3] = q[j+5] & 63;
2411
- } else {
2412
- r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
2413
- r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
2414
- r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
2415
- r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
2416
- }
2417
- return r;
2418
- }
2419
-
2420
- //====================================== dot products =========================
2421
-
2422
- void kernel_mul_mv_q2_K_f32_impl(
2423
- device const void * src0,
2424
- device const float * src1,
2425
- device float * dst,
2426
- constant int64_t & ne00,
2427
- constant int64_t & ne01,
2428
- constant int64_t & ne02,
2429
- constant int64_t & ne10,
2430
- constant int64_t & ne12,
2431
- constant int64_t & ne0,
2432
- constant int64_t & ne1,
2433
- constant uint & r2,
2434
- constant uint & r3,
2435
- uint3 tgpig[[threadgroup_position_in_grid]],
2436
- uint tiisg[[thread_index_in_simdgroup]],
2437
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2438
-
2439
- const int nb = ne00/QK_K;
2440
- const int r0 = tgpig.x;
2441
- const int r1 = tgpig.y;
2442
- const int im = tgpig.z;
2443
-
2444
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2445
- const int ib_row = first_row * nb;
2446
-
2447
- const uint i12 = im%ne12;
2448
- const uint i13 = im/ne12;
2449
-
2450
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2451
-
2452
- device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
2453
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2454
-
2455
- float yl[32];
2456
- float sumf[N_DST]={0.f}, all_sum;
2457
-
2458
- const int step = sizeof(block_q2_K) * nb;
2459
-
2460
- #if QK_K == 256
2461
- const int ix = tiisg/8; // 0...3
2462
- const int it = tiisg%8; // 0...7
2463
- const int iq = it/4; // 0 or 1
2464
- const int ir = it%4; // 0...3
2465
- const int is = (8*ir)/16;// 0 or 1
2466
-
2467
- device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
2468
-
2469
- for (int ib = ix; ib < nb; ib += 4) {
2470
-
2471
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
2472
- for (int i = 0; i < 8; ++i) {
2473
- yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
2474
- yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
2475
- yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
2476
- yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
2477
- }
2478
-
2479
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
2480
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
2481
- device const half * dh = &x[ib].d;
2482
-
2483
- for (int row = 0; row < N_DST; row++) {
2484
-
2485
- float4 acc1 = {0.f, 0.f, 0.f, 0.f};
2486
- float4 acc2 = {0.f, 0.f, 0.f, 0.f};
2487
- for (int i = 0; i < 8; i += 2) {
2488
- acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
2489
- acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
2490
- acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
2491
- acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
2492
- acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
2493
- acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
2494
- acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
2495
- acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
2496
- }
2497
- float dall = dh[0];
2498
- float dmin = dh[1] * 1.f/16.f;
2499
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
2500
- (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
2501
- (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
2502
- (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
2503
- dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
2504
-
2505
- qs += step/2;
2506
- sc += step;
2507
- dh += step/2;
2508
- }
2509
-
2510
- y4 += 4 * QK_K;
2511
- }
2512
- #else
2513
- const int ix = tiisg/2; // 0...15
2514
- const int it = tiisg%2; // 0...1
2515
-
2516
- device const float * y4 = y + ix * QK_K + 8 * it;
2517
-
2518
- for (int ib = ix; ib < nb; ib += 16) {
2519
-
2520
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
2521
- for (int i = 0; i < 8; ++i) {
2522
- yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
2523
- yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
2524
- yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
2525
- yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
2526
- }
2527
-
2528
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
2529
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
2530
- device const half * dh = &x[ib].d;
2531
-
2532
- for (int row = 0; row < N_DST; row++) {
2533
-
2534
- float4 acc1 = {0.f, 0.f, 0.f, 0.f};
2535
- float4 acc2 = {0.f, 0.f, 0.f, 0.f};
2536
- for (int i = 0; i < 8; i += 2) {
2537
- acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
2538
- acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
2539
- acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
2540
- acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
2541
- acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
2542
- acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
2543
- acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
2544
- acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
2545
- }
2546
-
2547
- float dall = dh[0];
2548
- float dmin = dh[1];
2549
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
2550
- (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
2551
- (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
2552
- (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
2553
- dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
2554
-
2555
- qs += step/2;
2556
- sc += step;
2557
- dh += step/2;
2558
- }
2559
-
2560
- y4 += 16 * QK_K;
2561
- }
2562
- #endif
2563
-
2564
- for (int row = 0; row < N_DST; ++row) {
2565
- all_sum = simd_sum(sumf[row]);
2566
- if (tiisg == 0) {
2567
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
2568
- }
2569
- }
2570
- }
2571
-
2572
- [[host_name("kernel_mul_mv_q2_K_f32")]]
2573
- kernel void kernel_mul_mv_q2_K_f32(
2574
- device const void * src0,
2575
- device const float * src1,
2576
- device float * dst,
2577
- constant int64_t & ne00,
2578
- constant int64_t & ne01[[buffer(4)]],
2579
- constant int64_t & ne02[[buffer(5)]],
2580
- constant int64_t & ne10[[buffer(9)]],
2581
- constant int64_t & ne12[[buffer(11)]],
2582
- constant int64_t & ne0 [[buffer(15)]],
2583
- constant int64_t & ne1 [[buffer(16)]],
2584
- constant uint & r2 [[buffer(17)]],
2585
- constant uint & r3 [[buffer(18)]],
2586
- uint3 tgpig[[threadgroup_position_in_grid]],
2587
- uint tiisg[[thread_index_in_simdgroup]],
2588
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2589
-
2590
- kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2591
- }
2592
-
2593
- #if QK_K == 256
2594
- void kernel_mul_mv_q3_K_f32_impl(
2595
- device const void * src0,
2596
- device const float * src1,
2597
- device float * dst,
2598
- constant int64_t & ne00,
2599
- constant int64_t & ne01,
2600
- constant int64_t & ne02,
2601
- constant int64_t & ne10,
2602
- constant int64_t & ne12,
2603
- constant int64_t & ne0,
2604
- constant int64_t & ne1,
2605
- constant uint & r2,
2606
- constant uint & r3,
2607
- uint3 tgpig[[threadgroup_position_in_grid]],
2608
- uint tiisg[[thread_index_in_simdgroup]],
2609
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2610
-
2611
- const int nb = ne00/QK_K;
2612
-
2613
- const int64_t r0 = tgpig.x;
2614
- const int64_t r1 = tgpig.y;
2615
- const int64_t im = tgpig.z;
2616
-
2617
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
2618
-
2619
- const uint i12 = im%ne12;
2620
- const uint i13 = im/ne12;
2621
-
2622
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2623
-
2624
- device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
2625
- device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2626
-
2627
- float yl[32];
2628
-
2629
- //const uint16_t kmask1 = 0x3030;
2630
- //const uint16_t kmask2 = 0x0f0f;
2631
-
2632
- const int tid = tiisg/4;
2633
- const int ix = tiisg%4;
2634
- const int ip = tid/4; // 0 or 1
2635
- const int il = 2*((tid%4)/2); // 0 or 2
2636
- const int ir = tid%2;
2637
- const int n = 8;
2638
- const int l0 = n*ir;
2639
-
2640
- // One would think that the Metal compiler would figure out that ip and il can only have
2641
- // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
2642
- // with these two tales.
2643
- //
2644
- // Possible masks for the high bit
2645
- const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
2646
- {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
2647
- {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
2648
- {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
2649
-
2650
- // Possible masks for the low 2 bits
2651
- const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
2652
-
2653
- const ushort4 hm = mm[2*ip + il/2];
2654
-
2655
- const int shift = 2*il;
2656
- const float v1 = il == 0 ? 4.f : 64.f;
2657
- const float v2 = 4.f * v1;
2658
-
2659
- const uint16_t s_shift1 = 4*ip;
2660
- const uint16_t s_shift2 = s_shift1 + il;
2661
-
2662
- const int q_offset = 32*ip + l0;
2663
- const int y_offset = 128*ip + 32*il + l0;
2664
-
2665
- const int step = sizeof(block_q3_K) * nb / 2;
2666
-
2667
- device const float * y1 = yy + ix*QK_K + y_offset;
2668
-
2669
- uint32_t scales32, aux32;
2670
- thread uint16_t * scales16 = (thread uint16_t *)&scales32;
2671
- thread const int8_t * scales = (thread const int8_t *)&scales32;
2672
-
2673
- float sumf1[2] = {0.f};
2674
- float sumf2[2] = {0.f};
2675
- for (int i = ix; i < nb; i += 4) {
2676
-
2677
- for (int l = 0; l < 8; ++l) {
2678
- yl[l+ 0] = y1[l+ 0];
2679
- yl[l+ 8] = y1[l+16];
2680
- yl[l+16] = y1[l+32];
2681
- yl[l+24] = y1[l+48];
2682
- }
2683
-
2684
- device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
2685
- device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
2686
- device const uint16_t * a = (device const uint16_t *)(x[i].scales);
2687
- device const half * dh = &x[i].d;
2688
-
2689
- for (int row = 0; row < 2; ++row) {
2690
-
2691
- const float d_all = (float)dh[0];
2692
-
2693
- scales16[0] = a[4];
2694
- scales16[1] = a[5];
2695
- aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
2696
- scales16[0] = a[il+0];
2697
- scales16[1] = a[il+1];
2698
- scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
2699
-
2700
- float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
2701
- for (int l = 0; l < n; l += 2) {
2702
- const int32_t qs = q[l/2];
2703
- s1 += yl[l+0] * (qs & qm[il/2][0]);
2704
- s2 += yl[l+1] * (qs & qm[il/2][1]);
2705
- s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
2706
- s4 += yl[l+16] * (qs & qm[il/2][2]);
2707
- s5 += yl[l+17] * (qs & qm[il/2][3]);
2708
- s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
2709
- }
2710
- float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
2711
- float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
2712
- sumf1[row] += d1 * (scales[0] - 32);
2713
- sumf2[row] += d2 * (scales[2] - 32);
2714
-
2715
- s1 = s2 = s3 = s4 = s5 = s6 = 0;
2716
- for (int l = 0; l < n; l += 2) {
2717
- const int32_t qs = q[l/2+8];
2718
- s1 += yl[l+8] * (qs & qm[il/2][0]);
2719
- s2 += yl[l+9] * (qs & qm[il/2][1]);
2720
- s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
2721
- s4 += yl[l+24] * (qs & qm[il/2][2]);
2722
- s5 += yl[l+25] * (qs & qm[il/2][3]);
2723
- s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
2724
- }
2725
- d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
2726
- d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
2727
- sumf1[row] += d1 * (scales[1] - 32);
2728
- sumf2[row] += d2 * (scales[3] - 32);
2729
-
2730
- q += step;
2731
- h += step;
2732
- a += step;
2733
- dh += step;
2734
-
2735
- }
2736
-
2737
- y1 += 4 * QK_K;
2738
-
2739
- }
2740
-
2741
- for (int row = 0; row < 2; ++row) {
2742
- const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
2743
- sumf1[row] = simd_sum(sumf);
2744
- }
2745
- if (tiisg == 0) {
2746
- for (int row = 0; row < 2; ++row) {
2747
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
2748
- }
2749
- }
2750
- }
2751
- #else
2752
- void kernel_mul_mv_q3_K_f32_impl(
2753
- device const void * src0,
2754
- device const float * src1,
2755
- device float * dst,
2756
- constant int64_t & ne00,
2757
- constant int64_t & ne01,
2758
- constant int64_t & ne02,
2759
- constant int64_t & ne10,
2760
- constant int64_t & ne12,
2761
- constant int64_t & ne0,
2762
- constant int64_t & ne1,
2763
- constant uint & r2,
2764
- constant uint & r3,
2765
- uint3 tgpig[[threadgroup_position_in_grid]],
2766
- uint tiisg[[thread_index_in_simdgroup]],
2767
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2768
-
2769
- const int nb = ne00/QK_K;
2770
-
2771
- const int64_t r0 = tgpig.x;
2772
- const int64_t r1 = tgpig.y;
2773
- const int64_t im = tgpig.z;
2774
-
2775
- const int row = 2 * r0 + sgitg;
2776
-
2777
- const uint i12 = im%ne12;
2778
- const uint i13 = im/ne12;
2779
-
2780
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2781
-
2782
- device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
2783
- device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2784
-
2785
- const int ix = tiisg/4;
2786
- const int il = 4 * (tiisg%4);// 0, 4, 8, 12
2787
- const int iq = il/8; // 0, 0, 1, 1
2788
- const int in = il%8; // 0, 4, 0, 4
2789
-
2790
- float2 sum = {0.f, 0.f};
2791
-
2792
- for (int i = ix; i < nb; i += 8) {
2793
-
2794
- const float d_all = (float)(x[i].d);
2795
-
2796
- device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
2797
- device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
2798
- device const uint16_t * s = (device const uint16_t *)(x[i].scales);
2799
- device const float * y = yy + i * QK_K + il;
2800
-
2801
- const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
2802
- const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
2803
- const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
2804
- const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
2805
-
2806
- for (int l = 0; l < 4; l += 2) {
2807
- const uint16_t hm = h[l/2] >> iq;
2808
- sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
2809
- + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
2810
- + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
2811
- + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
2812
- sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
2813
- + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
2814
- + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
2815
- + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
2816
- }
2817
-
2818
- }
2819
- const float sumf = sum[0] + sum[1] * 1.f/256.f;
2820
-
2821
- const float tot = simd_sum(sumf);
2822
- if (tiisg == 0) {
2823
- dst[r1*ne0 + im*ne0*ne1 + row] = tot;
2824
- }
2825
-
2826
- }
2827
- #endif
2828
-
2829
- [[host_name("kernel_mul_mv_q3_K_f32")]]
2830
- kernel void kernel_mul_mv_q3_K_f32(
2831
- device const void * src0,
2832
- device const float * src1,
2833
- device float * dst,
2834
- constant int64_t & ne00,
2835
- constant int64_t & ne01[[buffer(4)]],
2836
- constant int64_t & ne02[[buffer(5)]],
2837
- constant int64_t & ne10[[buffer(9)]],
2838
- constant int64_t & ne12[[buffer(11)]],
2839
- constant int64_t & ne0 [[buffer(15)]],
2840
- constant int64_t & ne1 [[buffer(16)]],
2841
- constant uint & r2 [[buffer(17)]],
2842
- constant uint & r3 [[buffer(18)]],
2843
- uint3 tgpig[[threadgroup_position_in_grid]],
2844
- uint tiisg[[thread_index_in_simdgroup]],
2845
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2846
-
2847
- kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2848
- }
2849
-
2850
- #if QK_K == 256
2851
- void kernel_mul_mv_q4_K_f32_impl(
2852
- device const void * src0,
2853
- device const float * src1,
2854
- device float * dst,
2855
- constant int64_t & ne00,
2856
- constant int64_t & ne01,
2857
- constant int64_t & ne02,
2858
- constant int64_t & ne10,
2859
- constant int64_t & ne12,
2860
- constant int64_t & ne0,
2861
- constant int64_t & ne1,
2862
- constant uint & r2,
2863
- constant uint & r3,
2864
- uint3 tgpig[[threadgroup_position_in_grid]],
2865
- uint tiisg[[thread_index_in_simdgroup]],
2866
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2867
-
2868
- const uint16_t kmask1 = 0x3f3f;
2869
- const uint16_t kmask2 = 0x0f0f;
2870
- const uint16_t kmask3 = 0xc0c0;
2871
-
2872
- const int ix = tiisg/8; // 0...3
2873
- const int it = tiisg%8; // 0...7
2874
- const int iq = it/4; // 0 or 1
2875
- const int ir = it%4; // 0...3
2876
-
2877
- const int nb = ne00/QK_K;
2878
- const int r0 = tgpig.x;
2879
- const int r1 = tgpig.y;
2880
- const int im = tgpig.z;
2881
- //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2882
- const int first_row = r0 * N_DST;
2883
- const int ib_row = first_row * nb;
2884
-
2885
- const uint i12 = im%ne12;
2886
- const uint i13 = im/ne12;
2887
-
2888
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2889
-
2890
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
2891
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2892
-
2893
- float yl[16];
2894
- float yh[16];
2895
- float sumf[N_DST]={0.f}, all_sum;
2896
-
2897
- const int step = sizeof(block_q4_K) * nb / 2;
2898
-
2899
- device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
2900
-
2901
- uint16_t sc16[4];
2902
- thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
2903
-
2904
- for (int ib = ix; ib < nb; ib += 4) {
2905
-
2906
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
2907
- for (int i = 0; i < 8; ++i) {
2908
- yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
2909
- yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
2910
- yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
2911
- yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
2912
- }
2913
-
2914
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
2915
- device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
2916
- device const half * dh = &x[ib].d;
2917
-
2918
- for (int row = 0; row < N_DST; row++) {
2919
-
2920
- sc16[0] = sc[0] & kmask1;
2921
- sc16[1] = sc[2] & kmask1;
2922
- sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
2923
- sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
2924
-
2925
- device const uint16_t * q2 = q1 + 32;
2926
-
2927
- float4 acc1 = {0.f, 0.f, 0.f, 0.f};
2928
- float4 acc2 = {0.f, 0.f, 0.f, 0.f};
2929
- for (int i = 0; i < 8; i += 2) {
2930
- acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
2931
- acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
2932
- acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
2933
- acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
2934
- acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
2935
- acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
2936
- acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
2937
- acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
2938
- }
2939
-
2940
- float dall = dh[0];
2941
- float dmin = dh[1];
2942
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
2943
- (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
2944
- (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
2945
- (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
2946
- dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
2947
-
2948
- q1 += step;
2949
- sc += step;
2950
- dh += step;
2951
- }
2952
-
2953
- y4 += 4 * QK_K;
2954
- }
2955
-
2956
- for (int row = 0; row < N_DST; ++row) {
2957
- all_sum = simd_sum(sumf[row]);
2958
- if (tiisg == 0) {
2959
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
2960
- }
2961
- }
2962
- }
2963
- #else
2964
- void kernel_mul_mv_q4_K_f32_impl(
2965
- device const void * src0,
2966
- device const float * src1,
2967
- device float * dst,
2968
- constant int64_t & ne00,
2969
- constant int64_t & ne01,
2970
- constant int64_t & ne02,
2971
- constant int64_t & ne10,
2972
- constant int64_t & ne12,
2973
- constant int64_t & ne0,
2974
- constant int64_t & ne1,
2975
- constant uint & r2,
2976
- constant uint & r3,
2977
- uint3 tgpig[[threadgroup_position_in_grid]],
2978
- uint tiisg[[thread_index_in_simdgroup]],
2979
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2980
-
2981
- const int ix = tiisg/4; // 0...7
2982
- const int it = tiisg%4; // 0...3
2983
-
2984
- const int nb = ne00/QK_K;
2985
- const int r0 = tgpig.x;
2986
- const int r1 = tgpig.y;
2987
- const int im = tgpig.z;
2988
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2989
- const int ib_row = first_row * nb;
2990
-
2991
- const uint i12 = im%ne12;
2992
- const uint i13 = im/ne12;
2993
-
2994
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2995
-
2996
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
2997
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2998
-
2999
- float yl[8];
3000
- float yh[8];
3001
- float sumf[N_DST]={0.f}, all_sum;
3002
-
3003
- const int step = sizeof(block_q4_K) * nb / 2;
3004
-
3005
- device const float * y4 = y + ix * QK_K + 8 * it;
3006
-
3007
- uint16_t sc16[4];
3008
-
3009
- for (int ib = ix; ib < nb; ib += 8) {
3010
-
3011
- float2 sumy = {0.f, 0.f};
3012
- for (int i = 0; i < 8; ++i) {
3013
- yl[i] = y4[i+ 0]; sumy[0] += yl[i];
3014
- yh[i] = y4[i+32]; sumy[1] += yh[i];
3015
- }
3016
-
3017
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
3018
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
3019
- device const half * dh = x[ib].d;
3020
-
3021
- for (int row = 0; row < N_DST; row++) {
3022
-
3023
- sc16[0] = sc[0] & 0x000f;
3024
- sc16[1] = sc[0] & 0x0f00;
3025
- sc16[2] = sc[0] & 0x00f0;
3026
- sc16[3] = sc[0] & 0xf000;
3027
-
3028
- float2 acc1 = {0.f, 0.f};
3029
- float2 acc2 = {0.f, 0.f};
3030
- for (int i = 0; i < 8; i += 2) {
3031
- acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
3032
- acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
3033
- acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
3034
- acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
3035
- }
3036
-
3037
- float dall = dh[0];
3038
- float dmin = dh[1];
3039
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
3040
- (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
3041
- dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
3042
-
3043
- qs += step;
3044
- sc += step;
3045
- dh += step;
3046
- }
3047
-
3048
- y4 += 8 * QK_K;
3049
- }
3050
-
3051
- for (int row = 0; row < N_DST; ++row) {
3052
- all_sum = simd_sum(sumf[row]);
3053
- if (tiisg == 0) {
3054
- dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
3055
- }
3056
- }
3057
- }
3058
- #endif
3059
-
3060
- [[host_name("kernel_mul_mv_q4_K_f32")]]
3061
- kernel void kernel_mul_mv_q4_K_f32(
3062
- device const void * src0,
3063
- device const float * src1,
3064
- device float * dst,
3065
- constant int64_t & ne00,
3066
- constant int64_t & ne01[[buffer(4)]],
3067
- constant int64_t & ne02[[buffer(5)]],
3068
- constant int64_t & ne10[[buffer(9)]],
3069
- constant int64_t & ne12[[buffer(11)]],
3070
- constant int64_t & ne0 [[buffer(15)]],
3071
- constant int64_t & ne1 [[buffer(16)]],
3072
- constant uint & r2 [[buffer(17)]],
3073
- constant uint & r3 [[buffer(18)]],
3074
- uint3 tgpig[[threadgroup_position_in_grid]],
3075
- uint tiisg[[thread_index_in_simdgroup]],
3076
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3077
-
3078
- kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3079
- }
3080
-
3081
- void kernel_mul_mv_q5_K_f32_impl(
3082
- device const void * src0,
3083
- device const float * src1,
3084
- device float * dst,
3085
- constant int64_t & ne00,
3086
- constant int64_t & ne01,
3087
- constant int64_t & ne02,
3088
- constant int64_t & ne10,
3089
- constant int64_t & ne12,
3090
- constant int64_t & ne0,
3091
- constant int64_t & ne1,
3092
- constant uint & r2,
3093
- constant uint & r3,
3094
- uint3 tgpig[[threadgroup_position_in_grid]],
3095
- uint tiisg[[thread_index_in_simdgroup]],
3096
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3097
-
3098
- const int nb = ne00/QK_K;
3099
-
3100
- const int64_t r0 = tgpig.x;
3101
- const int64_t r1 = tgpig.y;
3102
- const int im = tgpig.z;
3103
-
3104
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
3105
-
3106
- const uint i12 = im%ne12;
3107
- const uint i13 = im/ne12;
3108
-
3109
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3110
-
3111
- device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
3112
- device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3113
-
3114
- float sumf[2]={0.f};
3115
-
3116
- const int step = sizeof(block_q5_K) * nb;
3117
-
3118
- #if QK_K == 256
3119
- #
3120
- float yl[16], yh[16];
3121
-
3122
- const uint16_t kmask1 = 0x3f3f;
3123
- const uint16_t kmask2 = 0x0f0f;
3124
- const uint16_t kmask3 = 0xc0c0;
3125
-
3126
- const int tid = tiisg/4;
3127
- const int ix = tiisg%4;
3128
- const int iq = tid/4;
3129
- const int ir = tid%4;
3130
- const int n = 8;
3131
-
3132
- const int l0 = n*ir;
3133
- const int q_offset = 32*iq + l0;
3134
- const int y_offset = 64*iq + l0;
3135
-
3136
- const uint8_t hm1 = 1u << (2*iq);
3137
- const uint8_t hm2 = hm1 << 1;
3138
- const uint8_t hm3 = hm1 << 4;
3139
- const uint8_t hm4 = hm2 << 4;
3140
-
3141
- uint16_t sc16[4];
3142
- thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
3143
-
3144
- device const float * y1 = yy + ix*QK_K + y_offset;
3145
-
3146
- for (int i = ix; i < nb; i += 4) {
3147
-
3148
- device const uint8_t * q1 = x[i].qs + q_offset;
3149
- device const uint8_t * qh = x[i].qh + l0;
3150
- device const half * dh = &x[i].d;
3151
- device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
3152
-
3153
- device const float * y2 = y1 + 128;
3154
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
3155
- for (int l = 0; l < 8; ++l) {
3156
- yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
3157
- yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
3158
- yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
3159
- yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
3160
- }
3161
-
3162
- for (int row = 0; row < 2; ++row) {
3163
-
3164
- device const uint8_t * q2 = q1 + 64;
3165
-
3166
- sc16[0] = a[0] & kmask1;
3167
- sc16[1] = a[2] & kmask1;
3168
- sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
3169
- sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
3170
-
3171
- float4 acc1 = {0.f};
3172
- float4 acc2 = {0.f};
3173
- for (int l = 0; l < n; ++l) {
3174
- uint8_t h = qh[l];
3175
- acc1[0] += yl[l+0] * (q1[l] & 0x0F);
3176
- acc1[1] += yl[l+8] * (q1[l] & 0xF0);
3177
- acc1[2] += yh[l+0] * (q2[l] & 0x0F);
3178
- acc1[3] += yh[l+8] * (q2[l] & 0xF0);
3179
- acc2[0] += h & hm1 ? yl[l+0] : 0.f;
3180
- acc2[1] += h & hm2 ? yl[l+8] : 0.f;
3181
- acc2[2] += h & hm3 ? yh[l+0] : 0.f;
3182
- acc2[3] += h & hm4 ? yh[l+8] : 0.f;
3183
- }
3184
- const float dall = dh[0];
3185
- const float dmin = dh[1];
3186
- sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
3187
- sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
3188
- sc8[4] * (acc1[2] + 16.f*acc2[2]) +
3189
- sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
3190
- dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
3191
-
3192
- q1 += step;
3193
- qh += step;
3194
- dh += step/2;
3195
- a += step/2;
3196
-
3197
- }
3198
-
3199
- y1 += 4 * QK_K;
3200
-
3201
- }
3202
- #else
3203
- float yl[8], yh[8];
3204
-
3205
- const int il = 4 * (tiisg/8); // 0, 4, 8, 12
3206
- const int ix = tiisg%8;
3207
- const int iq = il/8; // 0, 0, 1, 1
3208
- const int in = il%8; // 0, 4, 0, 4
3209
-
3210
- device const float * y = yy + ix*QK_K + il;
3211
-
3212
- for (int i = ix; i < nb; i += 8) {
3213
-
3214
- for (int l = 0; l < 4; ++l) {
3215
- yl[l+0] = y[l+ 0];
3216
- yl[l+4] = y[l+16];
3217
- yh[l+0] = y[l+32];
3218
- yh[l+4] = y[l+48];
3219
- }
3220
-
3221
- device const half * dh = &x[i].d;
3222
- device const uint8_t * q = x[i].qs + il;
3223
- device const uint8_t * h = x[i].qh + in;
3224
- device const int8_t * s = x[i].scales;
3225
-
3226
- for (int row = 0; row < 2; ++row) {
3227
-
3228
- const float d = dh[0];
3229
-
3230
- float2 acc = {0.f, 0.f};
3231
- for (int l = 0; l < 4; ++l) {
3232
- const uint8_t hl = h[l] >> iq;
3233
- acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
3234
- + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
3235
- acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
3236
- + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
3237
- }
3238
- sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
3239
-
3240
- q += step;
3241
- h += step;
3242
- s += step;
3243
- dh += step/2;
3244
-
3245
- }
3246
-
3247
- y += 8 * QK_K;
3248
- }
3249
- #endif
3250
-
3251
- for (int row = 0; row < 2; ++row) {
3252
- const float tot = simd_sum(sumf[row]);
3253
- if (tiisg == 0) {
3254
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
3255
- }
3256
- }
3257
- }
3258
-
3259
- [[host_name("kernel_mul_mv_q5_K_f32")]]
3260
- kernel void kernel_mul_mv_q5_K_f32(
3261
- device const void * src0,
3262
- device const float * src1,
3263
- device float * dst,
3264
- constant int64_t & ne00,
3265
- constant int64_t & ne01[[buffer(4)]],
3266
- constant int64_t & ne02[[buffer(5)]],
3267
- constant int64_t & ne10[[buffer(9)]],
3268
- constant int64_t & ne12[[buffer(11)]],
3269
- constant int64_t & ne0 [[buffer(15)]],
3270
- constant int64_t & ne1 [[buffer(16)]],
3271
- constant uint & r2 [[buffer(17)]],
3272
- constant uint & r3 [[buffer(18)]],
3273
- uint3 tgpig[[threadgroup_position_in_grid]],
3274
- uint tiisg[[thread_index_in_simdgroup]],
3275
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3276
-
3277
- kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3278
- }
3279
-
3280
- void kernel_mul_mv_q6_K_f32_impl(
3281
- device const void * src0,
3282
- device const float * src1,
3283
- device float * dst,
3284
- constant int64_t & ne00,
3285
- constant int64_t & ne01,
3286
- constant int64_t & ne02,
3287
- constant int64_t & ne10,
3288
- constant int64_t & ne12,
3289
- constant int64_t & ne0,
3290
- constant int64_t & ne1,
3291
- constant uint & r2,
3292
- constant uint & r3,
3293
- uint3 tgpig[[threadgroup_position_in_grid]],
3294
- uint tiisg[[thread_index_in_simdgroup]],
3295
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3296
-
3297
- const uint8_t kmask1 = 0x03;
3298
- const uint8_t kmask2 = 0x0C;
3299
- const uint8_t kmask3 = 0x30;
3300
- const uint8_t kmask4 = 0xC0;
3301
-
3302
- const int nb = ne00/QK_K;
3303
-
3304
- const int64_t r0 = tgpig.x;
3305
- const int64_t r1 = tgpig.y;
3306
- const int im = tgpig.z;
3307
-
3308
- const int row = 2 * r0 + sgitg;
3309
-
3310
- const uint i12 = im%ne12;
3311
- const uint i13 = im/ne12;
3312
-
3313
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3314
-
3315
- device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
3316
- device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3317
-
3318
- float sumf = 0;
3319
-
3320
- #if QK_K == 256
3321
- const int tid = tiisg/2;
3322
- const int ix = tiisg%2;
3323
- const int ip = tid/8; // 0 or 1
3324
- const int il = tid%8;
3325
- const int n = 4;
3326
- const int l0 = n*il;
3327
- const int is = 8*ip + l0/16;
3328
-
3329
- const int y_offset = 128*ip + l0;
3330
- const int q_offset_l = 64*ip + l0;
3331
- const int q_offset_h = 32*ip + l0;
3332
-
3333
- for (int i = ix; i < nb; i += 2) {
3334
-
3335
- device const uint8_t * q1 = x[i].ql + q_offset_l;
3336
- device const uint8_t * q2 = q1 + 32;
3337
- device const uint8_t * qh = x[i].qh + q_offset_h;
3338
- device const int8_t * sc = x[i].scales + is;
3339
-
3340
- device const float * y = yy + i * QK_K + y_offset;
3341
-
3342
- const float dall = x[i].d;
3343
-
3344
- float4 sums = {0.f, 0.f, 0.f, 0.f};
3345
- for (int l = 0; l < n; ++l) {
3346
- sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
3347
- sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
3348
- sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
3349
- sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
3350
- }
3351
-
3352
- sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
3353
-
3354
- }
3355
-
3356
- #else
3357
- const int ix = tiisg/4;
3358
- const int il = 4*(tiisg%4);
3359
-
3360
- for (int i = ix; i < nb; i += 8) {
3361
- device const float * y = yy + i * QK_K + il;
3362
- device const uint8_t * ql = x[i].ql + il;
3363
- device const uint8_t * qh = x[i].qh + il;
3364
- device const int8_t * s = x[i].scales;
3365
-
3366
- const float d = x[i].d;
3367
-
3368
- float4 sums = {0.f, 0.f, 0.f, 0.f};
3369
- for (int l = 0; l < 4; ++l) {
3370
- sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
3371
- sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
3372
- sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
3373
- sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
3374
- }
3375
- sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
3376
- }
3377
-
3378
- #endif
3379
-
3380
- const float tot = simd_sum(sumf);
3381
- if (tiisg == 0) {
3382
- dst[r1*ne0 + im*ne0*ne1 + row] = tot;
3383
- }
3384
- }
3385
-
3386
- [[host_name("kernel_mul_mv_q6_K_f32")]]
3387
- kernel void kernel_mul_mv_q6_K_f32(
3388
- device const void * src0,
3389
- device const float * src1,
3390
- device float * dst,
3391
- constant int64_t & ne00,
3392
- constant int64_t & ne01[[buffer(4)]],
3393
- constant int64_t & ne02[[buffer(5)]],
3394
- constant int64_t & ne10[[buffer(9)]],
3395
- constant int64_t & ne12[[buffer(11)]],
3396
- constant int64_t & ne0 [[buffer(15)]],
3397
- constant int64_t & ne1 [[buffer(16)]],
3398
- constant uint & r2 [[buffer(17)]],
3399
- constant uint & r3 [[buffer(18)]],
3400
- uint3 tgpig[[threadgroup_position_in_grid]],
3401
- uint tiisg[[thread_index_in_simdgroup]],
3402
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3403
-
3404
- kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3405
- }
3406
-
3407
- //============================= templates and their specializations =============================
3408
-
3409
- // NOTE: this is not dequantizing - we are simply fitting the template
3410
- template <typename type4x4>
3411
- void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
3412
- float4x4 temp = *(((device float4x4 *)src));
3413
- for (int i = 0; i < 16; i++){
3414
- reg[i/4][i%4] = temp[i/4][i%4];
3415
- }
3416
- }
3417
-
3418
- template <typename type4x4>
3419
- void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
3420
- half4x4 temp = *(((device half4x4 *)src));
3421
- for (int i = 0; i < 16; i++){
3422
- reg[i/4][i%4] = temp[i/4][i%4];
3423
- }
3424
- }
3425
-
3426
- template <typename type4x4>
3427
- void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
3428
- device const uint16_t * qs = ((device const uint16_t *)xb + 1);
3429
- const float d1 = il ? (xb->d / 16.h) : xb->d;
3430
- const float d2 = d1 / 256.f;
3431
- const float md = -8.h * xb->d;
3432
- const ushort mask0 = il ? 0x00F0 : 0x000F;
3433
- const ushort mask1 = mask0 << 8;
3434
-
3435
- for (int i=0;i<8;i++) {
3436
- reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
3437
- reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
3438
- }
3439
- }
3440
-
3441
- template <typename type4x4>
3442
- void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
3443
- device const uint16_t * qs = ((device const uint16_t *)xb + 2);
3444
- const float d1 = il ? (xb->d / 16.h) : xb->d;
3445
- const float d2 = d1 / 256.f;
3446
- const float m = xb->m;
3447
- const ushort mask0 = il ? 0x00F0 : 0x000F;
3448
- const ushort mask1 = mask0 << 8;
3449
-
3450
- for (int i=0;i<8;i++) {
3451
- reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
3452
- reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
3453
- }
3454
- }
3455
-
3456
- template <typename type4x4>
3457
- void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
3458
- device const uint16_t * qs = ((device const uint16_t *)xb + 3);
3459
- const float d = xb->d;
3460
- const float md = -16.h * xb->d;
3461
- const ushort mask = il ? 0x00F0 : 0x000F;
3462
-
3463
- const uint32_t qh = *((device const uint32_t *)xb->qh);
3464
-
3465
- const int x_mv = il ? 4 : 0;
3466
-
3467
- const int gh_mv = il ? 12 : 0;
3468
- const int gh_bk = il ? 0 : 4;
3469
-
3470
- for (int i = 0; i < 8; i++) {
3471
- // extract the 5-th bits for x0 and x1
3472
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
3473
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
3474
-
3475
- // combine the 4-bits from qs with the 5th bit
3476
- const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
3477
- const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
3478
-
3479
- reg[i/2][2*(i%2)+0] = d * x0 + md;
3480
- reg[i/2][2*(i%2)+1] = d * x1 + md;
3481
- }
3482
- }
3483
-
3484
- template <typename type4x4>
3485
- void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
3486
- device const uint16_t * qs = ((device const uint16_t *)xb + 4);
3487
- const float d = xb->d;
3488
- const float m = xb->m;
3489
- const ushort mask = il ? 0x00F0 : 0x000F;
3490
-
3491
- const uint32_t qh = *((device const uint32_t *)xb->qh);
3492
-
3493
- const int x_mv = il ? 4 : 0;
3494
-
3495
- const int gh_mv = il ? 12 : 0;
3496
- const int gh_bk = il ? 0 : 4;
3497
-
3498
- for (int i = 0; i < 8; i++) {
3499
- // extract the 5-th bits for x0 and x1
3500
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
3501
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
3502
-
3503
- // combine the 4-bits from qs with the 5th bit
3504
- const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
3505
- const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
3506
-
3507
- reg[i/2][2*(i%2)+0] = d * x0 + m;
3508
- reg[i/2][2*(i%2)+1] = d * x1 + m;
3509
- }
3510
- }
3511
-
3512
- template <typename type4x4>
3513
- void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
3514
- device const int8_t * qs = ((device const int8_t *)xb->qs);
3515
- const half d = xb->d;
3516
-
3517
- for (int i=0;i<16;i++) {
3518
- reg[i/4][i%4] = (qs[i + 16*il] * d);
3519
- }
3520
- }
3521
-
3522
- template <typename type4x4>
3523
- void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
3524
- const float d = xb->d;
3525
- const float min = xb->dmin;
3526
- device const uint8_t * q = (device const uint8_t *)xb->qs;
3527
- float dl, ml;
3528
- uint8_t sc = xb->scales[il];
3529
-
3530
- #if QK_K == 256
3531
- q = q + 32*(il/8) + 16*(il&1);
3532
- il = (il/2)%4;
3533
- #endif
3534
- half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
3535
- uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
3536
- dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
3537
- for (int i = 0; i < 16; ++i) {
3538
- reg[i/4][i%4] = dl * (q[i] & mask) - ml;
3539
- }
3540
- }
3541
-
3542
- template <typename type4x4>
3543
- void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
3544
- const half d_all = xb->d;
3545
- device const uint8_t * q = (device const uint8_t *)xb->qs;
3546
- device const uint8_t * h = (device const uint8_t *)xb->hmask;
3547
- device const int8_t * scales = (device const int8_t *)xb->scales;
3548
-
3549
- #if QK_K == 256
3550
- q = q + 32 * (il/8) + 16 * (il&1);
3551
- h = h + 16 * (il&1);
3552
- uint8_t m = 1 << (il/2);
3553
- uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
3554
- ((il/4)>0 ? 12 : 3);
3555
- uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
3556
- uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
3557
- int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
3558
- : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
3559
- half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
3560
- const half ml = 4.h * dl;
3561
-
3562
- il = (il/2) & 3;
3563
- const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
3564
- const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
3565
- dl *= coef;
3566
-
3567
- for (int i = 0; i < 16; ++i) {
3568
- reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
3569
- }
3570
- #else
3571
- float kcoef = il&1 ? 1.f/16.f : 1.f;
3572
- uint16_t kmask = il&1 ? 0xF0 : 0x0F;
3573
- float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
3574
- float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
3575
- uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
3576
- uint8_t m = 1<<(il*2);
3577
- for (int i = 0; i < 16; ++i) {
3578
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
3579
- }
3580
- #endif
3581
- }
3582
-
3583
- static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
3584
- return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
3585
- : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
3586
- }
3587
-
3588
- template <typename type4x4>
3589
- void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
3590
- device const uchar * q = xb->qs;
3591
-
3592
- #if QK_K == 256
3593
- short is = (il/4) * 2;
3594
- q = q + (il/4) * 32 + 16 * (il&1);
3595
- il = il & 3;
3596
- const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3597
- const float d = il < 2 ? xb->d : xb->d / 16.h;
3598
- const float min = xb->dmin;
3599
- const float dl = d * sc[0];
3600
- const float ml = min * sc[1];
3601
- #else
3602
- q = q + 16 * (il&1);
3603
- device const uint8_t * s = xb->scales;
3604
- device const half2 * dh = (device const half2 *)xb->d;
3605
- const float2 d = (float2)dh[0];
3606
- const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
3607
- const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
3608
- #endif
3609
- const ushort mask = il<2 ? 0x0F : 0xF0;
3610
- for (int i = 0; i < 16; ++i) {
3611
- reg[i/4][i%4] = dl * (q[i] & mask) - ml;
3612
- }
3613
- }
3614
-
3615
- template <typename type4x4>
3616
- void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
3617
- device const uint8_t * q = xb->qs;
3618
- device const uint8_t * qh = xb->qh;
3619
-
3620
- #if QK_K == 256
3621
- short is = (il/4) * 2;
3622
- q = q + 32 * (il/4) + 16 * (il&1);
3623
- qh = qh + 16 * (il&1);
3624
- uint8_t ul = 1 << (il/2);
3625
- il = il & 3;
3626
- const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3627
- const float d = il < 2 ? xb->d : xb->d / 16.h;
3628
- const float min = xb->dmin;
3629
- const float dl = d * sc[0];
3630
- const float ml = min * sc[1];
3631
-
3632
- const ushort mask = il<2 ? 0x0F : 0xF0;
3633
- const float qh_val = il<2 ? 16.f : 256.f;
3634
- for (int i = 0; i < 16; ++i) {
3635
- reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
3636
- }
3637
- #else
3638
- q = q + 16 * (il&1);
3639
- device const int8_t * s = xb->scales;
3640
- const float dl = xb->d * s[il];
3641
- uint8_t m = 1<<(il*2);
3642
- const float coef = il<2 ? 1.f : 1.f/16.f;
3643
- const ushort mask = il<2 ? 0x0F : 0xF0;
3644
- for (int i = 0; i < 16; ++i) {
3645
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
3646
- }
3647
- #endif
3648
- }
3649
-
3650
- template <typename type4x4>
3651
- void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
3652
- const half d_all = xb->d;
3653
- device const uint8_t * ql = (device const uint8_t *)xb->ql;
3654
- device const uint8_t * qh = (device const uint8_t *)xb->qh;
3655
- device const int8_t * scales = (device const int8_t *)xb->scales;
3656
-
3657
- #if QK_K == 256
3658
- ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
3659
- qh = qh + 32*(il/8) + 16*(il&1);
3660
- half sc = scales[(il%2) + 2 * ((il/2))];
3661
- il = (il/2) & 3;
3662
- #else
3663
- ql = ql + 16 * (il&1);
3664
- half sc = scales[il];
3665
- #endif
3666
- const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
3667
- const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
3668
- const half coef = il>1 ? 1.f/16.h : 1.h;
3669
- const half ml = d_all * sc * 32.h;
3670
- const half dl = d_all * sc * coef;
3671
- for (int i = 0; i < 16; ++i) {
3672
- const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
3673
- : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
3674
- reg[i/4][i%4] = dl * q - ml;
3675
- }
3676
- }
3677
-
3678
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
3679
- kernel void kernel_get_rows(
3680
- device const void * src0,
3681
- device const char * src1,
3682
- device float * dst,
3683
- constant int64_t & ne00,
3684
- constant uint64_t & nb01,
3685
- constant uint64_t & nb02,
3686
- constant int64_t & ne10,
3687
- constant uint64_t & nb10,
3688
- constant uint64_t & nb11,
3689
- constant uint64_t & nb1,
3690
- constant uint64_t & nb2,
3691
- uint3 tgpig[[threadgroup_position_in_grid]],
3692
- uint tiitg[[thread_index_in_threadgroup]],
3693
- uint3 tptg [[threads_per_threadgroup]]) {
3694
- //const int64_t i = tgpig;
3695
- //const int64_t r = ((device int32_t *) src1)[i];
3696
-
3697
- const int64_t i10 = tgpig.x;
3698
- const int64_t i11 = tgpig.y;
3699
-
3700
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3701
-
3702
- const int64_t i02 = i11;
3703
-
3704
- for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
3705
- float4x4 temp;
3706
- dequantize_func(
3707
- ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
3708
- *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
3709
- }
3710
- }
3711
-
3712
- kernel void kernel_get_rows_f32(
3713
- device const void * src0,
3714
- device const char * src1,
3715
- device float * dst,
3716
- constant int64_t & ne00,
3717
- constant uint64_t & nb01,
3718
- constant uint64_t & nb02,
3719
- constant int64_t & ne10,
3720
- constant uint64_t & nb10,
3721
- constant uint64_t & nb11,
3722
- constant uint64_t & nb1,
3723
- constant uint64_t & nb2,
3724
- uint3 tgpig[[threadgroup_position_in_grid]],
3725
- uint tiitg[[thread_index_in_threadgroup]],
3726
- uint3 tptg [[threads_per_threadgroup]]) {
3727
- const int64_t i10 = tgpig.x;
3728
- const int64_t i11 = tgpig.y;
3729
-
3730
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3731
-
3732
- const int64_t i02 = i11;
3733
-
3734
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3735
- ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3736
- ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
3737
- }
3738
- }
3739
-
3740
- kernel void kernel_get_rows_f16(
3741
- device const void * src0,
3742
- device const char * src1,
3743
- device float * dst,
3744
- constant int64_t & ne00,
3745
- constant uint64_t & nb01,
3746
- constant uint64_t & nb02,
3747
- constant int64_t & ne10,
3748
- constant uint64_t & nb10,
3749
- constant uint64_t & nb11,
3750
- constant uint64_t & nb1,
3751
- constant uint64_t & nb2,
3752
- uint3 tgpig[[threadgroup_position_in_grid]],
3753
- uint tiitg[[thread_index_in_threadgroup]],
3754
- uint3 tptg [[threads_per_threadgroup]]) {
3755
- const int64_t i10 = tgpig.x;
3756
- const int64_t i11 = tgpig.y;
3757
-
3758
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3759
-
3760
- const int64_t i02 = i11;
3761
-
3762
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3763
- ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3764
- ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
3765
- }
3766
- }
3767
-
3768
- #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
3769
- #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
3770
- #define BLOCK_SIZE_K 32
3771
- #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
3772
- #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
3773
- #define THREAD_PER_BLOCK 128
3774
- #define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
3775
- #define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
3776
- #define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
3777
- #define SG_MAT_ROW 8
3778
-
3779
- // each block_q contains 16*nl weights
3780
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3781
- void kernel_mul_mm_impl(device const uchar * src0,
3782
- device const uchar * src1,
3783
- device float * dst,
3784
- constant int64_t & ne00,
3785
- constant int64_t & ne02,
3786
- constant int64_t & nb01,
3787
- constant int64_t & nb02,
3788
- constant int64_t & ne12,
3789
- constant int64_t & nb10,
3790
- constant int64_t & nb11,
3791
- constant int64_t & nb12,
3792
- constant int64_t & ne0,
3793
- constant int64_t & ne1,
3794
- constant uint & r2,
3795
- constant uint & r3,
3796
- threadgroup uchar * shared_memory [[threadgroup(0)]],
3797
- uint3 tgpig[[threadgroup_position_in_grid]],
3798
- uint tiitg[[thread_index_in_threadgroup]],
3799
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3800
-
3801
- threadgroup half * sa = (threadgroup half *)(shared_memory);
3802
- threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
3803
-
3804
- const uint r0 = tgpig.y;
3805
- const uint r1 = tgpig.x;
3806
- const uint im = tgpig.z;
3807
-
3808
- // if this block is of 64x32 shape or smaller
3809
- short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
3810
- short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
3811
-
3812
- // a thread shouldn't load data outside of the matrix
3813
- short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
3814
- short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
3815
-
3816
- simdgroup_half8x8 ma[4];
3817
- simdgroup_float8x8 mb[2];
3818
- simdgroup_float8x8 c_res[8];
3819
- for (int i = 0; i < 8; i++){
3820
- c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
3821
- }
3822
-
3823
- short il = (tiitg % THREAD_PER_ROW);
3824
-
3825
- const uint i12 = im%ne12;
3826
- const uint i13 = im/ne12;
3827
-
3828
- uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
3829
- ushort offset1 = il/nl;
3830
-
3831
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
3832
- device const float * y = (device const float *)(src1
3833
- + nb12 * im
3834
- + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
3835
- + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
3836
-
3837
- for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
3838
- // load data and store to threadgroup memory
3839
- half4x4 temp_a;
3840
- dequantize_func(x, il, temp_a);
3841
- threadgroup_barrier(mem_flags::mem_threadgroup);
3842
-
3843
- #pragma unroll(16)
3844
- for (int i = 0; i < 16; i++) {
3845
- *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
3846
- + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
3847
- + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
3848
- }
3849
-
3850
- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
3851
-
3852
- il = (il + 2 < nl) ? il + 2 : il % 2;
3853
- x = (il < 2) ? x + (2+nl-1)/nl : x;
3854
- y += BLOCK_SIZE_K;
3855
-
3856
- threadgroup_barrier(mem_flags::mem_threadgroup);
3857
-
3858
- // load matrices from threadgroup memory and conduct outer products
3859
- threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
3860
- threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
3861
-
3862
- #pragma unroll(4)
3863
- for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
3864
- #pragma unroll(4)
3865
- for (int i = 0; i < 4; i++) {
3866
- simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
3867
- }
3868
- simdgroup_barrier(mem_flags::mem_none);
3869
- #pragma unroll(2)
3870
- for (int i = 0; i < 2; i++) {
3871
- simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
3872
- }
3873
-
3874
- lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
3875
- lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
3876
-
3877
- #pragma unroll(8)
3878
- for (int i = 0; i < 8; i++){
3879
- simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
3880
- }
3881
- }
3882
- }
3883
-
3884
- if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
3885
- device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
3886
- + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
3887
- for (int i = 0; i < 8; i++) {
3888
- simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
3889
- }
3890
- } else {
3891
- // block is smaller than 64x32, we should avoid writing data outside of the matrix
3892
- threadgroup_barrier(mem_flags::mem_threadgroup);
3893
- threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
3894
- + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
3895
- for (int i = 0; i < 8; i++) {
3896
- simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
3897
- }
3898
-
3899
- threadgroup_barrier(mem_flags::mem_threadgroup);
3900
-
3901
- device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
3902
- if (sgitg == 0) {
3903
- for (int i = 0; i < n_rows; i++) {
3904
- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
3905
- *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
3906
- }
3907
- }
3908
- }
3909
- }
3910
- }
3911
-
3912
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3913
- kernel void kernel_mul_mm(device const uchar * src0,
3914
- device const uchar * src1,
3915
- device float * dst,
3916
- constant int64_t & ne00,
3917
- constant int64_t & ne02,
3918
- constant int64_t & nb01,
3919
- constant int64_t & nb02,
3920
- constant int64_t & ne12,
3921
- constant int64_t & nb10,
3922
- constant int64_t & nb11,
3923
- constant int64_t & nb12,
3924
- constant int64_t & ne0,
3925
- constant int64_t & ne1,
3926
- constant uint & r2,
3927
- constant uint & r3,
3928
- threadgroup uchar * shared_memory [[threadgroup(0)]],
3929
- uint3 tgpig[[threadgroup_position_in_grid]],
3930
- uint tiitg[[thread_index_in_threadgroup]],
3931
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3932
- kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3933
- src0,
3934
- src1,
3935
- dst,
3936
- ne00,
3937
- ne02,
3938
- nb01,
3939
- nb02,
3940
- ne12,
3941
- nb10,
3942
- nb11,
3943
- nb12,
3944
- ne0,
3945
- ne1,
3946
- r2,
3947
- r3,
3948
- shared_memory,
3949
- tgpig,
3950
- tiitg,
3951
- sgitg);
3952
- }
3953
-
3954
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3955
- kernel void kernel_mul_mm_id(
3956
- device const uchar * ids,
3957
- device const uchar * src1,
3958
- device uchar * dst,
3959
- constant int64_t & nbi1,
3960
- constant int64_t & ne00,
3961
- constant int64_t & ne02,
3962
- constant int64_t & nb01,
3963
- constant int64_t & nb02,
3964
- constant int64_t & ne12,
3965
- constant int64_t & ne13,
3966
- constant int64_t & nb10,
3967
- constant int64_t & nb11,
3968
- constant int64_t & nb12,
3969
- constant int64_t & ne0,
3970
- constant int64_t & ne1,
3971
- constant int64_t & nb1,
3972
- constant uint & r2,
3973
- constant uint & r3,
3974
- constant int & idx,
3975
- device const uchar * src00,
3976
- device const uchar * src01,
3977
- device const uchar * src02,
3978
- device const uchar * src03,
3979
- device const uchar * src04,
3980
- device const uchar * src05,
3981
- device const uchar * src06,
3982
- device const uchar * src07,
3983
- threadgroup uchar * shared_memory [[threadgroup(0)]],
3984
- uint3 tgpig[[threadgroup_position_in_grid]],
3985
- uint tiitg[[thread_index_in_threadgroup]],
3986
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3987
- device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
3988
-
3989
- const int64_t bid = tgpig.z/(ne12*ne13);
3990
-
3991
- tgpig.z = tgpig.z%(ne12*ne13);
3992
-
3993
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
3994
-
3995
- kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3996
- src0[id],
3997
- src1 + bid*nb11,
3998
- (device float *) (dst + bid*nb1),
3999
- ne00,
4000
- ne02,
4001
- nb01,
4002
- nb02,
4003
- ne12,
4004
- nb10,
4005
- nb11,
4006
- nb12,
4007
- ne0,
4008
- ne1,
4009
- r2,
4010
- r3,
4011
- shared_memory,
4012
- tgpig,
4013
- tiitg,
4014
- sgitg);
4015
- }
4016
-
4017
- #if QK_K == 256
4018
- #define QK_NL 16
4019
- #else
4020
- #define QK_NL 4
4021
- #endif
4022
-
4023
- //
4024
- // get rows
4025
- //
4026
-
4027
- typedef void (get_rows_t)(
4028
- device const void * src0,
4029
- device const char * src1,
4030
- device float * dst,
4031
- constant int64_t & ne00,
4032
- constant uint64_t & nb01,
4033
- constant uint64_t & nb02,
4034
- constant int64_t & ne10,
4035
- constant uint64_t & nb10,
4036
- constant uint64_t & nb11,
4037
- constant uint64_t & nb1,
4038
- constant uint64_t & nb2,
4039
- uint3, uint, uint3);
4040
-
4041
- //template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
4042
- //template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
4043
- template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
4044
- template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
4045
- template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
4046
- template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
4047
- template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
4048
- template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
4049
- template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
4050
- template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
4051
- template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
4052
- template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
4053
-
4054
- //
4055
- // matrix-matrix multiplication
4056
- //
4057
-
4058
- typedef void (mat_mm_t)(
4059
- device const uchar * src0,
4060
- device const uchar * src1,
4061
- device float * dst,
4062
- constant int64_t & ne00,
4063
- constant int64_t & ne02,
4064
- constant int64_t & nb01,
4065
- constant int64_t & nb02,
4066
- constant int64_t & ne12,
4067
- constant int64_t & nb10,
4068
- constant int64_t & nb11,
4069
- constant int64_t & nb12,
4070
- constant int64_t & ne0,
4071
- constant int64_t & ne1,
4072
- constant uint & r2,
4073
- constant uint & r3,
4074
- threadgroup uchar *,
4075
- uint3, uint, uint);
4076
-
4077
- template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
4078
- template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
4079
- template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
4080
- template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
4081
- template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
4082
- template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
4083
- template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
4084
- template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
4085
- template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
4086
- template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
4087
- template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
4088
- template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
4089
-
4090
- //
4091
- // indirect matrix-matrix multiplication
4092
- //
4093
-
4094
- typedef void (mat_mm_id_t)(
4095
- device const uchar * ids,
4096
- device const uchar * src1,
4097
- device uchar * dst,
4098
- constant int64_t & nbi1,
4099
- constant int64_t & ne00,
4100
- constant int64_t & ne02,
4101
- constant int64_t & nb01,
4102
- constant int64_t & nb02,
4103
- constant int64_t & ne12,
4104
- constant int64_t & ne13,
4105
- constant int64_t & nb10,
4106
- constant int64_t & nb11,
4107
- constant int64_t & nb12,
4108
- constant int64_t & ne0,
4109
- constant int64_t & ne1,
4110
- constant int64_t & nb1,
4111
- constant uint & r2,
4112
- constant uint & r3,
4113
- constant int & idx,
4114
- device const uchar * src00,
4115
- device const uchar * src01,
4116
- device const uchar * src02,
4117
- device const uchar * src03,
4118
- device const uchar * src04,
4119
- device const uchar * src05,
4120
- device const uchar * src06,
4121
- device const uchar * src07,
4122
- threadgroup uchar *,
4123
- uint3, uint, uint);
4124
-
4125
- template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
4126
- template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
4127
- template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
4128
- template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
4129
- template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
4130
- template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
4131
- template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
4132
- template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
4133
- template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
4134
- template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
4135
- template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
4136
- template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
4137
-
4138
- //
4139
- // matrix-vector multiplication
4140
- //
4141
-
4142
- [[host_name("kernel_mul_mv_id_f32_f32")]]
4143
- kernel void kernel_mul_mv_id_f32_f32(
4144
- device const char * ids,
4145
- device const char * src1,
4146
- device uchar * dst,
4147
- constant int64_t & nbi1,
4148
- constant int64_t & ne00,
4149
- constant int64_t & ne01,
4150
- constant int64_t & ne02,
4151
- constant uint64_t & nb00,
4152
- constant uint64_t & nb01,
4153
- constant uint64_t & nb02,
4154
- constant int64_t & ne10,
4155
- constant int64_t & ne11,
4156
- constant int64_t & ne12,
4157
- constant int64_t & ne13,
4158
- constant uint64_t & nb10,
4159
- constant uint64_t & nb11,
4160
- constant uint64_t & nb12,
4161
- constant int64_t & ne0,
4162
- constant int64_t & ne1,
4163
- constant int64_t & nb1,
4164
- constant uint & r2,
4165
- constant uint & r3,
4166
- constant int & idx,
4167
- device const char * src00,
4168
- device const char * src01,
4169
- device const char * src02,
4170
- device const char * src03,
4171
- device const char * src04,
4172
- device const char * src05,
4173
- device const char * src06,
4174
- device const char * src07,
4175
- uint3 tgpig[[threadgroup_position_in_grid]],
4176
- uint tiitg[[thread_index_in_threadgroup]],
4177
- uint tiisg[[thread_index_in_simdgroup]],
4178
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4179
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4180
-
4181
- const int64_t bid = tgpig.z/(ne12*ne13);
4182
-
4183
- tgpig.z = tgpig.z%(ne12*ne13);
4184
-
4185
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4186
-
4187
- kernel_mul_mv_f32_f32_impl(
4188
- src0[id],
4189
- src1 + bid*nb11,
4190
- (device float *) (dst + bid*nb1),
4191
- ne00,
4192
- ne01,
4193
- ne02,
4194
- nb00,
4195
- nb01,
4196
- nb02,
4197
- ne10,
4198
- ne11,
4199
- ne12,
4200
- nb10,
4201
- nb11,
4202
- nb12,
4203
- ne0,
4204
- ne1,
4205
- r2,
4206
- r3,
4207
- tgpig,
4208
- tiisg);
4209
- }
4210
-
4211
- [[host_name("kernel_mul_mv_id_f16_f32")]]
4212
- kernel void kernel_mul_mv_id_f16_f32(
4213
- device const char * ids,
4214
- device const char * src1,
4215
- device uchar * dst,
4216
- constant int64_t & nbi1,
4217
- constant int64_t & ne00,
4218
- constant int64_t & ne01,
4219
- constant int64_t & ne02,
4220
- constant uint64_t & nb00,
4221
- constant uint64_t & nb01,
4222
- constant uint64_t & nb02,
4223
- constant int64_t & ne10,
4224
- constant int64_t & ne11,
4225
- constant int64_t & ne12,
4226
- constant int64_t & ne13,
4227
- constant uint64_t & nb10,
4228
- constant uint64_t & nb11,
4229
- constant uint64_t & nb12,
4230
- constant int64_t & ne0,
4231
- constant int64_t & ne1,
4232
- constant int64_t & nb1,
4233
- constant uint & r2,
4234
- constant uint & r3,
4235
- constant int & idx,
4236
- device const char * src00,
4237
- device const char * src01,
4238
- device const char * src02,
4239
- device const char * src03,
4240
- device const char * src04,
4241
- device const char * src05,
4242
- device const char * src06,
4243
- device const char * src07,
4244
- uint3 tgpig[[threadgroup_position_in_grid]],
4245
- uint tiitg[[thread_index_in_threadgroup]],
4246
- uint tiisg[[thread_index_in_simdgroup]],
4247
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4248
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4249
-
4250
- const int64_t bid = tgpig.z/(ne12*ne13);
4251
-
4252
- tgpig.z = tgpig.z%(ne12*ne13);
4253
-
4254
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4255
-
4256
- kernel_mul_mv_f16_f32_impl(
4257
- src0[id],
4258
- src1 + bid*nb11,
4259
- (device float *) (dst + bid*nb1),
4260
- ne00,
4261
- ne01,
4262
- ne02,
4263
- nb00,
4264
- nb01,
4265
- nb02,
4266
- ne10,
4267
- ne11,
4268
- ne12,
4269
- nb10,
4270
- nb11,
4271
- nb12,
4272
- ne0,
4273
- ne1,
4274
- r2,
4275
- r3,
4276
- tgpig,
4277
- tiisg);
4278
- }
4279
-
4280
- [[host_name("kernel_mul_mv_id_q8_0_f32")]]
4281
- kernel void kernel_mul_mv_id_q8_0_f32(
4282
- device const char * ids,
4283
- device const char * src1,
4284
- device uchar * dst,
4285
- constant int64_t & nbi1,
4286
- constant int64_t & ne00,
4287
- constant int64_t & ne01,
4288
- constant int64_t & ne02,
4289
- constant uint64_t & nb00,
4290
- constant uint64_t & nb01,
4291
- constant uint64_t & nb02,
4292
- constant int64_t & ne10,
4293
- constant int64_t & ne11,
4294
- constant int64_t & ne12,
4295
- constant int64_t & ne13,
4296
- constant uint64_t & nb10,
4297
- constant uint64_t & nb11,
4298
- constant uint64_t & nb12,
4299
- constant int64_t & ne0,
4300
- constant int64_t & ne1,
4301
- constant int64_t & nb1,
4302
- constant uint & r2,
4303
- constant uint & r3,
4304
- constant int & idx,
4305
- device const char * src00,
4306
- device const char * src01,
4307
- device const char * src02,
4308
- device const char * src03,
4309
- device const char * src04,
4310
- device const char * src05,
4311
- device const char * src06,
4312
- device const char * src07,
4313
- uint3 tgpig[[threadgroup_position_in_grid]],
4314
- uint tiitg[[thread_index_in_threadgroup]],
4315
- uint tiisg[[thread_index_in_simdgroup]],
4316
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4317
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4318
-
4319
- const int64_t bid = tgpig.z/(ne12*ne13);
4320
-
4321
- tgpig.z = tgpig.z%(ne12*ne13);
4322
-
4323
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4324
-
4325
- kernel_mul_mv_q8_0_f32_impl(
4326
- src0[id],
4327
- (device const float *) (src1 + bid*nb11),
4328
- (device float *) ( dst + bid*nb1),
4329
- ne00,
4330
- ne01,
4331
- ne02,
4332
- ne10,
4333
- ne12,
4334
- ne0,
4335
- ne1,
4336
- r2,
4337
- r3,
4338
- tgpig,
4339
- tiisg,
4340
- sgitg);
4341
- }
4342
-
4343
- [[host_name("kernel_mul_mv_id_q4_0_f32")]]
4344
- kernel void kernel_mul_mv_id_q4_0_f32(
4345
- device const char * ids,
4346
- device const char * src1,
4347
- device uchar * dst,
4348
- constant int64_t & nbi1,
4349
- constant int64_t & ne00,
4350
- constant int64_t & ne01,
4351
- constant int64_t & ne02,
4352
- constant uint64_t & nb00,
4353
- constant uint64_t & nb01,
4354
- constant uint64_t & nb02,
4355
- constant int64_t & ne10,
4356
- constant int64_t & ne11,
4357
- constant int64_t & ne12,
4358
- constant int64_t & ne13,
4359
- constant uint64_t & nb10,
4360
- constant uint64_t & nb11,
4361
- constant uint64_t & nb12,
4362
- constant int64_t & ne0,
4363
- constant int64_t & ne1,
4364
- constant int64_t & nb1,
4365
- constant uint & r2,
4366
- constant uint & r3,
4367
- constant int & idx,
4368
- device const char * src00,
4369
- device const char * src01,
4370
- device const char * src02,
4371
- device const char * src03,
4372
- device const char * src04,
4373
- device const char * src05,
4374
- device const char * src06,
4375
- device const char * src07,
4376
- uint3 tgpig[[threadgroup_position_in_grid]],
4377
- uint tiitg[[thread_index_in_threadgroup]],
4378
- uint tiisg[[thread_index_in_simdgroup]],
4379
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4380
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4381
-
4382
- const int64_t bid = tgpig.z/(ne12*ne13);
4383
-
4384
- tgpig.z = tgpig.z%(ne12*ne13);
4385
-
4386
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4387
-
4388
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4389
- src0[id],
4390
- (device const float *) (src1 + bid*nb11),
4391
- (device float *) ( dst + bid*nb1),
4392
- ne00,
4393
- ne01,
4394
- ne02,
4395
- ne10,
4396
- ne12,
4397
- ne0,
4398
- ne1,
4399
- r2,
4400
- r3,
4401
- tgpig,
4402
- tiisg,
4403
- sgitg);
4404
- }
4405
-
4406
- [[host_name("kernel_mul_mv_id_q4_1_f32")]]
4407
- kernel void kernel_mul_mv_id_q4_1_f32(
4408
- device const char * ids,
4409
- device const char * src1,
4410
- device uchar * dst,
4411
- constant int64_t & nbi1,
4412
- constant int64_t & ne00,
4413
- constant int64_t & ne01,
4414
- constant int64_t & ne02,
4415
- constant uint64_t & nb00,
4416
- constant uint64_t & nb01,
4417
- constant uint64_t & nb02,
4418
- constant int64_t & ne10,
4419
- constant int64_t & ne11,
4420
- constant int64_t & ne12,
4421
- constant int64_t & ne13,
4422
- constant uint64_t & nb10,
4423
- constant uint64_t & nb11,
4424
- constant uint64_t & nb12,
4425
- constant int64_t & ne0,
4426
- constant int64_t & ne1,
4427
- constant int64_t & nb1,
4428
- constant uint & r2,
4429
- constant uint & r3,
4430
- constant int & idx,
4431
- device const char * src00,
4432
- device const char * src01,
4433
- device const char * src02,
4434
- device const char * src03,
4435
- device const char * src04,
4436
- device const char * src05,
4437
- device const char * src06,
4438
- device const char * src07,
4439
- uint3 tgpig[[threadgroup_position_in_grid]],
4440
- uint tiitg[[thread_index_in_threadgroup]],
4441
- uint tiisg[[thread_index_in_simdgroup]],
4442
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4443
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4444
-
4445
- const int64_t bid = tgpig.z/(ne12*ne13);
4446
-
4447
- tgpig.z = tgpig.z%(ne12*ne13);
4448
-
4449
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4450
-
4451
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4452
- src0[id],
4453
- (device const float *) (src1 + bid*nb11),
4454
- (device float *) ( dst + bid*nb1),
4455
- ne00,
4456
- ne01,
4457
- ne02,
4458
- ne10,
4459
- ne12,
4460
- ne0,
4461
- ne1,
4462
- r2,
4463
- r3,
4464
- tgpig,
4465
- tiisg,
4466
- sgitg);
4467
- }
4468
-
4469
- [[host_name("kernel_mul_mv_id_q5_0_f32")]]
4470
- kernel void kernel_mul_mv_id_q5_0_f32(
4471
- device const char * ids,
4472
- device const char * src1,
4473
- device uchar * dst,
4474
- constant int64_t & nbi1,
4475
- constant int64_t & ne00,
4476
- constant int64_t & ne01,
4477
- constant int64_t & ne02,
4478
- constant uint64_t & nb00,
4479
- constant uint64_t & nb01,
4480
- constant uint64_t & nb02,
4481
- constant int64_t & ne10,
4482
- constant int64_t & ne11,
4483
- constant int64_t & ne12,
4484
- constant int64_t & ne13,
4485
- constant uint64_t & nb10,
4486
- constant uint64_t & nb11,
4487
- constant uint64_t & nb12,
4488
- constant int64_t & ne0,
4489
- constant int64_t & ne1,
4490
- constant int64_t & nb1,
4491
- constant uint & r2,
4492
- constant uint & r3,
4493
- constant int & idx,
4494
- device const char * src00,
4495
- device const char * src01,
4496
- device const char * src02,
4497
- device const char * src03,
4498
- device const char * src04,
4499
- device const char * src05,
4500
- device const char * src06,
4501
- device const char * src07,
4502
- uint3 tgpig[[threadgroup_position_in_grid]],
4503
- uint tiitg[[thread_index_in_threadgroup]],
4504
- uint tiisg[[thread_index_in_simdgroup]],
4505
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4506
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4507
-
4508
- const int64_t bid = tgpig.z/(ne12*ne13);
4509
-
4510
- tgpig.z = tgpig.z%(ne12*ne13);
4511
-
4512
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4513
-
4514
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4515
- src0[id],
4516
- (device const float *) (src1 + bid*nb11),
4517
- (device float *) ( dst + bid*nb1),
4518
- ne00,
4519
- ne01,
4520
- ne02,
4521
- ne10,
4522
- ne12,
4523
- ne0,
4524
- ne1,
4525
- r2,
4526
- r3,
4527
- tgpig,
4528
- tiisg,
4529
- sgitg);
4530
- }
4531
-
4532
- [[host_name("kernel_mul_mv_id_q5_1_f32")]]
4533
- kernel void kernel_mul_mv_id_q5_1_f32(
4534
- device const char * ids,
4535
- device const char * src1,
4536
- device uchar * dst,
4537
- constant int64_t & nbi1,
4538
- constant int64_t & ne00,
4539
- constant int64_t & ne01,
4540
- constant int64_t & ne02,
4541
- constant uint64_t & nb00,
4542
- constant uint64_t & nb01,
4543
- constant uint64_t & nb02,
4544
- constant int64_t & ne10,
4545
- constant int64_t & ne11,
4546
- constant int64_t & ne12,
4547
- constant int64_t & ne13,
4548
- constant uint64_t & nb10,
4549
- constant uint64_t & nb11,
4550
- constant uint64_t & nb12,
4551
- constant int64_t & ne0,
4552
- constant int64_t & ne1,
4553
- constant int64_t & nb1,
4554
- constant uint & r2,
4555
- constant uint & r3,
4556
- constant int & idx,
4557
- device const char * src00,
4558
- device const char * src01,
4559
- device const char * src02,
4560
- device const char * src03,
4561
- device const char * src04,
4562
- device const char * src05,
4563
- device const char * src06,
4564
- device const char * src07,
4565
- uint3 tgpig[[threadgroup_position_in_grid]],
4566
- uint tiitg[[thread_index_in_threadgroup]],
4567
- uint tiisg[[thread_index_in_simdgroup]],
4568
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4569
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4570
-
4571
- const int64_t bid = tgpig.z/(ne12*ne13);
4572
-
4573
- tgpig.z = tgpig.z%(ne12*ne13);
4574
-
4575
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4576
-
4577
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4578
- src0[id],
4579
- (device const float *) (src1 + bid*nb11),
4580
- (device float *) ( dst + bid*nb1),
4581
- ne00,
4582
- ne01,
4583
- ne02,
4584
- ne10,
4585
- ne12,
4586
- ne0,
4587
- ne1,
4588
- r2,
4589
- r3,
4590
- tgpig,
4591
- tiisg,
4592
- sgitg);
4593
- }
4594
-
4595
- [[host_name("kernel_mul_mv_id_q2_K_f32")]]
4596
- kernel void kernel_mul_mv_id_q2_K_f32(
4597
- device const char * ids,
4598
- device const char * src1,
4599
- device uchar * dst,
4600
- constant int64_t & nbi1,
4601
- constant int64_t & ne00,
4602
- constant int64_t & ne01,
4603
- constant int64_t & ne02,
4604
- constant uint64_t & nb00,
4605
- constant uint64_t & nb01,
4606
- constant uint64_t & nb02,
4607
- constant int64_t & ne10,
4608
- constant int64_t & ne11,
4609
- constant int64_t & ne12,
4610
- constant int64_t & ne13,
4611
- constant uint64_t & nb10,
4612
- constant uint64_t & nb11,
4613
- constant uint64_t & nb12,
4614
- constant int64_t & ne0,
4615
- constant int64_t & ne1,
4616
- constant int64_t & nb1,
4617
- constant uint & r2,
4618
- constant uint & r3,
4619
- constant int & idx,
4620
- device const char * src00,
4621
- device const char * src01,
4622
- device const char * src02,
4623
- device const char * src03,
4624
- device const char * src04,
4625
- device const char * src05,
4626
- device const char * src06,
4627
- device const char * src07,
4628
- uint3 tgpig[[threadgroup_position_in_grid]],
4629
- uint tiitg[[thread_index_in_threadgroup]],
4630
- uint tiisg[[thread_index_in_simdgroup]],
4631
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4632
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4633
-
4634
- const int64_t bid = tgpig.z/(ne12*ne13);
4635
-
4636
- tgpig.z = tgpig.z%(ne12*ne13);
4637
-
4638
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4639
-
4640
- kernel_mul_mv_q2_K_f32_impl(
4641
- src0[id],
4642
- (device const float *) (src1 + bid*nb11),
4643
- (device float *) ( dst + bid*nb1),
4644
- ne00,
4645
- ne01,
4646
- ne02,
4647
- ne10,
4648
- ne12,
4649
- ne0,
4650
- ne1,
4651
- r2,
4652
- r3,
4653
- tgpig,
4654
- tiisg,
4655
- sgitg);
4656
- }
4657
-
4658
- [[host_name("kernel_mul_mv_id_q3_K_f32")]]
4659
- kernel void kernel_mul_mv_id_q3_K_f32(
4660
- device const char * ids,
4661
- device const char * src1,
4662
- device uchar * dst,
4663
- constant int64_t & nbi1,
4664
- constant int64_t & ne00,
4665
- constant int64_t & ne01,
4666
- constant int64_t & ne02,
4667
- constant uint64_t & nb00,
4668
- constant uint64_t & nb01,
4669
- constant uint64_t & nb02,
4670
- constant int64_t & ne10,
4671
- constant int64_t & ne11,
4672
- constant int64_t & ne12,
4673
- constant int64_t & ne13,
4674
- constant uint64_t & nb10,
4675
- constant uint64_t & nb11,
4676
- constant uint64_t & nb12,
4677
- constant int64_t & ne0,
4678
- constant int64_t & ne1,
4679
- constant int64_t & nb1,
4680
- constant uint & r2,
4681
- constant uint & r3,
4682
- constant int & idx,
4683
- device const char * src00,
4684
- device const char * src01,
4685
- device const char * src02,
4686
- device const char * src03,
4687
- device const char * src04,
4688
- device const char * src05,
4689
- device const char * src06,
4690
- device const char * src07,
4691
- uint3 tgpig[[threadgroup_position_in_grid]],
4692
- uint tiitg[[thread_index_in_threadgroup]],
4693
- uint tiisg[[thread_index_in_simdgroup]],
4694
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4695
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4696
-
4697
- const int64_t bid = tgpig.z/(ne12*ne13);
4698
-
4699
- tgpig.z = tgpig.z%(ne12*ne13);
4700
-
4701
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4702
-
4703
- kernel_mul_mv_q3_K_f32_impl(
4704
- src0[id],
4705
- (device const float *) (src1 + bid*nb11),
4706
- (device float *) ( dst + bid*nb1),
4707
- ne00,
4708
- ne01,
4709
- ne02,
4710
- ne10,
4711
- ne12,
4712
- ne0,
4713
- ne1,
4714
- r2,
4715
- r3,
4716
- tgpig,
4717
- tiisg,
4718
- sgitg);
4719
- }
4720
-
4721
- [[host_name("kernel_mul_mv_id_q4_K_f32")]]
4722
- kernel void kernel_mul_mv_id_q4_K_f32(
4723
- device const char * ids,
4724
- device const char * src1,
4725
- device uchar * dst,
4726
- constant int64_t & nbi1,
4727
- constant int64_t & ne00,
4728
- constant int64_t & ne01,
4729
- constant int64_t & ne02,
4730
- constant uint64_t & nb00,
4731
- constant uint64_t & nb01,
4732
- constant uint64_t & nb02,
4733
- constant int64_t & ne10,
4734
- constant int64_t & ne11,
4735
- constant int64_t & ne12,
4736
- constant int64_t & ne13,
4737
- constant uint64_t & nb10,
4738
- constant uint64_t & nb11,
4739
- constant uint64_t & nb12,
4740
- constant int64_t & ne0,
4741
- constant int64_t & ne1,
4742
- constant int64_t & nb1,
4743
- constant uint & r2,
4744
- constant uint & r3,
4745
- constant int & idx,
4746
- device const char * src00,
4747
- device const char * src01,
4748
- device const char * src02,
4749
- device const char * src03,
4750
- device const char * src04,
4751
- device const char * src05,
4752
- device const char * src06,
4753
- device const char * src07,
4754
- uint3 tgpig[[threadgroup_position_in_grid]],
4755
- uint tiitg[[thread_index_in_threadgroup]],
4756
- uint tiisg[[thread_index_in_simdgroup]],
4757
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4758
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4759
-
4760
- const int64_t bid = tgpig.z/(ne12*ne13);
4761
-
4762
- tgpig.z = tgpig.z%(ne12*ne13);
4763
-
4764
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4765
-
4766
- kernel_mul_mv_q4_K_f32_impl(
4767
- src0[id],
4768
- (device const float *) (src1 + bid*nb11),
4769
- (device float *) ( dst + bid*nb1),
4770
- ne00,
4771
- ne01,
4772
- ne02,
4773
- ne10,
4774
- ne12,
4775
- ne0,
4776
- ne1,
4777
- r2,
4778
- r3,
4779
- tgpig,
4780
- tiisg,
4781
- sgitg);
4782
- }
4783
-
4784
- [[host_name("kernel_mul_mv_id_q5_K_f32")]]
4785
- kernel void kernel_mul_mv_id_q5_K_f32(
4786
- device const char * ids,
4787
- device const char * src1,
4788
- device uchar * dst,
4789
- constant int64_t & nbi1,
4790
- constant int64_t & ne00,
4791
- constant int64_t & ne01,
4792
- constant int64_t & ne02,
4793
- constant uint64_t & nb00,
4794
- constant uint64_t & nb01,
4795
- constant uint64_t & nb02,
4796
- constant int64_t & ne10,
4797
- constant int64_t & ne11,
4798
- constant int64_t & ne12,
4799
- constant int64_t & ne13,
4800
- constant uint64_t & nb10,
4801
- constant uint64_t & nb11,
4802
- constant uint64_t & nb12,
4803
- constant int64_t & ne0,
4804
- constant int64_t & ne1,
4805
- constant int64_t & nb1,
4806
- constant uint & r2,
4807
- constant uint & r3,
4808
- constant int & idx,
4809
- device const char * src00,
4810
- device const char * src01,
4811
- device const char * src02,
4812
- device const char * src03,
4813
- device const char * src04,
4814
- device const char * src05,
4815
- device const char * src06,
4816
- device const char * src07,
4817
- uint3 tgpig[[threadgroup_position_in_grid]],
4818
- uint tiitg[[thread_index_in_threadgroup]],
4819
- uint tiisg[[thread_index_in_simdgroup]],
4820
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4821
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4822
-
4823
- const int64_t bid = tgpig.z/(ne12*ne13);
4824
-
4825
- tgpig.z = tgpig.z%(ne12*ne13);
4826
-
4827
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4828
-
4829
- kernel_mul_mv_q5_K_f32_impl(
4830
- src0[id],
4831
- (device const float *) (src1 + bid*nb11),
4832
- (device float *) ( dst + bid*nb1),
4833
- ne00,
4834
- ne01,
4835
- ne02,
4836
- ne10,
4837
- ne12,
4838
- ne0,
4839
- ne1,
4840
- r2,
4841
- r3,
4842
- tgpig,
4843
- tiisg,
4844
- sgitg);
4845
- }
4846
-
4847
- [[host_name("kernel_mul_mv_id_q6_K_f32")]]
4848
- kernel void kernel_mul_mv_id_q6_K_f32(
4849
- device const char * ids,
4850
- device const char * src1,
4851
- device uchar * dst,
4852
- constant int64_t & nbi1,
4853
- constant int64_t & ne00,
4854
- constant int64_t & ne01,
4855
- constant int64_t & ne02,
4856
- constant uint64_t & nb00,
4857
- constant uint64_t & nb01,
4858
- constant uint64_t & nb02,
4859
- constant int64_t & ne10,
4860
- constant int64_t & ne11,
4861
- constant int64_t & ne12,
4862
- constant int64_t & ne13,
4863
- constant uint64_t & nb10,
4864
- constant uint64_t & nb11,
4865
- constant uint64_t & nb12,
4866
- constant int64_t & ne0,
4867
- constant int64_t & ne1,
4868
- constant int64_t & nb1,
4869
- constant uint & r2,
4870
- constant uint & r3,
4871
- constant int & idx,
4872
- device const char * src00,
4873
- device const char * src01,
4874
- device const char * src02,
4875
- device const char * src03,
4876
- device const char * src04,
4877
- device const char * src05,
4878
- device const char * src06,
4879
- device const char * src07,
4880
- uint3 tgpig[[threadgroup_position_in_grid]],
4881
- uint tiitg[[thread_index_in_threadgroup]],
4882
- uint tiisg[[thread_index_in_simdgroup]],
4883
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4884
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4885
-
4886
- const int64_t bid = tgpig.z/(ne12*ne13);
4887
-
4888
- tgpig.z = tgpig.z%(ne12*ne13);
4889
-
4890
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4891
-
4892
- kernel_mul_mv_q6_K_f32_impl(
4893
- src0[id],
4894
- (device const float *) (src1 + bid*nb11),
4895
- (device float *) ( dst + bid*nb1),
4896
- ne00,
4897
- ne01,
4898
- ne02,
4899
- ne10,
4900
- ne12,
4901
- ne0,
4902
- ne1,
4903
- r2,
4904
- r3,
4905
- tgpig,
4906
- tiisg,
4907
- sgitg);
4908
- }