llama_cpp 0.3.6 → 0.3.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/ext/llama_cpp/extconf.rb +2 -2
- data/ext/llama_cpp/llama_cpp.cpp +8 -0
- data/ext/llama_cpp/src/ggml-alloc.c +44 -6
- data/ext/llama_cpp/src/ggml-alloc.h +4 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +1398 -702
- data/ext/llama_cpp/src/ggml-cuda.h +19 -23
- data/ext/llama_cpp/src/ggml-metal.h +6 -3
- data/ext/llama_cpp/src/ggml-metal.m +112 -146
- data/ext/llama_cpp/src/ggml-metal.metal +471 -498
- data/ext/llama_cpp/src/ggml.c +396 -150
- data/ext/llama_cpp/src/ggml.h +113 -32
- data/ext/llama_cpp/src/llama-util.h +51 -9
- data/ext/llama_cpp/src/llama.cpp +390 -210
- data/ext/llama_cpp/src/llama.h +20 -1
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +1 -0
- metadata +2 -2
@@ -18,47 +18,6 @@ typedef struct {
|
|
18
18
|
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
19
19
|
} block_q4_1;
|
20
20
|
|
21
|
-
static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
|
22
|
-
const int qk = QK4_0;
|
23
|
-
|
24
|
-
assert(k % qk == 0);
|
25
|
-
|
26
|
-
const int nb = k / qk;
|
27
|
-
|
28
|
-
for (int i = 0; i < nb; i++) {
|
29
|
-
const half d = x[i].d;
|
30
|
-
|
31
|
-
for (int j = 0; j < qk/2; ++j) {
|
32
|
-
const int x0 = (x[i].qs[j] & 0x0F) - 8;
|
33
|
-
const int x1 = (x[i].qs[j] >> 4) - 8;
|
34
|
-
|
35
|
-
y[i*qk + j + 0 ] = x0*d;
|
36
|
-
y[i*qk + j + qk/2] = x1*d;
|
37
|
-
}
|
38
|
-
}
|
39
|
-
}
|
40
|
-
|
41
|
-
static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
|
42
|
-
const int qk = QK4_1;
|
43
|
-
|
44
|
-
assert(k % qk == 0);
|
45
|
-
|
46
|
-
const int nb = k / qk;
|
47
|
-
|
48
|
-
for (int i = 0; i < nb; i++) {
|
49
|
-
const half d = x[i].d;
|
50
|
-
const half m = x[i].m;
|
51
|
-
|
52
|
-
for (int j = 0; j < qk/2; ++j) {
|
53
|
-
const int x0 = (x[i].qs[j] & 0x0F);
|
54
|
-
const int x1 = (x[i].qs[j] >> 4);
|
55
|
-
|
56
|
-
y[i*qk + j + 0 ] = x0*d + m;
|
57
|
-
y[i*qk + j + qk/2] = x1*d + m;
|
58
|
-
}
|
59
|
-
}
|
60
|
-
}
|
61
|
-
|
62
21
|
kernel void kernel_add(
|
63
22
|
device const float * src0,
|
64
23
|
device const float * src1,
|
@@ -219,54 +178,6 @@ kernel void kernel_diag_mask_inf(
|
|
219
178
|
}
|
220
179
|
}
|
221
180
|
|
222
|
-
kernel void kernel_get_rows_f16(
|
223
|
-
device const void * src0,
|
224
|
-
device const int * src1,
|
225
|
-
device float * dst,
|
226
|
-
constant int64_t & ne00,
|
227
|
-
constant uint64_t & nb01,
|
228
|
-
constant uint64_t & nb1,
|
229
|
-
uint tpig[[thread_position_in_grid]]) {
|
230
|
-
const int i = tpig;
|
231
|
-
const int r = ((device int32_t *) src1)[i];
|
232
|
-
|
233
|
-
for (int j = 0; j < ne00; j++) {
|
234
|
-
dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j];
|
235
|
-
}
|
236
|
-
}
|
237
|
-
|
238
|
-
kernel void kernel_get_rows_q4_0(
|
239
|
-
device const void * src0,
|
240
|
-
device const int * src1,
|
241
|
-
device float * dst,
|
242
|
-
constant int64_t & ne00,
|
243
|
-
constant uint64_t & nb01,
|
244
|
-
constant uint64_t & nb1,
|
245
|
-
uint tpig[[thread_position_in_grid]]) {
|
246
|
-
const int i = tpig;
|
247
|
-
const int r = ((device int32_t *) src1)[i];
|
248
|
-
|
249
|
-
dequantize_row_q4_0(
|
250
|
-
(device const block_q4_0 *) ((device char *) src0 + r*nb01),
|
251
|
-
(device float *) ((device char *) dst + i*nb1), ne00);
|
252
|
-
}
|
253
|
-
|
254
|
-
kernel void kernel_get_rows_q4_1(
|
255
|
-
device const void * src0,
|
256
|
-
device const int * src1,
|
257
|
-
device float * dst,
|
258
|
-
constant int64_t & ne00,
|
259
|
-
constant uint64_t & nb01,
|
260
|
-
constant uint64_t & nb1,
|
261
|
-
uint tpig[[thread_position_in_grid]]) {
|
262
|
-
const int i = tpig;
|
263
|
-
const int r = ((device int32_t *) src1)[i];
|
264
|
-
|
265
|
-
dequantize_row_q4_1(
|
266
|
-
(device const block_q4_1 *) ((device char *) src0 + r*nb01),
|
267
|
-
(device float *) ((device char *) dst + i*nb1), ne00);
|
268
|
-
}
|
269
|
-
|
270
181
|
kernel void kernel_norm(
|
271
182
|
device const void * src0,
|
272
183
|
device float * dst,
|
@@ -432,14 +343,16 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
|
|
432
343
|
// N_DST, so this is another explicit assumption of the implementation.
|
433
344
|
template<typename block_q_type, int nr, int nsg, int nw>
|
434
345
|
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
|
435
|
-
int64_t ne00, int64_t ne10, int64_t ne0, int64_t
|
436
|
-
|
346
|
+
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
|
347
|
+
uint3 tgpig, uint tiisg, uint sgitg) {
|
437
348
|
const int nb = ne00/QK4_0;
|
438
349
|
const int r0 = tgpig.x;
|
439
350
|
const int r1 = tgpig.y;
|
351
|
+
const int im = tgpig.z;
|
440
352
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
441
|
-
|
442
|
-
device const
|
353
|
+
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
|
354
|
+
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
355
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
443
356
|
float yl[16]; // src1 vector cache
|
444
357
|
float sumf[nr]={0.f};
|
445
358
|
|
@@ -470,7 +383,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
470
383
|
for (int row = 0; row < nr; ++row) {
|
471
384
|
const float tot = simd_sum(sumf[row]);
|
472
385
|
if (tiisg == 0 && first_row + row < ne01) {
|
473
|
-
dst[r1*ne0 + first_row + row] = tot;
|
386
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
474
387
|
}
|
475
388
|
}
|
476
389
|
}
|
@@ -480,13 +393,17 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|
480
393
|
device const float * src1,
|
481
394
|
device float * dst,
|
482
395
|
constant int64_t & ne00,
|
483
|
-
constant int64_t & ne10,
|
484
|
-
constant int64_t & ne0,
|
485
396
|
constant int64_t & ne01[[buffer(4)]],
|
486
|
-
|
397
|
+
constant int64_t & ne02[[buffer(5)]],
|
398
|
+
constant int64_t & ne10[[buffer(9)]],
|
399
|
+
constant int64_t & ne12[[buffer(11)]],
|
400
|
+
constant int64_t & ne0[[buffer(15)]],
|
401
|
+
constant int64_t & ne1[[buffer(16)]],
|
402
|
+
constant uint & gqa[[buffer(17)]],
|
403
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
487
404
|
uint tiisg[[thread_index_in_simdgroup]],
|
488
405
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
489
|
-
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,
|
406
|
+
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
490
407
|
}
|
491
408
|
|
492
409
|
kernel void kernel_mul_mat_q4_1_f32(
|
@@ -494,13 +411,17 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|
494
411
|
device const float * src1,
|
495
412
|
device float * dst,
|
496
413
|
constant int64_t & ne00,
|
497
|
-
constant int64_t & ne10,
|
498
|
-
constant int64_t & ne0,
|
499
414
|
constant int64_t & ne01[[buffer(4)]],
|
500
|
-
|
415
|
+
constant int64_t & ne02[[buffer(5)]],
|
416
|
+
constant int64_t & ne10[[buffer(9)]],
|
417
|
+
constant int64_t & ne12[[buffer(11)]],
|
418
|
+
constant int64_t & ne0[[buffer(15)]],
|
419
|
+
constant int64_t & ne1[[buffer(16)]],
|
420
|
+
constant uint & gqa[[buffer(17)]],
|
421
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
501
422
|
uint tiisg[[thread_index_in_simdgroup]],
|
502
423
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
503
|
-
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,
|
424
|
+
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
504
425
|
}
|
505
426
|
|
506
427
|
kernel void kernel_mul_mat_f16_f32(
|
@@ -869,354 +790,6 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
869
790
|
return r;
|
870
791
|
}
|
871
792
|
|
872
|
-
//========================================== dequantization =============================
|
873
|
-
|
874
|
-
static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
|
875
|
-
assert(k % QK_K == 0);
|
876
|
-
const int nb = k / QK_K;
|
877
|
-
|
878
|
-
for (int i = 0; i < nb; i++) {
|
879
|
-
|
880
|
-
const float d = x[i].d;
|
881
|
-
const float min = x[i].dmin;
|
882
|
-
|
883
|
-
device const uint8_t * q = x[i].qs;
|
884
|
-
|
885
|
-
#if QK_K == 256
|
886
|
-
int is = 0;
|
887
|
-
float dl, ml;
|
888
|
-
for (int n = 0; n < QK_K; n += 128) {
|
889
|
-
int shift = 0;
|
890
|
-
for (int j = 0; j < 4; ++j) {
|
891
|
-
|
892
|
-
uint8_t sc = x[i].scales[is++];
|
893
|
-
dl = d * (sc & 0xF); ml = min * (sc >> 4);
|
894
|
-
for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
|
895
|
-
|
896
|
-
sc = x[i].scales[is++];
|
897
|
-
dl = d * (sc & 0xF); ml = min * (sc >> 4);
|
898
|
-
for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
|
899
|
-
|
900
|
-
shift += 2;
|
901
|
-
}
|
902
|
-
q += 32;
|
903
|
-
}
|
904
|
-
#else
|
905
|
-
float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
|
906
|
-
float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
|
907
|
-
float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
|
908
|
-
float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
|
909
|
-
for (int l = 0; l < 16; ++l) {
|
910
|
-
y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
|
911
|
-
y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
|
912
|
-
y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
|
913
|
-
y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
|
914
|
-
}
|
915
|
-
y += QK_K;
|
916
|
-
#endif
|
917
|
-
|
918
|
-
}
|
919
|
-
}
|
920
|
-
|
921
|
-
static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
|
922
|
-
assert(k % QK_K == 0);
|
923
|
-
const int nb = k / QK_K;
|
924
|
-
|
925
|
-
#if QK_K == 256
|
926
|
-
|
927
|
-
const uint16_t kmask1 = 0x0303;
|
928
|
-
const uint16_t kmask2 = 0x0f0f;
|
929
|
-
|
930
|
-
uint16_t aux[8];
|
931
|
-
thread const int8_t * scales = (thread const int8_t*)aux;
|
932
|
-
|
933
|
-
for (int i = 0; i < nb; i++) {
|
934
|
-
|
935
|
-
const float d_all = (float)(x[i].d);
|
936
|
-
|
937
|
-
device const uint8_t * q = x[i].qs;
|
938
|
-
device const uint8_t * h = x[i].hmask;
|
939
|
-
uint8_t m = 1;
|
940
|
-
|
941
|
-
device const uint16_t * a = (device const uint16_t *)x[i].scales;
|
942
|
-
aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
|
943
|
-
aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4);
|
944
|
-
aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4);
|
945
|
-
aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4);
|
946
|
-
aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4);
|
947
|
-
aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4);
|
948
|
-
aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
|
949
|
-
aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);
|
950
|
-
|
951
|
-
int is = 0;
|
952
|
-
float dl;
|
953
|
-
for (int n = 0; n < QK_K; n += 128) {
|
954
|
-
int shift = 0;
|
955
|
-
for (int j = 0; j < 4; ++j) {
|
956
|
-
|
957
|
-
dl = d_all * (scales[is++] - 32);
|
958
|
-
for (int l = 0; l < 16; ++l) {
|
959
|
-
*y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
|
960
|
-
}
|
961
|
-
|
962
|
-
dl = d_all * (scales[is++] - 32);
|
963
|
-
for (int l = 0; l < 16; ++l) {
|
964
|
-
*y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
|
965
|
-
}
|
966
|
-
|
967
|
-
shift += 2;
|
968
|
-
m <<= 1;
|
969
|
-
}
|
970
|
-
q += 32;
|
971
|
-
}
|
972
|
-
}
|
973
|
-
#else
|
974
|
-
for (int i = 0; i < nb; i++) {
|
975
|
-
|
976
|
-
const float d_all = (float)(x[i].d);
|
977
|
-
|
978
|
-
device const uint8_t * q = x[i].qs;
|
979
|
-
device const uint8_t * hm = x[i].hmask;
|
980
|
-
|
981
|
-
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
982
|
-
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
983
|
-
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
984
|
-
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
985
|
-
|
986
|
-
for (int l = 0; l < 8; ++l) {
|
987
|
-
uint8_t h = hm[l];
|
988
|
-
y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
|
989
|
-
y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
|
990
|
-
y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
|
991
|
-
y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
|
992
|
-
y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
|
993
|
-
y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
|
994
|
-
y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
|
995
|
-
y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
|
996
|
-
}
|
997
|
-
y += QK_K;
|
998
|
-
}
|
999
|
-
#endif
|
1000
|
-
|
1001
|
-
}
|
1002
|
-
|
1003
|
-
static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
|
1004
|
-
assert(k % QK_K == 0);
|
1005
|
-
const int nb = k / QK_K;
|
1006
|
-
|
1007
|
-
for (int i = 0; i < nb; i++) {
|
1008
|
-
|
1009
|
-
device const uint8_t * q = x[i].qs;
|
1010
|
-
|
1011
|
-
#if QK_K == 256
|
1012
|
-
const float d = x[i].d;
|
1013
|
-
const float min = x[i].dmin;
|
1014
|
-
|
1015
|
-
device const uint8_t * scales = x[i].scales;
|
1016
|
-
|
1017
|
-
int is = 0;
|
1018
|
-
for (int j = 0; j < QK_K; j += 64) {
|
1019
|
-
const uchar4 sc = get_scale_min_k4(is, scales);
|
1020
|
-
const float d1 = d * sc[0]; const float m1 = min * sc[1];
|
1021
|
-
const float d2 = d * sc[2]; const float m2 = min * sc[3];
|
1022
|
-
for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
|
1023
|
-
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
|
1024
|
-
q += 32; is += 2;
|
1025
|
-
}
|
1026
|
-
#else
|
1027
|
-
device const uint8_t * s = x[i].scales;
|
1028
|
-
device const half2 * dh = (device const half2 *)x[i].d;
|
1029
|
-
const float2 d = (float2)dh[0];
|
1030
|
-
const float d1 = d[0] * (s[0] & 0xF);
|
1031
|
-
const float d2 = d[0] * (s[1] & 0xF);
|
1032
|
-
const float m1 = d[1] * (s[0] >> 4);
|
1033
|
-
const float m2 = d[1] * (s[1] >> 4);
|
1034
|
-
for (int l = 0; l < 32; ++l) {
|
1035
|
-
y[l+ 0] = d1 * (q[l] & 0xF) - m1;
|
1036
|
-
y[l+32] = d2 * (q[l] >> 4) - m2;
|
1037
|
-
}
|
1038
|
-
y += QK_K;
|
1039
|
-
#endif
|
1040
|
-
|
1041
|
-
}
|
1042
|
-
}
|
1043
|
-
|
1044
|
-
static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
|
1045
|
-
assert(k % QK_K == 0);
|
1046
|
-
const int nb = k / QK_K;
|
1047
|
-
|
1048
|
-
#if QK_K == 256
|
1049
|
-
for (int i = 0; i < nb; i++) {
|
1050
|
-
|
1051
|
-
const float d = (float)(x[i].d);
|
1052
|
-
const float min = (float)(x[i].dmin);
|
1053
|
-
|
1054
|
-
device const uint8_t * ql = x[i].qs;
|
1055
|
-
device const uint8_t * qh = x[i].qh;
|
1056
|
-
|
1057
|
-
int is = 0;
|
1058
|
-
uint8_t u1 = 1, u2 = 2;
|
1059
|
-
for (int j = 0; j < QK_K; j += 64) {
|
1060
|
-
const uchar4 sc = get_scale_min_k4(is, x[i].scales);
|
1061
|
-
const float d1 = d * sc[0]; const float m1 = min * sc[1];
|
1062
|
-
const float d2 = d * sc[2]; const float m2 = min * sc[3];
|
1063
|
-
for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
|
1064
|
-
for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
|
1065
|
-
ql += 32; is += 2;
|
1066
|
-
u1 <<= 2; u2 <<= 2;
|
1067
|
-
}
|
1068
|
-
}
|
1069
|
-
#else
|
1070
|
-
for (int i = 0; i < nb; i++) {
|
1071
|
-
|
1072
|
-
const float d = (float)x[i].d;
|
1073
|
-
|
1074
|
-
device const uint8_t * ql = x[i].qs;
|
1075
|
-
device const uint8_t * qh = x[i].qh;
|
1076
|
-
device const int8_t * sc = x[i].scales;
|
1077
|
-
|
1078
|
-
for (int l = 0; l < 8; ++l) {
|
1079
|
-
y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
|
1080
|
-
y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
|
1081
|
-
y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
|
1082
|
-
y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
|
1083
|
-
y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
|
1084
|
-
y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
|
1085
|
-
y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
|
1086
|
-
y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
|
1087
|
-
}
|
1088
|
-
y += QK_K;
|
1089
|
-
}
|
1090
|
-
#endif
|
1091
|
-
|
1092
|
-
}
|
1093
|
-
|
1094
|
-
static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
|
1095
|
-
assert(k % QK_K == 0);
|
1096
|
-
const int nb = k / QK_K;
|
1097
|
-
|
1098
|
-
for (int i = 0; i < nb; i++) {
|
1099
|
-
|
1100
|
-
device const uint8_t * ql = x[i].ql;
|
1101
|
-
device const uint8_t * qh = x[i].qh;
|
1102
|
-
device const int8_t * sc = x[i].scales;
|
1103
|
-
|
1104
|
-
const float d = x[i].d;
|
1105
|
-
|
1106
|
-
#if QK_K == 256
|
1107
|
-
for (int n = 0; n < QK_K; n += 128) {
|
1108
|
-
for (int l = 0; l < 32; ++l) {
|
1109
|
-
int is = l/16;
|
1110
|
-
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
1111
|
-
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
1112
|
-
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
1113
|
-
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
1114
|
-
y[l + 0] = d * sc[is + 0] * q1;
|
1115
|
-
y[l + 32] = d * sc[is + 2] * q2;
|
1116
|
-
y[l + 64] = d * sc[is + 4] * q3;
|
1117
|
-
y[l + 96] = d * sc[is + 6] * q4;
|
1118
|
-
}
|
1119
|
-
y += 128;
|
1120
|
-
ql += 64;
|
1121
|
-
qh += 32;
|
1122
|
-
sc += 8;
|
1123
|
-
}
|
1124
|
-
#else
|
1125
|
-
for (int l = 0; l < 16; ++l) {
|
1126
|
-
const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
1127
|
-
const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
1128
|
-
const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
1129
|
-
const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
1130
|
-
y[l+ 0] = d * sc[0] * q1;
|
1131
|
-
y[l+16] = d * sc[1] * q2;
|
1132
|
-
y[l+32] = d * sc[2] * q3;
|
1133
|
-
y[l+48] = d * sc[3] * q4;
|
1134
|
-
}
|
1135
|
-
y += 64;
|
1136
|
-
#endif
|
1137
|
-
}
|
1138
|
-
}
|
1139
|
-
|
1140
|
-
kernel void kernel_get_rows_q2_K(
|
1141
|
-
device const void * src0,
|
1142
|
-
device const int * src1,
|
1143
|
-
device float * dst,
|
1144
|
-
constant int64_t & ne00,
|
1145
|
-
constant uint64_t & nb01,
|
1146
|
-
constant uint64_t & nb1,
|
1147
|
-
uint tpig[[thread_position_in_grid]]) {
|
1148
|
-
const int i = tpig;
|
1149
|
-
const int r = ((device int32_t *) src1)[i];
|
1150
|
-
|
1151
|
-
dequantize_row_q2_K(
|
1152
|
-
(device const block_q2_K *) ((device char *) src0 + r*nb01),
|
1153
|
-
(device float *) ((device char *) dst + i*nb1), ne00);
|
1154
|
-
}
|
1155
|
-
|
1156
|
-
kernel void kernel_get_rows_q3_K(
|
1157
|
-
device const void * src0,
|
1158
|
-
device const int * src1,
|
1159
|
-
device float * dst,
|
1160
|
-
constant int64_t & ne00,
|
1161
|
-
constant uint64_t & nb01,
|
1162
|
-
constant uint64_t & nb1,
|
1163
|
-
uint tpig[[thread_position_in_grid]]) {
|
1164
|
-
const int i = tpig;
|
1165
|
-
const int r = ((device int32_t *) src1)[i];
|
1166
|
-
|
1167
|
-
dequantize_row_q3_K(
|
1168
|
-
(device const block_q3_K *) ((device char *) src0 + r*nb01),
|
1169
|
-
(device float *) ((device char *) dst + i*nb1), ne00);
|
1170
|
-
}
|
1171
|
-
|
1172
|
-
kernel void kernel_get_rows_q4_K(
|
1173
|
-
device const void * src0,
|
1174
|
-
device const int * src1,
|
1175
|
-
device float * dst,
|
1176
|
-
constant int64_t & ne00,
|
1177
|
-
constant uint64_t & nb01,
|
1178
|
-
constant uint64_t & nb1,
|
1179
|
-
uint tpig[[thread_position_in_grid]]) {
|
1180
|
-
const int i = tpig;
|
1181
|
-
const int r = ((device int32_t *) src1)[i];
|
1182
|
-
|
1183
|
-
dequantize_row_q4_K(
|
1184
|
-
(device const block_q4_K *) ((device char *) src0 + r*nb01),
|
1185
|
-
(device float *) ((device char *) dst + i*nb1), ne00);
|
1186
|
-
}
|
1187
|
-
|
1188
|
-
kernel void kernel_get_rows_q5_K(
|
1189
|
-
device const void * src0,
|
1190
|
-
device const int * src1,
|
1191
|
-
device float * dst,
|
1192
|
-
constant int64_t & ne00,
|
1193
|
-
constant uint64_t & nb01,
|
1194
|
-
constant uint64_t & nb1,
|
1195
|
-
uint tpig[[thread_position_in_grid]]) {
|
1196
|
-
const int i = tpig;
|
1197
|
-
const int r = ((device int32_t *) src1)[i];
|
1198
|
-
|
1199
|
-
dequantize_row_q5_K(
|
1200
|
-
(device const block_q5_K *) ((device char *) src0 + r*nb01),
|
1201
|
-
(device float *) ((device char *) dst + i*nb1), ne00);
|
1202
|
-
}
|
1203
|
-
|
1204
|
-
kernel void kernel_get_rows_q6_K(
|
1205
|
-
device const void * src0,
|
1206
|
-
device const int * src1,
|
1207
|
-
device float * dst,
|
1208
|
-
constant int64_t & ne00,
|
1209
|
-
constant uint64_t & nb01,
|
1210
|
-
constant uint64_t & nb1,
|
1211
|
-
uint tpig[[thread_position_in_grid]]) {
|
1212
|
-
const int i = tpig;
|
1213
|
-
const int r = ((device int32_t *) src1)[i];
|
1214
|
-
|
1215
|
-
dequantize_row_q6_K(
|
1216
|
-
(device const block_q6_K *) ((device char *) src0 + r*nb01),
|
1217
|
-
(device float *) ((device char *) dst + i*nb1), ne00);
|
1218
|
-
}
|
1219
|
-
|
1220
793
|
//====================================== dot products =========================
|
1221
794
|
|
1222
795
|
kernel void kernel_mul_mat_q2_K_f32(
|
@@ -1224,21 +797,27 @@ kernel void kernel_mul_mat_q2_K_f32(
|
|
1224
797
|
device const float * src1,
|
1225
798
|
device float * dst,
|
1226
799
|
constant int64_t & ne00,
|
1227
|
-
constant int64_t & ne10,
|
1228
|
-
constant int64_t & ne0,
|
1229
800
|
constant int64_t & ne01[[buffer(4)]],
|
1230
|
-
|
801
|
+
constant int64_t & ne02[[buffer(5)]],
|
802
|
+
constant int64_t & ne10[[buffer(9)]],
|
803
|
+
constant int64_t & ne12[[buffer(11)]],
|
804
|
+
constant int64_t & ne0[[buffer(15)]],
|
805
|
+
constant int64_t & ne1[[buffer(16)]],
|
806
|
+
constant uint & gqa[[buffer(17)]],
|
807
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1231
808
|
uint tiisg[[thread_index_in_simdgroup]],
|
1232
809
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1233
810
|
|
1234
811
|
const int nb = ne00/QK_K;
|
1235
812
|
const int r0 = tgpig.x;
|
1236
813
|
const int r1 = tgpig.y;
|
814
|
+
const int r2 = tgpig.z;
|
1237
815
|
|
1238
816
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1239
817
|
const int ib_row = first_row * nb;
|
1240
|
-
|
1241
|
-
device const
|
818
|
+
const uint offset0 = r2/gqa*(nb*ne0);
|
819
|
+
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
|
820
|
+
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
1242
821
|
float yl[32];
|
1243
822
|
float sumf[N_DST]={0.f}, all_sum;
|
1244
823
|
|
@@ -1351,7 +930,7 @@ kernel void kernel_mul_mat_q2_K_f32(
|
|
1351
930
|
for (int row = 0; row < N_DST; ++row) {
|
1352
931
|
all_sum = simd_sum(sumf[row]);
|
1353
932
|
if (tiisg == 0) {
|
1354
|
-
dst[r1*ne0 + first_row + row] = all_sum;
|
933
|
+
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
|
1355
934
|
}
|
1356
935
|
}
|
1357
936
|
}
|
@@ -1362,10 +941,14 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1362
941
|
device const float * src1,
|
1363
942
|
device float * dst,
|
1364
943
|
constant int64_t & ne00,
|
1365
|
-
constant int64_t &
|
1366
|
-
constant int64_t &
|
1367
|
-
constant int64_t &
|
1368
|
-
|
944
|
+
constant int64_t & ne01[[buffer(4)]],
|
945
|
+
constant int64_t & ne02[[buffer(5)]],
|
946
|
+
constant int64_t & ne10[[buffer(9)]],
|
947
|
+
constant int64_t & ne12[[buffer(11)]],
|
948
|
+
constant int64_t & ne0[[buffer(15)]],
|
949
|
+
constant int64_t & ne1[[buffer(16)]],
|
950
|
+
constant uint & gqa[[buffer(17)]],
|
951
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1369
952
|
uint tiisg[[thread_index_in_simdgroup]],
|
1370
953
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1371
954
|
|
@@ -1373,11 +956,12 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1373
956
|
|
1374
957
|
const int64_t r0 = tgpig.x;
|
1375
958
|
const int64_t r1 = tgpig.y;
|
959
|
+
const int64_t r2 = tgpig.z;
|
1376
960
|
|
1377
961
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
1378
|
-
|
1379
|
-
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
|
1380
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
962
|
+
const uint offset0 = r2/gqa*(nb*ne0);
|
963
|
+
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
|
964
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
1381
965
|
|
1382
966
|
float yl[16];
|
1383
967
|
|
@@ -1465,7 +1049,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1465
1049
|
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
|
1466
1050
|
const float tot = simd_sum(sumf);
|
1467
1051
|
if (tiisg == 0) {
|
1468
|
-
dst[r1*ne0 + first_row + row] = tot;
|
1052
|
+
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
|
1469
1053
|
}
|
1470
1054
|
}
|
1471
1055
|
}
|
@@ -1475,10 +1059,14 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1475
1059
|
device const float * src1,
|
1476
1060
|
device float * dst,
|
1477
1061
|
constant int64_t & ne00,
|
1478
|
-
constant int64_t &
|
1479
|
-
constant int64_t &
|
1480
|
-
constant int64_t &
|
1481
|
-
|
1062
|
+
constant int64_t & ne01[[buffer(4)]],
|
1063
|
+
constant int64_t & ne02[[buffer(5)]],
|
1064
|
+
constant int64_t & ne10[[buffer(9)]],
|
1065
|
+
constant int64_t & ne12[[buffer(11)]],
|
1066
|
+
constant int64_t & ne0[[buffer(15)]],
|
1067
|
+
constant int64_t & ne1[[buffer(16)]],
|
1068
|
+
constant uint & gqa[[buffer(17)]],
|
1069
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1482
1070
|
uint tiisg[[thread_index_in_simdgroup]],
|
1483
1071
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1484
1072
|
|
@@ -1486,11 +1074,12 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1486
1074
|
|
1487
1075
|
const int64_t r0 = tgpig.x;
|
1488
1076
|
const int64_t r1 = tgpig.y;
|
1077
|
+
const int64_t r2 = tgpig.z;
|
1489
1078
|
|
1490
1079
|
const int row = 2 * r0 + sgitg;
|
1491
|
-
|
1492
|
-
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
|
1493
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1080
|
+
const uint offset0 = r2/gqa*(nb*ne0);
|
1081
|
+
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
|
1082
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
1494
1083
|
const int ix = tiisg/4;
|
1495
1084
|
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
1496
1085
|
const int im = il/8; // 0, 0, 1, 1
|
@@ -1529,7 +1118,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1529
1118
|
|
1530
1119
|
const float tot = simd_sum(sumf);
|
1531
1120
|
if (tiisg == 0) {
|
1532
|
-
dst[r1*ne0 + row] = tot;
|
1121
|
+
dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
|
1533
1122
|
}
|
1534
1123
|
|
1535
1124
|
}
|
@@ -1541,10 +1130,14 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1541
1130
|
device const float * src1,
|
1542
1131
|
device float * dst,
|
1543
1132
|
constant int64_t & ne00,
|
1544
|
-
constant int64_t & ne10,
|
1545
|
-
constant int64_t & ne0,
|
1546
1133
|
constant int64_t & ne01[[buffer(4)]],
|
1547
|
-
|
1134
|
+
constant int64_t & ne02[[buffer(5)]],
|
1135
|
+
constant int64_t & ne10[[buffer(9)]],
|
1136
|
+
constant int64_t & ne12[[buffer(11)]],
|
1137
|
+
constant int64_t & ne0[[buffer(15)]],
|
1138
|
+
constant int64_t & ne1[[buffer(16)]],
|
1139
|
+
constant uint & gqa[[buffer(17)]],
|
1140
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1548
1141
|
uint tiisg[[thread_index_in_simdgroup]],
|
1549
1142
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1550
1143
|
|
@@ -1560,10 +1153,12 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1560
1153
|
const int nb = ne00/QK_K;
|
1561
1154
|
const int r0 = tgpig.x;
|
1562
1155
|
const int r1 = tgpig.y;
|
1156
|
+
const int r2 = tgpig.z;
|
1563
1157
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1564
1158
|
const int ib_row = first_row * nb;
|
1565
|
-
|
1566
|
-
device const
|
1159
|
+
const uint offset0 = r2/gqa*(nb*ne0);
|
1160
|
+
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
1161
|
+
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
1567
1162
|
float yl[16];
|
1568
1163
|
float yh[16];
|
1569
1164
|
float sumf[N_DST]={0.f}, all_sum;
|
@@ -1630,7 +1225,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1630
1225
|
for (int row = 0; row < N_DST; ++row) {
|
1631
1226
|
all_sum = simd_sum(sumf[row]);
|
1632
1227
|
if (tiisg == 0) {
|
1633
|
-
dst[r1*ne0 + first_row + row] = all_sum;
|
1228
|
+
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
|
1634
1229
|
}
|
1635
1230
|
}
|
1636
1231
|
}
|
@@ -1640,10 +1235,14 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1640
1235
|
device const float * src1,
|
1641
1236
|
device float * dst,
|
1642
1237
|
constant int64_t & ne00,
|
1643
|
-
constant int64_t & ne10,
|
1644
|
-
constant int64_t & ne0,
|
1645
1238
|
constant int64_t & ne01[[buffer(4)]],
|
1646
|
-
|
1239
|
+
constant int64_t & ne02[[buffer(5)]],
|
1240
|
+
constant int64_t & ne10[[buffer(9)]],
|
1241
|
+
constant int64_t & ne12[[buffer(11)]],
|
1242
|
+
constant int64_t & ne0[[buffer(15)]],
|
1243
|
+
constant int64_t & ne1[[buffer(16)]],
|
1244
|
+
constant uint & gqa[[buffer(17)]],
|
1245
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1647
1246
|
uint tiisg[[thread_index_in_simdgroup]],
|
1648
1247
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1649
1248
|
|
@@ -1653,10 +1252,12 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1653
1252
|
const int nb = ne00/QK_K;
|
1654
1253
|
const int r0 = tgpig.x;
|
1655
1254
|
const int r1 = tgpig.y;
|
1255
|
+
const int r2 = tgpig.z;
|
1656
1256
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1657
1257
|
const int ib_row = first_row * nb;
|
1658
|
-
|
1659
|
-
device const
|
1258
|
+
const uint offset0 = r2/gqa*(nb*ne0);
|
1259
|
+
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
1260
|
+
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
1660
1261
|
float yl[8];
|
1661
1262
|
float yh[8];
|
1662
1263
|
float sumf[N_DST]={0.f}, all_sum;
|
@@ -1712,7 +1313,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1712
1313
|
for (int row = 0; row < N_DST; ++row) {
|
1713
1314
|
all_sum = simd_sum(sumf[row]);
|
1714
1315
|
if (tiisg == 0) {
|
1715
|
-
dst[r1*ne0 + first_row + row] = all_sum;
|
1316
|
+
dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
|
1716
1317
|
}
|
1717
1318
|
}
|
1718
1319
|
}
|
@@ -1723,9 +1324,14 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
1723
1324
|
device const float * src1,
|
1724
1325
|
device float * dst,
|
1725
1326
|
constant int64_t & ne00,
|
1726
|
-
constant int64_t &
|
1727
|
-
constant int64_t &
|
1728
|
-
|
1327
|
+
constant int64_t & ne01[[buffer(4)]],
|
1328
|
+
constant int64_t & ne02[[buffer(5)]],
|
1329
|
+
constant int64_t & ne10[[buffer(9)]],
|
1330
|
+
constant int64_t & ne12[[buffer(11)]],
|
1331
|
+
constant int64_t & ne0[[buffer(15)]],
|
1332
|
+
constant int64_t & ne1[[buffer(16)]],
|
1333
|
+
constant uint & gqa[[buffer(17)]],
|
1334
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1729
1335
|
uint tiisg[[thread_index_in_simdgroup]],
|
1730
1336
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1731
1337
|
|
@@ -1733,11 +1339,12 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
1733
1339
|
|
1734
1340
|
const int64_t r0 = tgpig.x;
|
1735
1341
|
const int64_t r1 = tgpig.y;
|
1342
|
+
const int r2 = tgpig.z;
|
1736
1343
|
|
1737
1344
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
1738
|
-
|
1739
|
-
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
|
1740
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1345
|
+
const uint offset0 = r2/gqa*(nb*ne0);
|
1346
|
+
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
|
1347
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
1741
1348
|
|
1742
1349
|
float sumf[2]={0.f};
|
1743
1350
|
|
@@ -1871,7 +1478,7 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
1871
1478
|
for (int row = 0; row < 2; ++row) {
|
1872
1479
|
const float tot = simd_sum(sumf[row]);
|
1873
1480
|
if (tiisg == 0) {
|
1874
|
-
dst[r1*ne0 + first_row + row] = tot;
|
1481
|
+
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
|
1875
1482
|
}
|
1876
1483
|
}
|
1877
1484
|
|
@@ -1882,9 +1489,14 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1882
1489
|
device const float * src1,
|
1883
1490
|
device float * dst,
|
1884
1491
|
constant int64_t & ne00,
|
1885
|
-
constant int64_t &
|
1886
|
-
constant int64_t &
|
1887
|
-
|
1492
|
+
constant int64_t & ne01[[buffer(4)]],
|
1493
|
+
constant int64_t & ne02[[buffer(5)]],
|
1494
|
+
constant int64_t & ne10[[buffer(9)]],
|
1495
|
+
constant int64_t & ne12[[buffer(11)]],
|
1496
|
+
constant int64_t & ne0[[buffer(15)]],
|
1497
|
+
constant int64_t & ne1[[buffer(16)]],
|
1498
|
+
constant uint & gqa[[buffer(17)]],
|
1499
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1888
1500
|
uint tiisg[[thread_index_in_simdgroup]],
|
1889
1501
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1890
1502
|
|
@@ -1897,11 +1509,12 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1897
1509
|
|
1898
1510
|
const int64_t r0 = tgpig.x;
|
1899
1511
|
const int64_t r1 = tgpig.y;
|
1512
|
+
const int r2 = tgpig.z;
|
1900
1513
|
|
1901
1514
|
const int row = 2 * r0 + sgitg;
|
1902
|
-
|
1903
|
-
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb
|
1904
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1515
|
+
const uint offset0 = r2/gqa*(nb*ne0);
|
1516
|
+
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
|
1517
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
1905
1518
|
|
1906
1519
|
float sumf = 0;
|
1907
1520
|
|
@@ -1967,6 +1580,366 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1967
1580
|
|
1968
1581
|
const float tot = simd_sum(sumf);
|
1969
1582
|
if (tiisg == 0) {
|
1970
|
-
dst[r1*ne0 + row] = tot;
|
1583
|
+
dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
|
1584
|
+
}
|
1585
|
+
}
|
1586
|
+
|
1587
|
+
//============================= templates and their specializations =============================
|
1588
|
+
|
1589
|
+
// Copy one 4x4 tile of f16 weights into `reg`, converting each element to the
// element type of type4x4. `il` is ignored: for f16 a single "block" is exactly
// one 4x4 tile (the f16 instantiations use nl == 1).
template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
    const half4x4 tile = *src;
    for (int row = 0; row < 4; ++row) {
        for (int col = 0; col < 4; ++col) {
            reg[row][col] = tile[row][col];
        }
    }
}
|
1596
|
+
|
1597
|
+
// Dequantize 16 weights of a q4_0 block into a 4x4 tile: value = (q - 8) * d.
//   il == 0 : the low nibble of every quant byte
//   il != 0 : the high nibble of every quant byte
// The high nibble is never shifted down. Instead the scale is divided by 16
// and the -8 zero point is multiplied by 16, which gives the same value.
template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
    // Skip the first 16-bit word of the block (assumed to be the half scale `d`)
    // and read the quants two bytes at a time.
    device const uint16_t * words = (device const uint16_t *)xb + 1;

    half   scale   = xb->d;
    half   bias    = -8.h;
    ushort mask_lo = 0x000F;   // nibble taken from the first byte of each word
    ushort mask_hi = 0x0F00;   // nibble taken from the second byte of each word
    if (il) {
        scale   = xb->d / 16.h;
        bias    = -8.h * 16.h;
        mask_lo = 0x00F0;
        mask_hi = 0xF000;
    }

    for (int r = 0; r < 4; ++r) {
        for (int c = 0; c < 2; ++c) {
            const uint16_t w = words[2*r + c];
            reg[r][2*c]     = ((w & mask_lo) + bias) * scale;
            reg[r][2*c + 1] = (((w & mask_hi) >> 8) + bias) * scale;
        }
    }
}
|
1610
|
+
|
1611
|
+
// Dequantize 16 weights of a q4_1 block into a 4x4 tile: value = q * d + m.
//   il == 0 : the low nibble of every quant byte
//   il != 0 : the high nibble, left unshifted, so d is pre-divided by 16
template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
    // Skip the first two 16-bit words of the block (assumed to be the half
    // fields `d` and `m`) and read the quants two bytes at a time.
    device const uint16_t * words = (device const uint16_t *)xb + 2;

    const half mn      = xb->m;
    half       scale   = xb->d;
    ushort     mask_lo = 0x000F;
    ushort     mask_hi = 0x0F00;
    if (il) {
        scale   = xb->d / 16.h;
        mask_lo = 0x00F0;
        mask_hi = 0xF000;
    }

    for (int r = 0; r < 4; ++r) {
        for (int c = 0; c < 2; ++c) {
            const uint16_t w = words[2*r + c];
            reg[r][2*c]     = ((w & mask_lo) * scale) + mn;
            reg[r][2*c + 1] = (((w & mask_hi) >> 8) * scale) + mn;
        }
    }
}
|
1624
|
+
|
1625
|
+
// Dequantize 16 weights of a q2_K block into a 4x4 tile: value = dl*q - ml.
// scales[il] packs the sub-block scale (low nibble) and min (high nibble).
// For QK_K == 256, il also picks the 16 quant bytes to read and, via
// (il/2)%4, which 2-bit field of each byte holds the weight.
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
    device const uint8_t * qbytes = (device const uint8_t *)xb->qs;
    const uint8_t packed_scale = xb->scales[il];   // must be read before il is remapped

