llama_cpp 0.1.4 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,1133 @@
1
+ #include <metal_stdlib>
2
+
3
+ using namespace metal;
4
+
5
// Branch-based maximum. Both arguments may be evaluated twice, so do not pass
// expressions with side effects.
#define MAX(x, y) ((x) > (y) ? (x) : (y))

// q4_0: QK4_0 weights per block, stored as 4-bit values with a single
// half-precision delta. dequantize_row_q4_0 decodes each value as (q - 8) * d.
#define QK4_0 32
// NOTE(review): QR4_0 is not referenced in the visible part of this file.
#define QR4_0 2
typedef struct {
    half    d;              // delta
    uint8_t qs[QK4_0 / 2];  // nibbles / quants: byte j holds value j (low nibble) and value j + QK4_0/2 (high nibble)
} block_q4_0;

// q4_1: like q4_0 but with an additional per-block offset;
// dequantize_row_q4_1 decodes each value as q * d + m.
// NOTE(review): this layout presumably mirrors the host-side struct byte-for-byte — keep in sync.
#define QK4_1 32
typedef struct {
    half    d;              // delta
    half    m;              // min
    uint8_t qs[QK4_1 / 2];  // nibbles / quants
} block_q4_1;
20
+
21
// Expands k q4_0-encoded weights into floats.
// Within each block, byte j of qs carries two 4-bit codes: the low nibble is
// output j and the high nibble is output j + QK4_0/2. Each code is recentred
// by -8 and multiplied by the block delta.
static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
    assert(k % QK4_0 == 0);

    const int half_qk = QK4_0/2;
    const int nblocks = k / QK4_0;

    for (int ib = 0; ib < nblocks; ++ib) {
        device const block_q4_0 * blk = x + ib;
        device float * out = y + ib*QK4_0;

        const half delta = blk->d;

        for (int j = 0; j < half_qk; j++) {
            const uint8_t packed = blk->qs[j];
            const int lo = (packed & 0x0F) - 8;
            const int hi = (packed >> 4) - 8;

            out[j]           = lo*delta;
            out[j + half_qk] = hi*delta;
        }
    }
}
40
+
41
// Expands k q4_1-encoded weights into floats.
// Same nibble layout as q4_0 (low nibble -> first half of the block, high
// nibble -> second half), but codes are unsigned and decoded as q*d + m.
static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
    assert(k % QK4_1 == 0);

    const int half_qk = QK4_1/2;
    const int nblocks = k / QK4_1;

    for (int ib = 0; ib < nblocks; ++ib) {
        device const block_q4_1 * blk = x + ib;
        device float * out = y + ib*QK4_1;

        const half delta  = blk->d;
        const half offset = blk->m;

        for (int j = 0; j < half_qk; j++) {
            const uint8_t packed = blk->qs[j];
            const int lo = packed & 0x0F;
            const int hi = packed >> 4;

            out[j]           = lo*delta + offset;
            out[j + half_qk] = hi*delta + offset;
        }
    }
}
61
+
62
// Element-wise addition, one thread per element: dst[i] = src0[i] + src1[i].
kernel void kernel_add(
        device const float * src0,
        device const float * src1,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float lhs = src0[tpig];
    const float rhs = src1[tpig];
    dst[tpig] = lhs + rhs;
}
69
+
70
// Element-wise product, one thread per element: dst[i] = src0[i] * src1[i].
kernel void kernel_mul(
        device const float * src0,
        device const float * src1,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float lhs = src0[tpig];
    const float rhs = src1[tpig];
    dst[tpig] = lhs * rhs;
}
77
+
78
// Multiplies every row of src0 by the single row src1 (length ne00).
// One thread per element of src0. The element's column is its flat index
// modulo ne00, which assumes src0 is densely packed with row length ne00.
kernel void kernel_mul_row(
        device const float * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        uint tpig[[thread_position_in_grid]]) {
    const int64_t col = tpig % ne00;
    dst[tpig] = src0[tpig] * src1[col];
}
88
+
89
// Multiplies every element by a scalar, one thread per element.
kernel void kernel_scale(
        device const float * src0,
        device       float * dst,
        constant     float & scale,
        uint tpig[[thread_position_in_grid]]) {
    const float v = src0[tpig];
    dst[tpig] = v * scale;
}
96
+
97
// SiLU activation, one thread per element: x * sigmoid(x) == x / (1 + e^-x).
kernel void kernel_silu(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float v     = src0[tpig];
    const float denom = 1.0f + exp(-v);
    dst[tpig] = v / denom;
}
104
+
105
// ReLU activation, one thread per element: negative inputs are clamped to 0.
kernel void kernel_relu(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float v = src0[tpig];
    dst[tpig] = max(0.0f, v);
}
111
+
112
// Coefficients of the tanh approximation of GELU:
//   gelu(x) ~= 0.5*x*(1 + tanh(sqrt(2/pi) * x * (1 + 0.044715*x^2)))
constant float GELU_COEF_A    = 0.044715f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

