llama_cpp 0.7.0 → 0.8.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/llama_cpp/extconf.rb +1 -1
- data/ext/llama_cpp/llama_cpp.cpp +41 -21
- data/ext/llama_cpp/src/ggml-alloc.c +62 -107
- data/ext/llama_cpp/src/ggml-alloc.h +11 -5
- data/ext/llama_cpp/src/ggml-backend.c +385 -0
- data/ext/llama_cpp/src/ggml-backend.h +143 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +500 -78
- data/ext/llama_cpp/src/ggml-cuda.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.h +18 -1
- data/ext/llama_cpp/src/ggml-metal.m +396 -127
- data/ext/llama_cpp/src/ggml-metal.metal +290 -46
- data/ext/llama_cpp/src/ggml-opencl.cpp +47 -71
- data/ext/llama_cpp/src/ggml.c +71 -55
- data/ext/llama_cpp/src/ggml.h +15 -9
- data/ext/llama_cpp/src/k_quants.c +12 -20
- data/ext/llama_cpp/src/k_quants.h +5 -5
- data/ext/llama_cpp/src/llama.cpp +1851 -250
- data/ext/llama_cpp/src/llama.h +18 -12
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -4
- metadata +5 -3
@@ -13,11 +13,26 @@ typedef struct {
|
|
13
13
|
|
14
14
|
#define QK4_1 32
|
15
15
|
typedef struct {
|
16
|
-
half d;
|
17
|
-
half m;
|
16
|
+
half d; // delta
|
17
|
+
half m; // min
|
18
18
|
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
19
19
|
} block_q4_1;
|
20
20
|
|
21
|
+
// Q5_0: blocks of 32 quants with a single scale; each weight dequantizes to d * (q - 16),
// where q is a 5-bit value. Low 4 bits live in qs: byte j holds quant j in its low nibble
// and quant j+16 in its high nibble. The 5th bit of quant j is bit j of the 32-bit qh word.
// NOTE(review): kernels read qh as one uint32 at byte offset 2 (right after `d`), i.e. not
// 4-byte aligned — confirm this is acceptable on all target GPUs.
#define QK5_0 32
typedef struct {
    half d;                // delta
    uint8_t qh[4];         // 5-th bit of quants
    uint8_t qs[QK5_0 / 2]; // nibbles / quants
} block_q5_0;

// Q5_1: like Q5_0 but with an explicit minimum; each weight dequantizes to d * q + m.
// qh therefore starts at byte offset 4 and qs at byte offset 8.
#define QK5_1 32
typedef struct {
    half d;                // delta
    half m;                // min
    uint8_t qh[4];         // 5-th bit of quants
    uint8_t qs[QK5_1 / 2]; // nibbles / quants
} block_q5_1;
|
35
|
+
|
21
36
|
#define QK8_0 32
|
22
37
|
typedef struct {
|
23
38
|
half d; // delta
|
@@ -132,6 +147,13 @@ kernel void kernel_relu(
|
|
132
147
|
dst[tpig] = max(0.0f, src0[tpig]);
|
133
148
|
}
|
134
149
|
|
150
|
+
// Element-wise square: dst[i] = src0[i] * src0[i].
// One element per thread; the index is the thread's position in the grid, so the
// host is expected to dispatch exactly one thread per element (no bounds check here).
kernel void kernel_sqr(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float v = src0[tpig];
    dst[tpig] = v * v;
}
|
156
|
+
|
135
157
|
constant float GELU_COEF_A = 0.044715f;
|
136
158
|
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
137
159
|
|
@@ -338,10 +360,11 @@ kernel void kernel_rms_norm(
|
|
338
360
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
339
361
|
uint tiisg[[thread_index_in_simdgroup]],
|
340
362
|
uint ntg[[threads_per_threadgroup]]) {
|
341
|
-
device const float4 * x
|
342
|
-
device const float
|
343
|
-
|
344
|
-
|
363
|
+
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
364
|
+
device const float * x_scalar = (device const float *) x;
|
365
|
+
|
366
|
+
float4 sumf = 0;
|
367
|
+
float all_sum = 0;
|
345
368
|
|
346
369
|
// parallel sum
|
347
370
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
@@ -354,6 +377,7 @@ kernel void kernel_rms_norm(
|
|
354
377
|
}
|
355
378
|
|
356
379
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
380
|
+
|
357
381
|
// broadcast, simd group number is ntg / 32
|
358
382
|
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
359
383
|
if (tpitg < i) {
|
@@ -361,7 +385,9 @@ kernel void kernel_rms_norm(
|
|
361
385
|
}
|
362
386
|
}
|
363
387
|
if (tpitg == 0) {
|
364
|
-
for (int i = 4 * (ne00 / 4); i < ne00; i++) {
|
388
|
+
for (int i = 4 * (ne00 / 4); i < ne00; i++) {
|
389
|
+
sum[0] += x_scalar[i];
|
390
|
+
}
|
365
391
|
sum[0] /= ne00;
|
366
392
|
}
|
367
393
|
|
@@ -376,7 +402,9 @@ kernel void kernel_rms_norm(
|
|
376
402
|
y[i00] = x[i00] * scale;
|
377
403
|
}
|
378
404
|
if (tpitg == 0) {
|
379
|
-
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
|
405
|
+
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
|
406
|
+
y_scalar[i00] = x_scalar[i00] * scale;
|
407
|
+
}
|
380
408
|
}
|
381
409
|
}
|
382
410
|
|
@@ -386,8 +414,11 @@ kernel void kernel_rms_norm(
|
|
386
414
|
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
387
415
|
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
|
388
416
|
float d = qb_curr->d;
|
417
|
+
|
389
418
|
float2 acc = 0.f;
|
419
|
+
|
390
420
|
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
|
421
|
+
|
391
422
|
for (int i = 0; i < 8; i+=2) {
|
392
423
|
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
393
424
|
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
@@ -404,8 +435,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
|
|
404
435
|
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
|
405
436
|
float d = qb_curr->d;
|
406
437
|
float m = qb_curr->m;
|
407
|
-
|
438
|
+
|
408
439
|
float2 acc = 0.f;
|
440
|
+
|
441
|
+
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
|
442
|
+
|
409
443
|
for (int i = 0; i < 8; i+=2) {
|
410
444
|
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
411
445
|
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
@@ -415,9 +449,52 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
|
|
415
449
|
return d * (acc[0] + acc[1]) + sumy * m;
|
416
450
|
}
|
417
451
|
|
452
|
+
// Dot product of half a q5_0 block with 16 cached src1 values (yl).
// sumy: sum of the same 16 src1 values *before* the yl scaling; it carries the -16 offset.
// il:   0 or 8 (= QK5_0/4); selects which 8 bytes of qs (and which 5th bits) are used.
// The yl values are expected to be pre-divided by 1, 256, 16 and 4096 so that each quant
// can be used in place inside its 16-bit word without shifting it down (the missing bit
// shifts 1, 1/16, 1/256, 1/4096).
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
    const float d = qb_curr->d;

    // qs starts after d (one uint16 word) and qh (two uint16 words)
    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
    // bit j of qh is the 5th bit of quant j
    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);

    float lo = 0.f; // contributions of the low-nibble quants
    float hi = 0.f; // contributions of the high-nibble quants (index + 16)

    for (int i = 0; i < 8; i += 2) {
        const uint16_t w  = qs[i / 2];
        const int      b  = i + il;          // 5th-bit index for the low-nibble quants
        const int      bh = b + QK5_0/2;     // 5th-bit index for the high-nibble quants

        lo += yl[i + 0] * ((w & 0x000F) | ((qh >> (b     ) << 4 ) & 0x00010))
            + yl[i + 1] * ((w & 0x0F00) | ((qh >> (b  + 1) << 12) & 0x01000));
        hi += yl[i + 8] * ((w & 0x00F0) | ((qh >> (bh    ) << 8 ) & 0x00100))
            + yl[i + 9] * ((w & 0xF000) | ((qh >> (bh + 1) << 16) & 0x10000));
    }

    return d * (sumy * -16.f + lo + hi);
}
|
472
|
+
|
473
|
+
// Dot product of half a q5_1 block with 16 cached src1 values (yl).
// sumy: sum of the same 16 src1 values *before* the yl scaling; it carries the min term m.
// il:   0 or 8 (= QK5_1/4); selects which 8 bytes of qs (and which 5th bits) are used.
// The yl values are expected to be pre-divided by 1, 256, 16 and 4096 so that each quant
// can be used in place inside its 16-bit word without shifting it down (the missing bit
// shifts 1, 1/16, 1/256, 1/4096).
inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
    float d = qb_curr->d;
    float m = qb_curr->m;

    float2 acc = 0.f;

    // qs starts after d, m (two uint16 words) and qh (two uint16 words)
    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
    // bit j of qh is the 5th bit of quant j
    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);

    for (int i = 0; i < 8; i+=2) {
        // acc[0]: low-nibble quants; acc[1]: high-nibble quants, whose 5th bits sit
        // QK5_1/2 positions higher in qh (use the q5_1 block size, not q5_0's).
        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_1/2) << 8 ) & 0x00100))
                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_1/2) << 16) & 0x10000));
    }
    return d * (acc[0] + acc[1]) + sumy * m;
}
|
494
|
+
|
418
495
|
// putting them in the kernel cause a significant performance penalty
|
419
|
-
#define N_DST 4
|
420
|
-
#define N_SIMDGROUP 2
|
496
|
+
#define N_DST 4 // each SIMD group works on 4 rows
|
497
|
+
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
421
498
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
422
499
|
//Note: This is a template, but strictly speaking it only applies to
|
423
500
|
// quantizations where the block size is 32. It also does not
|
@@ -428,18 +505,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
428
505
|
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
|
429
506
|
uint3 tgpig, uint tiisg, uint sgitg) {
|
430
507
|
const int nb = ne00/QK4_0;
|
508
|
+
|
431
509
|
const int r0 = tgpig.x;
|
432
510
|
const int r1 = tgpig.y;
|
433
511
|
const int im = tgpig.z;
|
512
|
+
|
434
513
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
514
|
+
|
435
515
|
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
|
516
|
+
|
436
517
|
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
437
518
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
438
|
-
float yl[16]; // src1 vector cache
|
439
|
-
float sumf[nr]={0.f};
|
440
519
|
|
441
|
-
|
442
|
-
|
520
|
+
float yl[16]; // src1 vector cache
|
521
|
+
float sumf[nr] = {0.f};
|
522
|
+
|
523
|
+
const int ix = (tiisg/2);
|
524
|
+
const int il = (tiisg%2)*8;
|
443
525
|
|
444
526
|
device const float * yb = y + ix * QK4_0 + il;
|
445
527
|
|
@@ -450,6 +532,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
450
532
|
sumy += yb[i] + yb[i+1];
|
451
533
|
yl[i+0] = yb[i+ 0];
|
452
534
|
yl[i+1] = yb[i+ 1]/256.f;
|
535
|
+
|
453
536
|
sumy += yb[i+16] + yb[i+17];
|
454
537
|
yl[i+8] = yb[i+16]/16.f;
|
455
538
|
yl[i+9] = yb[i+17]/4096.f;
|
@@ -465,12 +548,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
465
548
|
for (int row = 0; row < nr; ++row) {
|
466
549
|
const float tot = simd_sum(sumf[row]);
|
467
550
|
if (tiisg == 0 && first_row + row < ne01) {
|
468
|
-
dst[
|
551
|
+
dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
|
469
552
|
}
|
470
553
|
}
|
471
554
|
}
|
472
555
|
|
473
|
-
kernel void
|
556
|
+
kernel void kernel_mul_mv_q4_0_f32(
|
474
557
|
device const void * src0,
|
475
558
|
device const float * src1,
|
476
559
|
device float * dst,
|
@@ -483,12 +566,12 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|
483
566
|
constant int64_t & ne1[[buffer(16)]],
|
484
567
|
constant uint & gqa[[buffer(17)]],
|
485
568
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
486
|
-
uint
|
487
|
-
uint
|
569
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
570
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
488
571
|
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
489
572
|
}
|
490
573
|
|
491
|
-
kernel void
|
574
|
+
kernel void kernel_mul_mv_q4_1_f32(
|
492
575
|
device const void * src0,
|
493
576
|
device const float * src1,
|
494
577
|
device float * dst,
|
@@ -506,9 +589,46 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|
506
589
|
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
507
590
|
}
|
508
591
|
|
592
|
+
// Matrix-vector product: q5_0-quantized src0 times f32 src1, written as f32 to dst.
// Thin entry point around the shared mul_vec_q_n_f32 template, instantiated for
// block_q5_0 with N_DST rows per SIMD group and N_SIMDGROUP SIMD groups per threadgroup.
// ne00..ne03 are bound implicitly to buffers 0..3; the remaining arguments use the
// explicit buffer indices below.
kernel void kernel_mul_mv_q5_0_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01[[buffer(4)]],
        constant   int64_t & ne02[[buffer(5)]],
        constant   int64_t & ne10[[buffer(9)]],
        constant   int64_t & ne12[[buffer(11)]],
        constant   int64_t & ne0[[buffer(15)]],
        constant   int64_t & ne1[[buffer(16)]],
        constant   uint    & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
        src0, src1, dst,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        gqa,
        tgpig, tiisg, sgitg);
}
|
609
|
+
|
610
|
+
// Matrix-vector product: q5_1-quantized src0 times f32 src1, written as f32 to dst.
// Thin entry point around the shared mul_vec_q_n_f32 template, instantiated for
// block_q5_1 with N_DST rows per SIMD group and N_SIMDGROUP SIMD groups per threadgroup.
// The argument list mirrors kernel_mul_mv_q5_0_f32 so the host can bind both the same way.
kernel void kernel_mul_mv_q5_1_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01[[buffer(4)]],
        constant   int64_t & ne02[[buffer(5)]],
        constant   int64_t & ne10[[buffer(9)]],
        constant   int64_t & ne12[[buffer(11)]],
        constant   int64_t & ne0[[buffer(15)]],
        constant   int64_t & ne1[[buffer(16)]],
        constant   uint    & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
        src0, src1, dst,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        gqa,
        tgpig, tiisg, sgitg);
}
|
627
|
+
|
628
|
+
|
509
629
|
#define NB_Q8_0 8
|
510
630
|
|
511
|
-
kernel void
|
631
|
+
kernel void kernel_mul_mv_q8_0_f32(
|
512
632
|
device const void * src0,
|
513
633
|
device const float * src1,
|
514
634
|
device float * dst,
|
@@ -572,7 +692,7 @@ kernel void kernel_mul_mat_q8_0_f32(
|
|
572
692
|
|
573
693
|
#define N_F32_F32 4
|
574
694
|
|
575
|
-
kernel void
|
695
|
+
kernel void kernel_mul_mv_f32_f32(
|
576
696
|
device const char * src0,
|
577
697
|
device const char * src1,
|
578
698
|
device float * dst,
|
@@ -643,7 +763,7 @@ kernel void kernel_mul_mat_f32_f32(
|
|
643
763
|
}
|
644
764
|
}
|
645
765
|
|
646
|
-
kernel void
|
766
|
+
kernel void kernel_mul_mv_f16_f32_1row(
|
647
767
|
device const char * src0,
|
648
768
|
device const char * src1,
|
649
769
|
device float * dst,
|
@@ -662,7 +782,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
|
|
662
782
|
constant int64_t & ne0,
|
663
783
|
constant int64_t & ne1,
|
664
784
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
665
|
-
uint
|
785
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
666
786
|
|
667
787
|
const int64_t r0 = tgpig.x;
|
668
788
|
const int64_t r1 = tgpig.y;
|
@@ -697,7 +817,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
|
|
697
817
|
|
698
818
|
#define N_F16_F32 4
|
699
819
|
|
700
|
-
kernel void
|
820
|
+
kernel void kernel_mul_mv_f16_f32(
|
701
821
|
device const char * src0,
|
702
822
|
device const char * src1,
|
703
823
|
device float * dst,
|
@@ -769,7 +889,7 @@ kernel void kernel_mul_mat_f16_f32(
|
|
769
889
|
}
|
770
890
|
|
771
891
|
// Assumes row size (ne00) is a multiple of 4
|
772
|
-
kernel void
|
892
|
+
kernel void kernel_mul_mv_f16_f32_l4(
|
773
893
|
device const char * src0,
|
774
894
|
device const char * src1,
|
775
895
|
device float * dst,
|
@@ -1098,6 +1218,62 @@ kernel void kernel_cpy_f32_f32(
|
|
1098
1218
|
}
|
1099
1219
|
}
|
1100
1220
|
|
1221
|
+
// Concatenate src0 and src1 along dimension 2: dst rows with i02 < ne02 are copied from
// src0, and the rest from src1. Assumes ne2 == ne02 + ne12 and matching dims 0/1/3 — TODO
// confirm against the host-side op definition.
// Layout: one threadgroup per dst row (tgpig = (i01, i02, i03)); the threads of the group
// stride over that row's ne0 elements with step ntg.x.
// Elements are copied as float, so all three tensors are assumed to be F32.
kernel void kernel_concat(
        device const char * src0,
        device const char * src1,
        device       char * dst,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant  int64_t & ne03,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant uint64_t & nb03,
        constant  int64_t & ne10,
        constant  int64_t & ne11,
        constant  int64_t & ne12,
        constant  int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant uint64_t & nb13,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant  int64_t & ne2,
        constant  int64_t & ne3,
        constant uint64_t & nb0,
        constant uint64_t & nb1,
        constant uint64_t & nb2,
        constant uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {

    const int64_t i03 = tgpig.z;
    const int64_t i02 = tgpig.y;
    const int64_t i01 = tgpig.x;

    // The source tensor depends only on i02, so the choice is uniform across the
    // threadgroup and is made once, outside the copy loop.
    device const char * src_row;
    uint64_t            src_nb0;
    if (i02 < ne02) {
        src_row = src0 + i03*nb03 + i02*nb02 + i01*nb01;
        src_nb0 = nb00;
    } else {
        // dst slices [ne02, ne2) come from src1, so index src1 relative to the end of
        // src0. (Using i02 % ne12 here picks the wrong slices whenever ne02 != ne12.)
        const int64_t i13 = i03 % ne13;
        const int64_t i12 = i02 - ne02;
        const int64_t i11 = i01 % ne11;
        src_row = src1 + i13*nb13 + i12*nb12 + i11*nb11;
        src_nb0 = nb10;
    }
    device char * dst_row = dst + i03*nb3 + i02*nb2 + i01*nb1;

    for (int64_t i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
        *((device float *)(dst_row + i0*nb0)) = *((device const float *)(src_row + i0*src_nb0));
    }
}
|
1276
|
+
|
1101
1277
|
//============================================ k-quants ======================================================
|
1102
1278
|
|
1103
1279
|
#ifndef QK_K
|
@@ -1190,7 +1366,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
1190
1366
|
|
1191
1367
|
//====================================== dot products =========================
|
1192
1368
|
|
1193
|
-
kernel void
|
1369
|
+
kernel void kernel_mul_mv_q2_K_f32(
|
1194
1370
|
device const void * src0,
|
1195
1371
|
device const float * src1,
|
1196
1372
|
device float * dst,
|
@@ -1334,7 +1510,7 @@ kernel void kernel_mul_mat_q2_K_f32(
|
|
1334
1510
|
}
|
1335
1511
|
|
1336
1512
|
#if QK_K == 256
|
1337
|
-
kernel void
|
1513
|
+
kernel void kernel_mul_mv_q3_K_f32(
|
1338
1514
|
device const void * src0,
|
1339
1515
|
device const float * src1,
|
1340
1516
|
device float * dst,
|
@@ -1486,7 +1662,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1486
1662
|
}
|
1487
1663
|
}
|
1488
1664
|
#else
|
1489
|
-
kernel void
|
1665
|
+
kernel void kernel_mul_mv_q3_K_f32(
|
1490
1666
|
device const void * src0,
|
1491
1667
|
device const float * src1,
|
1492
1668
|
device float * dst,
|
@@ -1557,7 +1733,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1557
1733
|
#endif
|
1558
1734
|
|
1559
1735
|
#if QK_K == 256
|
1560
|
-
kernel void
|
1736
|
+
kernel void kernel_mul_mv_q4_K_f32(
|
1561
1737
|
device const void * src0,
|
1562
1738
|
device const float * src1,
|
1563
1739
|
device float * dst,
|
@@ -1663,7 +1839,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1663
1839
|
}
|
1664
1840
|
}
|
1665
1841
|
#else
|
1666
|
-
kernel void
|
1842
|
+
kernel void kernel_mul_mv_q4_K_f32(
|
1667
1843
|
device const void * src0,
|
1668
1844
|
device const float * src1,
|
1669
1845
|
device float * dst,
|
@@ -1752,7 +1928,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1752
1928
|
}
|
1753
1929
|
#endif
|
1754
1930
|
|
1755
|
-
kernel void
|
1931
|
+
kernel void kernel_mul_mv_q5_K_f32(
|
1756
1932
|
device const void * src0,
|
1757
1933
|
device const float * src1,
|
1758
1934
|
device float * dst,
|
@@ -1925,7 +2101,7 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
1925
2101
|
|
1926
2102
|
}
|
1927
2103
|
|
1928
|
-
kernel void
|
2104
|
+
kernel void kernel_mul_mv_q6_K_f32(
|
1929
2105
|
device const void * src0,
|
1930
2106
|
device const float * src1,
|
1931
2107
|
device float * dst,
|
@@ -2074,6 +2250,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
|
|
2074
2250
|
}
|
2075
2251
|
}
|
2076
2252
|
|
2253
|
+
// Dequantize half of a q5_0 block (16 of its 32 quants) into a 4x4 register tile.
//   il == 0 -> quants 0..15  (low nibbles of qs, 5th bits 0..15 of qh)
//   il != 0 -> quants 16..31 (high nibbles of qs, 5th bits 16..31 of qh)
// Each output is d * q - 16 * d, where q is the 5-bit quant.
template <typename type4x4>
void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
    // qs starts after d (one uint16 word) and qh (two uint16 words)
    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
    const uint32_t qh = *((device const uint32_t *)xb->qh);

    const float d  = xb->d;
    const float md = -16.h * xb->d;

    const bool   upper     = il != 0;
    const ushort nib_mask  = upper ? 0x00F0 : 0x000F; // which nibble of each qs byte
    const int    nib_shift = upper ? 4 : 0;           // bring that nibble down to bits 0..3
    // Move 5th bit (2k or 2k+1, plus 16 for the upper half) of qh onto bit position 4:
    // lower half shifts right by 2k then left by 4; upper half shifts right by 12 + 2k.
    const int    qh_rshift = upper ? 12 : 0;
    const int    qh_lshift = upper ? 0 : 4;

    for (int k = 0; k < 8; ++k) {
        const uint16_t pair = qs[k]; // two consecutive qs bytes

        const uint8_t bit0 = ((qh >> (qh_rshift + 2*k    )) << qh_lshift) & 0x10;
        const uint8_t bit1 = ((qh >> (qh_rshift + 2*k + 1)) << qh_lshift) & 0x10;

        const int32_t q0 = (((pair     ) & nib_mask) >> nib_shift) | bit0;
        const int32_t q1 = (((pair >> 8) & nib_mask) >> nib_shift) | bit1;

        const int row = k / 2;
        const int col = 2 * (k % 2);
        reg[row][col + 0] = d * q0 + md;
        reg[row][col + 1] = d * q1 + md;
    }
}
|
2280
|
+
|
2281
|
+
// Dequantize half of a q5_1 block (16 of its 32 quants) into a 4x4 register tile.
//   il == 0 -> quants 0..15  (low nibbles of qs, 5th bits 0..15 of qh)
//   il != 0 -> quants 16..31 (high nibbles of qs, 5th bits 16..31 of qh)
// Each output is d * q + m, where q is the 5-bit quant.
template <typename type4x4>
void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
    // qs starts after d, m (two uint16 words) and qh (two uint16 words)
    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
    const uint32_t qh = *((device const uint32_t *)xb->qh);

    const float d = xb->d;
    const float m = xb->m;

    const ushort nib_mask  = il ? 0x00F0 : 0x000F; // which nibble of each qs byte
    const int    nib_shift = il ? 4 : 0;           // bring that nibble down to bits 0..3
    // Move the matching 5th bit of qh onto bit position 4 (see the q = nibble | bit below).
    const int    qh_rshift = il ? 12 : 0;
    const int    qh_lshift = il ? 0 : 4;

    for (int k = 0; k < 8; k++) {
        const uint16_t pair = qs[k]; // two consecutive qs bytes

        const uint8_t bit0 = ((qh >> (qh_rshift + 2*k    )) << qh_lshift) & 0x10;
        const uint8_t bit1 = ((qh >> (qh_rshift + 2*k + 1)) << qh_lshift) & 0x10;

        const int32_t q0 = (((pair     ) & nib_mask) >> nib_shift) | bit0;
        const int32_t q1 = (((pair >> 8) & nib_mask) >> nib_shift) | bit1;

        reg[k/2][2*(k%2) + 0] = d * q0 + m;
        reg[k/2][2*(k%2) + 1] = d * q1 + m;
    }
}
|
2308
|
+
|
2077
2309
|
template <typename type4x4>
|
2078
2310
|
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
2079
2311
|
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
@@ -2263,7 +2495,7 @@ kernel void kernel_get_rows(
|
|
2263
2495
|
}
|
2264
2496
|
|
2265
2497
|
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
2266
|
-
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix
|
2498
|
+
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
2267
2499
|
#define BLOCK_SIZE_K 32
|
2268
2500
|
#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
|
2269
2501
|
#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
|
@@ -2300,9 +2532,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2300
2532
|
const uint r0 = tgpig.y;
|
2301
2533
|
const uint r1 = tgpig.x;
|
2302
2534
|
const uint im = tgpig.z;
|
2535
|
+
|
2303
2536
|
// if this block is of 64x32 shape or smaller
|
2304
2537
|
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
2305
2538
|
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
2539
|
+
|
2306
2540
|
// a thread shouldn't load data outside of the matrix
|
2307
2541
|
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
2308
2542
|
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
@@ -2326,26 +2560,30 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2326
2560
|
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
2327
2561
|
|
2328
2562
|
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
2329
|
-
//load data and store to threadgroup memory
|
2563
|
+
// load data and store to threadgroup memory
|
2330
2564
|
half4x4 temp_a;
|
2331
2565
|
dequantize_func(x, il, temp_a);
|
2332
2566
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2567
|
+
|
2333
2568
|
#pragma unroll(16)
|
2334
2569
|
for (int i = 0; i < 16; i++) {
|
2335
2570
|
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
2336
|
-
+
|
2337
|
-
+
|
2571
|
+
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
2572
|
+
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
2338
2573
|
}
|
2339
|
-
|
2340
|
-
|
2574
|
+
|
2575
|
+
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
|
2576
|
+
|
2341
2577
|
il = (il + 2 < nl) ? il + 2 : il % 2;
|
2342
2578
|
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
2343
2579
|
y += BLOCK_SIZE_K;
|
2344
2580
|
|
2345
2581
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2346
|
-
|
2582
|
+
|
2583
|
+
// load matrices from threadgroup memory and conduct outer products
|
2347
2584
|
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
2348
2585
|
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
2586
|
+
|
2349
2587
|
#pragma unroll(4)
|
2350
2588
|
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
2351
2589
|
#pragma unroll(4)
|
@@ -2360,6 +2598,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2360
2598
|
|
2361
2599
|
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
2362
2600
|
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
2601
|
+
|
2363
2602
|
#pragma unroll(8)
|
2364
2603
|
for (int i = 0; i < 8; i++){
|
2365
2604
|
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
@@ -2368,25 +2607,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2368
2607
|
}
|
2369
2608
|
|
2370
2609
|
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
|
2371
|
-
device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
|
2372
|
-
|
2610
|
+
device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
|
2611
|
+
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
|
2373
2612
|
for (int i = 0; i < 8; i++) {
|
2374
2613
|
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
|
2375
2614
|
}
|
2376
2615
|
} else {
|
2377
2616
|
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
2378
2617
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2379
|
-
threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
|
2618
|
+
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
|
2380
2619
|
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
2381
2620
|
for (int i = 0; i < 8; i++) {
|
2382
2621
|
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
2383
2622
|
}
|
2384
2623
|
|
2385
2624
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2386
|
-
|
2387
|
-
|
2625
|
+
|
2626
|
+
device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
|
2627
|
+
if (sgitg == 0) {
|
2388
2628
|
for (int i = 0; i < n_rows; i++) {
|
2389
|
-
for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
|
2629
|
+
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
2390
2630
|
*(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
|
2391
2631
|
}
|
2392
2632
|
}
|
@@ -2407,6 +2647,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
|
|
2407
2647
|
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
2408
2648
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
2409
2649
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
2650
|
+
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
|
2651
|
+
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
|
2410
2652
|
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
|
2411
2653
|
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
|
2412
2654
|
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
|
@@ -2435,6 +2677,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
|
|
2435
2677
|
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
2436
2678
|
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
2437
2679
|
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
2680
|
+
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
|
2681
|
+
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
|
2438
2682
|
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
2439
2683
|
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
2440
2684
|
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|