#if QK_K == 256
    qbytes += 32*(il/8) + 16*(il&1);
    il = (il/2)%4;
#endif

    // The 2-bit field is extracted in place; coef undoes its position in the byte.
    half  coef;
    uchar mask;
    if (il > 2)      { coef = 1/64.h; mask = 192; }
    else if (il > 1) { coef = 1/16.h; mask = 48;  }
    else if (il > 0) { coef = 1/4.h;  mask = 12;  }
    else             { coef = 1.h;    mask = 3;   }

    const half d   = xb->d;
    const half dmn = xb->dmin;
    const half dl  = d * (packed_scale & 0xF) * coef;
    const half ml  = dmn * (packed_scale >> 4);

    for (int r = 0; r < 4; ++r) {
        for (int c = 0; c < 4; ++c) {
            reg[r][c] = dl * (qbytes[4*r + c] & mask) - ml;
        }
    }
}
|
1644
|
+
|
1645
|
+
// Dequantize 16 weights of a q3_K block into a 4x4 tile.
// Each weight is a 2-bit field from `qs` plus one high bit from `hmask`.
// When that bit is clear, 4 is subtracted from the 2-bit value.
template <typename type4x4>
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
    const float d_all = (float)(xb->d);
    device const uint8_t * qbytes = (device const uint8_t *)xb->qs;
    device const uint8_t * hbytes = (device const uint8_t *)xb->hmask;
    device const int8_t  * scales = (device const int8_t *)xb->scales;

