llama_cpp 0.3.7 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +17 -0
- data/README.md +1 -1
- data/examples/chat.rb +2 -4
- data/ext/llama_cpp/extconf.rb +3 -3
- data/ext/llama_cpp/llama_cpp.cpp +118 -117
- data/ext/llama_cpp/src/ggml-alloc.c +97 -53
- data/ext/llama_cpp/src/ggml-alloc.h +4 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +1010 -497
- data/ext/llama_cpp/src/ggml-cuda.h +32 -23
- data/ext/llama_cpp/src/ggml-metal.h +9 -3
- data/ext/llama_cpp/src/ggml-metal.m +142 -161
- data/ext/llama_cpp/src/ggml-metal.metal +577 -500
- data/ext/llama_cpp/src/ggml.c +2064 -233
- data/ext/llama_cpp/src/ggml.h +238 -13
- data/ext/llama_cpp/src/k_quants.c +110 -54
- data/ext/llama_cpp/src/llama-util.h +10 -8
- data/ext/llama_cpp/src/llama.cpp +4544 -2890
- data/ext/llama_cpp/src/llama.h +133 -123
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +8 -8
- metadata +2 -2
@@ -18,46 +18,11 @@ typedef struct {
|
|
18
18
|
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
19
19
|
} block_q4_1;
|
20
20
|
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
const int nb = k / qk;
|
27
|
-
|
28
|
-
for (int i = 0; i < nb; i++) {
|
29
|
-
const half d = x[i].d;
|
30
|
-
|
31
|
-
for (int j = 0; j < qk/2; ++j) {
|
32
|
-
const int x0 = (x[i].qs[j] & 0x0F) - 8;
|
33
|
-
const int x1 = (x[i].qs[j] >> 4) - 8;
|
34
|
-
|
35
|
-
y[i*qk + j + 0 ] = x0*d;
|
36
|
-
y[i*qk + j + qk/2] = x1*d;
|
37
|
-
}
|
38
|
-
}
|
39
|
-
}
|
40
|
-
|
41
|
-
static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
|
42
|
-
const int qk = QK4_1;
|
43
|
-
|
44
|
-
assert(k % qk == 0);
|
45
|
-
|
46
|
-
const int nb = k / qk;
|
47
|
-
|
48
|
-
for (int i = 0; i < nb; i++) {
|
49
|
-
const half d = x[i].d;
|
50
|
-
const half m = x[i].m;
|
51
|
-
|
52
|
-
for (int j = 0; j < qk/2; ++j) {
|
53
|
-
const int x0 = (x[i].qs[j] & 0x0F);
|
54
|
-
const int x1 = (x[i].qs[j] >> 4);
|
55
|
-
|
56
|
-
y[i*qk + j + 0 ] = x0*d + m;
|
57
|
-
y[i*qk + j + qk/2] = x1*d + m;
|
58
|
-
}
|
59
|
-
}
|
60
|
-
}
|
21
|
+
#define QK8_0 32
|
22
|
+
typedef struct {
|
23
|
+
half d; // delta
|
24
|
+
int8_t qs[QK8_0]; // quants
|
25
|
+
} block_q8_0;
|
61
26
|
|
62
27
|
kernel void kernel_add(
|
63
28
|
device const float * src0,
|
@@ -128,7 +93,12 @@ kernel void kernel_gelu(
|
|
128
93
|
device float * dst,
|
129
94
|
uint tpig[[thread_position_in_grid]]) {
|
130
95
|
float x = src0[tpig];
|
131
|
-
|
96
|
+
|
97
|
+
// BEWARE !!!
|
98
|
+
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
99
|
+
// This was observed with Falcon 7B and 40B models
|
100
|
+
//
|
101
|
+
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
132
102
|
}
|
133
103
|
|
134
104
|
kernel void kernel_soft_max(
|
@@ -219,54 +189,6 @@ kernel void kernel_diag_mask_inf(
|
|
219
189
|
}
|
220
190
|
}
|
221
191
|
|
222
|
-
kernel void kernel_get_rows_f16(
|
223
|
-
device const void * src0,
|
224
|
-
device const int * src1,
|
225
|
-
device float * dst,
|
226
|
-
constant int64_t & ne00,
|
227
|
-
constant uint64_t & nb01,
|
228
|
-
constant uint64_t & nb1,
|
229
|
-
uint tpig[[thread_position_in_grid]]) {
|
230
|
-
const int i = tpig;
|
231
|
-
const int r = ((device int32_t *) src1)[i];
|
232
|
-
|
233
|
-
for (int j = 0; j < ne00; j++) {
|
234
|
-
dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j];
|
235
|
-
}
|
236
|
-
}
|
237
|
-
|
238
|
-
kernel void kernel_get_rows_q4_0(
|
239
|
-
device const void * src0,
|
240
|
-
device const int * src1,
|
241
|
-
device float * dst,
|
242
|
-
constant int64_t & ne00,
|
243
|
-
constant uint64_t & nb01,
|
244
|
-
constant uint64_t & nb1,
|
245
|
-
uint tpig[[thread_position_in_grid]]) {
|
246
|
-
const int i = tpig;
|
247
|
-
const int r = ((device int32_t *) src1)[i];
|
248
|
-
|
249
|
-
dequantize_row_q4_0(
|
250
|
-
(device const block_q4_0 *) ((device char *) src0 + r*nb01),
|
251
|
-
(device float *) ((device char *) dst + i*nb1), ne00);
|
252
|
-
}
|
253
|
-
|
254
|
-
kernel void kernel_get_rows_q4_1(
|
255
|
-
device const void * src0,
|
256
|
-
device const int * src1,
|
257
|
-
device float * dst,
|
258
|
-
constant int64_t & ne00,
|
259
|
-
constant uint64_t & nb01,
|
260
|
-
constant uint64_t & nb1,
|
261
|
-
uint tpig[[thread_position_in_grid]]) {
|
262
|
-
const int i = tpig;
|
263
|
-
const int r = ((device int32_t *) src1)[i];
|
264
|
-
|
265
|
-
dequantize_row_q4_1(
|
266
|
-
(device const block_q4_1 *) ((device char *) src0 + r*nb01),
|
267
|
-
(device float *) ((device char *) dst + i*nb1), ne00);
|
268
|
-
}
|
269
|
-
|
270
192
|
kernel void kernel_norm(
|
271
193
|
device const void * src0,
|
272
194
|
device float * dst,
|
@@ -432,14 +354,16 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
|
|
432
354
|
// N_DST, so this is another explicit assumption of the implementation.
|
433
355
|
template<typename block_q_type, int nr, int nsg, int nw>
|
434
356
|
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
|
435
|
-
int64_t ne00, int64_t ne10, int64_t ne0, int64_t
|
436
|
-
|
357
|
+
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
|
358
|
+
uint3 tgpig, uint tiisg, uint sgitg) {
|
437
359
|
const int nb = ne00/QK4_0;
|
438
360
|
const int r0 = tgpig.x;
|
439
361
|
const int r1 = tgpig.y;
|
362
|
+
const int im = tgpig.z;
|
440
363
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
441
|
-
|
442
|
-
device const
|
364
|
+
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
|
365
|
+
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
366
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
443
367
|
float yl[16]; // src1 vector cache
|
444
368
|
float sumf[nr]={0.f};
|
445
369
|
|
@@ -470,7 +394,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
470
394
|
for (int row = 0; row < nr; ++row) {
|
471
395
|
const float tot = simd_sum(sumf[row]);
|
472
396
|
if (tiisg == 0 && first_row + row < ne01) {
|
473
|
-
dst[r1*ne0 + first_row + row] = tot;
|
397
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
474
398
|
}
|
475
399
|
}
|
476
400
|
}
|
@@ -480,13 +404,17 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|
480
404
|
device const float * src1,
|
481
405
|
device float * dst,
|
482
406
|
constant int64_t & ne00,
|
483
|
-
constant int64_t & ne10,
|
484
|
-
constant int64_t & ne0,
|
485
407
|
constant int64_t & ne01[[buffer(4)]],
|
486
|
-
|
408
|
+
constant int64_t & ne02[[buffer(5)]],
|
409
|
+
constant int64_t & ne10[[buffer(9)]],
|
410
|
+
constant int64_t & ne12[[buffer(11)]],
|
411
|
+
constant int64_t & ne0[[buffer(15)]],
|
412
|
+
constant int64_t & ne1[[buffer(16)]],
|
413
|
+
constant uint & gqa[[buffer(17)]],
|
414
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
487
415
|
uint tiisg[[thread_index_in_simdgroup]],
|
488
416
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
489
|
-
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,
|
417
|
+
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
490
418
|
}
|
491
419
|
|
492
420
|
kernel void kernel_mul_mat_q4_1_f32(
|
@@ -494,13 +422,79 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|
494
422
|
device const float * src1,
|
495
423
|
device float * dst,
|
496
424
|
constant int64_t & ne00,
|
497
|
-
constant int64_t & ne10,
|
498
|
-
constant int64_t & ne0,
|
499
425
|
constant int64_t & ne01[[buffer(4)]],
|
500
|
-
|
426
|
+
constant int64_t & ne02[[buffer(5)]],
|
427
|
+
constant int64_t & ne10[[buffer(9)]],
|
428
|
+
constant int64_t & ne12[[buffer(11)]],
|
429
|
+
constant int64_t & ne0[[buffer(15)]],
|
430
|
+
constant int64_t & ne1[[buffer(16)]],
|
431
|
+
constant uint & gqa[[buffer(17)]],
|
432
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
501
433
|
uint tiisg[[thread_index_in_simdgroup]],
|
502
434
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
503
|
-
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,
|
435
|
+
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
436
|
+
}
|
437
|
+
|
438
|
+
kernel void kernel_mul_mat_q8_0_f32(
|
439
|
+
device const void * src0,
|
440
|
+
device const float * src1,
|
441
|
+
device float * dst,
|
442
|
+
constant int64_t & ne00,
|
443
|
+
constant int64_t & ne01[[buffer(4)]],
|
444
|
+
constant int64_t & ne02[[buffer(5)]],
|
445
|
+
constant int64_t & ne10[[buffer(9)]],
|
446
|
+
constant int64_t & ne12[[buffer(11)]],
|
447
|
+
constant int64_t & ne0[[buffer(15)]],
|
448
|
+
constant int64_t & ne1[[buffer(16)]],
|
449
|
+
constant uint & gqa[[buffer(17)]],
|
450
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
451
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
452
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
453
|
+
const int nr = N_DST;
|
454
|
+
const int nsg = N_SIMDGROUP;
|
455
|
+
const int nw = N_SIMDWIDTH;
|
456
|
+
|
457
|
+
const int nb = ne00/QK8_0;
|
458
|
+
const int r0 = tgpig.x;
|
459
|
+
const int r1 = tgpig.y;
|
460
|
+
const int im = tgpig.z;
|
461
|
+
const int first_row = (r0 * nsg + sgitg) * nr;
|
462
|
+
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
|
463
|
+
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
|
464
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
465
|
+
|
466
|
+
float yl[16];
|
467
|
+
float sumf[nr]={0.f};
|
468
|
+
|
469
|
+
const int ix = tiisg/2;
|
470
|
+
const int il = tiisg%2;
|
471
|
+
|
472
|
+
device const float * yb = y + ix * QK8_0 + 16*il;
|
473
|
+
|
474
|
+
// each thread in a SIMD group deals with half a block.
|
475
|
+
for (int ib = ix; ib < nb; ib += nw/2) {
|
476
|
+
for (int i = 0; i < 16; ++i) {
|
477
|
+
yl[i] = yb[i];
|
478
|
+
}
|
479
|
+
|
480
|
+
for (int row = 0; row < nr; row++) {
|
481
|
+
device const int8_t * qs = x[ib+row*nb].qs + 16*il;
|
482
|
+
float sumq = 0.f;
|
483
|
+
for (int iq = 0; iq < 16; ++iq) {
|
484
|
+
sumq += qs[iq] * yl[iq];
|
485
|
+
}
|
486
|
+
sumf[row] += sumq*x[ib+row*nb].d;
|
487
|
+
}
|
488
|
+
|
489
|
+
yb += QK8_0 * 16;
|
490
|
+
}
|
491
|
+
|
492
|
+
for (int row = 0; row < nr; ++row) {
|
493
|
+
const float tot = simd_sum(sumf[row]);
|
494
|
+
if (tiisg == 0 && first_row + row < ne01) {
|
495
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
496
|
+
}
|
497
|
+
}
|
504
498
|
}
|
505
499
|
|
506
500
|
kernel void kernel_mul_mat_f16_f32(
|
@@ -554,7 +548,6 @@ kernel void kernel_mul_mat_f16_f32(
|
|
554
548
|
}
|
555
549
|
}
|
556
550
|
|
557
|
-
|
558
551
|
kernel void kernel_alibi_f32(
|
559
552
|
device const float * src0,
|
560
553
|
device float * dst,
|
@@ -650,7 +643,25 @@ kernel void kernel_rope(
|
|
650
643
|
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
651
644
|
}
|
652
645
|
} else {
|
653
|
-
|
646
|
+
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
647
|
+
for (int64_t ic = 0; ic < n_dims; ic += 2) {
|
648
|
+
const float cos_theta = cos(theta);
|
649
|
+
const float sin_theta = sin(theta);
|
650
|
+
|
651
|
+
theta *= theta_scale;
|
652
|
+
|
653
|
+
const int64_t i0 = ib*n_dims + ic/2;
|
654
|
+
|
655
|
+
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
656
|
+
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
657
|
+
|
658
|
+
const float x0 = src[0];
|
659
|
+
const float x1 = src[n_dims/2];
|
660
|
+
|
661
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
662
|
+
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
663
|
+
}
|
664
|
+
}
|
654
665
|
}
|
655
666
|
}
|
656
667
|
|
@@ -869,354 +880,6 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
869
880
|
return r;
|
870
881
|
}
|
871
882
|
|
872
|
-
//========================================== dequantization =============================
|
873
|
-
|
874
|
-
static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
|
875
|
-
assert(k % QK_K == 0);
|
876
|
-
const int nb = k / QK_K;
|
877
|
-
|
878
|
-
for (int i = 0; i < nb; i++) {
|
879
|
-
|
880
|
-
const float d = x[i].d;
|
881
|
-
const float min = x[i].dmin;
|
882
|
-
|
883
|
-
device const uint8_t * q = x[i].qs;
|
884
|
-
|
885
|
-
#if QK_K == 256
|
886
|
-
int is = 0;
|
887
|
-
float dl, ml;
|
888
|
-
for (int n = 0; n < QK_K; n += 128) {
|
889
|
-
int shift = 0;
|
890
|
-
for (int j = 0; j < 4; ++j) {
|
891
|
-
|
892
|
-
uint8_t sc = x[i].scales[is++];
|
893
|
-
dl = d * (sc & 0xF); ml = min * (sc >> 4);
|
894
|
-
for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
|
895
|
-
|
896
|
-
sc = x[i].scales[is++];
|
897
|
-
dl = d * (sc & 0xF); ml = min * (sc >> 4);
|
898
|
-
for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
|
899
|
-
|
900
|
-
shift += 2;
|
901
|
-
}
|
902
|
-
q += 32;
|
903
|
-
}
|
904
|
-
#else
|
905
|
-
float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
|
906
|
-
float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
|
907
|
-
float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
|
908
|
-
float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
|
909
|
-
for (int l = 0; l < 16; ++l) {
|
910
|
-
y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
|
911
|
-
y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
|
912
|
-
y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
|
913
|
-
y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
|
914
|
-
}
|
915
|
-
y += QK_K;
|
916
|
-
#endif
|
917
|
-
|
918
|
-
}
|
919
|
-
}
|
920
|
-
|
921
|
-
static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
|
922
|
-
assert(k % QK_K == 0);
|
923
|
-
const int nb = k / QK_K;
|
924
|
-
|
925
|
-
#if QK_K == 256
|
926
|
-
|
927
|
-
const uint16_t kmask1 = 0x0303;
|
928
|
-
const uint16_t kmask2 = 0x0f0f;
|
929
|
-
|
930
|
-
uint16_t aux[8];
|
931
|
-
thread const int8_t * scales = (thread const int8_t*)aux;
|
932
|
-
|
933
|
-
for (int i = 0; i < nb; i++) {
|
934
|
-
|
935
|
-
const float d_all = (float)(x[i].d);
|
936
|
-
|
937
|
-
device const uint8_t * q = x[i].qs;
|
938
|
-
device const uint8_t * h = x[i].hmask;
|
939
|
-
uint8_t m = 1;
|
940
|
-
|
941
|
-
device const uint16_t * a = (device const uint16_t *)x[i].scales;
|
942
|
-
aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
|
943
|
-
aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4);
|
944
|
-
aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4);
|
945
|
-
aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4);
|
946
|
-
aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4);
|
947
|
-
aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4);
|
948
|
-
aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
|
949
|
-
aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);
|
950
|
-
|
951
|
-
int is = 0;
|
952
|
-
float dl;
|
953
|
-
for (int n = 0; n < QK_K; n += 128) {
|
954
|
-
int shift = 0;
|
955
|
-
for (int j = 0; j < 4; ++j) {
|
956
|
-
|
957
|
-
dl = d_all * (scales[is++] - 32);
|
958
|
-
for (int l = 0; l < 16; ++l) {
|
959
|
-
*y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
|
960
|
-
}
|
961
|
-
|
962
|
-
dl = d_all * (scales[is++] - 32);
|
963
|
-
for (int l = 0; l < 16; ++l) {
|
964
|
-
*y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
|
965
|
-
}
|
966
|
-
|
967
|
-
shift += 2;
|
968
|
-
m <<= 1;
|
969
|
-
}
|
970
|
-
q += 32;
|
971
|
-
}
|
972
|
-
}
|
973
|
-
#else
|
974
|
-
for (int i = 0; i < nb; i++) {
|
975
|
-
|
976
|
-
const float d_all = (float)(x[i].d);
|
977
|
-
|
978
|
-
device const uint8_t * q = x[i].qs;
|
979
|
-
device const uint8_t * hm = x[i].hmask;
|
980
|
-
|
981
|
-
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
982
|
-
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
983
|
-
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
984
|
-
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
985
|
-
|
986
|
-
for (int l = 0; l < 8; ++l) {
|
987
|
-
uint8_t h = hm[l];
|
988
|
-
y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
|
989
|
-
y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
|
990
|
-
y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
|
991
|
-
y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
|
992
|
-
y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
|
993
|
-
y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
|
994
|
-
y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
|
995
|
-
y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
|
996
|
-
}
|
997
|
-
y += QK_K;
|
998
|
-
}
|
999
|
-
#endif
|
1000
|
-
|
1001
|
-
}
|
1002
|
-
|
1003
|
-
static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
|
1004
|
-
assert(k % QK_K == 0);
|
1005
|
-
const int nb = k / QK_K;
|
1006
|
-
|
1007
|
-
for (int i = 0; i < nb; i++) {
|
1008
|
-
|
1009
|
-
device const uint8_t * q = x[i].qs;
|
1010
|
-
|
1011
|
-
#if QK_K == 256
|
1012
|
-
const float d = x[i].d;
|
1013
|
-
const float min = x[i].dmin;
|
1014
|
-
|
1015
|
-
device const uint8_t * scales = x[i].scales;
|
1016
|
-
|
1017
|
-
int is = 0;
|
1018
|
-
for (int j = 0; j < QK_K; j += 64) {
|
1019
|
-
const uchar4 sc = get_scale_min_k4(is, scales);
|
1020
|
-
const float d1 = d * sc[0]; const float m1 = min * sc[1];
|
1021
|
-
const float d2 = d * sc[2]; const float m2 = min * sc[3];
|
1022
|
-
for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
|
1023
|
-
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
|
1024
|
-
q += 32; is += 2;
|
1025
|
-
}
|
1026
|
-
#else
|
1027
|
-
device const uint8_t * s = x[i].scales;
|
1028
|
-
device const half2 * dh = (device const half2 *)x[i].d;
|
1029
|
-
const float2 d = (float2)dh[0];
|
1030
|
-
const float d1 = d[0] * (s[0] & 0xF);
|
1031
|
-
const float d2 = d[0] * (s[1] & 0xF);
|
1032
|
-
const float m1 = d[1] * (s[0] >> 4);
|
1033
|
-
const float m2 = d[1] * (s[1] >> 4);
|
1034
|
-
for (int l = 0; l < 32; ++l) {
|
1035
|
-
y[l+ 0] = d1 * (q[l] & 0xF) - m1;
|
1036
|
-
y[l+32] = d2 * (q[l] >> 4) - m2;
|
1037
|
-
}
|
1038
|
-
y += QK_K;
|
1039
|
-
#endif
|
1040
|
-
|
1041
|
-
}
|
1042
|
-
}
|
1043
|
-
|
1044
|
-
static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
|
1045
|
-
assert(k % QK_K == 0);
|
1046
|
-
const int nb = k / QK_K;
|
1047
|
-
|
1048
|
-
#if QK_K == 256
|
1049
|
-
for (int i = 0; i < nb; i++) {
|
1050
|
-
|
1051
|
-
const float d = (float)(x[i].d);
|
1052
|
-
const float min = (float)(x[i].dmin);
|
1053
|
-
|
1054
|
-
device const uint8_t * ql = x[i].qs;
|
1055
|
-
device const uint8_t * qh = x[i].qh;
|
1056
|
-
|
1057
|
-
int is = 0;
|
1058
|
-
uint8_t u1 = 1, u2 = 2;
|
1059
|
-
for (int j = 0; j < QK_K; j += 64) {
|
1060
|
-
const uchar4 sc = get_scale_min_k4(is, x[i].scales);
|
1061
|
-
const float d1 = d * sc[0]; const float m1 = min * sc[1];
|
1062
|
-
const float d2 = d * sc[2]; const float m2 = min * sc[3];
|
1063
|
-
for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
|
1064
|
-
for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
|
1065
|
-
ql += 32; is += 2;
|
1066
|
-
u1 <<= 2; u2 <<= 2;
|
1067
|
-
}
|
1068
|
-
}
|
1069
|
-
#else
|
1070
|
-
for (int i = 0; i < nb; i++) {
|
1071
|
-
|
1072
|
-
const float d = (float)x[i].d;
|
1073
|
-
|
1074
|
-
device const uint8_t * ql = x[i].qs;
|
1075
|
-
device const uint8_t * qh = x[i].qh;
|
1076
|
-
device const int8_t * sc = x[i].scales;
|
1077
|
-
|
1078
|
-
for (int l = 0; l < 8; ++l) {
|
1079
|
-
y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
|
1080
|
-
y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
|
1081
|
-
y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
|
1082
|
-
y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
|
1083
|
-
y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
|
1084
|
-
y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
|
1085
|
-
y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
|
1086
|
-
y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
|
1087
|
-
}
|
1088
|
-
y += QK_K;
|
1089
|
-
}
|
1090
|
-
#endif
|
1091
|
-
|
1092
|
-
}
|
1093
|
-
|
1094
|
-
static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
|
1095
|
-
assert(k % QK_K == 0);
|
1096
|
-
const int nb = k / QK_K;
|
1097
|
-
|
1098
|
-
for (int i = 0; i < nb; i++) {
|
1099
|
-
|
1100
|
-
device const uint8_t * ql = x[i].ql;
|
1101
|
-
device const uint8_t * qh = x[i].qh;
|
1102
|
-
device const int8_t * sc = x[i].scales;
|
1103
|
-
|
1104
|
-
const float d = x[i].d;
|
1105
|
-
|
1106
|
-
#if QK_K == 256
|
1107
|
-
for (int n = 0; n < QK_K; n += 128) {
|
1108
|
-
for (int l = 0; l < 32; ++l) {
|
1109
|
-
int is = l/16;
|
1110
|
-
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
1111
|
-
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
1112
|
-
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
1113
|
-
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
1114
|
-
y[l + 0] = d * sc[is + 0] * q1;
|
1115
|
-
y[l + 32] = d * sc[is + 2] * q2;
|
1116
|
-
y[l + 64] = d * sc[is + 4] * q3;
|
1117
|
-
y[l + 96] = d * sc[is + 6] * q4;
|
1118
|
-
}
|
1119
|
-
y += 128;
|
1120
|
-
ql += 64;
|
1121
|
-
qh += 32;
|
1122
|
-
sc += 8;
|
1123
|
-
}
|
1124
|
-
#else
|
1125
|
-
for (int l = 0; l < 16; ++l) {
|
1126
|
-
const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
1127
|
-
const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
1128
|
-
const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
1129
|
-
const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
1130
|
-
y[l+ 0] = d * sc[0] * q1;
|
1131
|
-
y[l+16] = d * sc[1] * q2;
|
1132
|
-
y[l+32] = d * sc[2] * q3;
|
1133
|
-
y[l+48] = d * sc[3] * q4;
|
1134
|
-
}
|
1135
|
-
y += 64;
|
1136
|
-
#endif
|
1137
|
-
}
|
1138
|
-
}
|
1139
|
-
|
1140
|
-
kernel void kernel_get_rows_q2_K(
|
1141
|
-
device const void * src0,
|
1142
|
-
device const int * src1,
|
1143
|
-
device float * dst,
|
1144
|
-
constant int64_t & ne00,
|
1145
|
-
constant uint64_t & nb01,
|
1146
|
-
constant uint64_t & nb1,
|
1147
|
-
uint tpig[[thread_position_in_grid]]) {
|
1148
|
-
const int i = tpig;
|
1149
|
-
const int r = ((device int32_t *) src1)[i];
|
1150
|
-
|
1151
|
-
dequantize_row_q2_K(
|
1152
|
-
(device const block_q2_K *) ((device char *) src0 + r*nb01),
|
1153
|
-
(device float *) ((device char *) dst + i*nb1), ne00);
|
1154
|
-
}
|
1155
|
-
|
1156
|
-
kernel void kernel_get_rows_q3_K(
|
1157
|
-
device const void * src0,
|
1158
|
-
device const int * src1,
|
1159
|
-
device float * dst,
|
1160
|
-
constant int64_t & ne00,
|
1161
|
-
constant uint64_t & nb01,
|
1162
|
-
constant uint64_t & nb1,
|
1163
|
-
uint tpig[[thread_position_in_grid]]) {
|
1164
|
-
const int i = tpig;
|
1165
|
-
const int r = ((device int32_t *) src1)[i];
|
1166
|
-
|
1167
|
-
dequantize_row_q3_K(
|
1168
|
-
(device const block_q3_K *) ((device char *) src0 + r*nb01),
|
1169
|
-
(device float *) ((device char *) dst + i*nb1), ne00);
|
1170
|
-
}
|
1171
|
-
|
1172
|
-
kernel void kernel_get_rows_q4_K(
|
1173
|
-
device const void * src0,
|
1174
|
-
device const int * src1,
|
1175
|
-
device float * dst,
|
1176
|
-
constant int64_t & ne00,
|
1177
|
-
constant uint64_t & nb01,
|
1178
|
-
constant uint64_t & nb1,
|
1179
|
-
uint tpig[[thread_position_in_grid]]) {
|
1180
|
-
const int i = tpig;
|
1181
|
-
const int r = ((device int32_t *) src1)[i];
|
1182
|
-
|
1183
|
-
dequantize_row_q4_K(
|
1184
|
-
(device const block_q4_K *) ((device char *) src0 + r*nb01),
|
1185
|
-
(device float *) ((device char *) dst + i*nb1), ne00);
|
1186
|
-
}
|
1187
|
-
|
1188
|
-
kernel void kernel_get_rows_q5_K(
|
1189
|
-
device const void * src0,
|
1190
|
-
device const int * src1,
|
1191
|
-
device float * dst,
|
1192
|
-
constant int64_t & ne00,
|
1193
|
-
constant uint64_t & nb01,
|
1194
|
-
constant uint64_t & nb1,
|
1195
|
-
uint tpig[[thread_position_in_grid]]) {
|
1196
|
-
const int i = tpig;
|
1197
|
-
const int r = ((device int32_t *) src1)[i];
|
1198
|
-
|
1199
|
-
dequantize_row_q5_K(
|
1200
|
-
(device const block_q5_K *) ((device char *) src0 + r*nb01),
|
1201
|
-
(device float *) ((device char *) dst + i*nb1), ne00);
|
1202
|
-
}
|
1203
|
-
|
1204
|
-
kernel void kernel_get_rows_q6_K(
|
1205
|
-
device const void * src0,
|
1206
|
-
device const int * src1,
|
1207
|
-
device float * dst,
|
1208
|
-
constant int64_t & ne00,
|
1209
|
-
constant uint64_t & nb01,
|
1210
|
-
constant uint64_t & nb1,
|
1211
|
-
uint tpig[[thread_position_in_grid]]) {
|
1212
|
-
const int i = tpig;
|
1213
|
-
const int r = ((device int32_t *) src1)[i];
|
1214
|
-
|
1215
|
-
dequantize_row_q6_K(
|
1216
|
-
(device const block_q6_K *) ((device char *) src0 + r*nb01),
|
1217
|
-
(device float *) ((device char *) dst + i*nb1), ne00);
|
1218
|
-
}
|
1219
|
-
|
1220
883
|
//====================================== dot products =========================
|
1221
884
|
|
1222
885
|
kernel void kernel_mul_mat_q2_K_f32(
|
@@ -1224,21 +887,27 @@ kernel void kernel_mul_mat_q2_K_f32(
|
|
1224
887
|
device const float * src1,
|
1225
888
|
device float * dst,
|
1226
889
|
constant int64_t & ne00,
|
1227
|
-
constant int64_t & ne10,
|
1228
|
-
constant int64_t & ne0,
|
1229
890
|
constant int64_t & ne01[[buffer(4)]],
|
1230
|
-
|
891
|
+
constant int64_t & ne02[[buffer(5)]],
|
892
|
+
constant int64_t & ne10[[buffer(9)]],
|
893
|
+
constant int64_t & ne12[[buffer(11)]],
|
894
|
+
constant int64_t & ne0[[buffer(15)]],
|
895
|
+
constant int64_t & ne1[[buffer(16)]],
|
896
|
+
constant uint & gqa[[buffer(17)]],
|
897
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1231
898
|
uint tiisg[[thread_index_in_simdgroup]],
|
1232
899
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1233
900
|
|
1234
901
|
const int nb = ne00/QK_K;
|
1235
902
|
const int r0 = tgpig.x;
|
1236
903
|
const int r1 = tgpig.y;
|
904
|
+
const int r2 = tgpig.z;
|
1237
905
|
|
1238
906
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1239
907
|
const int ib_row = first_row * nb;
|
1240
|
-
|
1241
|
-
device const
|
908
|
+
const uint offset0 = r2/gqa*(nb*ne0);
|
909
|
+
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
|
910
|
+
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
1242
911
|
float yl[32];
|
1243
912
|
float sumf[N_DST]={0.f}, all_sum;
|
1244
913
|
|
@@ -1351,7 +1020,7 @@ kernel void kernel_mul_mat_q2_K_f32(
|
|
1351
1020
|
for (int row = 0; row < N_DST; ++row) {
|
1352
1021
|
all_sum = simd_sum(sumf[row]);
|
1353
1022
|
if (tiisg == 0) {
|
1354
|
-
dst[r1*ne0 + first_row + row] = all_sum;
|
1023
|
+
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
|
1355
1024
|
}
|
1356
1025
|
}
|
1357
1026
|
}
|
@@ -1362,10 +1031,14 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1362
1031
|
device const float * src1,
|
1363
1032
|
device float * dst,
|
1364
1033
|
constant int64_t & ne00,
|
1365
|
-
constant int64_t &
|
1366
|
-
constant int64_t &
|
1367
|
-
constant int64_t &
|
1368
|
-
|
1034
|
+
constant int64_t & ne01[[buffer(4)]],
|
1035
|
+
constant int64_t & ne02[[buffer(5)]],
|
1036
|
+
constant int64_t & ne10[[buffer(9)]],
|
1037
|
+
constant int64_t & ne12[[buffer(11)]],
|
1038
|
+
constant int64_t & ne0[[buffer(15)]],
|
1039
|
+
constant int64_t & ne1[[buffer(16)]],
|
1040
|
+
constant uint & gqa[[buffer(17)]],
|
1041
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1369
1042
|
uint tiisg[[thread_index_in_simdgroup]],
|
1370
1043
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1371
1044
|
|
@@ -1373,11 +1046,12 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1373
1046
|
|
1374
1047
|
const int64_t r0 = tgpig.x;
|
1375
1048
|
const int64_t r1 = tgpig.y;
|
1049
|
+
const int64_t r2 = tgpig.z;
|
1376
1050
|
|
1377
1051
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
1378
|
-
|
1379
|
-
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
|
1380
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1052
|
+
const uint offset0 = r2/gqa*(nb*ne0);
|
1053
|
+
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
|
1054
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
1381
1055
|
|
1382
1056
|
float yl[16];
|
1383
1057
|
|
@@ -1465,7 +1139,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1465
1139
|
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
|
1466
1140
|
const float tot = simd_sum(sumf);
|
1467
1141
|
if (tiisg == 0) {
|
1468
|
-
dst[r1*ne0 + first_row + row] = tot;
|
1142
|
+
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
|
1469
1143
|
}
|
1470
1144
|
}
|
1471
1145
|
}
|
@@ -1475,10 +1149,14 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1475
1149
|
device const float * src1,
|
1476
1150
|
device float * dst,
|
1477
1151
|
constant int64_t & ne00,
|
1478
|
-
constant int64_t &
|
1479
|
-
constant int64_t &
|
1480
|
-
constant int64_t &
|
1481
|
-
|
1152
|
+
constant int64_t & ne01[[buffer(4)]],
|
1153
|
+
constant int64_t & ne02[[buffer(5)]],
|
1154
|
+
constant int64_t & ne10[[buffer(9)]],
|
1155
|
+
constant int64_t & ne12[[buffer(11)]],
|
1156
|
+
constant int64_t & ne0[[buffer(15)]],
|
1157
|
+
constant int64_t & ne1[[buffer(16)]],
|
1158
|
+
constant uint & gqa[[buffer(17)]],
|
1159
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1482
1160
|
uint tiisg[[thread_index_in_simdgroup]],
|
1483
1161
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1484
1162
|
|
@@ -1486,11 +1164,12 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1486
1164
|
|
1487
1165
|
const int64_t r0 = tgpig.x;
|
1488
1166
|
const int64_t r1 = tgpig.y;
|
1167
|
+
const int64_t r2 = tgpig.z;
|
1489
1168
|
|
1490
1169
|
const int row = 2 * r0 + sgitg;
|
1491
|
-
|
1492
|
-
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
|
1493
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1170
|
+
const uint offset0 = r2/gqa*(nb*ne0);
|
1171
|
+
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
|
1172
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
1494
1173
|
const int ix = tiisg/4;
|
1495
1174
|
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
1496
1175
|
const int im = il/8; // 0, 0, 1, 1
|
@@ -1529,7 +1208,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1529
1208
|
|
1530
1209
|
const float tot = simd_sum(sumf);
|
1531
1210
|
if (tiisg == 0) {
|
1532
|
-
dst[r1*ne0 + row] = tot;
|
1211
|
+
dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
|
1533
1212
|
}
|
1534
1213
|
|
1535
1214
|
}
|
@@ -1541,10 +1220,14 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1541
1220
|
device const float * src1,
|
1542
1221
|
device float * dst,
|
1543
1222
|
constant int64_t & ne00,
|
1544
|
-
constant int64_t & ne10,
|
1545
|
-
constant int64_t & ne0,
|
1546
1223
|
constant int64_t & ne01[[buffer(4)]],
|
1547
|
-
|
1224
|
+
constant int64_t & ne02[[buffer(5)]],
|
1225
|
+
constant int64_t & ne10[[buffer(9)]],
|
1226
|
+
constant int64_t & ne12[[buffer(11)]],
|
1227
|
+
constant int64_t & ne0[[buffer(15)]],
|
1228
|
+
constant int64_t & ne1[[buffer(16)]],
|
1229
|
+
constant uint & gqa[[buffer(17)]],
|
1230
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1548
1231
|
uint tiisg[[thread_index_in_simdgroup]],
|
1549
1232
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1550
1233
|
|
@@ -1560,10 +1243,12 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1560
1243
|
const int nb = ne00/QK_K;
|
1561
1244
|
const int r0 = tgpig.x;
|
1562
1245
|
const int r1 = tgpig.y;
|
1246
|
+
const int r2 = tgpig.z;
|
1563
1247
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1564
1248
|
const int ib_row = first_row * nb;
|
1565
|
-
|
1566
|
-
device const
|
1249
|
+
const uint offset0 = r2/gqa*(nb*ne0);
|
1250
|
+
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
1251
|
+
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
1567
1252
|
float yl[16];
|
1568
1253
|
float yh[16];
|
1569
1254
|
float sumf[N_DST]={0.f}, all_sum;
|
@@ -1630,7 +1315,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1630
1315
|
for (int row = 0; row < N_DST; ++row) {
|
1631
1316
|
all_sum = simd_sum(sumf[row]);
|
1632
1317
|
if (tiisg == 0) {
|
1633
|
-
dst[r1*ne0 + first_row + row] = all_sum;
|
1318
|
+
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
|
1634
1319
|
}
|
1635
1320
|
}
|
1636
1321
|
}
|
@@ -1640,10 +1325,14 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1640
1325
|
device const float * src1,
|
1641
1326
|
device float * dst,
|
1642
1327
|
constant int64_t & ne00,
|
1643
|
-
constant int64_t & ne10,
|
1644
|
-
constant int64_t & ne0,
|
1645
1328
|
constant int64_t & ne01[[buffer(4)]],
|
1646
|
-
|
1329
|
+
constant int64_t & ne02[[buffer(5)]],
|
1330
|
+
constant int64_t & ne10[[buffer(9)]],
|
1331
|
+
constant int64_t & ne12[[buffer(11)]],
|
1332
|
+
constant int64_t & ne0[[buffer(15)]],
|
1333
|
+
constant int64_t & ne1[[buffer(16)]],
|
1334
|
+
constant uint & gqa[[buffer(17)]],
|
1335
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1647
1336
|
uint tiisg[[thread_index_in_simdgroup]],
|
1648
1337
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1649
1338
|
|
@@ -1653,10 +1342,12 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1653
1342
|
const int nb = ne00/QK_K;
|
1654
1343
|
const int r0 = tgpig.x;
|
1655
1344
|
const int r1 = tgpig.y;
|
1345
|
+
const int r2 = tgpig.z;
|
1656
1346
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1657
1347
|
const int ib_row = first_row * nb;
|
1658
|
-
|
1659
|
-
device const
|
1348
|
+
const uint offset0 = r2/gqa*(nb*ne0);
|
1349
|
+
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
1350
|
+
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
1660
1351
|
float yl[8];
|
1661
1352
|
float yh[8];
|
1662
1353
|
float sumf[N_DST]={0.f}, all_sum;
|
@@ -1712,7 +1403,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1712
1403
|
for (int row = 0; row < N_DST; ++row) {
|
1713
1404
|
all_sum = simd_sum(sumf[row]);
|
1714
1405
|
if (tiisg == 0) {
|
1715
|
-
dst[r1*ne0 + first_row + row] = all_sum;
|
1406
|
+
dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
|
1716
1407
|
}
|
1717
1408
|
}
|
1718
1409
|
}
|
@@ -1723,9 +1414,14 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
1723
1414
|
device const float * src1,
|
1724
1415
|
device float * dst,
|
1725
1416
|
constant int64_t & ne00,
|
1726
|
-
constant int64_t &
|
1727
|
-
constant int64_t &
|
1728
|
-
|
1417
|
+
constant int64_t & ne01[[buffer(4)]],
|
1418
|
+
constant int64_t & ne02[[buffer(5)]],
|
1419
|
+
constant int64_t & ne10[[buffer(9)]],
|
1420
|
+
constant int64_t & ne12[[buffer(11)]],
|
1421
|
+
constant int64_t & ne0[[buffer(15)]],
|
1422
|
+
constant int64_t & ne1[[buffer(16)]],
|
1423
|
+
constant uint & gqa[[buffer(17)]],
|
1424
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1729
1425
|
uint tiisg[[thread_index_in_simdgroup]],
|
1730
1426
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1731
1427
|
|
@@ -1733,11 +1429,12 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
1733
1429
|
|
1734
1430
|
const int64_t r0 = tgpig.x;
|
1735
1431
|
const int64_t r1 = tgpig.y;
|
1432
|
+
const int r2 = tgpig.z;
|
1736
1433
|
|
1737
1434
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
1738
|
-
|
1739
|
-
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
|
1740
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1435
|
+
const uint offset0 = r2/gqa*(nb*ne0);
|
1436
|
+
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
|
1437
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
1741
1438
|
|
1742
1439
|
float sumf[2]={0.f};
|
1743
1440
|
|
@@ -1871,7 +1568,7 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
1871
1568
|
for (int row = 0; row < 2; ++row) {
|
1872
1569
|
const float tot = simd_sum(sumf[row]);
|
1873
1570
|
if (tiisg == 0) {
|
1874
|
-
dst[r1*ne0 + first_row + row] = tot;
|
1571
|
+
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
|
1875
1572
|
}
|
1876
1573
|
}
|
1877
1574
|
|
@@ -1882,9 +1579,14 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1882
1579
|
device const float * src1,
|
1883
1580
|
device float * dst,
|
1884
1581
|
constant int64_t & ne00,
|
1885
|
-
constant int64_t &
|
1886
|
-
constant int64_t &
|
1887
|
-
|
1582
|
+
constant int64_t & ne01[[buffer(4)]],
|
1583
|
+
constant int64_t & ne02[[buffer(5)]],
|
1584
|
+
constant int64_t & ne10[[buffer(9)]],
|
1585
|
+
constant int64_t & ne12[[buffer(11)]],
|
1586
|
+
constant int64_t & ne0[[buffer(15)]],
|
1587
|
+
constant int64_t & ne1[[buffer(16)]],
|
1588
|
+
constant uint & gqa[[buffer(17)]],
|
1589
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1888
1590
|
uint tiisg[[thread_index_in_simdgroup]],
|
1889
1591
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1890
1592
|
|
@@ -1897,11 +1599,12 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1897
1599
|
|
1898
1600
|
const int64_t r0 = tgpig.x;
|
1899
1601
|
const int64_t r1 = tgpig.y;
|
1602
|
+
const int r2 = tgpig.z;
|
1900
1603
|
|
1901
1604
|
const int row = 2 * r0 + sgitg;
|
1902
|
-
|
1903
|
-
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb
|
1904
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1605
|
+
const uint offset0 = r2/gqa*(nb*ne0);
|
1606
|
+
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
|
1607
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
1905
1608
|
|
1906
1609
|
float sumf = 0;
|
1907
1610
|
|
@@ -1967,6 +1670,380 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1967
1670
|
|
1968
1671
|
const float tot = simd_sum(sumf);
|
1969
1672
|
if (tiisg == 0) {
|
1970
|
-
dst[r1*ne0 + row] = tot;
|
1673
|
+
dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
|
1674
|
+
}
|
1675
|
+
}
|
1676
|
+
|
1677
|
+
//============================= templates and their specializations =============================
|
1678
|
+
|
1679
|
+
template <typename type4x4>
|
1680
|
+
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
|
1681
|
+
half4x4 temp = *(((device half4x4 *)src));
|
1682
|
+
for (int i = 0; i < 16; i++){
|
1683
|
+
reg[i/4][i%4] = temp[i/4][i%4];
|
1684
|
+
}
|
1685
|
+
}
|
1686
|
+
|
1687
|
+
template <typename type4x4>
|
1688
|
+
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
1689
|
+
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
1690
|
+
const half d = il ? (xb->d / 16.h) : xb->d;
|
1691
|
+
const half m = il ? ( -8.h * 16.h) : -8.h;
|
1692
|
+
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
1693
|
+
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
1694
|
+
|
1695
|
+
for (int i=0;i<8;i++) {
|
1696
|
+
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
|
1697
|
+
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
|
1698
|
+
}
|
1699
|
+
}
|
1700
|
+
|
1701
|
+
template <typename type4x4>
|
1702
|
+
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
|
1703
|
+
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
1704
|
+
const half d = il ? (xb->d / 16.h) : xb->d;
|
1705
|
+
const half m = xb->m;
|
1706
|
+
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
1707
|
+
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
1708
|
+
|
1709
|
+
for (int i=0;i<8;i++) {
|
1710
|
+
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
|
1711
|
+
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
|
1712
|
+
}
|
1713
|
+
}
|
1714
|
+
|
1715
|
+
template <typename type4x4>
|
1716
|
+
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
1717
|
+
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
1718
|
+
const half d = xb->d;
|
1719
|
+
|
1720
|
+
for (int i=0;i<16;i++) {
|
1721
|
+
reg[i/4][i%4] = (qs[i + 16*il] * d);
|
1722
|
+
}
|
1723
|
+
}
|
1724
|
+
|
1725
|
+
template <typename type4x4>
|
1726
|
+
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
1727
|
+
const half d = xb->d;
|
1728
|
+
const half min = xb->dmin;
|
1729
|
+
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
1730
|
+
half dl, ml;
|
1731
|
+
uint8_t sc = xb->scales[il];
|
1732
|
+
|
1733
|
+
#if QK_K == 256
|
1734
|
+
q = q + 32*(il/8) + 16*(il&1);
|
1735
|
+
il = (il/2)%4;
|
1736
|
+
#endif
|
1737
|
+
half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
1738
|
+
uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
1739
|
+
dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
|
1740
|
+
for (int i = 0; i < 16; ++i) {
|
1741
|
+
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
1971
1742
|
}
|
1972
1743
|
}
|
1744
|
+
|
1745
|
+
template <typename type4x4>
|
1746
|
+
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
|
1747
|
+
const float d_all = (float)(xb->d);
|
1748
|
+
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
1749
|
+
device const uint8_t * h = (device const uint8_t *)xb->hmask;
|
1750
|
+
device const int8_t * scales = (device const int8_t *)xb->scales;
|
1751
|
+
|
1752
|
+
#if QK_K == 256
|
1753
|
+
q = q + 32 * (il/8) + 16 * (il&1);
|
1754
|
+
h = h + 16 * (il&1);
|
1755
|
+
uint8_t m = 1 << (il/2);
|
1756
|
+
uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
|
1757
|
+
((il/4)>0 ? 12 : 3);
|
1758
|
+
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
|
1759
|
+
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
|
1760
|
+
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
|
1761
|
+
(scale_2&kmask2) | ((scale_1&kmask1) << 4);
|
1762
|
+
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
|
1763
|
+
|
1764
|
+
il = (il/2)%4;
|
1765
|
+
float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
1766
|
+
uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
1767
|
+
|
1768
|
+
for (int i = 0; i < 16; ++i) {
|
1769
|
+
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
|
1770
|
+
}
|
1771
|
+
#else
|
1772
|
+
float kcoef = il&1 ? 1.f/16.f : 1.f;
|
1773
|
+
uint16_t kmask = il&1 ? 0xF0 : 0x0F;
|
1774
|
+
float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
|
1775
|
+
float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
1776
|
+
uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
1777
|
+
uint8_t m = 1<<(il*2);
|
1778
|
+
for (int i = 0; i < 16; ++i) {
|
1779
|
+
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
|
1780
|
+
}
|
1781
|
+
#endif
|
1782
|
+
}
|
1783
|
+
|
1784
|
+
template <typename type4x4>
|
1785
|
+
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
|
1786
|
+
device const uint8_t * q = xb->qs;
|
1787
|
+
|
1788
|
+
#if QK_K == 256
|
1789
|
+
const float d = (float)(xb->d);
|
1790
|
+
const float min = (float)(xb->dmin);
|
1791
|
+
short is = (il/4) * 2;
|
1792
|
+
q = q + (il/4) * 32 + 16 * (il&1);
|
1793
|
+
il = il%4;
|
1794
|
+
const uchar4 sc = get_scale_min_k4(is, xb->scales);
|
1795
|
+
const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
|
1796
|
+
const float ml = il<2 ? min * sc[1] : min * sc[3];
|
1797
|
+
#else
|
1798
|
+
q = q + 16 * (il&1);
|
1799
|
+
device const uint8_t * s = xb->scales;
|
1800
|
+
device const half2 * dh = (device const half2 *)xb->d;
|
1801
|
+
const float2 d = (float2)dh[0];
|
1802
|
+
const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
|
1803
|
+
const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4);
|
1804
|
+
#endif
|
1805
|
+
const ushort mask = il<2 ? 0x0F : 0xF0;
|
1806
|
+
for (int i = 0; i < 16; ++i) {
|
1807
|
+
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
1808
|
+
}
|
1809
|
+
}
|
1810
|
+
|
1811
|
+
template <typename type4x4>
|
1812
|
+
void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
|
1813
|
+
device const uint8_t * q = xb->qs;
|
1814
|
+
device const uint8_t * qh = xb->qh;
|
1815
|
+
|
1816
|
+
#if QK_K == 256
|
1817
|
+
const float d = (float)(xb->d);
|
1818
|
+
const float min = (float)(xb->dmin);
|
1819
|
+
short is = (il/4) * 2;
|
1820
|
+
q = q + 32 * (il/4) + 16 * (il&1);
|
1821
|
+
qh = qh + 16 * (il&1);
|
1822
|
+
uint8_t ul = 1 << (il/2);
|
1823
|
+
il = il%4;
|
1824
|
+
const uchar4 sc = get_scale_min_k4(is, xb->scales);
|
1825
|
+
const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
|
1826
|
+
const float ml = il<2 ? min * sc[1] : min * sc[3];
|
1827
|
+
|
1828
|
+
const ushort mask = il<2 ? 0x0F : 0xF0;
|
1829
|
+
const float qh_val = il<2 ? 16.f : 256.f;
|
1830
|
+
for (int i = 0; i < 16; ++i) {
|
1831
|
+
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
1832
|
+
}
|
1833
|
+
#else
|
1834
|
+
q = q + 16 * (il&1);
|
1835
|
+
device const int8_t * s = xb->scales;
|
1836
|
+
const float dl = xb->d * s[il];
|
1837
|
+
uint8_t m = 1<<(il*2);
|
1838
|
+
const float coef = il<2 ? 1.f : 1.f/16.f;
|
1839
|
+
const ushort mask = il<2 ? 0x0F : 0xF0;
|
1840
|
+
for (int i = 0; i < 16; ++i) {
|
1841
|
+
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
|
1842
|
+
}
|
1843
|
+
#endif
|
1844
|
+
}
|
1845
|
+
|
1846
|
+
template <typename type4x4>
|
1847
|
+
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
|
1848
|
+
const float d_all = (float)(xb->d);
|
1849
|
+
device const uint8_t * ql = (device const uint8_t *)xb->ql;
|
1850
|
+
device const uint8_t * qh = (device const uint8_t *)xb->qh;
|
1851
|
+
device const int8_t * scales = (device const int8_t *)xb->scales;
|
1852
|
+
|
1853
|
+
#if QK_K == 256
|
1854
|
+
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
1855
|
+
qh = qh + 32*(il/8) + 16*(il&1);
|
1856
|
+
float sc = scales[(il%2) + 2 * ((il/2))];
|
1857
|
+
il = (il/2)%4;
|
1858
|
+
#else
|
1859
|
+
ql = ql + 16 * (il&1);
|
1860
|
+
float sc = scales[il];
|
1861
|
+
#endif
|
1862
|
+
for (int i = 0; i < 16; ++i) {
|
1863
|
+
uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
1864
|
+
uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
1865
|
+
const float coef = il>1 ? 1.f/16.f : 1.f;
|
1866
|
+
float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
|
1867
|
+
((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
|
1868
|
+
reg[i/4][i%4] = d_all * sc * q * coef;
|
1869
|
+
}
|
1870
|
+
}
|
1871
|
+
|
1872
|
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
1873
|
+
kernel void kernel_get_rows(
|
1874
|
+
device const void * src0,
|
1875
|
+
device const int * src1,
|
1876
|
+
device float * dst,
|
1877
|
+
constant int64_t & ne00,
|
1878
|
+
constant uint64_t & nb01,
|
1879
|
+
constant uint64_t & nb1,
|
1880
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
1881
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
1882
|
+
uint tptg[[threads_per_threadgroup]]) {
|
1883
|
+
const int i = tgpig;
|
1884
|
+
const int r = ((device int32_t *) src1)[i];
|
1885
|
+
|
1886
|
+
for (int ind = tiitg; ind < ne00/16; ind += tptg) {
|
1887
|
+
float4x4 temp;
|
1888
|
+
dequantize_func(
|
1889
|
+
((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
|
1890
|
+
*(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
|
1891
|
+
}
|
1892
|
+
}
|
1893
|
+
|
1894
|
+
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
1895
|
+
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
|
1896
|
+
#define BLOCK_SIZE_K 32
|
1897
|
+
#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
|
1898
|
+
#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
|
1899
|
+
#define THREAD_PER_BLOCK 128
|
1900
|
+
#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
|
1901
|
+
#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
|
1902
|
+
#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
|
1903
|
+
#define SG_MAT_ROW 8
|
1904
|
+
|
1905
|
+
// each block_q contains 16*nl weights
|
1906
|
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
1907
|
+
kernel void kernel_mul_mm(device const uchar * src0,
|
1908
|
+
device const float * src1,
|
1909
|
+
device float * dst,
|
1910
|
+
constant int64_t & ne00,
|
1911
|
+
constant int64_t & ne02,
|
1912
|
+
constant int64_t & nb01,
|
1913
|
+
constant int64_t & nb02,
|
1914
|
+
constant int64_t & ne12,
|
1915
|
+
constant int64_t & ne0,
|
1916
|
+
constant int64_t & ne1,
|
1917
|
+
constant uint & gqa,
|
1918
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
1919
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1920
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
1921
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1922
|
+
|
1923
|
+
threadgroup half * sa = ((threadgroup half *)shared_memory);
|
1924
|
+
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
1925
|
+
|
1926
|
+
const uint r0 = tgpig.y;
|
1927
|
+
const uint r1 = tgpig.x;
|
1928
|
+
const uint im = tgpig.z;
|
1929
|
+
// if this block is of 64x32 shape or smaller
|
1930
|
+
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
1931
|
+
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
1932
|
+
// a thread shouldn't load data outside of the matrix
|
1933
|
+
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
1934
|
+
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
1935
|
+
|
1936
|
+
simdgroup_half8x8 ma[4];
|
1937
|
+
simdgroup_float8x8 mb[2];
|
1938
|
+
simdgroup_float8x8 c_res[8];
|
1939
|
+
for (int i = 0; i < 8; i++){
|
1940
|
+
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
1941
|
+
}
|
1942
|
+
|
1943
|
+
short il = (tiitg % THREAD_PER_ROW);
|
1944
|
+
uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
|
1945
|
+
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
1946
|
+
device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
|
1947
|
+
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
|
1948
|
+
|
1949
|
+
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
1950
|
+
//load data and store to threadgroup memory
|
1951
|
+
half4x4 temp_a;
|
1952
|
+
dequantize_func(x, il, temp_a);
|
1953
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1954
|
+
#pragma unroll(16)
|
1955
|
+
for (int i = 0; i < 16; i++) {
|
1956
|
+
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
1957
|
+
+ 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
|
1958
|
+
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
1959
|
+
}
|
1960
|
+
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
|
1961
|
+
= *((device float2x4 *)y);
|
1962
|
+
il = (il + 2 < nl) ? il + 2 : il % 2;
|
1963
|
+
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
1964
|
+
y += BLOCK_SIZE_K;
|
1965
|
+
|
1966
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1967
|
+
//load matrices from threadgroup memory and conduct outer products
|
1968
|
+
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
1969
|
+
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
1970
|
+
#pragma unroll(4)
|
1971
|
+
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
1972
|
+
#pragma unroll(4)
|
1973
|
+
for (int i = 0; i < 4; i++) {
|
1974
|
+
simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
|
1975
|
+
}
|
1976
|
+
simdgroup_barrier(mem_flags::mem_none);
|
1977
|
+
#pragma unroll(2)
|
1978
|
+
for (int i = 0; i < 2; i++) {
|
1979
|
+
simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
|
1980
|
+
}
|
1981
|
+
|
1982
|
+
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
1983
|
+
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
1984
|
+
#pragma unroll(8)
|
1985
|
+
for (int i = 0; i < 8; i++){
|
1986
|
+
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
1987
|
+
}
|
1988
|
+
}
|
1989
|
+
}
|
1990
|
+
|
1991
|
+
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
|
1992
|
+
device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
|
1993
|
+
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
|
1994
|
+
for (int i = 0; i < 8; i++) {
|
1995
|
+
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
|
1996
|
+
}
|
1997
|
+
} else {
|
1998
|
+
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
1999
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2000
|
+
threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
|
2001
|
+
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
2002
|
+
for (int i = 0; i < 8; i++) {
|
2003
|
+
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
2004
|
+
}
|
2005
|
+
|
2006
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2007
|
+
device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
|
2008
|
+
if (sgitg==0) {
|
2009
|
+
for (int i = 0; i < n_rows; i++) {
|
2010
|
+
for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
|
2011
|
+
*(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
|
2012
|
+
}
|
2013
|
+
}
|
2014
|
+
}
|
2015
|
+
}
|
2016
|
+
}
|
2017
|
+
|
2018
|
+
#if QK_K == 256
|
2019
|
+
#define QK_NL 16
|
2020
|
+
#else
|
2021
|
+
#define QK_NL 4
|
2022
|
+
#endif
|
2023
|
+
|
2024
|
+
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
|
2025
|
+
constant uint64_t &, constant uint64_t &, uint, uint, uint);
|
2026
|
+
|
2027
|
+
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
2028
|
+
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
2029
|
+
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
2030
|
+
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
|
2031
|
+
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
|
2032
|
+
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
|
2033
|
+
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
|
2034
|
+
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
2035
|
+
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
2036
|
+
|
2037
|
+
typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
|
2038
|
+
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
|
2039
|
+
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
|
2040
|
+
|
2041
|
+
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
2042
|
+
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
2043
|
+
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
2044
|
+
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
2045
|
+
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
2046
|
+
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
2047
|
+
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
2048
|
+
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
2049
|
+
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|