gpt_neox_client 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2049 @@
1
+ #include <metal_stdlib>
2
+
3
+ using namespace metal;
4
+
5
+ #define MAX(x, y) ((x) > (y) ? (x) : (y))
6
+
7
+ #define QK4_0 32
8
+ #define QR4_0 2
9
+ typedef struct {
10
+ half d; // delta
11
+ uint8_t qs[QK4_0 / 2]; // nibbles / quants
12
+ } block_q4_0;
13
+
14
+ #define QK4_1 32
15
+ typedef struct {
16
+ half d; // delta
17
+ half m; // min
18
+ uint8_t qs[QK4_1 / 2]; // nibbles / quants
19
+ } block_q4_1;
20
+
21
+ #define QK8_0 32
22
+ typedef struct {
23
+ half d; // delta
24
+ int8_t qs[QK8_0]; // quants
25
+ } block_q8_0;
26
+
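Note on the three block formats above: each covers 32 weights, and they decode as x = d*(q - 8) for q4_0 (q an unsigned 4-bit nibble), x = d*q + m for q4_1, and x = d*q for q8_0 (signed 8-bit q). A minimal host-side sketch in plain C, not part of this package, with the Metal half scale passed as a float and an illustrative helper name:

    #include <stdint.h>

    /* Illustrative reference: expand one q4_0 block into 32 floats.
     * Low nibbles hold elements 0..15 and high nibbles elements 16..31,
     * the same layout the kernels below assume. */
    static void dequant_q4_0_ref(float d, const uint8_t qs[16], float out[32]) {
        for (int i = 0; i < 16; ++i) {
            out[i]      = d * (float)((qs[i] & 0x0F) - 8);  /* low nibble  */
            out[i + 16] = d * (float)((qs[i] >> 4) - 8);    /* high nibble */
        }
    }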
27
+ kernel void kernel_add(
28
+ device const float * src0,
29
+ device const float * src1,
30
+ device float * dst,
31
+ uint tpig[[thread_position_in_grid]]) {
32
+ dst[tpig] = src0[tpig] + src1[tpig];
33
+ }
34
+
35
+ // assumption: src1 is a row
36
+ // broadcast src1 into src0
37
+ kernel void kernel_add_row(
38
+ device const float * src0,
39
+ device const float * src1,
40
+ device float * dst,
41
+ constant int64_t & ne00,
42
+ uint tpig[[thread_position_in_grid]]) {
43
+ dst[tpig] = src0[tpig] + src1[tpig % ne00];
44
+ }
45
+
46
+ kernel void kernel_mul(
47
+ device const float * src0,
48
+ device const float * src1,
49
+ device float * dst,
50
+ uint tpig[[thread_position_in_grid]]) {
51
+ dst[tpig] = src0[tpig] * src1[tpig];
52
+ }
53
+
54
+ // assumption: src1 is a row
55
+ // broadcast src1 into src0
56
+ kernel void kernel_mul_row(
57
+ device const float * src0,
58
+ device const float * src1,
59
+ device float * dst,
60
+ constant int64_t & ne00,
61
+ uint tpig[[thread_position_in_grid]]) {
62
+ dst[tpig] = src0[tpig] * src1[tpig % ne00];
63
+ }
64
+
65
+ kernel void kernel_scale(
66
+ device const float * src0,
67
+ device float * dst,
68
+ constant float & scale,
69
+ uint tpig[[thread_position_in_grid]]) {
70
+ dst[tpig] = src0[tpig] * scale;
71
+ }
72
+
73
+ kernel void kernel_silu(
74
+ device const float * src0,
75
+ device float * dst,
76
+ uint tpig[[thread_position_in_grid]]) {
77
+ float x = src0[tpig];
78
+ dst[tpig] = x / (1.0f + exp(-x));
79
+ }
80
+
81
+ kernel void kernel_relu(
82
+ device const float * src0,
83
+ device float * dst,
84
+ uint tpig[[thread_position_in_grid]]) {
85
+ dst[tpig] = max(0.0f, src0[tpig]);
86
+ }
87
+
88
+ constant float GELU_COEF_A = 0.044715f;
89
+ constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
90
+
91
+ kernel void kernel_gelu(
92
+ device const float * src0,
93
+ device float * dst,
94
+ uint tpig[[thread_position_in_grid]]) {
95
+ float x = src0[tpig];
96
+
97
+ // BEWARE !!!
98
+ // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
99
+ // This was observed with Falcon 7B and 40B models
100
+ //
101
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
102
+ }
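For reference, the expression in kernel_gelu is the usual tanh approximation of GELU: GELU(x) ~ 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))), with SQRT_2_OVER_PI and GELU_COEF_A holding the two constants. precise::tanh trades some speed for accuracy; the warning above suggests the fast-math tanh is not robust enough for the large activations these models can produce.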
103
+
104
+ kernel void kernel_soft_max(
105
+ device const float * src0,
106
+ device float * dst,
107
+ constant int64_t & ne00,
108
+ constant int64_t & ne01,
109
+ constant int64_t & ne02,
110
+ threadgroup float * buf [[threadgroup(0)]],
111
+ uint3 tgpig[[threadgroup_position_in_grid]],
112
+ uint3 tpitg[[thread_position_in_threadgroup]],
113
+ uint3 ntg[[threads_per_threadgroup]]) {
114
+ const int64_t i03 = tgpig[2];
115
+ const int64_t i02 = tgpig[1];
116
+ const int64_t i01 = tgpig[0];
117
+
118
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
119
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
120
+
121
+ // parallel max
122
+ buf[tpitg[0]] = -INFINITY;
123
+ for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
124
+ buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
125
+ }
126
+
127
+ // reduce
128
+ threadgroup_barrier(mem_flags::mem_threadgroup);
129
+ for (uint i = ntg[0]/2; i > 0; i /= 2) {
130
+ if (tpitg[0] < i) {
131
+ buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
132
+ }
133
+ threadgroup_barrier(mem_flags::mem_threadgroup);
134
+ }
135
+
136
+ // broadcast
137
+ if (tpitg[0] == 0) {
138
+ buf[0] = buf[0];
139
+ }
140
+
141
+ threadgroup_barrier(mem_flags::mem_threadgroup);
142
+
143
+ const float max = buf[0];
144
+
145
+ // parallel sum
146
+ buf[tpitg[0]] = 0.0f;
147
+ for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
148
+ buf[tpitg[0]] += exp(psrc0[i00] - max);
149
+ }
150
+
151
+ // reduce
152
+ threadgroup_barrier(mem_flags::mem_threadgroup);
153
+ for (uint i = ntg[0]/2; i > 0; i /= 2) {
154
+ if (tpitg[0] < i) {
155
+ buf[tpitg[0]] += buf[tpitg[0] + i];
156
+ }
157
+ threadgroup_barrier(mem_flags::mem_threadgroup);
158
+ }
159
+
160
+ // broadcast
161
+ if (tpitg[0] == 0) {
162
+ buf[0] = buf[0];
163
+ }
164
+
165
+ threadgroup_barrier(mem_flags::mem_threadgroup);
166
+
167
+ const float sum = buf[0];
168
+
169
+ for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
170
+ pdst[i00] = exp(psrc0[i00] - max) / sum;
171
+ }
172
+ }
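kernel_soft_max uses the numerically stable form of softmax, softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)): subtracting the row maximum leaves the result unchanged while keeping every exponent at or below zero, so exp cannot overflow. The first reduction finds the row maximum, the second accumulates the normalizer, and the final loop writes the normalized values.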
173
+
174
+ kernel void kernel_diag_mask_inf(
175
+ device const float * src0,
176
+ device float * dst,
177
+ constant int64_t & ne00,
178
+ constant int64_t & ne01,
179
+ constant int & n_past,
180
+ uint3 tpig[[thread_position_in_grid]]) {
181
+ const int64_t i02 = tpig[2];
182
+ const int64_t i01 = tpig[1];
183
+ const int64_t i00 = tpig[0];
184
+
185
+ if (i00 > n_past + i01) {
186
+ dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
187
+ } else {
188
+ dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
189
+ }
190
+ }
191
+
192
+ kernel void kernel_norm(
193
+ device const void * src0,
194
+ device float * dst,
195
+ constant int64_t & ne00,
196
+ constant uint64_t & nb01,
197
+ constant float & eps,
198
+ threadgroup float * sum [[threadgroup(0)]],
199
+ uint tgpig[[threadgroup_position_in_grid]],
200
+ uint tpitg[[thread_position_in_threadgroup]],
201
+ uint ntg[[threads_per_threadgroup]]) {
202
+ device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
203
+ // MEAN
204
+ // parallel sum
205
+ sum[tpitg] = 0.0f;
206
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
207
+ sum[tpitg] += x[i00];
208
+ }
209
+ // reduce
210
+ threadgroup_barrier(mem_flags::mem_threadgroup);
211
+ for (uint i = ntg/2; i > 0; i /= 2) {
212
+ if (tpitg < i) {
213
+ sum[tpitg] += sum[tpitg + i];
214
+ }
215
+ threadgroup_barrier(mem_flags::mem_threadgroup);
216
+ }
217
+ // broadcast
218
+ if (tpitg == 0) {
219
+ sum[0] /= ne00;
220
+ }
221
+ threadgroup_barrier(mem_flags::mem_threadgroup);
222
+ const float mean = sum[0];
223
+
224
+ // recenter
225
+ device float * y = dst + tgpig*ne00;
226
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
227
+ y[i00] = x[i00] - mean;
228
+ }
229
+
230
+ // VARIANCE
231
+ // parallel sum
232
+ sum[tpitg] = 0.0f;
233
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
234
+ sum[tpitg] += y[i00] * y[i00];
235
+ }
236
+ // reduce
237
+ threadgroup_barrier(mem_flags::mem_threadgroup);
238
+ for (uint i = ntg/2; i > 0; i /= 2) {
239
+ if (tpitg < i) {
240
+ sum[tpitg] += sum[tpitg + i];
241
+ }
242
+ threadgroup_barrier(mem_flags::mem_threadgroup);
243
+ }
244
+ // broadcast
245
+ if (tpitg == 0) {
246
+ sum[0] /= ne00;
247
+ }
248
+ threadgroup_barrier(mem_flags::mem_threadgroup);
249
+ const float variance = sum[0];
250
+
251
+ const float scale = 1.0f/sqrt(variance + eps);
252
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
253
+ y[i00] = y[i00] * scale;
254
+ }
255
+ }
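kernel_norm is a standard layer normalization over each row: y = (x - mean(x)) / sqrt(var(x) + eps), with the mean and the variance each computed by the same threadgroup tree reduction and the recentred values reused for the variance pass.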
256
+
257
+
258
+ kernel void kernel_rms_norm(
259
+ device const void * src0,
260
+ device float * dst,
261
+ constant int64_t & ne00,
262
+ constant uint64_t & nb01,
263
+ constant float & eps,
264
+ threadgroup float * sum [[threadgroup(0)]],
265
+ uint tgpig[[threadgroup_position_in_grid]],
266
+ uint tpitg[[thread_position_in_threadgroup]],
267
+ uint sgitg[[simdgroup_index_in_threadgroup]],
268
+ uint tiisg[[thread_index_in_simdgroup]],
269
+ uint ntg[[threads_per_threadgroup]]) {
270
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
271
+ device const float * x_scalar = (device const float *) x;
272
+ float4 sumf=0;
273
+ float all_sum=0;
274
+
275
+ // parallel sum
276
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
277
+ sumf += x[i00] * x[i00];
278
+ }
279
+ all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
280
+ all_sum = simd_sum(all_sum);
281
+ if (tiisg == 0) {
282
+ sum[sgitg] = all_sum;
283
+ }
284
+
285
+ threadgroup_barrier(mem_flags::mem_threadgroup);
286
+ // broadcast, simd group number is ntg / 32
287
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
288
+ if (tpitg < i) {
289
+ sum[tpitg] += sum[tpitg + i];
290
+ }
291
+ }
292
+ if (tpitg == 0) {
293
+ for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i] * x_scalar[i];}
294
+ sum[0] /= ne00;
295
+ }
296
+
297
+ threadgroup_barrier(mem_flags::mem_threadgroup);
298
+
299
+ const float mean = sum[0];
300
+ const float scale = 1.0f/sqrt(mean + eps);
301
+
302
+ device float4 * y = (device float4 *) (dst + tgpig*ne00);
303
+ device float * y_scalar = (device float *) y;
304
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
305
+ y[i00] = x[i00] * scale;
306
+ }
307
+ if (tpitg == 0) {
308
+ for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
309
+ }
310
+ }
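kernel_rms_norm differs from kernel_norm in that no mean is subtracted: each row is scaled by 1/sqrt(mean(x^2) + eps). The row is read as float4 vectors for throughput, the per-thread sums of squares are combined with simd_sum and a small per-simdgroup reduction, and the scalar tail loop covers row lengths that are not a multiple of 4.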
311
+
312
+ // function to calculate the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i])
313
+ // il indicates where the q4 quants begin (0 or QK4_0/4)
314
+ // we assume that the yl's have been multiplied by the appropriate scale factor
315
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
316
+ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
317
+ float d = qb_curr->d;
318
+ float2 acc = 0.f;
319
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
320
+ for (int i = 0; i < 8; i+=2) {
321
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
322
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
323
+ acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
324
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
325
+ }
326
+ return d * (sumy * -8.f + acc[0] + acc[1]);
327
+ }
328
+
329
+ // function to calculate the inner product between half a q4_1 block and 16 floats (yl); sumy is SUM(yl[i])
330
+ // il indicates where the q4 quants begin (0 or QK4_0/4)
331
+ // we assume that the yl's have been multiplied by the appropriate scale factor
332
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
333
+ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
334
+ float d = qb_curr->d;
335
+ float m = qb_curr->m;
336
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
337
+ float2 acc = 0.f;
338
+ for (int i = 0; i < 8; i+=2) {
339
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
340
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
341
+ acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
342
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
343
+ }
344
+ return d * (acc[0] + acc[1]) + sumy * m;
345
+ }
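A worked note on the two helpers above: the nibbles are masked out of each 16-bit word without being shifted down, so a quant q stored in bits 4-7, 8-11 or 12-15 is read back as 16*q, 256*q or 4096*q (for example, q = 3 in bits 8-11 reads as 0x0300 = 768 = 256*3). The caller compensates by pre-dividing the matching y values by 16, 256 or 4096 when filling yl[] (see mul_vec_q_n_f32 below), so every product collapses back to q*y without any shift instructions. In the q4_0 variant the -8.f*sumy term then folds the zero point of the encoding, x = d*(q - 8), into one correction; in the q4_1 variant the per-block minimum instead contributes sumy*m.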
346
+
347
+ // putting them in the kernel causes a significant performance penalty
348
+ #define N_DST 4 // each SIMD group works on 4 rows
349
+ #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
350
+ #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
351
+ // Note: This is a template, but strictly speaking it only applies to
352
+ // quantizations where the block size is 32. It also does not
353
+ // guard against the number of rows not being divisible by
354
+ // N_DST, so this is another explicit assumption of the implementation.
355
+ template<typename block_q_type, int nr, int nsg, int nw>
356
+ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
357
+ int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
358
+ uint3 tgpig, uint tiisg, uint sgitg) {
359
+ const int nb = ne00/QK4_0;
360
+ const int r0 = tgpig.x;
361
+ const int r1 = tgpig.y;
362
+ const int im = tgpig.z;
363
+ const int first_row = (r0 * nsg + sgitg) * nr;
364
+ const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
365
+ device const block_q_type * x = (device const block_q_type *) src0 + offset0;
366
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
367
+ float yl[16]; // src1 vector cache
368
+ float sumf[nr]={0.f};
369
+
370
+ const int ix = tiisg/2;
371
+ const int il = 8*(tiisg%2);
372
+
373
+ device const float * yb = y + ix * QK4_0 + il;
374
+
375
+ // each thread in a SIMD group deals with half a block.
376
+ for (int ib = ix; ib < nb; ib += nw/2) {
377
+ float sumy = 0;
378
+ for (int i = 0; i < 8; i += 2) {
379
+ sumy += yb[i] + yb[i+1];
380
+ yl[i+0] = yb[i+ 0];
381
+ yl[i+1] = yb[i+ 1]/256.f;
382
+ sumy += yb[i+16] + yb[i+17];
383
+ yl[i+8] = yb[i+16]/16.f;
384
+ yl[i+9] = yb[i+17]/4096.f;
385
+ }
386
+
387
+ for (int row = 0; row < nr; row++) {
388
+ sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
389
+ }
390
+
391
+ yb += QK4_0 * 16;
392
+ }
393
+
394
+ for (int row = 0; row < nr; ++row) {
395
+ const float tot = simd_sum(sumf[row]);
396
+ if (tiisg == 0 && first_row + row < ne01) {
397
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
398
+ }
399
+ }
400
+ }
401
+
402
+ kernel void kernel_mul_mat_q4_0_f32(
403
+ device const void * src0,
404
+ device const float * src1,
405
+ device float * dst,
406
+ constant int64_t & ne00,
407
+ constant int64_t & ne01[[buffer(4)]],
408
+ constant int64_t & ne02[[buffer(5)]],
409
+ constant int64_t & ne10[[buffer(9)]],
410
+ constant int64_t & ne12[[buffer(11)]],
411
+ constant int64_t & ne0[[buffer(15)]],
412
+ constant int64_t & ne1[[buffer(16)]],
413
+ constant uint & gqa[[buffer(17)]],
414
+ uint3 tgpig[[threadgroup_position_in_grid]],
415
+ uint tiisg[[thread_index_in_simdgroup]],
416
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
417
+ mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
418
+ }
419
+
420
+ kernel void kernel_mul_mat_q4_1_f32(
421
+ device const void * src0,
422
+ device const float * src1,
423
+ device float * dst,
424
+ constant int64_t & ne00,
425
+ constant int64_t & ne01[[buffer(4)]],
426
+ constant int64_t & ne02[[buffer(5)]],
427
+ constant int64_t & ne10[[buffer(9)]],
428
+ constant int64_t & ne12[[buffer(11)]],
429
+ constant int64_t & ne0[[buffer(15)]],
430
+ constant int64_t & ne1[[buffer(16)]],
431
+ constant uint & gqa[[buffer(17)]],
432
+ uint3 tgpig[[threadgroup_position_in_grid]],
433
+ uint tiisg[[thread_index_in_simdgroup]],
434
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
435
+ mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
436
+ }
437
+
438
+ kernel void kernel_mul_mat_q8_0_f32(
439
+ device const void * src0,
440
+ device const float * src1,
441
+ device float * dst,
442
+ constant int64_t & ne00,
443
+ constant int64_t & ne01[[buffer(4)]],
444
+ constant int64_t & ne02[[buffer(5)]],
445
+ constant int64_t & ne10[[buffer(9)]],
446
+ constant int64_t & ne12[[buffer(11)]],
447
+ constant int64_t & ne0[[buffer(15)]],
448
+ constant int64_t & ne1[[buffer(16)]],
449
+ constant uint & gqa[[buffer(17)]],
450
+ uint3 tgpig[[threadgroup_position_in_grid]],
451
+ uint tiisg[[thread_index_in_simdgroup]],
452
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
453
+ const int nr = N_DST;
454
+ const int nsg = N_SIMDGROUP;
455
+ const int nw = N_SIMDWIDTH;
456
+
457
+ const int nb = ne00/QK8_0;
458
+ const int r0 = tgpig.x;
459
+ const int r1 = tgpig.y;
460
+ const int im = tgpig.z;
461
+ const int first_row = (r0 * nsg + sgitg) * nr;
462
+ const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
463
+ device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
464
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
465
+
466
+ float yl[16];
467
+ float sumf[nr]={0.f};
468
+
469
+ const int ix = tiisg/2;
470
+ const int il = tiisg%2;
471
+
472
+ device const float * yb = y + ix * QK8_0 + 16*il;
473
+
474
+ // each thread in a SIMD group deals with half a block.
475
+ for (int ib = ix; ib < nb; ib += nw/2) {
476
+ for (int i = 0; i < 16; ++i) {
477
+ yl[i] = yb[i];
478
+ }
479
+
480
+ for (int row = 0; row < nr; row++) {
481
+ device const int8_t * qs = x[ib+row*nb].qs + 16*il;
482
+ float sumq = 0.f;
483
+ for (int iq = 0; iq < 16; ++iq) {
484
+ sumq += qs[iq] * yl[iq];
485
+ }
486
+ sumf[row] += sumq*x[ib+row*nb].d;
487
+ }
488
+
489
+ yb += QK8_0 * 16;
490
+ }
491
+
492
+ for (int row = 0; row < nr; ++row) {
493
+ const float tot = simd_sum(sumf[row]);
494
+ if (tiisg == 0 && first_row + row < ne01) {
495
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
496
+ }
497
+ }
498
+ }
499
+
500
+ kernel void kernel_mul_mat_f16_f32(
501
+ device const char * src0,
502
+ device const char * src1,
503
+ device float * dst,
504
+ constant int64_t & ne00,
505
+ constant int64_t & ne01,
506
+ constant int64_t & ne02,
507
+ constant uint64_t & nb00,
508
+ constant uint64_t & nb01,
509
+ constant uint64_t & nb02,
510
+ constant int64_t & ne10,
511
+ constant int64_t & ne11,
512
+ constant int64_t & ne12,
513
+ constant uint64_t & nb10,
514
+ constant uint64_t & nb11,
515
+ constant uint64_t & nb12,
516
+ constant int64_t & ne0,
517
+ constant int64_t & ne1,
518
+ threadgroup float * sum [[threadgroup(0)]],
519
+ uint3 tgpig[[threadgroup_position_in_grid]],
520
+ uint3 tpig[[thread_position_in_grid]],
521
+ uint3 tpitg[[thread_position_in_threadgroup]],
522
+ uint3 tptg[[threads_per_threadgroup]]) {
523
+
524
+ const int64_t r0 = tgpig.x;
525
+ const int64_t r1 = tgpig.y;
526
+ const int64_t im = tgpig.z;
527
+
528
+ device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
529
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
530
+
531
+ sum[tpitg.x] = 0.0f;
532
+
533
+ for (int i = tpitg.x; i < ne00; i += tptg.x) {
534
+ sum[tpitg.x] += (float) x[i] * (float) y[i];
535
+ }
536
+
537
+ // accumulate the sum from all threads in the threadgroup
538
+ threadgroup_barrier(mem_flags::mem_threadgroup);
539
+ for (uint i = tptg.x/2; i > 0; i /= 2) {
540
+ if (tpitg.x < i) {
541
+ sum[tpitg.x] += sum[tpitg.x + i];
542
+ }
543
+ threadgroup_barrier(mem_flags::mem_threadgroup);
544
+ }
545
+
546
+ if (tpitg.x == 0) {
547
+ dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
548
+ }
549
+ }
550
+
551
+ kernel void kernel_alibi_f32(
552
+ device const float * src0,
553
+ device float * dst,
554
+ constant int64_t & ne00,
555
+ constant int64_t & ne01,
556
+ constant int64_t & ne02,
557
+ constant int64_t & ne03,
558
+ constant uint64_t & nb00,
559
+ constant uint64_t & nb01,
560
+ constant uint64_t & nb02,
561
+ constant uint64_t & nb03,
562
+ constant int64_t & ne0,
563
+ constant int64_t & ne1,
564
+ constant int64_t & ne2,
565
+ constant int64_t & ne3,
566
+ constant uint64_t & nb0,
567
+ constant uint64_t & nb1,
568
+ constant uint64_t & nb2,
569
+ constant uint64_t & nb3,
570
+ constant float & m0,
571
+ uint3 tgpig[[threadgroup_position_in_grid]],
572
+ uint3 tpitg[[thread_position_in_threadgroup]],
573
+ uint3 ntg[[threads_per_threadgroup]]) {
574
+ const int64_t i03 = tgpig[2];
575
+ const int64_t i02 = tgpig[1];
576
+ const int64_t i01 = tgpig[0];
577
+
578
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
579
+
580
+ const int64_t i3 = n / (ne2*ne1*ne0);
581
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
582
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
583
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
584
+
585
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
586
+ float m_k = pow(m0, i2 + 1);
587
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
588
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
589
+ dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
590
+ }
591
+ }
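kernel_alibi_f32 adds the ALiBi position bias: head i2 uses the slope m_k = pow(m0, i2 + 1), and element i00 of a row of length ne00 receives the offset m_k*(i00 - ne00 + 1), which is zero at the last position and grows in magnitude further back, added directly to the attention scores.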
592
+
593
+ kernel void kernel_rope(
594
+ device const void * src0,
595
+ device float * dst,
596
+ constant int64_t & ne00,
597
+ constant int64_t & ne01,
598
+ constant int64_t & ne02,
599
+ constant int64_t & ne03,
600
+ constant uint64_t & nb00,
601
+ constant uint64_t & nb01,
602
+ constant uint64_t & nb02,
603
+ constant uint64_t & nb03,
604
+ constant int64_t & ne0,
605
+ constant int64_t & ne1,
606
+ constant int64_t & ne2,
607
+ constant int64_t & ne3,
608
+ constant uint64_t & nb0,
609
+ constant uint64_t & nb1,
610
+ constant uint64_t & nb2,
611
+ constant uint64_t & nb3,
612
+ constant int & n_past,
613
+ constant int & n_dims,
614
+ constant int & mode,
615
+ constant float & freq_base,
616
+ constant float & freq_scale,
617
+ uint3 tpig[[thread_position_in_grid]]) {
618
+ const int64_t i3 = tpig[2];
619
+ const int64_t i2 = tpig[1];
620
+ const int64_t i1 = tpig[0];
621
+
622
+ const bool is_neox = mode & 2;
623
+ const float theta_scale = pow(freq_base, -2.0f/n_dims);
624
+
625
+ const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
626
+
627
+ float theta = freq_scale * (float)p;
628
+
629
+ if (!is_neox) {
630
+ for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
631
+ const float cos_theta = cos(theta);
632
+ const float sin_theta = sin(theta);
633
+
634
+ theta *= theta_scale;
635
+
636
+ device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
637
+ device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
638
+
639
+ const float x0 = src[0];
640
+ const float x1 = src[1];
641
+
642
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
643
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
644
+ }
645
+ } else {
646
+ for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
647
+ for (int64_t ic = 0; ic < n_dims; ic += 2) {
648
+ const float cos_theta = cos(theta);
649
+ const float sin_theta = sin(theta);
650
+
651
+ theta *= theta_scale;
652
+
653
+ const int64_t i0 = ib*n_dims + ic/2;
654
+
655
+ device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
656
+ device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
657
+
658
+ const float x0 = src[0];
659
+ const float x1 = src[n_dims/2];
660
+
661
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
662
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
663
+ }
664
+ }
665
+ }
666
+ }
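Both branches of kernel_rope apply the standard rotary-embedding rotation: with theta_i = freq_scale * p * freq_base^(-2i/n_dims) for the i-th pair at position p, the pair (x0, x1) becomes (x0*cos(theta_i) - x1*sin(theta_i), x0*sin(theta_i) + x1*cos(theta_i)). The default path rotates adjacent elements (i0, i0+1), while the NeoX path (mode & 2) pairs elements that sit n_dims/2 apart.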
667
+
668
+ kernel void kernel_cpy_f16_f16(
669
+ device const half * src0,
670
+ device half * dst,
671
+ constant int64_t & ne00,
672
+ constant int64_t & ne01,
673
+ constant int64_t & ne02,
674
+ constant int64_t & ne03,
675
+ constant uint64_t & nb00,
676
+ constant uint64_t & nb01,
677
+ constant uint64_t & nb02,
678
+ constant uint64_t & nb03,
679
+ constant int64_t & ne0,
680
+ constant int64_t & ne1,
681
+ constant int64_t & ne2,
682
+ constant int64_t & ne3,
683
+ constant uint64_t & nb0,
684
+ constant uint64_t & nb1,
685
+ constant uint64_t & nb2,
686
+ constant uint64_t & nb3,
687
+ uint3 tgpig[[threadgroup_position_in_grid]],
688
+ uint3 tpitg[[thread_position_in_threadgroup]],
689
+ uint3 ntg[[threads_per_threadgroup]]) {
690
+ const int64_t i03 = tgpig[2];
691
+ const int64_t i02 = tgpig[1];
692
+ const int64_t i01 = tgpig[0];
693
+
694
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
695
+
696
+ const int64_t i3 = n / (ne2*ne1*ne0);
697
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
698
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
699
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
700
+
701
+ device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
702
+
703
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
704
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
705
+ dst_data[i00] = src[0];
706
+ }
707
+ }
708
+
709
+ kernel void kernel_cpy_f32_f16(
710
+ device const float * src0,
711
+ device half * dst,
712
+ constant int64_t & ne00,
713
+ constant int64_t & ne01,
714
+ constant int64_t & ne02,
715
+ constant int64_t & ne03,
716
+ constant uint64_t & nb00,
717
+ constant uint64_t & nb01,
718
+ constant uint64_t & nb02,
719
+ constant uint64_t & nb03,
720
+ constant int64_t & ne0,
721
+ constant int64_t & ne1,
722
+ constant int64_t & ne2,
723
+ constant int64_t & ne3,
724
+ constant uint64_t & nb0,
725
+ constant uint64_t & nb1,
726
+ constant uint64_t & nb2,
727
+ constant uint64_t & nb3,
728
+ uint3 tgpig[[threadgroup_position_in_grid]],
729
+ uint3 tpitg[[thread_position_in_threadgroup]],
730
+ uint3 ntg[[threads_per_threadgroup]]) {
731
+ const int64_t i03 = tgpig[2];
732
+ const int64_t i02 = tgpig[1];
733
+ const int64_t i01 = tgpig[0];
734
+
735
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
736
+
737
+ const int64_t i3 = n / (ne2*ne1*ne0);
738
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
739
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
740
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
741
+
742
+ device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
743
+
744
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
745
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
746
+
747
+ dst_data[i00] = src[0];
748
+ }
749
+ }
750
+
751
+ kernel void kernel_cpy_f32_f32(
752
+ device const float * src0,
753
+ device float * dst,
754
+ constant int64_t & ne00,
755
+ constant int64_t & ne01,
756
+ constant int64_t & ne02,
757
+ constant int64_t & ne03,
758
+ constant uint64_t & nb00,
759
+ constant uint64_t & nb01,
760
+ constant uint64_t & nb02,
761
+ constant uint64_t & nb03,
762
+ constant int64_t & ne0,
763
+ constant int64_t & ne1,
764
+ constant int64_t & ne2,
765
+ constant int64_t & ne3,
766
+ constant uint64_t & nb0,
767
+ constant uint64_t & nb1,
768
+ constant uint64_t & nb2,
769
+ constant uint64_t & nb3,
770
+ uint3 tgpig[[threadgroup_position_in_grid]],
771
+ uint3 tpitg[[thread_position_in_threadgroup]],
772
+ uint3 ntg[[threads_per_threadgroup]]) {
773
+ const int64_t i03 = tgpig[2];
774
+ const int64_t i02 = tgpig[1];
775
+ const int64_t i01 = tgpig[0];
776
+
777
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
778
+
779
+ const int64_t i3 = n / (ne2*ne1*ne0);
780
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
781
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
782
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
783
+
784
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
785
+
786
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
787
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
788
+
789
+ dst_data[i00] = src[0];
790
+ }
791
+ }
792
+
793
+ //============================================ k-quants ======================================================
794
+
795
+ #ifndef QK_K
796
+ #define QK_K 256
797
+ #else
798
+ static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
799
+ #endif
800
+
801
+ #if QK_K == 256
802
+ #define K_SCALE_SIZE 12
803
+ #else
804
+ #define K_SCALE_SIZE 4
805
+ #endif
806
+
807
+ typedef struct {
808
+ uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
809
+ uint8_t qs[QK_K/4]; // quants
810
+ half d; // super-block scale for quantized scales
811
+ half dmin; // super-block scale for quantized mins
812
+ } block_q2_K;
813
+ // 84 bytes / block
814
+
815
+ typedef struct {
816
+ uint8_t hmask[QK_K/8]; // quants - high bit
817
+ uint8_t qs[QK_K/4]; // quants - low 2 bits
818
+ #if QK_K == 64
819
+ uint8_t scales[2];
820
+ #else
821
+ uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
822
+ #endif
823
+ half d; // super-block scale
824
+ } block_q3_K;
825
+
826
+ #if QK_K == 64
827
+ typedef struct {
828
+ half d[2]; // super-block scales/mins
829
+ uint8_t scales[2];
830
+ uint8_t qs[QK_K/2]; // 4-bit quants
831
+ } block_q4_K;
832
+ #else
833
+ typedef struct {
834
+ half d; // super-block scale for quantized scales
835
+ half dmin; // super-block scale for quantized mins
836
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
837
+ uint8_t qs[QK_K/2]; // 4-bit quants
838
+ } block_q4_K;
839
+ #endif
840
+
841
+ #if QK_K == 64
842
+ typedef struct {
843
+ half d; // super-block scales/mins
844
+ int8_t scales[QK_K/16]; // 8-bit block scales
845
+ uint8_t qh[QK_K/8]; // quants, high bit
846
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
847
+ } block_q5_K;
848
+ #else
849
+ typedef struct {
850
+ half d; // super-block scale for quantized scales
851
+ half dmin; // super-block scale for quantized mins
852
+ uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
853
+ uint8_t qh[QK_K/8]; // quants, high bit
854
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
855
+ } block_q5_K;
856
+ // 176 bytes / block
857
+ #endif
858
+
859
+ typedef struct {
860
+ uint8_t ql[QK_K/2]; // quants, lower 4 bits
861
+ uint8_t qh[QK_K/4]; // quants, upper 2 bits
862
+ int8_t scales[QK_K/16]; // scales, quantized with 8 bits
863
+ half d; // super-block scale
864
+ } block_q6_K;
865
+ // 210 bytes / block
866
+
867
+ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
868
+ uchar4 r;
869
+ if (j < 4) {
870
+ r[0] = q[j+0] & 63;
871
+ r[2] = q[j+1] & 63;
872
+ r[1] = q[j+4] & 63;
873
+ r[3] = q[j+5] & 63;
874
+ } else {
875
+ r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
876
+ r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
877
+ r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
878
+ r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
879
+ }
880
+ return r;
881
+ }
882
+
883
+ //====================================== dot products =========================
884
+
885
+ kernel void kernel_mul_mat_q2_K_f32(
886
+ device const void * src0,
887
+ device const float * src1,
888
+ device float * dst,
889
+ constant int64_t & ne00,
890
+ constant int64_t & ne01[[buffer(4)]],
891
+ constant int64_t & ne02[[buffer(5)]],
892
+ constant int64_t & ne10[[buffer(9)]],
893
+ constant int64_t & ne12[[buffer(11)]],
894
+ constant int64_t & ne0[[buffer(15)]],
895
+ constant int64_t & ne1[[buffer(16)]],
896
+ constant uint & gqa[[buffer(17)]],
897
+ uint3 tgpig[[threadgroup_position_in_grid]],
898
+ uint tiisg[[thread_index_in_simdgroup]],
899
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
900
+
901
+ const int nb = ne00/QK_K;
902
+ const int r0 = tgpig.x;
903
+ const int r1 = tgpig.y;
904
+ const int r2 = tgpig.z;
905
+
906
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
907
+ const int ib_row = first_row * nb;
908
+ const uint offset0 = r2/gqa*(nb*ne0);
909
+ device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
910
+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
911
+ float yl[32];
912
+ float sumf[N_DST]={0.f}, all_sum;
913
+
914
+ const int step = sizeof(block_q2_K) * nb;
915
+
916
+ #if QK_K == 256
917
+ const int ix = tiisg/8; // 0...3
918
+ const int it = tiisg%8; // 0...7
919
+ const int im = it/4; // 0 or 1
920
+ const int ir = it%4; // 0...3
921
+ const int is = (8*ir)/16;// 0 or 1
922
+
923
+ device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
924
+
925
+ for (int ib = ix; ib < nb; ib += 4) {
926
+
927
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
928
+ for (int i = 0; i < 8; ++i) {
929
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
930
+ yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
931
+ yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
932
+ yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
933
+ }
934
+
935
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
936
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
937
+ device const half * dh = &x[ib].d;
938
+
939
+ for (int row = 0; row < N_DST; row++) {
940
+
941
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
942
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
943
+ for (int i = 0; i < 8; i += 2) {
944
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
945
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
946
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
947
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
948
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
949
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
950
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
951
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
952
+ }
953
+ float dall = dh[0];
954
+ float dmin = dh[1] * 1.f/16.f;
955
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
956
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
957
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
958
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
959
+ dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
960
+
961
+ qs += step/2;
962
+ sc += step;
963
+ dh += step/2;
964
+ }
965
+
966
+ y4 += 4 * QK_K;
967
+ }
968
+ #else
969
+ const int ix = tiisg/2; // 0...15
970
+ const int it = tiisg%2; // 0...1
971
+
972
+ device const float * y4 = y + ix * QK_K + 8 * it;
973
+
974
+ for (int ib = ix; ib < nb; ib += 16) {
975
+
976
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
977
+ for (int i = 0; i < 8; ++i) {
978
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
979
+ yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
980
+ yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
981
+ yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
982
+ }
983
+
984
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
985
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
986
+ device const half * dh = &x[ib].d;
987
+
988
+ for (int row = 0; row < N_DST; row++) {
989
+
990
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
991
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
992
+ for (int i = 0; i < 8; i += 2) {
993
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
994
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
995
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
996
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
997
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
998
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
999
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
1000
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
1001
+ }
1002
+
1003
+ float dall = dh[0];
1004
+ float dmin = dh[1];
1005
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
1006
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
1007
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
1008
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
1009
+ dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
1010
+
1011
+ qs += step/2;
1012
+ sc += step;
1013
+ dh += step/2;
1014
+ }
1015
+
1016
+ y4 += 16 * QK_K;
1017
+ }
1018
+ #endif
1019
+
1020
+ for (int row = 0; row < N_DST; ++row) {
1021
+ all_sum = simd_sum(sumf[row]);
1022
+ if (tiisg == 0) {
1023
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
1024
+ }
1025
+ }
1026
+ }
1027
+
1028
+ #if QK_K == 256
1029
+ kernel void kernel_mul_mat_q3_K_f32(
1030
+ device const void * src0,
1031
+ device const float * src1,
1032
+ device float * dst,
1033
+ constant int64_t & ne00,
1034
+ constant int64_t & ne01[[buffer(4)]],
1035
+ constant int64_t & ne02[[buffer(5)]],
1036
+ constant int64_t & ne10[[buffer(9)]],
1037
+ constant int64_t & ne12[[buffer(11)]],
1038
+ constant int64_t & ne0[[buffer(15)]],
1039
+ constant int64_t & ne1[[buffer(16)]],
1040
+ constant uint & gqa[[buffer(17)]],
1041
+ uint3 tgpig[[threadgroup_position_in_grid]],
1042
+ uint tiisg[[thread_index_in_simdgroup]],
1043
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1044
+
1045
+ const int nb = ne00/QK_K;
1046
+
1047
+ const int64_t r0 = tgpig.x;
1048
+ const int64_t r1 = tgpig.y;
1049
+ const int64_t r2 = tgpig.z;
1050
+
1051
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1052
+ const uint offset0 = r2/gqa*(nb*ne0);
1053
+ device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
1054
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1055
+
1056
+ float yl[16];
1057
+
1058
+ const uint16_t kmask1 = 0x0303;
1059
+ const uint16_t kmask2 = 0x0f0f;
1060
+
1061
+ const int tid = tiisg/2;
1062
+ const int ix = tiisg%2;
1063
+ const int ip = tid/8; // 0 or 1
1064
+ const int il = tid/2 - 4*ip; // 0...3
1065
+ const int ir = tid%2;
1066
+ const int n = 8;
1067
+ const int l0 = n*ir;
1068
+
1069
+ const uint16_t m1 = 1 << (4*ip + il);
1070
+ const uint16_t m2 = m1 << 8;
1071
+
1072
+ const int shift = 2*il;
1073
+ const uint16_t qm1 = 0x0003 << shift;
1074
+ const uint16_t qm2 = 0x0300 << shift;
1075
+ const int32_t v1 = 4 << shift;
1076
+ const int32_t v2 = 1024 << shift;
1077
+
1078
+ const uint16_t s_shift1 = 4*ip;
1079
+ const uint16_t s_shift2 = s_shift1 + 2*(il/2);
1080
+ const int ik = 4 + (il%2);
1081
+
1082
+ const int q_offset = 32*ip + l0;
1083
+ const int y_offset = 128*ip + 32*il + l0;
1084
+
1085
+ const int step = sizeof(block_q3_K) * nb / 2;
1086
+
1087
+ device const float * y1 = yy + ix*QK_K + y_offset;
1088
+
1089
+ float sumf1[2] = {0.f}, sumf2[2] = {0.f};
1090
+ for (int i = ix; i < nb; i += 2) {
1091
+
1092
+ for (int l = 0; l < 8; ++l) {
1093
+ yl[l+0] = y1[l+ 0];
1094
+ yl[l+8] = y1[l+16];
1095
+ }
1096
+
1097
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
1098
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
1099
+ device const uint16_t * a = (device const uint16_t *)(x[i].scales);
1100
+ device const half * dh = &x[i].d;
1101
+
1102
+ for (int row = 0; row < 2; ++row) {
1103
+
1104
+ const float d_all = (float)dh[0];
1105
+ const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
1106
+
1107
+ float s1 = 0, s2 = 0;
1108
+ for (int l = 0; l < n; l += 2) {
1109
+ const uint16_t qs = q[l/2];
1110
+ s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
1111
+ s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
1112
+ }
1113
+ float d = d_all * (s1 + 1.f/256.f * s2);
1114
+ sumf1[row] += d * scales[0];
1115
+ sumf2[row] += d;
1116
+
1117
+ s1 = s2 = 0;
1118
+ for (int l = 0; l < n; l += 2) {
1119
+ const uint16_t qs = q[l/2+8];
1120
+ s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
1121
+ s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
1122
+ }
1123
+ d = d_all * (s1 + 1.f/256.f * s2);
1124
+ sumf1[row] += d * scales[1];
1125
+ sumf2[row] += d;
1126
+
1127
+ q += step;
1128
+ h += step;
1129
+ a += step;
1130
+ dh += step;
1131
+
1132
+ }
1133
+
1134
+ y1 += 2 * QK_K;
1135
+
1136
+ }
1137
+
1138
+ for (int row = 0; row < 2; ++row) {
1139
+ const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
1140
+ const float tot = simd_sum(sumf);
1141
+ if (tiisg == 0) {
1142
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
1143
+ }
1144
+ }
1145
+ }
1146
+ #else
1147
+ kernel void kernel_mul_mat_q3_K_f32(
1148
+ device const void * src0,
1149
+ device const float * src1,
1150
+ device float * dst,
1151
+ constant int64_t & ne00,
1152
+ constant int64_t & ne01[[buffer(4)]],
1153
+ constant int64_t & ne02[[buffer(5)]],
1154
+ constant int64_t & ne10[[buffer(9)]],
1155
+ constant int64_t & ne12[[buffer(11)]],
1156
+ constant int64_t & ne0[[buffer(15)]],
1157
+ constant int64_t & ne1[[buffer(16)]],
1158
+ constant uint & gqa[[buffer(17)]],
1159
+ uint3 tgpig[[threadgroup_position_in_grid]],
1160
+ uint tiisg[[thread_index_in_simdgroup]],
1161
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1162
+
1163
+ const int nb = ne00/QK_K;
1164
+
1165
+ const int64_t r0 = tgpig.x;
1166
+ const int64_t r1 = tgpig.y;
1167
+ const int64_t r2 = tgpig.z;
1168
+
1169
+ const int row = 2 * r0 + sgitg;
1170
+ const uint offset0 = r2/gqa*(nb*ne0);
1171
+ device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
1172
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1173
+ const int ix = tiisg/4;
1174
+ const int il = 4 * (tiisg%4);// 0, 4, 8, 12
1175
+ const int im = il/8; // 0, 0, 1, 1
1176
+ const int in = il%8; // 0, 4, 0, 4
1177
+
1178
+ float2 sum = {0.f, 0.f};
1179
+
1180
+ for (int i = ix; i < nb; i += 8) {
1181
+
1182
+ const float d_all = (float)(x[i].d);
1183
+
1184
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
1185
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
1186
+ device const uint16_t * s = (device const uint16_t *)(x[i].scales);
1187
+ device const float * y = yy + i * QK_K + il;
1188
+
1189
+ const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
1190
+ const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
1191
+ const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
1192
+ const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
1193
+
1194
+ for (int l = 0; l < 4; l += 2) {
1195
+ const uint16_t hm = h[l/2] >> im;
1196
+ sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
1197
+ + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
1198
+ + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
1199
+ + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
1200
+ sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
1201
+ + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
1202
+ + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
1203
+ + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
1204
+ }
1205
+
1206
+ }
1207
+ const float sumf = sum[0] + sum[1] * 1.f/256.f;
1208
+
1209
+ const float tot = simd_sum(sumf);
1210
+ if (tiisg == 0) {
1211
+ dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
1212
+ }
1213
+
1214
+ }
1215
+ #endif
1216
+
1217
+ #if QK_K == 256
1218
+ kernel void kernel_mul_mat_q4_K_f32(
1219
+ device const void * src0,
1220
+ device const float * src1,
1221
+ device float * dst,
1222
+ constant int64_t & ne00,
1223
+ constant int64_t & ne01[[buffer(4)]],
1224
+ constant int64_t & ne02[[buffer(5)]],
1225
+ constant int64_t & ne10[[buffer(9)]],
1226
+ constant int64_t & ne12[[buffer(11)]],
1227
+ constant int64_t & ne0[[buffer(15)]],
1228
+ constant int64_t & ne1[[buffer(16)]],
1229
+ constant uint & gqa[[buffer(17)]],
1230
+ uint3 tgpig[[threadgroup_position_in_grid]],
1231
+ uint tiisg[[thread_index_in_simdgroup]],
1232
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1233
+
1234
+ const uint16_t kmask1 = 0x3f3f;
1235
+ const uint16_t kmask2 = 0x0f0f;
1236
+ const uint16_t kmask3 = 0xc0c0;
1237
+
1238
+ const int ix = tiisg/8; // 0...3
1239
+ const int it = tiisg%8; // 0...7
1240
+ const int im = it/4; // 0 or 1
1241
+ const int ir = it%4; // 0...3
1242
+
1243
+ const int nb = ne00/QK_K;
1244
+ const int r0 = tgpig.x;
1245
+ const int r1 = tgpig.y;
1246
+ const int r2 = tgpig.z;
1247
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1248
+ const int ib_row = first_row * nb;
1249
+ const uint offset0 = r2/gqa*(nb*ne0);
1250
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
1251
+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1252
+ float yl[16];
1253
+ float yh[16];
1254
+ float sumf[N_DST]={0.f}, all_sum;
1255
+
1256
+ const int step = sizeof(block_q4_K) * nb / 2;
1257
+
1258
+ device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
1259
+
1260
+ uint16_t sc16[4];
1261
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
1262
+
1263
+ for (int ib = ix; ib < nb; ib += 4) {
1264
+
1265
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1266
+ for (int i = 0; i < 8; ++i) {
1267
+ yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
1268
+ yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
1269
+ yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
1270
+ yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
1271
+ }
1272
+
1273
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
1274
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
1275
+ device const half * dh = &x[ib].d;
1276
+
1277
+ for (int row = 0; row < N_DST; row++) {
1278
+
1279
+ sc16[0] = sc[0] & kmask1;
1280
+ sc16[1] = sc[2] & kmask1;
1281
+ sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
1282
+ sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
1283
+
1284
+ device const uint16_t * q2 = q1 + 32;
1285
+
1286
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
1287
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
1288
+ for (int i = 0; i < 8; i += 2) {
1289
+ acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
1290
+ acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
1291
+ acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
1292
+ acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
1293
+ acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
1294
+ acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
1295
+ acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
1296
+ acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
1297
+ }
1298
+
1299
+ float dall = dh[0];
1300
+ float dmin = dh[1];
1301
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
1302
+ (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
1303
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
1304
+ (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
1305
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
1306
+
1307
+ q1 += step;
1308
+ sc += step;
1309
+ dh += step;
1310
+ }
1311
+
1312
+ y4 += 4 * QK_K;
1313
+ }
1314
+
1315
+ for (int row = 0; row < N_DST; ++row) {
1316
+ all_sum = simd_sum(sumf[row]);
1317
+ if (tiisg == 0) {
1318
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
1319
+ }
1320
+ }
1321
+ }
1322
+ #else
1323
+ kernel void kernel_mul_mat_q4_K_f32(
1324
+ device const void * src0,
1325
+ device const float * src1,
1326
+ device float * dst,
1327
+ constant int64_t & ne00,
1328
+ constant int64_t & ne01[[buffer(4)]],
1329
+ constant int64_t & ne02[[buffer(5)]],
1330
+ constant int64_t & ne10[[buffer(9)]],
1331
+ constant int64_t & ne12[[buffer(11)]],
1332
+ constant int64_t & ne0[[buffer(15)]],
1333
+ constant int64_t & ne1[[buffer(16)]],
1334
+ constant uint & gqa[[buffer(17)]],
1335
+ uint3 tgpig[[threadgroup_position_in_grid]],
1336
+ uint tiisg[[thread_index_in_simdgroup]],
1337
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1338
+
1339
+ const int ix = tiisg/4; // 0...7
1340
+ const int it = tiisg%4; // 0...3
1341
+
1342
+ const int nb = ne00/QK_K;
1343
+ const int r0 = tgpig.x;
1344
+ const int r1 = tgpig.y;
1345
+ const int r2 = tgpig.z;
1346
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1347
+ const int ib_row = first_row * nb;
1348
+ const uint offset0 = r2/gqa*(nb*ne0);
1349
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
1350
+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1351
+ float yl[8];
1352
+ float yh[8];
1353
+ float sumf[N_DST]={0.f}, all_sum;
1354
+
1355
+ const int step = sizeof(block_q4_K) * nb / 2;
1356
+
1357
+ device const float * y4 = y + ix * QK_K + 8 * it;
1358
+
1359
+ uint16_t sc16[4];
1360
+
1361
+ for (int ib = ix; ib < nb; ib += 8) {
1362
+
1363
+ float2 sumy = {0.f, 0.f};
1364
+ for (int i = 0; i < 8; ++i) {
1365
+ yl[i] = y4[i+ 0]; sumy[0] += yl[i];
1366
+ yh[i] = y4[i+32]; sumy[1] += yh[i];
1367
+ }
1368
+
1369
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
1370
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
1371
+ device const half * dh = x[ib].d;
1372
+
1373
+ for (int row = 0; row < N_DST; row++) {
1374
+
1375
+ sc16[0] = sc[0] & 0x000f;
1376
+ sc16[1] = sc[0] & 0x0f00;
1377
+ sc16[2] = sc[0] & 0x00f0;
1378
+ sc16[3] = sc[0] & 0xf000;
1379
+
1380
+ float2 acc1 = {0.f, 0.f};
1381
+ float2 acc2 = {0.f, 0.f};
1382
+ for (int i = 0; i < 8; i += 2) {
1383
+ acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
1384
+ acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
1385
+ acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
1386
+ acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
1387
+ }
1388
+
1389
+ float dall = dh[0];
1390
+ float dmin = dh[1];
1391
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
1392
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
1393
+ dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
1394
+
1395
+ qs += step;
1396
+ sc += step;
1397
+ dh += step;
1398
+ }
1399
+
1400
+ y4 += 8 * QK_K;
1401
+ }
1402
+
1403
+ for (int row = 0; row < N_DST; ++row) {
1404
+ all_sum = simd_sum(sumf[row]);
1405
+ if (tiisg == 0) {
1406
+ dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
1407
+ }
1408
+ }
1409
+ }
1410
+ #endif
1411
+
1412
+ kernel void kernel_mul_mat_q5_K_f32(
1413
+ device const void * src0,
1414
+ device const float * src1,
1415
+ device float * dst,
1416
+ constant int64_t & ne00,
1417
+ constant int64_t & ne01[[buffer(4)]],
1418
+ constant int64_t & ne02[[buffer(5)]],
1419
+ constant int64_t & ne10[[buffer(9)]],
1420
+ constant int64_t & ne12[[buffer(11)]],
1421
+ constant int64_t & ne0[[buffer(15)]],
1422
+ constant int64_t & ne1[[buffer(16)]],
1423
+ constant uint & gqa[[buffer(17)]],
1424
+ uint3 tgpig[[threadgroup_position_in_grid]],
1425
+ uint tiisg[[thread_index_in_simdgroup]],
1426
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1427
+
1428
+ const int nb = ne00/QK_K;
1429
+
1430
+ const int64_t r0 = tgpig.x;
1431
+ const int64_t r1 = tgpig.y;
1432
+ const int r2 = tgpig.z;
1433
+
1434
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1435
+ const uint offset0 = r2/gqa*(nb*ne0);
1436
+ device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
1437
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1438
+
1439
+ float sumf[2]={0.f};
1440
+
1441
+ const int step = sizeof(block_q5_K) * nb;
1442
+
1443
+ #if QK_K == 256
1444
+ #
1445
+ float yl[16], yh[16];
1446
+
1447
+ const uint16_t kmask1 = 0x3f3f;
1448
+ const uint16_t kmask2 = 0x0f0f;
1449
+ const uint16_t kmask3 = 0xc0c0;
1450
+
1451
+ const int tid = tiisg/4;
1452
+ const int ix = tiisg%4;
1453
+ const int im = tid/4;
1454
+ const int ir = tid%4;
1455
+ const int n = 8;
1456
+
1457
+ const int l0 = n*ir;
1458
+ const int q_offset = 32*im + l0;
1459
+ const int y_offset = 64*im + l0;
1460
+
1461
+ const uint8_t hm1 = 1u << (2*im);
1462
+ const uint8_t hm2 = hm1 << 1;
1463
+ const uint8_t hm3 = hm1 << 4;
1464
+ const uint8_t hm4 = hm2 << 4;
1465
+
1466
+ uint16_t sc16[4];
1467
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
1468
+
1469
+ device const float * y1 = yy + ix*QK_K + y_offset;
1470
+
1471
+ for (int i = ix; i < nb; i += 4) {
1472
+
1473
+ device const uint8_t * q1 = x[i].qs + q_offset;
1474
+ device const uint8_t * qh = x[i].qh + l0;
1475
+ device const half * dh = &x[i].d;
1476
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
1477
+
1478
+ device const float * y2 = y1 + 128;
1479
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1480
+ for (int l = 0; l < 8; ++l) {
1481
+ yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
1482
+ yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
1483
+ yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
1484
+ yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
1485
+ }
1486
+
1487
+ for (int row = 0; row < 2; ++row) {
1488
+
1489
+ device const uint8_t * q2 = q1 + 64;
1490
+
1491
+ sc16[0] = a[0] & kmask1;
1492
+ sc16[1] = a[2] & kmask1;
1493
+ sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
1494
+ sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
1495
+
1496
+ float4 acc = {0.f, 0.f, 0.f, 0.f};
1497
+ for (int l = 0; l < n; ++l) {
1498
+ uint8_t h = qh[l];
1499
+ acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
1500
+ acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
1501
+ acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
1502
+ acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
1503
+ }
1504
+ const float dall = dh[0];
1505
+ const float dmin = dh[1];
1506
+ sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
1507
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
1508
+
1509
+ q1 += step;
1510
+ qh += step;
1511
+ dh += step/2;
1512
+ a += step/2;
1513
+
1514
+ }
1515
+
1516
+ y1 += 4 * QK_K;
1517
+
1518
+ }
1519
+ #else
1520
+ float yl[8], yh[8];
1521
+
1522
+ const int il = 4 * (tiisg/8); // 0, 4, 8, 12
1523
+ const int ix = tiisg%8;
1524
+ const int im = il/8; // 0, 0, 1, 1
1525
+ const int in = il%8; // 0, 4, 0, 4
1526
+
1527
+ device const float * y = yy + ix*QK_K + il;
1528
+
1529
+ for (int i = ix; i < nb; i += 8) {
1530
+
1531
+ for (int l = 0; l < 4; ++l) {
1532
+ yl[l+0] = y[l+ 0];
1533
+ yl[l+4] = y[l+16];
1534
+ yh[l+0] = y[l+32];
1535
+ yh[l+4] = y[l+48];
1536
+ }
1537
+
1538
+ device const half * dh = &x[i].d;
1539
+ device const uint8_t * q = x[i].qs + il;
1540
+ device const uint8_t * h = x[i].qh + in;
1541
+ device const int8_t * s = x[i].scales;
1542
+
1543
+ for (int row = 0; row < 2; ++row) {
1544
+
1545
+ const float d = dh[0];
1546
+
1547
+ float2 acc = {0.f, 0.f};
1548
+ for (int l = 0; l < 4; ++l) {
1549
+ const uint8_t hl = h[l] >> im;
1550
+ acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
1551
+ + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
1552
+ acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
1553
+ + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
1554
+ }
1555
+ sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
1556
+
1557
+ q += step;
1558
+ h += step;
1559
+ s += step;
1560
+ dh += step/2;
1561
+
1562
+ }
1563
+
1564
+ y += 8 * QK_K;
1565
+ }
1566
+ #endif
1567
+
1568
+ for (int row = 0; row < 2; ++row) {
1569
+ const float tot = simd_sum(sumf[row]);
1570
+ if (tiisg == 0) {
1571
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
1572
+ }
1573
+ }
1574
+
1575
+ }
1576
+
1577
+ kernel void kernel_mul_mat_q6_K_f32(
1578
+ device const void * src0,
1579
+ device const float * src1,
1580
+ device float * dst,
1581
+ constant int64_t & ne00,
1582
+ constant int64_t & ne01[[buffer(4)]],
1583
+ constant int64_t & ne02[[buffer(5)]],
1584
+ constant int64_t & ne10[[buffer(9)]],
1585
+ constant int64_t & ne12[[buffer(11)]],
1586
+ constant int64_t & ne0[[buffer(15)]],
1587
+ constant int64_t & ne1[[buffer(16)]],
1588
+ constant uint & gqa[[buffer(17)]],
1589
+ uint3 tgpig[[threadgroup_position_in_grid]],
1590
+ uint tiisg[[thread_index_in_simdgroup]],
1591
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1592
+
1593
+ const uint8_t kmask1 = 0x03;
1594
+ const uint8_t kmask2 = 0x0C;
1595
+ const uint8_t kmask3 = 0x30;
1596
+ const uint8_t kmask4 = 0xC0;
1597
+
1598
+ const int nb = ne00/QK_K;
1599
+
1600
+ const int64_t r0 = tgpig.x;
1601
+ const int64_t r1 = tgpig.y;
1602
+ const int r2 = tgpig.z;
1603
+
1604
+ const int row = 2 * r0 + sgitg;
1605
+ const uint offset0 = r2/gqa*(nb*ne0);
1606
+ device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
1607
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1608
+
1609
+ float sumf = 0;
1610
+
1611
+ #if QK_K == 256
1612
+ const int tid = tiisg/2;
1613
+ const int ix = tiisg%2;
1614
+ const int ip = tid/8; // 0 or 1
1615
+ const int il = tid%8;
1616
+ const int n = 4;
1617
+ const int l0 = n*il;
1618
+ const int is = 8*ip + l0/16;
1619
+
1620
+ const int y_offset = 128*ip + l0;
1621
+ const int q_offset_l = 64*ip + l0;
1622
+ const int q_offset_h = 32*ip + l0;
1623
+
1624
+ for (int i = ix; i < nb; i += 2) {
1625
+
1626
+ device const uint8_t * q1 = x[i].ql + q_offset_l;
1627
+ device const uint8_t * q2 = q1 + 32;
1628
+ device const uint8_t * qh = x[i].qh + q_offset_h;
1629
+ device const int8_t * sc = x[i].scales + is;
1630
+
1631
+ device const float * y = yy + i * QK_K + y_offset;
1632
+
1633
+ const float dall = x[i].d;
1634
+
1635
+ float4 sums = {0.f, 0.f, 0.f, 0.f};
1636
+ for (int l = 0; l < n; ++l) {
1637
+ sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1638
+ sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1639
+ sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
1640
+ sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
1641
+ }
1642
+
1643
+ sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
1644
+
1645
+ }
1646
+
1647
+ #else
1648
+ const int ix = tiisg/4;
1649
+ const int il = 4*(tiisg%4);
1650
+
1651
+ for (int i = ix; i < nb; i += 8) {
1652
+ device const float * y = yy + i * QK_K + il;
1653
+ device const uint8_t * ql = x[i].ql + il;
1654
+ device const uint8_t * qh = x[i].qh + il;
1655
+ device const int8_t * s = x[i].scales;
1656
+
1657
+ const float d = x[i].d;
1658
+
1659
+ float4 sums = {0.f, 0.f, 0.f, 0.f};
1660
+ for (int l = 0; l < 4; ++l) {
1661
+ sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1662
+ sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1663
+ sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
1664
+ sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
1665
+ }
1666
+ sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
1667
+ }
1668
+
1669
+ #endif
1670
+
1671
+ const float tot = simd_sum(sumf);
1672
+ if (tiisg == 0) {
1673
+ dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
1674
+ }
1675
+ }
1676
+
1677
+ //============================= templates and their specializations =============================
1678
+
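+ // editor's note: each dequantize_* helper below expands one 16-weight slice of a
+ // quantized block into the thread-local 4x4 matrix `reg`. The `il` argument picks
+ // which slice of the block to expand (a block holds 16*nl weights); the templated
+ // kernel_get_rows and kernel_mul_mm further down call these helpers one slice at
+ // a time.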
1679
+ template <typename type4x4>
1680
+ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
1681
+ half4x4 temp = *(((device half4x4 *)src));
1682
+ for (int i = 0; i < 16; i++){
1683
+ reg[i/4][i%4] = temp[i/4][i%4];
1684
+ }
1685
+ }
1686
+
1687
+ template <typename type4x4>
1688
+ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
1689
+ device const uint16_t * qs = ((device const uint16_t *)xb + 1);
1690
+ const half d = il ? (xb->d / 16.h) : xb->d;
1691
+ const half m = il ? ( -8.h * 16.h) : -8.h;
1692
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
1693
+ const ushort mask1 = il ? 0xF000 : 0x0F00;
1694
+
1695
+ for (int i=0;i<8;i++) {
1696
+ reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
1697
+ reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
1698
+ }
1699
+ }
1700
+
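+ // editor's note, worked example for dequantize_q4_0 above: with il == 0,
+ // qs[i] = 0x4C2B and d = 0.5h, the two decoded values are (0x0B - 8) * 0.5 = 1.5
+ // and (0x0C - 8) * 0.5 = 2.0. For il == 1 the high nibbles are used without
+ // shifting them down; dividing d by 16 and widening m to -128 gives the same
+ // (nibble - 8) * d result.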
1701
+ template <typename type4x4>
1702
+ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
1703
+ device const uint16_t * qs = ((device const uint16_t *)xb + 2);
1704
+ const half d = il ? (xb->d / 16.h) : xb->d;
1705
+ const half m = xb->m;
1706
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
1707
+ const ushort mask1 = il ? 0xF000 : 0x0F00;
1708
+
1709
+ for (int i=0;i<8;i++) {
1710
+ reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
1711
+ reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
1712
+ }
1713
+ }
1714
+
1715
+ template <typename type4x4>
1716
+ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
1717
+ device const int8_t * qs = ((device const int8_t *)xb->qs);
1718
+ const half d = xb->d;
1719
+
1720
+ for (int i=0;i<16;i++) {
1721
+ reg[i/4][i%4] = (qs[i + 16*il] * d);
1722
+ }
1723
+ }
1724
+
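+ // editor's note: in the K-quant dequantizers below, the (mask, coef) pairs such as
+ // (192, 1/64.h) extract a 2-bit field from the packed byte and fold the right-shift
+ // into the scale instead of shifting: dl * (q & 192) * (1/64) equals dl * ((q >> 6) & 3).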
1725
+ template <typename type4x4>
1726
+ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
1727
+ const half d = xb->d;
1728
+ const half min = xb->dmin;
1729
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
1730
+ half dl, ml;
1731
+ uint8_t sc = xb->scales[il];
1732
+
1733
+ #if QK_K == 256
1734
+ q = q + 32*(il/8) + 16*(il&1);
1735
+ il = (il/2)%4;
1736
+ #endif
1737
+ half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
1738
+ uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1739
+ dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
1740
+ for (int i = 0; i < 16; ++i) {
1741
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
1742
+ }
1743
+ }
1744
+
1745
+ template <typename type4x4>
1746
+ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
1747
+ const float d_all = (float)(xb->d);
1748
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
1749
+ device const uint8_t * h = (device const uint8_t *)xb->hmask;
1750
+ device const int8_t * scales = (device const int8_t *)xb->scales;
1751
+
1752
+ #if QK_K == 256
1753
+ q = q + 32 * (il/8) + 16 * (il&1);
1754
+ h = h + 16 * (il&1);
1755
+ uint8_t m = 1 << (il/2);
1756
+ uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
1757
+ ((il/4)>0 ? 12 : 3);
1758
+ uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
1759
+ uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
1760
+ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
1761
+ (scale_2&kmask2) | ((scale_1&kmask1) << 4);
1762
+ float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
1763
+
1764
+ il = (il/2)%4;
1765
+ float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
1766
+ uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1767
+
1768
+ for (int i = 0; i < 16; ++i) {
1769
+ reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
1770
+ }
1771
+ #else
1772
+ float kcoef = il&1 ? 1.f/16.f : 1.f;
1773
+ uint16_t kmask = il&1 ? 0xF0 : 0x0F;
1774
+ float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
1775
+ float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
1776
+ uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1777
+ uint8_t m = 1<<(il*2);
1778
+ for (int i = 0; i < 16; ++i) {
1779
+ reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
1780
+ }
1781
+ #endif
1782
+ }
1783
+
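+ // editor's note: for QK_K == 256, q4_K and q5_K store per-sub-block scales and mins
+ // as packed 6-bit fields; get_scale_min_k4 (defined earlier in this file) unpacks
+ // them into a uchar4 holding two (scale, min) pairs, and each weight is then
+ // reconstructed as dl * (q & mask) - ml.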
1784
+ template <typename type4x4>
1785
+ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
1786
+ device const uint8_t * q = xb->qs;
1787
+
1788
+ #if QK_K == 256
1789
+ const float d = (float)(xb->d);
1790
+ const float min = (float)(xb->dmin);
1791
+ short is = (il/4) * 2;
1792
+ q = q + (il/4) * 32 + 16 * (il&1);
1793
+ il = il%4;
1794
+ const uchar4 sc = get_scale_min_k4(is, xb->scales);
1795
+ const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
1796
+ const float ml = il<2 ? min * sc[1] : min * sc[3];
1797
+ #else
1798
+ q = q + 16 * (il&1);
1799
+ device const uint8_t * s = xb->scales;
1800
+ device const half2 * dh = (device const half2 *)xb->d;
1801
+ const float2 d = (float2)dh[0];
1802
+ const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
1803
+ const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
1804
+ #endif
1805
+ const ushort mask = il<2 ? 0x0F : 0xF0;
1806
+ for (int i = 0; i < 16; ++i) {
1807
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
1808
+ }
1809
+ }
1810
+
1811
+ template <typename type4x4>
1812
+ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
1813
+ device const uint8_t * q = xb->qs;
1814
+ device const uint8_t * qh = xb->qh;
1815
+
1816
+ #if QK_K == 256
1817
+ const float d = (float)(xb->d);
1818
+ const float min = (float)(xb->dmin);
1819
+ short is = (il/4) * 2;
1820
+ q = q + 32 * (il/4) + 16 * (il&1);
1821
+ qh = qh + 16 * (il&1);
1822
+ uint8_t ul = 1 << (il/2);
1823
+ il = il%4;
1824
+ const uchar4 sc = get_scale_min_k4(is, xb->scales);
1825
+ const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
1826
+ const float ml = il<2 ? min * sc[1] : min * sc[3];
1827
+
1828
+ const ushort mask = il<2 ? 0x0F : 0xF0;
1829
+ const float qh_val = il<2 ? 16.f : 256.f;
1830
+ for (int i = 0; i < 16; ++i) {
1831
+ reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
1832
+ }
1833
+ #else
1834
+ q = q + 16 * (il&1);
1835
+ device const int8_t * s = xb->scales;
1836
+ const float dl = xb->d * s[il];
1837
+ uint8_t m = 1<<(il*2);
1838
+ const float coef = il<2 ? 1.f : 1.f/16.f;
1839
+ const ushort mask = il<2 ? 0x0F : 0xF0;
1840
+ for (int i = 0; i < 16; ++i) {
1841
+ reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
1842
+ }
1843
+ #endif
1844
+ }
1845
+
1846
+ template <typename type4x4>
1847
+ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
1848
+ const float d_all = (float)(xb->d);
1849
+ device const uint8_t * ql = (device const uint8_t *)xb->ql;
1850
+ device const uint8_t * qh = (device const uint8_t *)xb->qh;
1851
+ device const int8_t * scales = (device const int8_t *)xb->scales;
1852
+
1853
+ #if QK_K == 256
1854
+ ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
1855
+ qh = qh + 32*(il/8) + 16*(il&1);
1856
+ float sc = scales[(il%2) + 2 * ((il/2))];
1857
+ il = (il/2)%4;
1858
+ #else
1859
+ ql = ql + 16 * (il&1);
1860
+ float sc = scales[il];
1861
+ #endif
1862
+ for (int i = 0; i < 16; ++i) {
1863
+ uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1864
+ uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
1865
+ const float coef = il>1 ? 1.f/16.f : 1.f;
1866
+ float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
1867
+ ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
1868
+ reg[i/4][i%4] = d_all * sc * q * coef;
1869
+ }
1870
+ }
1871
+
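+ // editor's note: kernel_get_rows gathers rows of src0 selected by the int32 indices
+ // in src1. Each threadgroup handles one index r = src1[tgpig]; its threads
+ // cooperatively dequantize the row at src0 + r*nb01, 16 floats (one float4x4) per
+ // iteration, into dst + tgpig*nb1.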
1872
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
1873
+ kernel void kernel_get_rows(
1874
+ device const void * src0,
1875
+ device const int * src1,
1876
+ device float * dst,
1877
+ constant int64_t & ne00,
1878
+ constant uint64_t & nb01,
1879
+ constant uint64_t & nb1,
1880
+ uint tgpig[[threadgroup_position_in_grid]],
1881
+ uint tiitg[[thread_index_in_threadgroup]],
1882
+ uint tptg[[threads_per_threadgroup]]) {
1883
+ const int i = tgpig;
1884
+ const int r = ((device int32_t *) src1)[i];
1885
+
1886
+ for (int ind = tiitg; ind < ne00/16; ind += tptg) {
1887
+ float4x4 temp;
1888
+ dequantize_func(
1889
+ ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
1890
+ *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
1891
+ }
1892
+ }
1893
+
1894
+ #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
1895
+ #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
1896
+ #define BLOCK_SIZE_K 32
1897
+ #define THREAD_MAT_M 4 // each thread takes 4 simdgroup matrices from matrix A
1898
+ #define THREAD_MAT_N 2 // each thread takes 2 simdgroup matrices from matrix B
1899
+ #define THREAD_PER_BLOCK 128
1900
+ #define THREAD_PER_ROW 2 // 2 threads for each row in matrix A to load numbers
1901
+ #define THREAD_PER_COL 4 // 4 threads for each row in matrix B to load numbers
1902
+ #define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
1903
+ #define SG_MAT_ROW 8
1904
+
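+ // editor's note on the tiling below: each threadgroup (THREAD_PER_BLOCK = 128
+ // threads = 4 simdgroups) computes one 64x32 (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of
+ // dst, stepping over K in BLOCK_SIZE_K = 32 chunks. Every simdgroup owns a 32x16
+ // sub-tile kept as 8 8x8 simdgroup matrices (THREAD_MAT_M x THREAD_MAT_N). The
+ // threadgroup buffer holds the A tile as half (64*32*2 = 4096 bytes) followed by
+ // the B tile as float (32*32*4 = 4096 bytes), which is why sb starts at
+ // shared_memory + 4096 in kernel_mul_mm.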
1905
+ // each block_q contains 16*nl weights
1906
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
1907
+ kernel void kernel_mul_mm(device const uchar * src0,
1908
+ device const float * src1,
1909
+ device float * dst,
1910
+ constant int64_t & ne00,
1911
+ constant int64_t & ne02,
1912
+ constant int64_t & nb01,
1913
+ constant int64_t & nb02,
1914
+ constant int64_t & ne12,
1915
+ constant int64_t & ne0,
1916
+ constant int64_t & ne1,
1917
+ constant uint & gqa,
1918
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
1919
+ uint3 tgpig[[threadgroup_position_in_grid]],
1920
+ uint tiitg[[thread_index_in_threadgroup]],
1921
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1922
+
1923
+ threadgroup half * sa = ((threadgroup half *)shared_memory);
1924
+ threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
1925
+
1926
+ const uint r0 = tgpig.y;
1927
+ const uint r1 = tgpig.x;
1928
+ const uint im = tgpig.z;
1929
+ // handle output blocks at the matrix edge that are smaller than 64x32
1930
+ short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
1931
+ short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
1932
+ // a thread shouldn't load data outside of the matrix
1933
+ short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
1934
+ short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
1935
+
1936
+ simdgroup_half8x8 ma[4];
1937
+ simdgroup_float8x8 mb[2];
1938
+ simdgroup_float8x8 c_res[8];
1939
+ for (int i = 0; i < 8; i++){
1940
+ c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
1941
+ }
1942
+
1943
+ short il = (tiitg % THREAD_PER_ROW);
1944
+ uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
1945
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
1946
+ device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
1947
+ + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
1948
+
1949
+ for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
1950
+ // load data and store it in threadgroup memory
1951
+ half4x4 temp_a;
1952
+ dequantize_func(x, il, temp_a);
1953
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1954
+ #pragma unroll(16)
1955
+ for (int i = 0; i < 16; i++) {
1956
+ *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
1957
+ + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
1958
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
1959
+ }
1960
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
1961
+ = *((device float2x4 *)y);
1962
+ il = (il + 2 < nl) ? il + 2 : il % 2;
1963
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
1964
+ y += BLOCK_SIZE_K;
1965
+
1966
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1967
+ // load matrices from threadgroup memory and conduct outer products
1968
+ threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
1969
+ threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
1970
+ #pragma unroll(4)
1971
+ for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
1972
+ #pragma unroll(4)
1973
+ for (int i = 0; i < 4; i++) {
1974
+ simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
1975
+ }
1976
+ simdgroup_barrier(mem_flags::mem_none);
1977
+ #pragma unroll(2)
1978
+ for (int i = 0; i < 2; i++) {
1979
+ simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
1980
+ }
1981
+
1982
+ lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
1983
+ lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
1984
+ #pragma unroll(8)
1985
+ for (int i = 0; i < 8; i++){
1986
+ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
1987
+ }
1988
+ }
1989
+ }
1990
+
1991
+ if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
1992
+ device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
1993
+ + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
1994
+ for (int i = 0; i < 8; i++) {
1995
+ simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
1996
+ }
1997
+ } else {
1998
+ // the block is smaller than 64x32, so avoid writing data outside of the matrix
1999
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2000
+ threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
2001
+ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
2002
+ for (int i = 0; i < 8; i++) {
2003
+ simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
2004
+ }
2005
+
2006
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2007
+ device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
2008
+ if (sgitg==0) {
2009
+ for (int i = 0; i < n_rows; i++) {
2010
+ for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
2011
+ *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
2012
+ }
2013
+ }
2014
+ }
2015
+ }
2016
+ }
2017
+
2018
+ #if QK_K == 256
2019
+ #define QK_NL 16
2020
+ #else
2021
+ #define QK_NL 4
2022
+ #endif
2023
+
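+ // editor's note: QK_NL is the number of 16-weight slices per K-quant block
+ // (QK_K / 16): 16 for the QK_K == 256 build, 4 for QK_K == 64. It is passed as the
+ // `nl` template argument in the K-quant specializations below.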
2024
+ typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
2025
+ constant uint64_t &, constant uint64_t &, uint, uint, uint);
2026
+
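+ // editor's note: the get_rows_t / mat_mm_t typedefs spell out the kernel signature
+ // once so each explicit instantiation fits on one line, and [[host_name("...")]]
+ // fixes the symbol each specialization is exported under. (Illustrative, host code
+ // is not part of this file: a host would typically fetch one of these kernels by
+ // that string, e.g. via MTLLibrary's newFunctionWithName:.)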
2027
+ template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
2028
+ template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
2029
+ template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
2030
+ template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
2031
+ template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
2032
+ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
2033
+ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
2034
+ template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
2035
+ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
2036
+
2037
+ typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
2038
+ constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
2039
+ constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
2040
+
2041
+ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
2042
+ template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
2043
+ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
2044
+ template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
2045
+ template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
2046
+ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
2047
+ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
2048
+ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
2049
+ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;