#if QK_K == 256
    qbytes += 32 * (il/8) + 16 * (il&1);
    hbytes += 16 * (il&1);
    const uint8_t hbit = 1 << (il/2);

    // Rebuild the 6-bit scale of sub-block `il`:
    //   - the 4 low bits come from scales[il%8] (upper nibble when il >= 8,
    //     hence the /16 below);
    //   - the 2 high bits come from scales[8 + il%4].
    const short quarter = il/4;
    uint16_t hi_mask;
    if (quarter > 2)      { hi_mask = 192; }
    else if (quarter > 1) { hi_mask = 48;  }
    else if (quarter > 0) { hi_mask = 12;  }
    else                  { hi_mask = 3;   }
    const uint16_t lo_mask  = (il/8) ? 0xF0 : 0x0F;
    const uint16_t lo_src   = scales[il%8];
    const uint16_t hi_src   = scales[8 + il%4];
    const int      hi_shift = (quarter & 1) ? 2 : 4;
    const int16_t  dl_int   = (lo_src & lo_mask) | ((hi_src & hi_mask) << hi_shift);
    const float    dl       = (il < 8) ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);

    // From here on, il selects which 2-bit field of each quant byte to use.
    il = (il/2)%4;
#else
    const float    kcoef = (il & 1) ? 1.f/16.f : 1.f;
    const uint16_t kmask = (il & 1) ? 0xF0 : 0x0F;
    const float    dl    = d_all * ((scales[il/2] & kmask) * kcoef - 8);
    const uint8_t  hbit  = 1<<(il*2);