// GELU activation (tanh approximation), one thread per element.
kernel void kernel_gelu(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];

    // The tanh argument grows like x^3, so it easily reaches very large
    // magnitudes. The fast-math tanh may evaluate this through exponentials
    // that overflow (inf/inf -> NaN). precise::tanh saturates cleanly to +/-1.
    const float t = precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x));

    dst[tpig] = 0.5f*x*(1.0f + t);
}
122
+
123
// Numerically stable softmax over each row of length ne00.
// One threadgroup handles one row: tgpig = (i01, i02, i03).
// `buf` is threadgroup scratch and must hold ntg[0] floats.
// NOTE(review): the halving reductions assume ntg[0] is a power of two — confirm the host dispatch.
kernel void kernel_soft_max(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        threadgroup float  * buf [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    const uint tid = tpitg[0];
    const uint nth = ntg[0];

    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
    device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // parallel max: each thread scans a strided subset of the row
    float lmax = -INFINITY;
    for (int64_t i00 = tid; i00 < ne00; i00 += nth) {
        lmax = MAX(lmax, psrc0[i00]);
    }
    buf[tid] = lmax;

    // tree reduction; the barrier at the end of each step also makes buf[0]
    // visible to all threads once the loop finishes
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = nth/2; i > 0; i /= 2) {
        if (tid < i) {
            buf[tid] = MAX(buf[tid], buf[tid + i]);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const float row_max = buf[0];

    // Every thread must have read buf[0] before buf is reused below. Without
    // this barrier thread 0 can reset buf[0] while slower threads still read
    // the max from it.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // parallel sum of exp(x - max)
    float lsum = 0.0f;
    for (int64_t i00 = tid; i00 < ne00; i00 += nth) {
        lsum += exp(psrc0[i00] - row_max);
    }
    buf[tid] = lsum;

    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = nth/2; i > 0; i /= 2) {
        if (tid < i) {
            buf[tid] += buf[tid + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const float row_sum = buf[0];

    for (int64_t i00 = tid; i00 < ne00; i00 += nth) {
        pdst[i00] = exp(psrc0[i00] - row_max) / row_sum;
    }
}
192
+
193
// Causal mask: for each (i00, i01, i02), entries with column i00 beyond
// n_past + i01 become -INFINITY and all others are copied from src0.
// One thread per element. Assumes densely packed [ne02][ne01][ne00] storage.
kernel void kernel_diag_mask_inf(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant       int & n_past,
        uint3 tpig[[thread_position_in_grid]]) {
    const int64_t col   = tpig[0];
    const int64_t row   = tpig[1];
    const int64_t plane = tpig[2];

    const int64_t idx = (plane*ne01 + row)*ne00 + col;

    const bool masked = col > n_past + row;
    dst[idx] = masked ? -INFINITY : src0[idx];
}
210
+
211
// Gathers rows of an f16 matrix into an f32 matrix: output row `tpig`
// receives src0 row src1[tpig], converted to float. One thread per output row.
// nb01 / nb1 are byte strides, matching the other kernel_get_rows_* kernels.
kernel void kernel_get_rows_f16(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint tpig[[thread_position_in_grid]]) {
    const int i = tpig;
    const int r = src1[i];

    device const half  * src_row = (device const half  *) ((device const char *) src0 + r*nb01);
    // nb1 is a byte stride. The previous `dst[i*nb1 + j]` indexed floats with
    // it and so wrote sizeof(float) times too far.
    device       float * dst_row = (device       float *) ((device       char *) dst  + i*nb1);

    for (int j = 0; j < ne00; j++) {
        dst_row[j] = src_row[j];
    }
}
226
+
227
// Gathers and dequantizes rows of a q4_0 matrix: output row `tpig` receives
// src0 row src1[tpig] as ne00 floats. One thread per output row; nb01 and nb1
// are byte strides.
kernel void kernel_get_rows_q4_0(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint tpig[[thread_position_in_grid]]) {
    const int out_row = tpig;
    const int src_row = src1[out_row];

    device const block_q4_0 * blocks = (device const block_q4_0 *) ((device const char *) src0 + src_row*nb01);
    device       float      * out    = (device float *) ((device char *) dst + out_row*nb1);

    dequantize_row_q4_0(blocks, out, ne00);
}
242
+
243
// Gathers and dequantizes rows of a q4_1 matrix: output row `tpig` receives
// src0 row src1[tpig] as ne00 floats. One thread per output row; nb01 and nb1
// are byte strides.
kernel void kernel_get_rows_q4_1(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint tpig[[thread_position_in_grid]]) {
    const int out_row = tpig;
    const int src_row = src1[out_row];

    device const block_q4_1 * blocks = (device const block_q4_1 *) ((device const char *) src0 + src_row*nb01);
    device       float      * out    = (device float *) ((device char *) dst + out_row*nb1);

    dequantize_row_q4_1(blocks, out, ne00);
}
258
+
259
// RMS normalisation of one row per threadgroup:
//   y = x / sqrt(mean(x^2) + eps)
// The input row tgpig sits at byte offset tgpig*nb01 in src0. The output is
// written densely at dst + tgpig*ne00. `sum` is threadgroup scratch of ntg floats.
// NOTE(review): the halving reduction assumes ntg is a power of two.
kernel void kernel_rms_norm(
        device const  void * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant     float & eps,
        threadgroup float  * sum [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * row = (device const float *) ((device const char *) src0 + tgpig*nb01);

    // strided partial sum of squares, kept in a register
    float partial = 0.0f;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        partial += row[i00] * row[i00];
    }
    sum[tpitg] = partial;

    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint stride = ntg/2; stride > 0; stride /= 2) {
        if (tpitg < stride) {
            sum[tpitg] += sum[tpitg + stride];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // thread 0 turns the total into the mean, then everyone reads it
    if (tpitg == 0) {
        sum[0] /= ne00;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float mean  = sum[0];
    const float scale = 1.0f/sqrt(mean + eps);

    device float * out = dst + tgpig*ne00;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        out[i00] = row[i00] * scale;
    }
}
301
+
302
// Dot product of one q4_0 row of src0 with one f32 row of src1 per threadgroup:
//   dst[r1*ne0 + r0] = sum_k dequant(src0[r0][k]) * src1[r1][k]
// with r0 = tgpig.x and r1 = tgpig.y.
// src0 rows are addressed as r0*nb blocks and src1 rows as r1*ne10 floats,
// i.e. both are assumed densely packed (the nb* byte strides are not used).
// Work split: ix = tpitg.y/4 selects one of two interleaved block streams.
// iy = tpitg.y%4 selects 4 bytes (8 quants) inside each block.
// NOTE(review): the "0 or 1" / "0...3" ranges and the 4-/16-way reduction below
// assume tptg.y == 8 and that nth is a multiple of 16 — confirm against the host dispatch.
// `sum` is threadgroup scratch that must hold nth floats.
kernel void kernel_mul_mat_q4_0_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        threadgroup float  * sum [[threadgroup(0)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2  tpig[[thread_position_in_grid]],
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2  tptg[[threads_per_threadgroup]]) {
    // number of q4_0 blocks per row
    const int nb = ne00/QK4_0;

    // q4_0 codes are stored with a +8 offset
    const int8_t m8 = 8;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;

    device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
    device const float      * y = (device const float      *) src1 + r1*ne10;

    const uint nth = tptg.x*tptg.y;
    const uint ith = tptg.y*tpitg.x + tpitg.y;

    const int ix = tpitg.y/4;           // 0 or 1
    const int iy = tpitg.y - 4*ix;      // 0...3

    // byte offset of this thread's slice inside a block's qs
    const int first = 4 * iy;

    float sumf = 0;

    for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {

        const float d = (float)x[i].d;

        device const uint8_t * xl = x[i].qs + first;
        device const float   * yl = y + i * QK4_0 + first;

        float2 acc = {0.0f, 0.0f};

        for (int j = 0; j < 4; ++j) {

            // low nibble pairs with the first half of the block, high nibble with the second half (+16)
            acc[0] += yl[j+ 0] * ((int8_t)(xl[j] & 0xF) - m8);
            acc[1] += yl[j+16] * ((int8_t)(xl[j] >>  4) - m8);

        }

        sumf += d * (acc[0] + acc[1]);
    }

    sum[ith] = sumf;

    //
    // Accumulate the sum from all threads in the threadgroup in three stages:
    // groups of 4, then groups of 16, then thread 0 adds every 16th partial.
    // The original author notes this is slightly faster than the plain tree
    // reduction kept commented out below.
    //
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%4 == 0) {
        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%16 == 0) {
        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith == 0) {
        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
        dst[r1*ne0 + r0] = sum[0];
    }

    //// accumulate the sum from all threads in the threadgroup
    //threadgroup_barrier(mem_flags::mem_threadgroup);
    //for (uint i = nth/2; i > 0; i /= 2) {
    //    if (ith < i) {
    //        sum[ith] += sum[ith + i];
    //    }
    //    threadgroup_barrier(mem_flags::mem_threadgroup);
    //}

    //if (ith == 0) {
    //    dst[r1*ne0 + r0] = sum[0];
    //}
}
396
+
397
// Dot product of one q4_1 row of src0 with one f32 row of src1 per threadgroup:
//   dst[r1*ne0 + r0] = sum_k (d*q + m)[r0][k] * src1[r1][k]
// with r0 = tgpig.x and r1 = tgpig.y.
// Same work split and reduction as kernel_mul_mat_q4_0_f32. Both inputs are
// assumed densely packed (the nb* byte strides are not used).
// NOTE(review): assumes tptg.y == 8 and nth a multiple of 16 — confirm against the host dispatch.
// `sum` is threadgroup scratch that must hold nth floats.
kernel void kernel_mul_mat_q4_1_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        threadgroup float  * sum [[threadgroup(0)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2  tpig[[thread_position_in_grid]],
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2  tptg[[threads_per_threadgroup]]) {
    // number of q4_1 blocks per row
    const int nb = ne00/QK4_1;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;

    device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
    device const float      * y = (device const float      *) src1 + r1*ne10;

    const uint nth = tptg.x*tptg.y;
    const uint ith = tptg.y*tpitg.x + tpitg.y;

    const int ix = tpitg.y/4;           // 0 or 1
    const int iy = tpitg.y - 4*ix;      // 0...3

    // byte offset of this thread's slice inside a block's qs
    const int first = 4 * iy;

    float sumf = 0;

    for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {

        const float d = (float)x[i].d;
        const float m = (float)x[i].m;

        device const uint8_t * xl = x[i].qs + first;
        device const float   * yl = y + i * QK4_1 + first;

        float2 acc = {0.0f, 0.0f};

        for (int j = 0; j < 4; ++j) {

            // values are dequantized inline (d*q + m) because of the per-block offset m
            acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
            acc[1] += yl[j+16] * (d * (xl[j] >>  4) + m);

        }

        sumf += acc[0] + acc[1];
    }

    sum[ith] = sumf;

    //
    // Accumulate the sum from all threads in the threadgroup:
    // groups of 4, then groups of 16, then thread 0 adds every 16th partial.
    //
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%4 == 0) {
        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%16 == 0) {
        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith == 0) {
        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
        dst[r1*ne0 + r0] = sum[0];
    }
}
475
+
476
// Dot product of one f16 row of src0 with one f32 row of src1 per threadgroup.
// Grid: tgpig = (src0 row r0, src1 row r1, matrix index im). Rows are located
// via the byte strides nb01/nb02 and nb11/nb12.
// Each thread sums a strided subset of the ne00 products. A halving reduction
// in `sum` then combines them, and thread 0 writes dst[im][r1][r0].
// NOTE(review): the reduction assumes tptg.x is a power of two; `sum` must hold tptg.x floats.
kernel void kernel_mul_mat_f16_f32(
        device const  char * src0,
        device const  char * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        threadgroup float  * sum [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3  tpig[[thread_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3  tptg[[threads_per_threadgroup]]) {
    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
    const int64_t im = tgpig.z;

    const uint lane  = tpitg.x;
    const uint width = tptg.x;

    device const half  * row_a = (device const half  *) (src0 + r0*nb01 + im*nb02);
    device const float * row_b = (device const float *) (src1 + r1*nb11 + im*nb12);

    float partial = 0.0f;
    for (int i = lane; i < ne00; i += width) {
        partial += (float) row_a[i] * (float) row_b[i];
    }
    sum[lane] = partial;

    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint stride = width/2; stride > 0; stride /= 2) {
        if (lane < stride) {
            sum[lane] += sum[lane + stride];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (lane == 0) {
        dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
    }
}
524
+
525
// Rotary position embedding (RoPE), one row per thread: tpig = (i1, i2, i3).
// Each thread rotates consecutive pairs (x[i0], x[i0+1]) of its row by theta.
// theta starts at the position p and is multiplied by
// theta_scale = 10000^(-2/n_dims) after every pair.
// If mode bit 0 is clear, p = n_past + i2; if it is set, p = i2.
// mode bit 1 (NeoX-style rotation) is not implemented: dst is left untouched.
// NOTE(review): confirm the host never dispatches mode & 2 to this kernel.
kernel void kernel_rope(
        device const  void * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        constant       int & n_past,
        constant       int & n_dims,
        constant       int & mode,
        uint3 tpig[[thread_position_in_grid]]) {
    const int64_t i3 = tpig[2];
    const int64_t i2 = tpig[1];
    const int64_t i1 = tpig[0];

    const bool is_neox = mode & 2;
    // Explicit float literal: Metal has no double support, so keep the pow() in float.
    const float theta_scale = pow(10000.0f, -2.0f/n_dims);

    const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

    float theta = (float)p;

    if (!is_neox) {
        // row base addresses are loop-invariant; only the i0 offset changes
        device const char * src_row = (device const char *) src0 + i3*nb03 + i2*nb02 + i1*nb01;
        device       char * dst_row = (device       char *) dst  + i3*nb3  + i2*nb2  + i1*nb1;

        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
            const float cos_theta = cos(theta);
            const float sin_theta = sin(theta);

            theta *= theta_scale;

            device const float * const src      = (device const float *) (src_row + i0*nb00);
            device       float *       dst_data = (device       float *) (dst_row + i0*nb0);

            const float x0 = src[0];
            const float x1 = src[1];

            dst_data[0] = x0*cos_theta - x1*sin_theta;
            dst_data[1] = x0*sin_theta + x1*cos_theta;
        }
    } else {
        // TODO: implement NeoX-style rotation; currently dst is not written.
    }
}
579
+
580
// Copies an f32 tensor into an f16 tensor, possibly with a different shape.
// One threadgroup per source row: tgpig = (i01, i02, i03).
// The flat index n of the row's first element is unravelled into destination
// coordinates (i0, i1, i2, i3). The row is then written contiguously from
// there, which assumes dst is densely packed.
kernel void kernel_cpy_f32_f16(
        device const float * src0,
        device        half * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // unravel n in the destination shape
    const int64_t dst_plane  = ne1*ne0;
    const int64_t dst_volume = ne2*dst_plane;

    const int64_t i3 = n / dst_volume;
    int64_t rem      = n - i3*dst_volume;
    const int64_t i2 = rem / dst_plane;
    rem             -= i2*dst_plane;
    const int64_t i1 = rem / ne0;
    const int64_t i0 = rem - i1*ne0;

    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
    device const char * src_row = (device const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
        dst_data[i00] = *((device const float *) (src_row + i00*nb00));
    }
}
621
+
622
// Copies an f32 tensor into another f32 tensor, possibly with a different shape.
// One threadgroup per source row: tgpig = (i01, i02, i03).
// The flat index n of the row's first element is unravelled into destination
// coordinates (i0, i1, i2, i3). The row is then written contiguously from
// there, which assumes dst is densely packed.
kernel void kernel_cpy_f32_f32(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // unravel n in the destination shape
    const int64_t dst_plane  = ne1*ne0;
    const int64_t dst_volume = ne2*dst_plane;

    const int64_t i3 = n / dst_volume;
    int64_t rem      = n - i3*dst_volume;
    const int64_t i2 = rem / dst_plane;
    rem             -= i2*dst_plane;
    const int64_t i1 = rem / ne0;
    const int64_t i0 = rem - i1*ne0;

    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
    device const char * src_row = (device const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
        dst_data[i00] = *((device const float *) (src_row + i00*nb00));
    }
}
663
+
664
+ //============================================ k-quants ======================================================
665
+
666
// Number of weights per k-quant super-block.
#define QK_K 256

// NOTE(review): the k-quant block layouts below presumably mirror the
// host-side definitions byte-for-byte — keep them in sync.

// 2-bit k-quant super-block. Each qs byte packs four 2-bit quants (read with
// shifts 0/2/4/6 in dequantize_row_q2_k). Each scales byte holds a 4-bit scale
// (low nibble) and a 4-bit min (high nibble) for one group of 16 weights.
typedef struct {
    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
    uint8_t qs[QK_K/4];      // quants
    half d;                  // super-block scale for quantized scales
    half dmin;               // super-block scale for quantized mins
} block_q2_k;

// 4-bit k-quant super-block. The 12 scales bytes pack 6-bit scales and mins
// for groups of 32 weights; they are decoded by get_scale_min_k4.
typedef struct {
    half d;                    // super-block scale for quantized scales
    half dmin;                 // super-block scale for quantized mins
    uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
    uint8_t qs[QK_K/2];        // 4-bit quants
} block_q4_k;

// 6-bit k-quant super-block. Each quant is assembled from 4 low bits (ql) and
// 2 high bits (qh) and recentred by -32. There is one signed 8-bit scale per
// 16 weights.
typedef struct {
    uint8_t ql[QK_K/2];      // quants, lower 4 bits
    uint8_t qh[QK_K/4];      // quants, upper 2 bits
    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
    half d;                  // super-block scale
} block_q6_k;
688
+
689
// Decodes two consecutive (scale, min) pairs from the 12-byte packed q4_k
// scales array q:
//   result = (scale_j, min_j, scale_{j+1}, min_{j+1})
// Sub-blocks 0..3 keep their 6-bit values in the low bits of q[0..7]. Sub-blocks
// 4..7 combine a nibble from q[8..11] with the two top bits of q[0..7].
// Callers pass an even j.
static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
    uchar4 out;
    if (j >= 4) {
        out[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
        out[1] = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
        out[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
        out[3] = (q[j+5] >>  4) | ((q[j+1] >> 6) << 4);
    } else {
        out[0] = q[j+0] & 63;
        out[1] = q[j+4] & 63;
        out[2] = q[j+1] & 63;
        out[3] = q[j+5] & 63;
    }
    return out;
}
702
+
703
+ //========================================== dequantization =============================
704
+
705
// Expands k q2_k-encoded weights into floats.
// Each half super-block (128 weights) reads 32 qs bytes four times, using bit
// shifts 0, 2, 4 and 6. Each pass emits two 16-weight groups: bytes 0..15, then
// bytes 16..31. Every group uses its own scales byte (low nibble -> scale,
// high nibble -> min).
static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
    assert(k % QK_K == 0);
    const int nblocks = k / QK_K;

    for (int ib = 0; ib < nblocks; ++ib) {
        device const block_q2_k * blk = x + ib;

        const float d    = blk->d;
        const float dmin = blk->dmin;

        int si = 0; // next entry of blk->scales
        for (int half_off = 0; half_off < QK_K; half_off += 128) {
            device const uint8_t * q = blk->qs + half_off/4;

            for (int shift = 0; shift < 8; shift += 2) {
                for (int part = 0; part < 2; ++part) {
                    const uint8_t sc = blk->scales[si++];
                    const float dl = d * (sc & 0xF);
                    const float ml = dmin * (sc >> 4);

                    device const uint8_t * qq = q + 16*part;
                    for (int l = 0; l < 16; ++l) {
                        *y++ = dl * ((int8_t)((qq[l] >> shift) & 3)) - ml;
                    }
                }
            }
        }
    }
}
737
+
738
// Expands k q4_k-encoded weights into floats.
// Each 64-weight chunk reads 32 qs bytes. Low nibbles give the first 32
// weights, high nibbles the next 32. Each half has its own (scale, min) pair
// from get_scale_min_k4.
static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
    assert(k % QK_K == 0);
    const int nblocks = k / QK_K;

    for (int ib = 0; ib < nblocks; ++ib) {
        device const block_q4_k * blk = x + ib;

        const float d    = blk->d;
        const float dmin = blk->dmin;

        for (int g = 0; g < QK_K/64; ++g) {
            device const uint8_t * q   = blk->qs + 32*g;
            device       float   * out = y + 64*g;

            const uchar4 sm = get_scale_min_k4(2*g, blk->scales);
            const float d_lo = d * sm[0]; const float m_lo = dmin * sm[1];
            const float d_hi = d * sm[2]; const float m_hi = dmin * sm[3];

            for (int l = 0; l < 32; ++l) {
                out[l]      = d_lo * (q[l] & 0xF) - m_lo;
                out[l + 32] = d_hi * (q[l] >> 4)  - m_hi;
            }
        }

        y += QK_K;
    }
}
762
+
763
// Expands k q6_k-encoded weights into floats.
// Each 128-weight half super-block uses 64 ql bytes, 32 qh bytes and 8 scales.
// Quant l of each 32-weight quarter combines a nibble of ql with a 2-bit field
// of qh[l]. It is recentred by -32 and scaled by d * scales[group].
static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
    assert(k % QK_K == 0);
    const int nblocks = k / QK_K;

    for (int ib = 0; ib < nblocks; ++ib) {
        device const block_q6_k * blk = x + ib;

        const float d = blk->d;
        device float * block_out = y + ib*QK_K;

        for (int h = 0; h < QK_K/128; ++h) {
            device const uint8_t * ql  = blk->ql     + 64*h;
            device const uint8_t * qh  = blk->qh     + 32*h;
            device const int8_t  * sc  = blk->scales +  8*h;
            device       float   * out = block_out   + 128*h;

            for (int l = 0; l < 32; ++l) {
                const int is = l/16;
                const uint8_t hbits = qh[l];

                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((hbits >> 0) & 3) << 4)) - 32;
                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((hbits >> 2) & 3) << 4)) - 32;
                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((hbits >> 4) & 3) << 4)) - 32;
                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((hbits >> 6) & 3) << 4)) - 32;

                out[l +  0] = d * sc[is + 0] * q1;
                out[l + 32] = d * sc[is + 2] * q2;
                out[l + 64] = d * sc[is + 4] * q3;
                out[l + 96] = d * sc[is + 6] * q4;
            }
        }
    }
}
794
+
795
// Gathers and dequantizes rows of a q2_k matrix: output row `tpig` receives
// src0 row src1[tpig] as ne00 floats. One thread per output row; nb01 and nb1
// are byte strides.
kernel void kernel_get_rows_q2_k(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint tpig[[thread_position_in_grid]]) {
    const int out_row = tpig;
    const int src_row = src1[out_row];

    device const block_q2_k * blocks = (device const block_q2_k *) ((device const char *) src0 + src_row*nb01);
    device       float      * out    = (device float *) ((device char *) dst + out_row*nb1);

    dequantize_row_q2_k(blocks, out, ne00);
}
810
+
811
// Gathers and dequantizes rows of a q4_k matrix: output row `tpig` receives
// src0 row src1[tpig] as ne00 floats. One thread per output row; nb01 and nb1
// are byte strides.
kernel void kernel_get_rows_q4_k(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint tpig[[thread_position_in_grid]]) {
    const int out_row = tpig;
    const int src_row = src1[out_row];

    device const block_q4_k * blocks = (device const block_q4_k *) ((device const char *) src0 + src_row*nb01);
    device       float      * out    = (device float *) ((device char *) dst + out_row*nb1);

    dequantize_row_q4_k(blocks, out, ne00);
}
826
+
827
// Gathers and dequantizes rows of a q6_k matrix: output row `tpig` receives
// src0 row src1[tpig] as ne00 floats. One thread per output row; nb01 and nb1
// are byte strides.
kernel void kernel_get_rows_q6_k(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint tpig[[thread_position_in_grid]]) {
    const int out_row = tpig;
    const int src_row = src1[out_row];

    device const block_q6_k * blocks = (device const block_q6_k *) ((device const char *) src0 + src_row*nb01);
    device       float      * out    = (device float *) ((device char *) dst + out_row*nb1);

    dequantize_row_q6_k(blocks, out, ne00);
}
842
+
843
+ //====================================== dot products =========================
844
+
845
// Dot product of one q2_k row of src0 with one f32 row of src1 per threadgroup:
//   dst[r1*ne0 + r0] = sum_k dequant(src0[r0][k]) * src1[r1][k]
// with r0 = tgpig.x and r1 = tgpig.y.
// Both inputs are assumed densely packed (the nb* byte strides are not used).
// Super-blocks are distributed over tpitg.x. Within a super-block, tpitg.y
// (tid) selects 16 of the 256 weights:
//   - il chooses a 64-weight quarter;
//   - ir chooses an 8-byte slice of qs;
//   - shift1/shift2 choose which 2-bit fields of those bytes to read.
// The scale/min sums are factored out: s[0]/s[2] accumulate y*q, s[1]/s[3] accumulate y.
// NOTE(review): assumes tptg.y == 16 and nth a multiple of 16 — confirm against the host dispatch.
// `sum` is threadgroup scratch that must hold nth floats.
kernel void kernel_mul_mat_q2_k_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        threadgroup float  * sum [[threadgroup(0)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2  tpig[[thread_position_in_grid]],               // we don't use this for now
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2  tptg[[threads_per_threadgroup]]) {

    // number of super-blocks per row
    const int nb = ne00/QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;

    device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
    device const float     * yy = (device const float      *) src1 + r1*ne10;

    const int nth = tptg.x*tptg.y;
    const int ith = tptg.y*tpitg.x + tpitg.y;


    const int tid = tpitg.y;    // 0...15 (assuming tptg.y == 16)
    const int il  = tid/4;      // 0...3: 64-weight quarter of the super-block
    const int ir  = tid%4;      // 0...3: 8-byte slice within that quarter
    const int ip  = il/2;       // 0 or 1: which 32-byte half of qs
    const int shift1 = 4*(il%2);// 0 or 4
    const int shift2 = shift1+2;// 2 or 6
    const int n   = 8;          // bytes (and weights per shift) handled per thread
    const int is  = 4*il + (n*ir)/16; // index of the first scales byte used

    // redundant: overwritten below before any thread reads it
    sum[ith] = 0.0f;

    float sumf = 0;
    for (int i = tpitg.x; i < nb; i += tptg.x) {

        device const uint8_t * q = x[i].qs + 32*ip + n*ir;
        device const uint8_t * scales = x[i].scales + is;

        // low nibble = scale, high nibble = min; scales[2] belongs to the group 32 weights later
        uint8_t d1 = scales[0] & 0xF;
        uint8_t m1 = scales[0] >>  4;
        uint8_t d2 = scales[2] & 0xF;
        uint8_t m2 = scales[2] >>  4;

        device const float   * y = yy + i*QK_K + 64*il + n*ir;

        const float dall = (float)x[i].d;
        const float dmin = (float)x[i].dmin;

        float4 s = {0.f, 0.f, 0.f, 0.f};
        for (int l = 0; l < n; ++l) {
            s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); s[1] += y[l+ 0];
            s[2] += y[l+32] * ((q[l] >> shift2) & 3); s[3] += y[l+32];
        }
        sumf += dall * (s[0] * d1 + s[2] * d2) - dmin * (s[1] * m1 + s[3] * m2);


    }
    sum[ith] = sumf;

    //
    // Accumulate the sum from all threads in the threadgroup in three stages:
    // groups of 4, then groups of 16, then thread 0 adds every 16th partial.
    // The original author notes this is slightly faster than the plain tree
    // reduction kept commented out below.
    //
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%4 == 0) {
        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%16 == 0) {
        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith == 0) {
        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
        dst[r1*ne0 + r0] = sum[0];
    }

    //// accumulate the sum from all threads in the threadgroup
    //threadgroup_barrier(mem_flags::mem_threadgroup);
    //for (uint i = nth/2; i > 0; i /= 2) {
    //    if (ith < i) {
    //        sum[ith] += sum[ith + i];
    //    }
    //    threadgroup_barrier(mem_flags::mem_threadgroup);
    //}

    //if (ith == 0) {
    //    dst[r1*ne0 + r0] = sum[0];
    //}
}
949
+
950
// Dot product of one q4_k row of src0 with one f32 row of src1 per threadgroup:
//   dst[r1*ne0 + r0] = sum_k dequant(src0[r0][k]) * src1[r1][k]
// with r0 = tgpig.x and r1 = tgpig.y.
// Both inputs are assumed densely packed (the nb* byte strides are not used).
// Super-blocks are distributed over tpitg.x. Within a super-block, tpitg.y
// (tid) selects a 64-weight chunk (il) and an 8-byte slice of it (ir). Each
// thread handles 8 low-nibble and 8 high-nibble weights.
// The scale/min sums are factored out: s[0]/s[2] accumulate y*q, s[1]/s[3] accumulate y.
// NOTE(review): assumes tptg.y == 16 and nth a multiple of 16 — confirm against the host dispatch.
// `sum` is threadgroup scratch that must hold nth floats.
kernel void kernel_mul_mat_q4_k_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        threadgroup float  * sum [[threadgroup(0)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2  tpig[[thread_position_in_grid]],               // we don't use this for now
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2  tptg[[threads_per_threadgroup]]) {

    // number of super-blocks per row
    const int nb = ne00/QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;

    device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
    device const float     * yy = (device const float      *) src1 + r1*ne10;

    const uint nth = tptg.x*tptg.y;
    const uint ith = tptg.y*tpitg.x + tpitg.y;

    const int tid = tpitg.y;   // 0...15 (assuming tptg.y == 16)
    const int il  = tid/4;     // 0...3: 64-weight chunk of the super-block
    const int ir  = tid%4;     // 0...3: 8-byte slice within that chunk
    const int n   = 8;         // bytes handled per thread
    const int is  = 2*il;      // first of the two (scale, min) pairs for this chunk

    // redundant: overwritten below before any thread reads it
    sum[ith] = 0.0f;

    float sumf = 0;
    for (int i = tpitg.x; i < nb; i += tptg.x) {

        device const uint8_t * q = (x + i)->qs + 32*il + n*ir;
        device const float   * y = yy + i*QK_K + 64*il + n*ir;
        device const uint8_t * scales = (x + i)->scales;

        const float dall = (float)((x + i)->d);
        const float dmin = (float)((x + i)->dmin);

        // sc = (scale, min) for the low-nibble half, then for the high-nibble half
        const uchar4 sc = get_scale_min_k4(is, scales);

        float4 s = {0.f, 0.f, 0.f, 0.f};
        for (int l = 0; l < n; ++l) {
            s[0] += y[l+ 0] * (q[l] & 0xF); s[1] += y[l+ 0];
            s[2] += y[l+32] * (q[l] >> 4); s[3] += y[l+32];
        }
        sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]);

    }
    sum[ith] = sumf;

    //
    // Accumulate the sum from all threads in the threadgroup in three stages:
    // groups of 4, then groups of 16, then thread 0 adds every 16th partial.
    // The original author notes this is slightly faster than the plain tree
    // reduction kept commented out below.
    //
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%4 == 0) {
        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%16 == 0) {
        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith == 0) {
        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
        dst[r1*ne0 + r0] = sum[0];
    }

    //// accumulate the sum from all threads in the threadgroup
    //threadgroup_barrier(mem_flags::mem_threadgroup);
    //for (uint i = nth/2; i > 0; i /= 2) {
    //    if (ith < i) {
    //        sum[ith] += sum[ith + i];
    //    }
    //    threadgroup_barrier(mem_flags::mem_threadgroup);
    //}

    //if (ith == 0) {
    //    dst[r1*ne0 + r0] = sum[0];
    //}
}
1045
+
1046
// Dot product of one Q6_K-quantized row of src0 with one f32 row of src1.
//
// Launch layout, as read by this kernel:
//   - threadgroup (tgpig.x, tgpig.y) writes dst[tgpig.y*ne0 + tgpig.x], the dot product of
//     src0 row r0 = tgpig.x (nb super-blocks of QK_K values) with src1 row r1 = tgpig.y
//     (located at src1 + r1*ne10);
//   - threads along x (tpitg.x) stride over the super-blocks of the row;
//   - threads along y (tpitg.y) stride over the 16 slots of 16 values each super-block is
//     split into. Any tptg.y works; with tptg.y == 16 each thread owns exactly one slot.
//   - `sum` must provide at least tptg.x*tptg.y floats of threadgroup memory (every thread
//     writes sum[ith] with ith in [0, tptg.x*tptg.y)).
//
// Preconditions: only ne00/QK_K whole super-blocks are read, so any remainder of ne00 is
// ignored. The per-slot index math covers exactly 2 x 128 = 256 values per super-block,
// i.e. it assumes QK_K == 256.
// NOTE(review): ne01, nb00..nb12, ne11 and ne1 are not read here; presumably they keep the
// signature uniform with the other mul_mat kernels — confirm against the host encoder.
kernel void kernel_mul_mat_q6_k_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        threadgroup float  * sum [[threadgroup(0)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2  tpig[[thread_position_in_grid]],               // we don't use this for now
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2  tptg[[threads_per_threadgroup]]) {

    // Masks selecting the four 2-bit high-quant fields packed into each qh byte.
    const uint8_t kmask1 = 0x03;
    const uint8_t kmask2 = 0x0C;
    const uint8_t kmask3 = 0x30;
    const uint8_t kmask4 = 0xC0;

    // A super-block is handled as nslots slots; each slot covers n consecutive positions
    // at the four strides 0, 32, 64 and 96, i.e. 16 values.
    const uint nslots = 16;
    const int  n      = 4;

    const int nb = ne00/QK_K;  // whole super-blocks per row

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;

    device const block_q6_k * x  = (device const block_q6_k *) src0 + r0*nb;
    device const float      * yy = (device const float      *) src1 + r1*ne10;

    const uint nth = tptg.x*tptg.y;
    const uint ith = tptg.y*tpitg.x + tpitg.y;

    float sumf = 0;

    // Stride the slots over the y-threads so that every value is visited exactly once,
    // whatever tptg.y is. Threads with tpitg.y >= nslots get no slot and contribute 0.
    for (uint slot = tpitg.y; slot < nslots; slot += tptg.y) {
        const int ip = slot / 8;           // 0 or 1: which 128-value half of the super-block
        const int il = slot % 8;           // 0...7: group of n positions within that half
        const int is = 8*ip + (n*il)/16;   // first scale index used by this slot

        for (int i = tpitg.x; i < nb; i += tptg.x) {

            device const uint8_t * ql = x[i].ql + 64*ip + n*il;
            device const uint8_t * qh = x[i].qh + 32*ip + n*il;
            device const int8_t  * sc = x[i].scales + is;

            device const float * y = yy + i * QK_K + 128*ip + n*il;

            const float dall = x[i].d;

            // Rebuild each 6-bit quant from a 4-bit ql nibble (low bits) and a 2-bit qh field
            // (high bits), then recentre it by -32 before weighting by y.
            float4 sums = {0.f, 0.f, 0.f, 0.f};
            for (int l = 0; l < n; ++l) {
                sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
                sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
                sums[2] += y[l+64] * ((int8_t)((ql[l+ 0]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
                sums[3] += y[l+96] * ((int8_t)((ql[l+32]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
            }

            // The four strides (0, 32, 64, 96) use scales two entries apart.
            sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);

        }
    }

    sum[ith] = sumf;

    //
    // Accumulate the sum from all threads in the threadgroup: first in groups of 4,
    // then in groups of 16, then thread 0 adds up the group-of-16 partials.
    // The `ith + i < nth` guards keep every read inside the nth written entries when
    // nth is not a multiple of 16; when it is, they are always true.
    // All barriers are outside the conditionals so every thread reaches them.
    //
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%4 == 0) {
        for (uint i = 1; i < 4 && ith + i < nth; ++i) sum[ith] += sum[ith + i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%16 == 0) {
        for (uint i = 4; i < 16 && ith + i < nth; i += 4) sum[ith] += sum[ith + i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith == 0) {
        for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
        dst[r1*ne0 + r0] = sum[0];
    }

}