llama_cpp 0.1.4 → 0.2.1

New file added in this release (1,436 lines): the ggml Metal shader source (ggml-metal.metal).
#include <metal_stdlib>

using namespace metal;

#define MAX(x, y) ((x) > (y) ? (x) : (y))

#define QK4_0 32
#define QR4_0 2
typedef struct {
    half    d;              // delta
    uint8_t qs[QK4_0 / 2];  // nibbles / quants
} block_q4_0;

#define QK4_1 32
typedef struct {
    half    d;              // delta
    half    m;              // min
    uint8_t qs[QK4_1 / 2];  // nibbles / quants
} block_q4_1;

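//
// [illustrative note, not part of the original file] The block sizes follow
// from the layouts above: a block_q4_0 holds 32 weights in sizeof(half) + 16
// bytes = 18 bytes (4.5 bits/weight), and a block_q4_1 holds 32 weights in
// 2*sizeof(half) + 16 bytes = 20 bytes (5 bits/weight).
//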
static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
    const int qk = QK4_0;

    assert(k % qk == 0);

    const int nb = k / qk;

    for (int i = 0; i < nb; i++) {
        const half d = x[i].d;

        for (int j = 0; j < qk/2; ++j) {
            const int x0 = (x[i].qs[j] & 0x0F) - 8;
            const int x1 = (x[i].qs[j] >>   4) - 8;

            y[i*qk + j + 0   ] = x0*d;
            y[i*qk + j + qk/2] = x1*d;
        }
    }
}

static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
    const int qk = QK4_1;

    assert(k % qk == 0);

    const int nb = k / qk;

    for (int i = 0; i < nb; i++) {
        const half d = x[i].d;
        const half m = x[i].m;

        for (int j = 0; j < qk/2; ++j) {
            const int x0 = (x[i].qs[j] & 0x0F);
            const int x1 = (x[i].qs[j] >>   4);

            y[i*qk + j + 0   ] = x0*d + m;
            y[i*qk + j + qk/2] = x1*d + m;
        }
    }
}
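
//
// [illustrative sketch, not part of the original file] One packed byte holds
// two quants: the low nibble maps to element j, the high nibble to element
// j + qk/2. q4_0 is symmetric (y = d*(q - 8)), q4_1 is affine (y = d*q + m).
// A hypothetical helper recovering both elements of a q4_0 byte:
//
static inline float2 unpack_q4_0_pair_sketch(uint8_t q, half d) {
    return float2((int(q & 0x0F) - 8) * d,   // element j
                  (int(q >>   4) - 8) * d);  // element j + qk/2
}
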
kernel void kernel_add(
        device const float * src0,
        device const float * src1,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] + src1[tpig];
}

kernel void kernel_mul(
        device const float * src0,
        device const float * src1,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] * src1[tpig];
}

// assumption: src1 is a row
// broadcast src1 into src0
kernel void kernel_mul_row(
        device const float * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] * src1[tpig % ne00];
}

kernel void kernel_scale(
        device const float * src0,
        device       float * dst,
        constant     float & scale,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] * scale;
}

kernel void kernel_silu(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    float x = src0[tpig];
    dst[tpig] = x / (1.0f + exp(-x));
}

kernel void kernel_relu(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = max(0.0f, src0[tpig]);
}

constant float GELU_COEF_A    = 0.044715f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

kernel void kernel_gelu(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    float x = src0[tpig];
    dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
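
//
// [illustrative note, not part of the original file] kernel_silu computes
// SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x)), and kernel_gelu uses the
// common tanh approximation GELU(x) ~= 0.5*x*(1 + tanh(sqrt(2/pi)*(x +
// 0.044715*x^3))); the factored form above multiplies sqrt(2/pi) through
// x*(1 + 0.044715*x^2).
//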

kernel void kernel_soft_max(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        threadgroup float  * buf [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
    device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // parallel max
    buf[tpitg[0]] = -INFINITY;
    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
        buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
    }

    // reduce
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = ntg[0]/2; i > 0; i /= 2) {
        if (tpitg[0] < i) {
            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // broadcast (no-op store: after the reduction the result is already in buf[0])
    if (tpitg[0] == 0) {
        buf[0] = buf[0];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float max = buf[0];

    // parallel sum
    buf[tpitg[0]] = 0.0f;
    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
        buf[tpitg[0]] += exp(psrc0[i00] - max);
    }

    // reduce
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = ntg[0]/2; i > 0; i /= 2) {
        if (tpitg[0] < i) {
            buf[tpitg[0]] += buf[tpitg[0] + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // broadcast (no-op store, as above)
    if (tpitg[0] == 0) {
        buf[0] = buf[0];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float sum = buf[0];

    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
        pdst[i00] = exp(psrc0[i00] - max) / sum;
    }
}
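
//
// [illustrative note, not part of the original file] Subtracting the row max
// before exponentiating is the standard overflow guard: softmax(x)_i =
// exp(x_i - max(x)) / sum_j exp(x_j - max(x)) equals the textbook definition
// exactly, but keeps every exponent <= 0 so exp() cannot overflow.
//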

kernel void kernel_diag_mask_inf(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant       int & n_past,
        uint3 tpig[[thread_position_in_grid]]) {
    const int64_t i02 = tpig[2];
    const int64_t i01 = tpig[1];
    const int64_t i00 = tpig[0];

    if (i00 > n_past + i01) {
        dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
    } else {
        dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
    }
}

kernel void kernel_get_rows_f16(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint tpig[[thread_position_in_grid]]) {
    const int i = tpig;
    const int r = ((device int32_t *) src1)[i];

    for (int j = 0; j < ne00; j++) {
        dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j];
    }
}

kernel void kernel_get_rows_q4_0(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint tpig[[thread_position_in_grid]]) {
    const int i = tpig;
    const int r = ((device int32_t *) src1)[i];

    dequantize_row_q4_0(
            (device const block_q4_0 *) ((device char *) src0 + r*nb01),
                       (device float *) ((device char *)  dst + i*nb1), ne00);
}

kernel void kernel_get_rows_q4_1(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint tpig[[thread_position_in_grid]]) {
    const int i = tpig;
    const int r = ((device int32_t *) src1)[i];

    dequantize_row_q4_1(
            (device const block_q4_1 *) ((device char *) src0 + r*nb01),
                       (device float *) ((device char *)  dst + i*nb1), ne00);
}

kernel void kernel_rms_norm(
        device const  void * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant     float & eps,
        threadgroup float  * sum [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);

    // parallel sum
    sum[tpitg] = 0.0f;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        sum[tpitg] += x[i00] * x[i00];
    }

    // reduce
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = ntg/2; i > 0; i /= 2) {
        if (tpitg < i) {
            sum[tpitg] += sum[tpitg + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // broadcast
    if (tpitg == 0) {
        sum[0] /= ne00;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float mean  = sum[0];
    const float scale = 1.0f/sqrt(mean + eps);

    device float * y = dst + tgpig*ne00;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        y[i00] = x[i00] * scale;
    }
}
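
//
// [illustrative note, not part of the original file] kernel_rms_norm computes
// y_i = x_i / sqrt(mean(x^2) + eps) with mean(x^2) = (1/ne00) * sum_j x_j^2;
// unlike LayerNorm there is no mean subtraction and no learned bias here.
//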

kernel void kernel_mul_mat_q4_0_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne10,
        constant   int64_t & ne0,
        threadgroup float  * sum [[threadgroup(0)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2  tptg[[threads_per_threadgroup]]) {
    const int nb = ne00/QK4_0;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;

    device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
    device const float      * y = (device const float      *) src1 + r1*ne10;

    const int nth = tptg.x*tptg.y;
    const int ith = tptg.y*tpitg.x + tpitg.y;

    const int ix = tpitg.y/4;      // 0 or 1
    const int iy = tpitg.y - 4*ix; // 0...3

    const int first = 4 * iy;

    float sumf = 0;

    for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {

        const float d = (float)x[i].d;

        device const uint8_t * xl = x[i].qs + first;
        device const float   * yl = y + i * QK4_0 + first;

        float2 acc = {0.0f, 0.0f};

        for (int j = 0; j < 4; ++j) {

            acc[0] += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4);
            acc[1] += yl[j] + yl[j+16];

        }

        sumf += d * (acc[0] - 8.f*acc[1]);
    }

    sum[ith] = sumf;

    //
    // Accumulate the sum from all threads in the threadgroup
    //
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%4 == 0) {
        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%16 == 0) {
        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith == 0) {
        for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
        dst[r1*ne0 + r0] = sum[0];
    }
}
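
//
// [illustrative note, not part of the original file] The acc[0] - 8.f*acc[1]
// step above factors the q4_0 offset out of the inner loop:
//   sum_j d*(q_j - 8)*y_j = d*(sum_j q_j*y_j - 8*sum_j y_j)
// so the loop only accumulates the raw nibble products (acc[0]) and the
// activation sum (acc[1]) instead of subtracting 8 per element.
//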

kernel void kernel_mul_mat_q4_1_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne10,
        constant   int64_t & ne0,
        threadgroup float  * sum [[threadgroup(0)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2  tptg[[threads_per_threadgroup]]) {
    const int nb = ne00/QK4_1;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;

    device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
    device const float      * y = (device const float      *) src1 + r1*ne10;

    const uint nth = tptg.x*tptg.y;
    const uint ith = tptg.y*tpitg.x + tpitg.y;

    const int ix = tpitg.y/4;      // 0 or 1
    const int iy = tpitg.y - 4*ix; // 0...3

    const int first = 4 * iy;

    float sumf = 0;

    for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {

        const float d = (float)x[i].d;
        const float m = (float)x[i].m;

        device const uint8_t * xl = x[i].qs + first;
        device const float   * yl = y + i * QK4_1 + first;

        float2 acc = {0.0f, 0.0f};

        for (int j = 0; j < 4; ++j) {

            acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
            acc[1] += yl[j+16] * (d * (xl[j] >>  4) + m);

        }

        sumf += acc[0] + acc[1];
    }

    sum[ith] = sumf;

    //
    // Accumulate the sum from all threads in the threadgroup
    //
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%4 == 0) {
        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%16 == 0) {
        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith == 0) {
        for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
        dst[r1*ne0 + r0] = sum[0];
    }
}

kernel void kernel_mul_mat_f16_f32(
        device const  char * src0,
        device const  char * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        threadgroup float  * sum [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3  tpig[[thread_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3  tptg[[threads_per_threadgroup]]) {

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
    const int64_t im = tgpig.z;

    device const half  * x = (device const half  *) (src0 + r0*nb01 + im*nb02);
    device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);

    sum[tpitg.x] = 0.0f;

    for (int i = tpitg.x; i < ne00; i += tptg.x) {
        sum[tpitg.x] += (float) x[i] * (float) y[i];
    }

    // accumulate the sum from all threads in the threadgroup
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = tptg.x/2; i > 0; i /= 2) {
        if (tpitg.x < i) {
            sum[tpitg.x] += sum[tpitg.x + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (tpitg.x == 0) {
        dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
    }
}
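
//
// [illustrative sketch, not part of the original file] The halving tree
// reduction used by kernel_mul_mat_f16_f32, kernel_soft_max, and
// kernel_rms_norm, factored out as a hypothetical standalone helper: each
// round folds the upper half of the buffer onto the lower half until buf[0]
// holds the total. Assumes n is a power of two, as the kernels here do.
//
static float reduce_sum_sketch(threadgroup float * buf, uint tid, uint n) {
    for (uint i = n/2; i > 0; i /= 2) {
        if (tid < i) {
            buf[tid] += buf[tid + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    return buf[0];
}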

kernel void kernel_rope(
        device const  void * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        constant       int & n_past,
        constant       int & n_dims,
        constant       int & mode,
        uint3 tpig[[thread_position_in_grid]]) {
    const int64_t i3 = tpig[2];
    const int64_t i2 = tpig[1];
    const int64_t i1 = tpig[0];

    const bool is_neox = mode & 2;
    const float theta_scale = pow(10000.0, -2.0f/n_dims);

    const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

    float theta = (float)p;

    if (!is_neox) {
        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
            const float cos_theta = cos(theta);
            const float sin_theta = sin(theta);

            theta *= theta_scale;

            device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
            device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);

            const float x0 = src[0];
            const float x1 = src[1];

            dst_data[0] = x0*cos_theta - x1*sin_theta;
            dst_data[1] = x0*sin_theta + x1*cos_theta;
        }
    } else {
        // TODO: implement
    }
}
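
//
// [illustrative note, not part of the original file] For pair k (elements
// 2k, 2k+1) the loop above applies a plain 2-D rotation by the angle
// theta_k = p * 10000^(-2k/n_dims), where p is the token position:
//   (x0, x1) -> (x0*cos(theta_k) - x1*sin(theta_k),
//                x0*sin(theta_k) + x1*cos(theta_k))
// which is the rotary position embedding (RoPE) formulation.
//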

kernel void kernel_cpy_f32_f16(
        device const float * src0,
        device        half * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    const int64_t i3 = n / (ne2*ne1*ne0);
    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);

    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

        dst_data[i00] = src[0];
    }
}

kernel void kernel_cpy_f32_f32(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    const int64_t i3 = n / (ne2*ne1*ne0);
    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);

    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

        dst_data[i00] = src[0];
    }
}
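
//
// [illustrative note, not part of the original file] The copy kernels first
// flatten the source coordinates into a linear element index
//   n = ((i03*ne02 + i02)*ne01 + i01)*ne00
// and then re-expand n in the destination's shape (i3, i2, i1, i0); this is
// what lets a copy double as a reshape between tensors that share the same
// total element count.
//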

//============================================ k-quants ======================================================

#define QK_K 256

typedef struct {
    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
    uint8_t qs[QK_K/4];      // quants
    half d;                  // super-block scale for quantized scales
    half dmin;               // super-block scale for quantized mins
} block_q2_k;
// 84 bytes / block

typedef struct {
    uint8_t hmask[QK_K/8];     // quants - high bit
    uint8_t qs[QK_K/4];        // quants - low 2 bits
    uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
    half d;                    // super-block scale
} block_q3_k;
// 110 bytes / block

typedef struct {
    half d;                    // super-block scale for quantized scales
    half dmin;                 // super-block scale for quantized mins
    uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
    uint8_t qs[QK_K/2];        // 4-bit quants
} block_q4_k;
// 144 bytes / block

typedef struct {
    half d;                    // super-block scale for quantized scales
    half dmin;                 // super-block scale for quantized mins
    uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
    uint8_t qh[QK_K/8];        // quants, high bit
    uint8_t qs[QK_K/2];        // quants, low 4 bits
} block_q5_k;
// 176 bytes / block

typedef struct {
    uint8_t ql[QK_K/2];     // quants, lower 4 bits
    uint8_t qh[QK_K/4];     // quants, upper 2 bits
    int8_t scales[QK_K/16]; // scales, quantized with 8 bits
    half d;                 // super-block scale
} block_q6_k;
// 210 bytes / block
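
//
// [illustrative note, not part of the original file] The per-block byte
// counts above follow from QK_K = 256, e.g.:
//   block_q2_k: 256/16 + 256/4 + 2 + 2      =  84 bytes ->  84*8/256 = 2.625  bits/weight
//   block_q6_k: 256/2 + 256/4 + 256/16 + 2  = 210 bytes -> 210*8/256 = 6.5625 bits/weight
//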

static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
    uchar4 r;
    if (j < 4) {
        r[0] = q[j+0] & 63;
        r[2] = q[j+1] & 63;
        r[1] = q[j+4] & 63;
        r[3] = q[j+5] & 63;
    } else {
        r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
        r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
        r[1] = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
        r[3] = (q[j+5] >>  4) | ((q[j+1] >> 6) << 4);
    }
    return r;
}
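
//
// [illustrative note, not part of the original file] get_scale_min_k4 unpacks
// eight 6-bit (scale, min) pairs from 12 bytes: scales 0-3 sit in the low 6
// bits of bytes 0-3 and mins 0-3 in the low 6 bits of bytes 4-7; for entries
// 4-7 the low/high nibbles of bytes 8-11 supply the bottom 4 bits and the top
// 2 bits of bytes 0-7 supply the rest. The result is returned as
// (scale_j, min_j, scale_j+1, min_j+1) for sub-blocks j and j+1.
//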

//========================================== dequantization =============================

static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    for (int i = 0; i < nb; i++) {

        const float d   = x[i].d;
        const float min = x[i].dmin;

        device const uint8_t * q = x[i].qs;

        int is = 0;
        float dl, ml;
        for (int n = 0; n < QK_K; n += 128) {
            int shift = 0;
            for (int j = 0; j < 4; ++j) {

                uint8_t sc = x[i].scales[is++];
                dl = d * (sc & 0xF); ml = min * (sc >> 4);
                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;

                sc = x[i].scales[is++];
                dl = d * (sc & 0xF); ml = min * (sc >> 4);
                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;

                shift += 2;
            }
            q += 32;
        }

    }
}

static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    const uint16_t kmask1 = 0x0303;
    const uint16_t kmask2 = 0x0f0f;

    uint16_t aux[8];
    thread const int8_t * scales = (thread const int8_t*)aux;

    for (int i = 0; i < nb; i++) {

        const float d_all = (float)(x[i].d);

        device const uint8_t * q = x[i].qs;
        device const uint8_t * h = x[i].hmask;
        uint8_t m = 1;

        device const uint16_t * a = (device const uint16_t *)x[i].scales;
        aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
        aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4);
        aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4);
        aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4);
        aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4);
        aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4);
        aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
        aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);

        int is = 0;
        float dl;
        for (int n = 0; n < QK_K; n += 128) {
            int shift = 0;
            for (int j = 0; j < 4; ++j) {

                dl = d_all * (scales[is++] - 32);
                for (int l = 0; l < 16; ++l) {
                    *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
                }

                dl = d_all * (scales[is++] - 32);
                for (int l = 0; l < 16; ++l) {
                    *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
                }

                shift += 2;
                m <<= 1;
            }
            q += 32;
        }

    }
}

static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    for (int i = 0; i < nb; i++) {

        const float d   = x[i].d;
        const float min = x[i].dmin;

        device const uint8_t * q      = x[i].qs;
        device const uint8_t * scales = x[i].scales;

        int is = 0;
        for (int j = 0; j < QK_K; j += 64) {
            const uchar4 sc = get_scale_min_k4(is, scales);
            const float d1 = d * sc[0]; const float m1 = min * sc[1];
            const float d2 = d * sc[2]; const float m2 = min * sc[3];
            for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
            for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >>  4) - m2;
            q += 32; is += 2;
        }

    }
}

static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    for (int i = 0; i < nb; i++) {

        const float d   = (float)(x[i].d);
        const float min = (float)(x[i].dmin);

        device const uint8_t * ql = x[i].qs;
        device const uint8_t * qh = x[i].qh;

        int is = 0;
        uint8_t u1 = 1, u2 = 2;
        for (int j = 0; j < QK_K; j += 64) {
            const uchar4 sc = get_scale_min_k4(is, x[i].scales);
            const float d1 = d * sc[0]; const float m1 = min * sc[1];
            const float d2 = d * sc[2]; const float m2 = min * sc[3];
            for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
            for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >>  4) + (qh[l] & u2 ? 16 : 0)) - m2;
            ql += 32; is += 2;
            u1 <<= 2; u2 <<= 2;
        }
    }
}

static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    for (int i = 0; i < nb; i++) {

        device const uint8_t * ql = x[i].ql;
        device const uint8_t * qh = x[i].qh;
        device const int8_t  * sc = x[i].scales;

        const float d = x[i].d;

        for (int n = 0; n < QK_K; n += 128) {
            for (int l = 0; l < 32; ++l) {
                int is = l/16;
                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
                const int8_t q3 = (int8_t)((ql[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
                const int8_t q4 = (int8_t)((ql[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
                y[l +  0] = d * sc[is + 0] * q1;
                y[l + 32] = d * sc[is + 2] * q2;
                y[l + 64] = d * sc[is + 4] * q3;
                y[l + 96] = d * sc[is + 6] * q4;
            }
            y  += 128;
            ql += 64;
            qh += 32;
            sc += 8;
        }
    }
}
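
//
// [illustrative sketch, not part of the original file] A q6_k value is
// rebuilt from 4 low bits and 2 high bits and recentered into [-32, 31];
// a hypothetical scalar version of the unpacking used above:
//
static inline int unpack_q6_value_sketch(uint8_t lo4, uint8_t hi2) {
    return int((lo4 & 0xF) | ((hi2 & 3) << 4)) - 32;
}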

kernel void kernel_get_rows_q2_k(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint tpig[[thread_position_in_grid]]) {
    const int i = tpig;
    const int r = ((device int32_t *) src1)[i];

    dequantize_row_q2_k(
            (device const block_q2_k *) ((device char *) src0 + r*nb01),
                       (device float *) ((device char *)  dst + i*nb1), ne00);
}

kernel void kernel_get_rows_q3_k(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint tpig[[thread_position_in_grid]]) {
    const int i = tpig;
    const int r = ((device int32_t *) src1)[i];

    dequantize_row_q3_k(
            (device const block_q3_k *) ((device char *) src0 + r*nb01),
                       (device float *) ((device char *)  dst + i*nb1), ne00);
}

kernel void kernel_get_rows_q4_k(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint tpig[[thread_position_in_grid]]) {
    const int i = tpig;
    const int r = ((device int32_t *) src1)[i];

    dequantize_row_q4_k(
            (device const block_q4_k *) ((device char *) src0 + r*nb01),
                       (device float *) ((device char *)  dst + i*nb1), ne00);
}

kernel void kernel_get_rows_q5_k(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint tpig[[thread_position_in_grid]]) {
    const int i = tpig;
    const int r = ((device int32_t *) src1)[i];

    dequantize_row_q5_k(
            (device const block_q5_k *) ((device char *) src0 + r*nb01),
                       (device float *) ((device char *)  dst + i*nb1), ne00);
}

kernel void kernel_get_rows_q6_k(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint tpig[[thread_position_in_grid]]) {
    const int i = tpig;
    const int r = ((device int32_t *) src1)[i];

    dequantize_row_q6_k(
            (device const block_q6_k *) ((device char *) src0 + r*nb01),
                       (device float *) ((device char *)  dst + i*nb1), ne00);
}

//====================================== dot products =========================

kernel void kernel_mul_mat_q2_k_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne10,
        constant   int64_t & ne0,
        threadgroup float  * sum [[threadgroup(0)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2  tptg[[threads_per_threadgroup]]) {

    const int nb = ne00/QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;

    device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
    device const float     * yy = (device const float      *) src1 + r1*ne10;

    const int nth = tptg.x*tptg.y;
    const int ith = tptg.y*tpitg.x + tpitg.y;

    const int tid = tpitg.y;    // 0...15
    const int il  = tid/4;      // 0...3
    const int ir  = tid%4;      // 0...3
    const int ip  = il/2;       // 0 or 1
    const int shift1 = 4*(il%2);// 0 or 4
    const int shift2 = shift1+2;// 2 or 6
    const int n   = 8;
    const int is  = 4*il + (n*ir)/16;

    const int y_offset = 64*il + n*ir;
    const int q_offset = 32*ip + n*ir;

    sum[ith] = 0.0f;

    float sumf = 0;
    for (int i = tpitg.x; i < nb; i += tptg.x) {

        device const uint8_t * q      = x[i].qs + q_offset;
        device const uint8_t * scales = x[i].scales + is;

        uint8_t d1 = scales[0] & 0xF;
        uint8_t d2 = scales[2] & 0xF;
        uint8_t m1 = scales[0] >> 4;
        uint8_t m2 = scales[2] >> 4;

        device const float * y = yy + i*QK_K + y_offset;

        //float4 s = {0.f, 0.f, 0.f, 0.f};
        float2 s = {0.f, 0.f};
        float smin = 0;
        for (int l = 0; l < n; ++l) {
            s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
            s[1] += y[l+32] * ((q[l] >> shift2) & 3);
            smin += y[l+ 0] * m1 + y[l+32] * m2;
        }

        const float dall = (float)x[i].d;
        const float dmin = (float)x[i].dmin;

        sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;

    }
    sum[ith] = sumf;

    //int mask1 = (ith%4 == 0);
    //int mask2 = (ith%16 == 0);

    //threadgroup_barrier(mem_flags::mem_threadgroup);
    //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
    //threadgroup_barrier(mem_flags::mem_threadgroup);
    //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
    //threadgroup_barrier(mem_flags::mem_threadgroup);
    //if (ith == 0) {
    //    for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
    //    dst[r1*ne0 + r0] = sum[0];
    //}

    //
    // Accumulate the sum from all threads in the threadgroup
    // This version is slightly faster than the commented-out one above,
    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
    //
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%4 == 0) {
        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%16 == 0) {
        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith == 0) {
        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
        dst[r1*ne0 + r0] = sum[0];
    }
}
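
//
// [illustrative note, not part of the original file] As with q4_0, the dot
// product factors the affine q2_k reconstruction w = dall*d*q - dmin*m:
//   sum_j w_j*y_j = dall*(d1*s[0] + d2*s[1]) - dmin*smin
// where s[] holds the raw 2-bit-quant dot products of the two sub-blocks and
// smin the min-weighted activation sums.
//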

kernel void kernel_mul_mat_q3_k_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne10,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        threadgroup float  * sum [[threadgroup(0)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2  tptg[[threads_per_threadgroup]]) {

    const uint16_t kmask1 = 0x0303;
    const uint16_t kmask2 = 0x0f0f;

    const uint8_t m3 = 3;
    const int8_t  m4 = 4;

    const int nb = ne00/QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;

    device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
    device const float     * yy = (device const float      *) src1 + r1*ne10;

    const int nth = tptg.x*tptg.y;
    const int ith = tptg.y*tpitg.x + tpitg.y;

    const int tid = tpitg.y;      // expecting 16
    const int ip  = tid/8;        // 0 or 1
    const int il  = tid/2 - 4*ip; // 0...3
    const int ir  = tid%2;
    const int n   = 8;
    const int l0  = n*ir;

    const uint8_t m = 1 << (4*ip + il);

    const int shift = 2*il;

    const uint16_t s_shift1 = 4*ip;
    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
    const int ik = 4 + (il%2);

    const int q_offset =  32*ip + l0;
    const int y_offset = 128*ip + 32*il + l0;

    //float sumf = 0;
    float sumf1 = 0, sumf2 = 0;
    for (int i = tpitg.x; i < nb; i += tptg.x) {

        const float d_all = (float)(x[i].d);

        device const uint8_t * q = x[i].qs + q_offset;
        device const uint8_t * h = x[i].hmask + l0;
        device const float   * y = yy + i * QK_K + y_offset;

        device const uint16_t * a = (device const uint16_t *)x[i].scales;
        const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));

        float s = 0;
        for (int l = 0; l < n; ++l) {
            s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
        }
        float d = d_all * s;
        sumf1 += d * scales[0];
        sumf2 += d;
        //sumf += d_all * s * (scales[0] - 32);

        s = 0;
        for (int l = 0; l < n; ++l) {
            s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
        }
        d = d_all * s;
        sumf1 += d * scales[1];
        sumf2 += d;
        //sumf += d_all * s * (scales[1] - 32);

    }

    //sum[ith] = sumf;
    sum[ith] = sumf1 - 32.f*sumf2;

    //
    // Accumulate the sum from all threads in the threadgroup
    //
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%4 == 0) {
        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%16 == 0) {
        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith == 0) {
        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
        dst[r1*ne0 + r0] = sum[0];
    }

}
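
//
// [illustrative note, not part of the original file] The final
// sumf1 - 32.f*sumf2 applies the q3_k scale offset once per thread instead of
// once per sub-block: each term is d_all*s*(sc - 32) = d_all*s*sc - 32*d_all*s,
// so accumulating d_all*s*sc in sumf1 and d_all*s in sumf2 gives the same
// result as the commented-out per-term form.
//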

kernel void kernel_mul_mat_q4_k_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne10,
        constant   int64_t & ne0,
        threadgroup float  * sum [[threadgroup(0)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2  tptg[[threads_per_threadgroup]]) {

    const uint16_t kmask1 = 0x3f3f;
    const uint16_t kmask2 = 0x0f0f;
    const uint16_t kmask3 = 0xc0c0;

    const int nb = ne00/QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;

    device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
    device const float     * yy = (device const float      *) src1 + r1*ne10;

    const int nth = tptg.x*tptg.y;
    const int ith = tptg.y*tpitg.x + tpitg.y;

    const int tid = tpitg.y;   // 0...15
    const int il  = tid/4;     // 0...3
    const int ir  = tid - 4*il;// 0...3
    const int n   = 4;

    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
    const int in = il%2;

    const int l0 = n*(2*ir + in);
    const int q_offset = 32*im + l0;
    const int y_offset = 64*im + l0;

    sum[ith] = 0.0f;

    uchar2 sc1, sc2, sc3, sc4;

    float sumf = 0;
    for (int i = tpitg.x; i < nb; i += tptg.x) {

        device const uint8_t * q1 = (x + i)->qs + q_offset;
        device const uint8_t * q2 = q1 + 64;
        device const float   * y1 = yy + i*QK_K + y_offset;
        device const float   * y2 = y1 + 128;

        const float dall = (float)((x + i)->d);
        const float dmin = (float)((x + i)->dmin);

        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));

        float4 s = {0.f, 0.f, 0.f, 0.f};
        float smin = 0;
        for (int l = 0; l < n; ++l) {

            s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
            s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];

        }
        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;

    }

    sum[ith] = sumf;

    //
    // Accumulate the sum from all threads in the threadgroup
    // This version is slightly faster than the commented-out one below,
    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
    //
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%4 == 0) {
        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%16 == 0) {
        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith == 0) {
        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
        dst[r1*ne0 + r0] = sum[0];
    }

    //// accumulate the sum from all threads in the threadgroup
    //threadgroup_barrier(mem_flags::mem_threadgroup);
    //for (uint i = nth/2; i > 0; i /= 2) {
    //    if (ith < i) {
    //        sum[ith] += sum[ith + i];
    //    }
    //    threadgroup_barrier(mem_flags::mem_threadgroup);
    //}

    //if (ith == 0) {
    //    dst[r1*ne0 + r0] = sum[0];
    //}
}

kernel void kernel_mul_mat_q5_k_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne10,
        constant   int64_t & ne0,
        threadgroup float  * sum [[threadgroup(0)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2  tptg[[threads_per_threadgroup]]) {

    const uint16_t kmask1 = 0x3f3f;
    const uint16_t kmask2 = 0x0f0f;
    const uint16_t kmask3 = 0xc0c0;

    const int nb = ne00/QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;

    device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
    device const float     * yy = (device const float      *) src1 + r1*ne10;

    const int nth = tptg.x*tptg.y;
    const int ith = tptg.y*tpitg.x + tpitg.y;

    const int tid = tpitg.y;   // 0...15
    const int il  = tid/4;     // 0...3
    const int ir  = tid - 4*il;// 0...3
    const int n   = 4;

    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
    const int in = il%2;

    const int l0 = n*(2*ir + in);
    const int q_offset = 32*im + l0;
    const int y_offset = 64*im + l0;

    const uint8_t hm1 = 1u << (2*im);
    const uint8_t hm2 = hm1 << 1;
    const uint8_t hm3 = hm1 << 4;
    const uint8_t hm4 = hm2 << 4;

    uchar2 sc1, sc2, sc3, sc4;

    float sumf = 0;
    for (int i = tpitg.x; i < nb; i += tptg.x) {

        device const uint8_t * q1 = (x + i)->qs + q_offset;
        device const uint8_t * q2 = q1 + 64;
        device const uint8_t * qh = (x + i)->qh + l0;
        device const float   * y1 = yy + i*QK_K + y_offset;
        device const float   * y2 = y1 + 128;

        const float dall = (float)((x + i)->d);
        const float dmin = (float)((x + i)->dmin);

        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));

        float4 s = {0.f, 0.f, 0.f, 0.f};
        float smin = 0;
        for (int l = 0; l < n; ++l) {

            s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
            s[1] += y1[l+32] * ((q1[l] >>  4) + (qh[l] & hm2 ? 16 : 0));
            s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
            s[3] += y2[l+32] * ((q2[l] >>  4) + (qh[l] & hm4 ? 16 : 0));
            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];

        }
        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;

    }
    sum[ith] = sumf;

    //
    // Accumulate the sum from all threads in the threadgroup
    //
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%4 == 0) {
        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%16 == 0) {
        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith == 0) {
        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
        dst[r1*ne0 + r0] = sum[0];
    }

}

kernel void kernel_mul_mat_q6_k_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne10,
        constant   int64_t & ne0,
        threadgroup float  * sum [[threadgroup(0)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2  tptg[[threads_per_threadgroup]]) {

    const uint8_t kmask1 = 0x03;
    const uint8_t kmask2 = 0x0C;
    const uint8_t kmask3 = 0x30;
    const uint8_t kmask4 = 0xC0;

    const int nb = ne00/QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;

    device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
    device const float     * yy = (device const float      *) src1 + r1*ne10;

    const int nth = tptg.x*tptg.y;
    const int ith = tptg.y*tpitg.x + tpitg.y;

    // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
    const int iqs = 16 * tpitg.y;
    const int ip  = iqs / 128;         // 0 or 1
    const int il  = (iqs - 128*ip)/16; // 0...7
    const int n   = 4;
    const int l0  = n*il;
    const int is  = 8*ip + l0/16;

    const int y_offset   = 128*ip + l0;
    const int q_offset_l =  64*ip + l0;
    const int q_offset_h =  32*ip + l0;

    float sumf = 0;
    for (int i = tpitg.x; i < nb; i += tptg.x) {

        device const uint8_t * ql = x[i].ql + q_offset_l;
        device const uint8_t * qh = x[i].qh + q_offset_h;
        device const int8_t  * sc = x[i].scales + is;

        device const float * y = yy + i * QK_K + y_offset;

        const float dall = x[i].d;

        float4 sums = {0.f, 0.f, 0.f, 0.f};
        for (int l = 0; l < n; ++l) {
            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
            sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
            sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >>  4) | ((qh[l] & kmask3) << 0)) - 32);
            sums[3] += y[l+96] * ((int8_t)((ql[l+32] >>  4) | ((qh[l] & kmask4) >> 2)) - 32);
        }

        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);

    }

    sum[ith] = sumf;

    //
    // Accumulate the sum from all threads in the threadgroup
    //
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%4 == 0) {
        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%16 == 0) {
        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith == 0) {
        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
        dst[r1*ne0 + r0] = sum[0];
    }

}