#endif

    // The 2-bit field is extracted in place; coef undoes its position in the byte.
    float   coef;
    uint8_t mask;
    if (il > 2)      { coef = 1/64.h; mask = 192; }
    else if (il > 1) { coef = 1/16.h; mask = 48;  }
    else if (il > 0) { coef = 1/4.h;  mask = 12;  }
    else             { coef = 1.h;    mask = 3;   }

    for (int r = 0; r < 4; ++r) {
        for (int c = 0; c < 4; ++c) {
            const int k = 4*r + c;
#if QK_K == 256
            const bool high_set = hbytes[k] & hbit;
#else
            const bool high_set = hbytes[k%8] & (hbit * (1 + k/8));
#endif
            reg[r][c] = coef * dl * ((qbytes[k] & mask) - (high_set ? 0 : 4.f/coef));
        }
    }
}
|
1683
|
+
|
1684
|
+
// Dequantize 16 weights of a q4_K block into a 4x4 tile: value = dl*q - ml.
// For QK_K == 256:
//   - il/4 picks a 32-byte group of qs and a scale/min pair (via get_scale_min_k4);
//   - il&1 picks which 16 bytes of that group are read;
//   - il%4 < 2 selects the low nibble, otherwise the high nibble.
template <typename type4x4>
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
    device const uint8_t * qbytes = xb->qs;
    float dl;
    float ml;

#if QK_K == 256
    const float d    = (float)(xb->d);
    const float dmin = (float)(xb->dmin);
    const short is   = (il/4) * 2;
    qbytes += (il/4) * 32 + 16 * (il&1);
    il = il%4;
    const uchar4 sc = get_scale_min_k4(is, xb->scales);
    if (il < 2) {
        dl = d * sc[0];
        ml = dmin * sc[1];
    } else {
        // The high nibble is left unshifted, so the /16 is folded into the scale.
        dl = d * sc[2]/16.h;
        ml = dmin * sc[3];
    }
#else
    qbytes += 16 * (il&1);
    device const uint8_t * s = xb->scales;
    device const half2 * dh = (device const half2 *)xb->d;
    const float2 d = (float2)dh[0];
    if (il < 2) {
        dl = d[0] * (s[0]&0xF);
        ml = d[1] * (s[0]>>4);
    } else {
        dl = d[0] * (s[1]&0xF)/16.h;
        ml = d[1] * (s[1]>>4);
    }
#endif

    const ushort mask = (il < 2) ? 0x0F : 0xF0;
    for (int r = 0; r < 4; ++r) {
        for (int c = 0; c < 4; ++c) {
            reg[r][c] = dl * (qbytes[4*r + c] & mask) - ml;
        }
    }
}
|
1710
|
+
|
1711
|
+
// Dequantize 16 weights of a q5_K block into a 4x4 tile.
// Each weight is a 4-bit nibble from `qs` plus a fifth bit from `qh`.
template <typename type4x4>
void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
    device const uint8_t * qbytes = xb->qs;
    device const uint8_t * hbytes = xb->qh;

#if QK_K == 256
    const float d    = (float)(xb->d);
    const float dmin = (float)(xb->dmin);
    const short is   = (il/4) * 2;
    qbytes += 32 * (il/4) + 16 * (il&1);
    hbytes += 16 * (il&1);
    const uint8_t hbit = 1 << (il/2);
    il = il%4;
    const uchar4 sc = get_scale_min_k4(is, xb->scales);

    // Low-nibble slices use scale/min sc[0]/sc[1]; high-nibble slices use
    // sc[2]/sc[3]. The high nibble is left unshifted (x16), so the scale is
    // divided by 16 and the fifth bit is worth 256 instead of 16.
    float  dl;
    float  ml;
    float  qh_val;
    ushort mask;
    if (il < 2) {
        dl = d * sc[0];       ml = dmin * sc[1];  mask = 0x0F;  qh_val = 16.f;
    } else {
        dl = d * sc[2]/16.h;  ml = dmin * sc[3];  mask = 0xF0;  qh_val = 256.f;
    }

    for (int r = 0; r < 4; ++r) {
        for (int c = 0; c < 4; ++c) {
            const int k = 4*r + c;
            reg[r][c] = dl * ((qbytes[k] & mask) + ((hbytes[k] & hbit) ? qh_val : 0)) - ml;
        }
    }
#else
    qbytes += 16 * (il&1);
    device const int8_t * s = xb->scales;
    const float   dl   = xb->d * s[il];
    const uint8_t hbit = 1<<(il*2);
    const float   coef = (il < 2) ? 1.f : 1.f/16.f;
    const ushort  mask = (il < 2) ? 0x0F : 0xF0;

    for (int r = 0; r < 4; ++r) {
        for (int c = 0; c < 4; ++c) {
            const int k = 4*r + c;
            reg[r][c] = coef * dl * ((qbytes[k] & mask) - ((hbytes[k%8] & (hbit*(1+k/8))) ? 0.f : 16.f/coef));
        }
    }
#endif
}
|
1745
|
+
|
1746
|
+
// Dequantize 16 weights of a q6_K block into a 4x4 tile.
// Each 6-bit quant is 4 bits from `ql` plus 2 bits from `qh`, offset by -32,
// and then scaled by d * scales[sub-block].
template <typename type4x4>
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
    const float d_all = (float)(xb->d);
    device const uint8_t * lo_bytes = (device const uint8_t *)xb->ql;
    device const uint8_t * hi_bytes = (device const uint8_t *)xb->qh;
    device const int8_t  * scales   = (device const int8_t *)xb->scales;

#if QK_K == 256
    lo_bytes += 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
    hi_bytes += 32*(il/8) + 16*(il&1);
    // (il%2) + 2*(il/2) == il for every integer il: one scale per 16-weight slice.
    const float sc = scales[il];
    il = (il/2)%4;
#else
    lo_bytes += 16 * (il&1);
    const float sc = scales[il];
#endif

    // All selectors depend only on il, so compute them once outside the loop.
    uint16_t hi_mask;
    if (il > 2)      { hi_mask = 192; }
    else if (il > 1) { hi_mask = 48;  }
    else if (il > 0) { hi_mask = 12;  }
    else             { hi_mask = 3;   }
    const uint16_t lo_mask  = (il > 1) ? 0xF0 : 0x0F;
    const float    coef     = (il > 1) ? 1.f/16.f : 1.f;
    const int      hi_shift = (il & 1) ? 2 : 4;

    for (int r = 0; r < 4; ++r) {
        for (int c = 0; c < 4; ++c) {
            const int k = 4*r + c;
            const float q = ((lo_bytes[k] & lo_mask) | ((hi_bytes[k] & hi_mask) << hi_shift)) - 32.f/coef;
            reg[r][c] = d_all * sc * q * coef;
        }
    }
}
|
1771
|
+
|
1772
|
+
// Row gather: dst row i receives the dequantized src0 row src1[i].
// There is one threadgroup per output row (tgpig). Its threads stride over
// the row one 4x4 tile (16 values) at a time, with tile t being sub-tile t%nl
// of block t/nl. Only ne00/16 full tiles are written; a remainder of
// ne00 % 16 values would be left untouched.
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint                 tgpig[[threadgroup_position_in_grid]],
        uint                 tiitg[[thread_index_in_threadgroup]],
        uint                 tptg[[threads_per_threadgroup]]) {
    const int out_row = tgpig;
    const int src_row = src1[out_row];

    device const block_q * blocks = (device const block_q *) ((device const char *) src0 + src_row*nb01);
    device float4x4      * tiles  = (device float4x4 *) ((device char *) dst + out_row*nb1);

    for (int t = tiitg; t < ne00/16; t += tptg) {
        float4x4 vals;
        dequantize_func(blocks + t/nl, t%nl, vals);
        tiles[t] = vals;
    }
}
|
1793
|
+
|
1794
|
+
// Tiling parameters for kernel_mul_mm. A threadgroup computes one
// BLOCK_SIZE_M x BLOCK_SIZE_N tile of dst and walks the shared dimension
// (ne00) BLOCK_SIZE_K values at a time.
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A (src0 rows)
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B (src1 rows)
#define BLOCK_SIZE_K 32
#define THREAD_MAT_M 4 // each simdgroup takes 4 simdgroup matrices from matrix A
#define THREAD_MAT_N 2 // each simdgroup takes 2 simdgroup matrices from matrix B
#define THREAD_PER_BLOCK 128
#define THREAD_PER_ROW 2 // 2 threads load each row of matrix A (16 weights each)
#define THREAD_PER_COL 4 // 4 threads load each row of matrix B (8 floats each)
#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
#define SG_MAT_ROW 8

// Matrix-matrix product of a (possibly quantized) src0 with an f32 src1:
//   dst[i1*ne0 + i0] = dot(src0 row i0, src1 row i1), batched over tgpig.z.
//
// Layout and launch:
//   - Each block_q contains 16*nl weights. dequantize_func expands one
//     16-weight slice (selected by il) into a half4x4.
//   - Grid: tgpig.x -> tile of BLOCK_SIZE_N src1 rows,
//           tgpig.y -> tile of BLOCK_SIZE_M src0 rows,
//           tgpig.z -> batch im (the src0 batch used is im/gqa).
//   - Indexing via sgitg%2 / sgitg/2 expects 4 simdgroups per threadgroup.
//     Presumably the kernel is launched with THREAD_PER_BLOCK threads;
//     confirm against the host code.
//
// Threadgroup memory:
//   - shared_memory must be at least 8192 bytes: sa (64x32 half, 4096 bytes)
//     followed by sb (32x32 float, 4096 bytes).
//   - The partial-tile path reuses all 8192 bytes as a 64x32 float staging
//     buffer.
//
// NOTE(review): each K step reads BLOCK_SIZE_K values per row, so ne00 is
// assumed to be a multiple of BLOCK_SIZE_K. Confirm the host enforces this.
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm(device const  uchar * src0,
                           device const  float * src1,
                           device        float * dst,
                           constant    int64_t & ne00,
                           constant    int64_t & ne02,
                           constant    int64_t & nb01,
                           constant    int64_t & nb02,
                           constant    int64_t & ne12,
                           constant    int64_t & ne0,
                           constant    int64_t & ne1,
                           constant       uint & gqa,
                           threadgroup   uchar * shared_memory [[threadgroup(0)]],
                           uint3                 tgpig[[threadgroup_position_in_grid]],
                           uint                  tiitg[[thread_index_in_threadgroup]],
                           uint                  sgitg[[simdgroup_index_in_threadgroup]]) {

    threadgroup half  * sa = ((threadgroup half *)shared_memory);
    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);

    const uint r0 = tgpig.y;
    const uint r1 = tgpig.x;
    const uint im = tgpig.z;
    // if this block is of 64x32 shape or smaller
    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
    // a thread shouldn't load data outside of the matrix
    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;

    simdgroup_half8x8  ma[4];
    simdgroup_float8x8 mb[2];
    simdgroup_float8x8 c_res[8];
    for (int i = 0; i < 8; i++){
        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
    }

    short il = (tiitg % THREAD_PER_ROW);
    // 64-bit: (im/gqa)*nb02 is a byte offset and must not be truncated to 32 bits.
    const uint64_t offset0 = (uint64_t)(im/gqa)*nb02;
    const ushort   offset1 = il/nl;
    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
    device const float   * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00
                             + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;

    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
        // load data and store to threadgroup memory
        half4x4 temp_a;
        dequantize_func(x, il, temp_a);

        // sa/sb are still being read by other simdgroups from the previous
        // K step; wait for all of them before overwriting (write-after-read).
        threadgroup_barrier(mem_flags::mem_threadgroup);

        #pragma unroll(16)
        for (int i = 0; i < 16; i++) {
            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8)
                                  + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8))
                 + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
        }
        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL))
                = *((device const float2x4 *)y);
        il = (il + 2 < nl) ? il + 2 : il % 2;
        x = (il < 2) ? x + (2+nl-1)/nl : x;
        y += BLOCK_SIZE_K;

        // make the freshly stored tiles visible to every simdgroup
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // load matrices from threadgroup memory and conduct outer products
        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
        #pragma unroll(4)
        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
            #pragma unroll(4)
            for (int i = 0; i < 4; i++) {
                simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
            }
            simdgroup_barrier(mem_flags::mem_none);
            #pragma unroll(2)
            for (int i = 0; i < 2; i++) {
                simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
            }

            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
            #pragma unroll(8)
            for (int i = 0; i < 8; i++){
                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
            }
        }
    }

    if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
        // full tile: each simdgroup stores its 32x16 sub-tile straight to dst
        device float * C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1)
                         + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
        for (int i = 0; i < 8; i++) {
            simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
        }
    } else {
        // Block is smaller than 64x32, so stage it in threadgroup memory and
        // copy only the in-bounds part, to avoid writing outside the matrix.
        // temp_str aliases sa/sb: wait until every simdgroup has finished
        // reading them in the last K step before overwriting.
        threadgroup_barrier(mem_flags::mem_threadgroup);
        threadgroup float * temp_str = ((threadgroup float *)shared_memory)
                                     + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
        for (int i = 0; i < 8; i++) {
            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
        device float * C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
        if (sgitg == 0) {
            for (int i = 0; i < n_rows; i++) {
                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                    *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                }
            }
        }
    }
}
|
1915
|
+
|
1916
|
+
// Number of 16-weight slices (4x4 tiles) in one K-quant block. It is passed
// as the `nl` template argument of the K-quant kernel specializations.
#if QK_K != 256
#define QK_NL 4
#else
#define QK_NL 16
#endif
|
1921
|
+
|
1922
|
+
// Function type shared by every kernel_get_rows specialization. Parameter
// names are for documentation only; the order matches kernel_get_rows.
typedef void (get_rows_t)(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint                 tgpig,
        uint                 tiitg,
        uint                 tptg);

// One exported specialization per source type, each under its host_name.
template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1,     dequantize_f16>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2,     dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2,     dequantize_q4_1>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
1933
|
+
|
1934
|
+
// Function type shared by every kernel_mul_mm specialization. Parameter
// names are for documentation only; the order matches kernel_mul_mm.
typedef void (mat_mm_t)(
        device const  uchar * src0,
        device const  float * src1,
        device        float * dst,
        constant    int64_t & ne00,
        constant    int64_t & ne02,
        constant    int64_t & nb01,
        constant    int64_t & nb02,
        constant    int64_t & ne12,
        constant    int64_t & ne0,
        constant    int64_t & ne1,
        constant       uint & gqa,
        threadgroup   uchar * shared_memory,
        uint3                 tgpig,
        uint                  tiitg,
        uint                  sgitg);

// One exported specialization per src0 type, each under its host_name.
template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1,     dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2,     dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2,     dequantize_q4_1>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;