llama_cpp 0.7.0 → 0.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/llama_cpp/extconf.rb +1 -1
- data/ext/llama_cpp/llama_cpp.cpp +41 -21
- data/ext/llama_cpp/src/ggml-alloc.c +62 -107
- data/ext/llama_cpp/src/ggml-alloc.h +11 -5
- data/ext/llama_cpp/src/ggml-backend.c +385 -0
- data/ext/llama_cpp/src/ggml-backend.h +143 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +500 -78
- data/ext/llama_cpp/src/ggml-cuda.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.h +18 -1
- data/ext/llama_cpp/src/ggml-metal.m +396 -127
- data/ext/llama_cpp/src/ggml-metal.metal +290 -46
- data/ext/llama_cpp/src/ggml-opencl.cpp +47 -71
- data/ext/llama_cpp/src/ggml.c +71 -55
- data/ext/llama_cpp/src/ggml.h +15 -9
- data/ext/llama_cpp/src/k_quants.c +12 -20
- data/ext/llama_cpp/src/k_quants.h +5 -5
- data/ext/llama_cpp/src/llama.cpp +1851 -250
- data/ext/llama_cpp/src/llama.h +18 -12
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -4
- metadata +5 -3
@@ -13,11 +13,26 @@ typedef struct {
|
|
13
13
|
|
14
14
|
#define QK4_1 32
|
15
15
|
typedef struct {
|
16
|
-
half d;
|
17
|
-
half m;
|
16
|
+
half d; // delta
|
17
|
+
half m; // min
|
18
18
|
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
19
19
|
} block_q4_1;
|
20
20
|
|
21
|
+
#define QK5_0 32
|
22
|
+
typedef struct {
|
23
|
+
half d; // delta
|
24
|
+
uint8_t qh[4]; // 5-th bit of quants
|
25
|
+
uint8_t qs[QK5_0 / 2]; // nibbles / quants
|
26
|
+
} block_q5_0;
|
27
|
+
|
28
|
+
#define QK5_1 32
|
29
|
+
typedef struct {
|
30
|
+
half d; // delta
|
31
|
+
half m; // min
|
32
|
+
uint8_t qh[4]; // 5-th bit of quants
|
33
|
+
uint8_t qs[QK5_1 / 2]; // nibbles / quants
|
34
|
+
} block_q5_1;
|
35
|
+
|
21
36
|
#define QK8_0 32
|
22
37
|
typedef struct {
|
23
38
|
half d; // delta
|
@@ -132,6 +147,13 @@ kernel void kernel_relu(
|
|
132
147
|
dst[tpig] = max(0.0f, src0[tpig]);
|
133
148
|
}
|
134
149
|
|
150
|
+
kernel void kernel_sqr(
|
151
|
+
device const float * src0,
|
152
|
+
device float * dst,
|
153
|
+
uint tpig[[thread_position_in_grid]]) {
|
154
|
+
dst[tpig] = src0[tpig] * src0[tpig];
|
155
|
+
}
|
156
|
+
|
135
157
|
constant float GELU_COEF_A = 0.044715f;
|
136
158
|
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
137
159
|
|
@@ -338,10 +360,11 @@ kernel void kernel_rms_norm(
|
|
338
360
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
339
361
|
uint tiisg[[thread_index_in_simdgroup]],
|
340
362
|
uint ntg[[threads_per_threadgroup]]) {
|
341
|
-
device const float4 * x
|
342
|
-
device const float
|
343
|
-
|
344
|
-
|
363
|
+
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
364
|
+
device const float * x_scalar = (device const float *) x;
|
365
|
+
|
366
|
+
float4 sumf = 0;
|
367
|
+
float all_sum = 0;
|
345
368
|
|
346
369
|
// parallel sum
|
347
370
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
@@ -354,6 +377,7 @@ kernel void kernel_rms_norm(
|
|
354
377
|
}
|
355
378
|
|
356
379
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
380
|
+
|
357
381
|
// broadcast, simd group number is ntg / 32
|
358
382
|
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
359
383
|
if (tpitg < i) {
|
@@ -361,7 +385,9 @@ kernel void kernel_rms_norm(
|
|
361
385
|
}
|
362
386
|
}
|
363
387
|
if (tpitg == 0) {
|
364
|
-
for (int i = 4 * (ne00 / 4); i < ne00; i++) {
|
388
|
+
for (int i = 4 * (ne00 / 4); i < ne00; i++) {
|
389
|
+
sum[0] += x_scalar[i];
|
390
|
+
}
|
365
391
|
sum[0] /= ne00;
|
366
392
|
}
|
367
393
|
|
@@ -376,7 +402,9 @@ kernel void kernel_rms_norm(
|
|
376
402
|
y[i00] = x[i00] * scale;
|
377
403
|
}
|
378
404
|
if (tpitg == 0) {
|
379
|
-
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
|
405
|
+
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
|
406
|
+
y_scalar[i00] = x_scalar[i00] * scale;
|
407
|
+
}
|
380
408
|
}
|
381
409
|
}
|
382
410
|
|
@@ -386,8 +414,11 @@ kernel void kernel_rms_norm(
|
|
386
414
|
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
387
415
|
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
|
388
416
|
float d = qb_curr->d;
|
417
|
+
|
389
418
|
float2 acc = 0.f;
|
419
|
+
|
390
420
|
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
|
421
|
+
|
391
422
|
for (int i = 0; i < 8; i+=2) {
|
392
423
|
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
393
424
|
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
@@ -404,8 +435,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
|
|
404
435
|
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
|
405
436
|
float d = qb_curr->d;
|
406
437
|
float m = qb_curr->m;
|
407
|
-
|
438
|
+
|
408
439
|
float2 acc = 0.f;
|
440
|
+
|
441
|
+
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
|
442
|
+
|
409
443
|
for (int i = 0; i < 8; i+=2) {
|
410
444
|
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
411
445
|
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
@@ -415,9 +449,52 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
|
|
415
449
|
return d * (acc[0] + acc[1]) + sumy * m;
|
416
450
|
}
|
417
451
|
|
452
|
+
// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
453
|
+
// il indicates where the q5 quants begin (0 or QK5_0/4)
|
454
|
+
// we assume that the yl's have been multiplied with the appropriate scale factor
|
455
|
+
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
456
|
+
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
|
457
|
+
float d = qb_curr->d;
|
458
|
+
|
459
|
+
float2 acc = 0.f;
|
460
|
+
|
461
|
+
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
|
462
|
+
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
463
|
+
|
464
|
+
for (int i = 0; i < 8; i+=2) {
|
465
|
+
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
|
466
|
+
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
|
467
|
+
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
|
468
|
+
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
|
469
|
+
}
|
470
|
+
return d * (sumy * -16.f + acc[0] + acc[1]);
|
471
|
+
}
|
472
|
+
|
473
|
+
// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
|
474
|
+
// il indicates where the q5 quants begin (0 or QK5_1/4)
|
475
|
+
// we assume that the yl's have been multiplied with the appropriate scale factor
|
476
|
+
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
477
|
+
inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
|
478
|
+
float d = qb_curr->d;
|
479
|
+
float m = qb_curr->m;
|
480
|
+
|
481
|
+
float2 acc = 0.f;
|
482
|
+
|
483
|
+
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
|
484
|
+
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
485
|
+
|
486
|
+
for (int i = 0; i < 8; i+=2) {
|
487
|
+
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
|
488
|
+
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
|
489
|
+
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
|
490
|
+
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
|
491
|
+
}
|
492
|
+
return d * (acc[0] + acc[1]) + sumy * m;
|
493
|
+
}
|
494
|
+
|
418
495
|
// putting them in the kernel cause a significant performance penalty
|
419
|
-
#define N_DST 4
|
420
|
-
#define N_SIMDGROUP 2
|
496
|
+
#define N_DST 4 // each SIMD group works on 4 rows
|
497
|
+
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
421
498
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
422
499
|
//Note: This is a template, but strictly speaking it only applies to
|
423
500
|
// quantizations where the block size is 32. It also does not
|
@@ -428,18 +505,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
428
505
|
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
|
429
506
|
uint3 tgpig, uint tiisg, uint sgitg) {
|
430
507
|
const int nb = ne00/QK4_0;
|
508
|
+
|
431
509
|
const int r0 = tgpig.x;
|
432
510
|
const int r1 = tgpig.y;
|
433
511
|
const int im = tgpig.z;
|
512
|
+
|
434
513
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
514
|
+
|
435
515
|
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
|
516
|
+
|
436
517
|
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
437
518
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
438
|
-
float yl[16]; // src1 vector cache
|
439
|
-
float sumf[nr]={0.f};
|
440
519
|
|
441
|
-
|
442
|
-
|
520
|
+
float yl[16]; // src1 vector cache
|
521
|
+
float sumf[nr] = {0.f};
|
522
|
+
|
523
|
+
const int ix = (tiisg/2);
|
524
|
+
const int il = (tiisg%2)*8;
|
443
525
|
|
444
526
|
device const float * yb = y + ix * QK4_0 + il;
|
445
527
|
|
@@ -450,6 +532,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
450
532
|
sumy += yb[i] + yb[i+1];
|
451
533
|
yl[i+0] = yb[i+ 0];
|
452
534
|
yl[i+1] = yb[i+ 1]/256.f;
|
535
|
+
|
453
536
|
sumy += yb[i+16] + yb[i+17];
|
454
537
|
yl[i+8] = yb[i+16]/16.f;
|
455
538
|
yl[i+9] = yb[i+17]/4096.f;
|
@@ -465,12 +548,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
465
548
|
for (int row = 0; row < nr; ++row) {
|
466
549
|
const float tot = simd_sum(sumf[row]);
|
467
550
|
if (tiisg == 0 && first_row + row < ne01) {
|
468
|
-
dst[
|
551
|
+
dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
|
469
552
|
}
|
470
553
|
}
|
471
554
|
}
|
472
555
|
|
473
|
-
kernel void
|
556
|
+
kernel void kernel_mul_mv_q4_0_f32(
|
474
557
|
device const void * src0,
|
475
558
|
device const float * src1,
|
476
559
|
device float * dst,
|
@@ -483,12 +566,12 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|
483
566
|
constant int64_t & ne1[[buffer(16)]],
|
484
567
|
constant uint & gqa[[buffer(17)]],
|
485
568
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
486
|
-
uint
|
487
|
-
uint
|
569
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
570
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
488
571
|
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
489
572
|
}
|
490
573
|
|
491
|
-
kernel void
|
574
|
+
kernel void kernel_mul_mv_q4_1_f32(
|
492
575
|
device const void * src0,
|
493
576
|
device const float * src1,
|
494
577
|
device float * dst,
|
@@ -506,9 +589,46 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|
506
589
|
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
507
590
|
}
|
508
591
|
|
592
|
+
kernel void kernel_mul_mv_q5_0_f32(
|
593
|
+
device const void * src0,
|
594
|
+
device const float * src1,
|
595
|
+
device float * dst,
|
596
|
+
constant int64_t & ne00,
|
597
|
+
constant int64_t & ne01[[buffer(4)]],
|
598
|
+
constant int64_t & ne02[[buffer(5)]],
|
599
|
+
constant int64_t & ne10[[buffer(9)]],
|
600
|
+
constant int64_t & ne12[[buffer(11)]],
|
601
|
+
constant int64_t & ne0[[buffer(15)]],
|
602
|
+
constant int64_t & ne1[[buffer(16)]],
|
603
|
+
constant uint & gqa[[buffer(17)]],
|
604
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
605
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
606
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
607
|
+
mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
608
|
+
}
|
609
|
+
|
610
|
+
kernel void kernel_mul_mv_q5_1_f32(
|
611
|
+
device const void * src0,
|
612
|
+
device const float * src1,
|
613
|
+
device float * dst,
|
614
|
+
constant int64_t & ne00,
|
615
|
+
constant int64_t & ne01[[buffer(4)]],
|
616
|
+
constant int64_t & ne02[[buffer(5)]],
|
617
|
+
constant int64_t & ne10[[buffer(9)]],
|
618
|
+
constant int64_t & ne12[[buffer(11)]],
|
619
|
+
constant int64_t & ne0[[buffer(15)]],
|
620
|
+
constant int64_t & ne1[[buffer(16)]],
|
621
|
+
constant uint & gqa[[buffer(17)]],
|
622
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
623
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
624
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
625
|
+
mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
626
|
+
}
|
627
|
+
|
628
|
+
|
509
629
|
#define NB_Q8_0 8
|
510
630
|
|
511
|
-
kernel void
|
631
|
+
kernel void kernel_mul_mv_q8_0_f32(
|
512
632
|
device const void * src0,
|
513
633
|
device const float * src1,
|
514
634
|
device float * dst,
|
@@ -572,7 +692,7 @@ kernel void kernel_mul_mat_q8_0_f32(
|
|
572
692
|
|
573
693
|
#define N_F32_F32 4
|
574
694
|
|
575
|
-
kernel void
|
695
|
+
kernel void kernel_mul_mv_f32_f32(
|
576
696
|
device const char * src0,
|
577
697
|
device const char * src1,
|
578
698
|
device float * dst,
|
@@ -643,7 +763,7 @@ kernel void kernel_mul_mat_f32_f32(
|
|
643
763
|
}
|
644
764
|
}
|
645
765
|
|
646
|
-
kernel void
|
766
|
+
kernel void kernel_mul_mv_f16_f32_1row(
|
647
767
|
device const char * src0,
|
648
768
|
device const char * src1,
|
649
769
|
device float * dst,
|
@@ -662,7 +782,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
|
|
662
782
|
constant int64_t & ne0,
|
663
783
|
constant int64_t & ne1,
|
664
784
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
665
|
-
uint
|
785
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
666
786
|
|
667
787
|
const int64_t r0 = tgpig.x;
|
668
788
|
const int64_t r1 = tgpig.y;
|
@@ -697,7 +817,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
|
|
697
817
|
|
698
818
|
#define N_F16_F32 4
|
699
819
|
|
700
|
-
kernel void
|
820
|
+
kernel void kernel_mul_mv_f16_f32(
|
701
821
|
device const char * src0,
|
702
822
|
device const char * src1,
|
703
823
|
device float * dst,
|
@@ -769,7 +889,7 @@ kernel void kernel_mul_mat_f16_f32(
|
|
769
889
|
}
|
770
890
|
|
771
891
|
// Assumes row size (ne00) is a multiple of 4
|
772
|
-
kernel void
|
892
|
+
kernel void kernel_mul_mv_f16_f32_l4(
|
773
893
|
device const char * src0,
|
774
894
|
device const char * src1,
|
775
895
|
device float * dst,
|
@@ -1098,6 +1218,62 @@ kernel void kernel_cpy_f32_f32(
|
|
1098
1218
|
}
|
1099
1219
|
}
|
1100
1220
|
|
1221
|
+
kernel void kernel_concat(
|
1222
|
+
device const char * src0,
|
1223
|
+
device const char * src1,
|
1224
|
+
device char * dst,
|
1225
|
+
constant int64_t & ne00,
|
1226
|
+
constant int64_t & ne01,
|
1227
|
+
constant int64_t & ne02,
|
1228
|
+
constant int64_t & ne03,
|
1229
|
+
constant uint64_t & nb00,
|
1230
|
+
constant uint64_t & nb01,
|
1231
|
+
constant uint64_t & nb02,
|
1232
|
+
constant uint64_t & nb03,
|
1233
|
+
constant int64_t & ne10,
|
1234
|
+
constant int64_t & ne11,
|
1235
|
+
constant int64_t & ne12,
|
1236
|
+
constant int64_t & ne13,
|
1237
|
+
constant uint64_t & nb10,
|
1238
|
+
constant uint64_t & nb11,
|
1239
|
+
constant uint64_t & nb12,
|
1240
|
+
constant uint64_t & nb13,
|
1241
|
+
constant int64_t & ne0,
|
1242
|
+
constant int64_t & ne1,
|
1243
|
+
constant int64_t & ne2,
|
1244
|
+
constant int64_t & ne3,
|
1245
|
+
constant uint64_t & nb0,
|
1246
|
+
constant uint64_t & nb1,
|
1247
|
+
constant uint64_t & nb2,
|
1248
|
+
constant uint64_t & nb3,
|
1249
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1250
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1251
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1252
|
+
|
1253
|
+
const int64_t i03 = tgpig.z;
|
1254
|
+
const int64_t i02 = tgpig.y;
|
1255
|
+
const int64_t i01 = tgpig.x;
|
1256
|
+
|
1257
|
+
const int64_t i13 = i03 % ne13;
|
1258
|
+
const int64_t i12 = i02 % ne12;
|
1259
|
+
const int64_t i11 = i01 % ne11;
|
1260
|
+
|
1261
|
+
device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
|
1262
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
1263
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
1264
|
+
|
1265
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
1266
|
+
if (i02 < ne02) {
|
1267
|
+
((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
|
1268
|
+
src0_ptr += ntg.x*nb00;
|
1269
|
+
} else {
|
1270
|
+
((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
|
1271
|
+
src1_ptr += ntg.x*nb10;
|
1272
|
+
}
|
1273
|
+
dst_ptr += ntg.x*nb0;
|
1274
|
+
}
|
1275
|
+
}
|
1276
|
+
|
1101
1277
|
//============================================ k-quants ======================================================
|
1102
1278
|
|
1103
1279
|
#ifndef QK_K
|
@@ -1190,7 +1366,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
1190
1366
|
|
1191
1367
|
//====================================== dot products =========================
|
1192
1368
|
|
1193
|
-
kernel void
|
1369
|
+
kernel void kernel_mul_mv_q2_K_f32(
|
1194
1370
|
device const void * src0,
|
1195
1371
|
device const float * src1,
|
1196
1372
|
device float * dst,
|
@@ -1334,7 +1510,7 @@ kernel void kernel_mul_mat_q2_K_f32(
|
|
1334
1510
|
}
|
1335
1511
|
|
1336
1512
|
#if QK_K == 256
|
1337
|
-
kernel void
|
1513
|
+
kernel void kernel_mul_mv_q3_K_f32(
|
1338
1514
|
device const void * src0,
|
1339
1515
|
device const float * src1,
|
1340
1516
|
device float * dst,
|
@@ -1486,7 +1662,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1486
1662
|
}
|
1487
1663
|
}
|
1488
1664
|
#else
|
1489
|
-
kernel void
|
1665
|
+
kernel void kernel_mul_mv_q3_K_f32(
|
1490
1666
|
device const void * src0,
|
1491
1667
|
device const float * src1,
|
1492
1668
|
device float * dst,
|
@@ -1557,7 +1733,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1557
1733
|
#endif
|
1558
1734
|
|
1559
1735
|
#if QK_K == 256
|
1560
|
-
kernel void
|
1736
|
+
kernel void kernel_mul_mv_q4_K_f32(
|
1561
1737
|
device const void * src0,
|
1562
1738
|
device const float * src1,
|
1563
1739
|
device float * dst,
|
@@ -1663,7 +1839,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1663
1839
|
}
|
1664
1840
|
}
|
1665
1841
|
#else
|
1666
|
-
kernel void
|
1842
|
+
kernel void kernel_mul_mv_q4_K_f32(
|
1667
1843
|
device const void * src0,
|
1668
1844
|
device const float * src1,
|
1669
1845
|
device float * dst,
|
@@ -1752,7 +1928,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1752
1928
|
}
|
1753
1929
|
#endif
|
1754
1930
|
|
1755
|
-
kernel void
|
1931
|
+
kernel void kernel_mul_mv_q5_K_f32(
|
1756
1932
|
device const void * src0,
|
1757
1933
|
device const float * src1,
|
1758
1934
|
device float * dst,
|
@@ -1925,7 +2101,7 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
1925
2101
|
|
1926
2102
|
}
|
1927
2103
|
|
1928
|
-
kernel void
|
2104
|
+
kernel void kernel_mul_mv_q6_K_f32(
|
1929
2105
|
device const void * src0,
|
1930
2106
|
device const float * src1,
|
1931
2107
|
device float * dst,
|
@@ -2074,6 +2250,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
|
|
2074
2250
|
}
|
2075
2251
|
}
|
2076
2252
|
|
2253
|
+
template <typename type4x4>
|
2254
|
+
void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
|
2255
|
+
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
|
2256
|
+
const float d = xb->d;
|
2257
|
+
const float md = -16.h * xb->d;
|
2258
|
+
const ushort mask = il ? 0x00F0 : 0x000F;
|
2259
|
+
|
2260
|
+
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
2261
|
+
|
2262
|
+
const int x_mv = il ? 4 : 0;
|
2263
|
+
|
2264
|
+
const int gh_mv = il ? 12 : 0;
|
2265
|
+
const int gh_bk = il ? 0 : 4;
|
2266
|
+
|
2267
|
+
for (int i = 0; i < 8; i++) {
|
2268
|
+
// extract the 5-th bits for x0 and x1
|
2269
|
+
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
2270
|
+
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
2271
|
+
|
2272
|
+
// combine the 4-bits from qs with the 5th bit
|
2273
|
+
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
2274
|
+
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
2275
|
+
|
2276
|
+
reg[i/2][2*(i%2)+0] = d * x0 + md;
|
2277
|
+
reg[i/2][2*(i%2)+1] = d * x1 + md;
|
2278
|
+
}
|
2279
|
+
}
|
2280
|
+
|
2281
|
+
template <typename type4x4>
|
2282
|
+
void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
|
2283
|
+
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
|
2284
|
+
const float d = xb->d;
|
2285
|
+
const float m = xb->m;
|
2286
|
+
const ushort mask = il ? 0x00F0 : 0x000F;
|
2287
|
+
|
2288
|
+
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
2289
|
+
|
2290
|
+
const int x_mv = il ? 4 : 0;
|
2291
|
+
|
2292
|
+
const int gh_mv = il ? 12 : 0;
|
2293
|
+
const int gh_bk = il ? 0 : 4;
|
2294
|
+
|
2295
|
+
for (int i = 0; i < 8; i++) {
|
2296
|
+
// extract the 5-th bits for x0 and x1
|
2297
|
+
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
2298
|
+
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
2299
|
+
|
2300
|
+
// combine the 4-bits from qs with the 5th bit
|
2301
|
+
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
2302
|
+
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
2303
|
+
|
2304
|
+
reg[i/2][2*(i%2)+0] = d * x0 + m;
|
2305
|
+
reg[i/2][2*(i%2)+1] = d * x1 + m;
|
2306
|
+
}
|
2307
|
+
}
|
2308
|
+
|
2077
2309
|
template <typename type4x4>
|
2078
2310
|
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
2079
2311
|
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
@@ -2263,7 +2495,7 @@ kernel void kernel_get_rows(
|
|
2263
2495
|
}
|
2264
2496
|
|
2265
2497
|
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
2266
|
-
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix
|
2498
|
+
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
2267
2499
|
#define BLOCK_SIZE_K 32
|
2268
2500
|
#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
|
2269
2501
|
#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
|
@@ -2300,9 +2532,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2300
2532
|
const uint r0 = tgpig.y;
|
2301
2533
|
const uint r1 = tgpig.x;
|
2302
2534
|
const uint im = tgpig.z;
|
2535
|
+
|
2303
2536
|
// if this block is of 64x32 shape or smaller
|
2304
2537
|
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
2305
2538
|
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
2539
|
+
|
2306
2540
|
// a thread shouldn't load data outside of the matrix
|
2307
2541
|
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
2308
2542
|
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
@@ -2326,26 +2560,30 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2326
2560
|
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
2327
2561
|
|
2328
2562
|
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
2329
|
-
//load data and store to threadgroup memory
|
2563
|
+
// load data and store to threadgroup memory
|
2330
2564
|
half4x4 temp_a;
|
2331
2565
|
dequantize_func(x, il, temp_a);
|
2332
2566
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2567
|
+
|
2333
2568
|
#pragma unroll(16)
|
2334
2569
|
for (int i = 0; i < 16; i++) {
|
2335
2570
|
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
2336
|
-
+
|
2337
|
-
+
|
2571
|
+
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
2572
|
+
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
2338
2573
|
}
|
2339
|
-
|
2340
|
-
|
2574
|
+
|
2575
|
+
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
|
2576
|
+
|
2341
2577
|
il = (il + 2 < nl) ? il + 2 : il % 2;
|
2342
2578
|
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
2343
2579
|
y += BLOCK_SIZE_K;
|
2344
2580
|
|
2345
2581
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2346
|
-
|
2582
|
+
|
2583
|
+
// load matrices from threadgroup memory and conduct outer products
|
2347
2584
|
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
2348
2585
|
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
2586
|
+
|
2349
2587
|
#pragma unroll(4)
|
2350
2588
|
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
2351
2589
|
#pragma unroll(4)
|
@@ -2360,6 +2598,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2360
2598
|
|
2361
2599
|
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
2362
2600
|
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
2601
|
+
|
2363
2602
|
#pragma unroll(8)
|
2364
2603
|
for (int i = 0; i < 8; i++){
|
2365
2604
|
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
@@ -2368,25 +2607,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2368
2607
|
}
|
2369
2608
|
|
2370
2609
|
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
|
2371
|
-
device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
|
2372
|
-
|
2610
|
+
device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
|
2611
|
+
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
|
2373
2612
|
for (int i = 0; i < 8; i++) {
|
2374
2613
|
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
|
2375
2614
|
}
|
2376
2615
|
} else {
|
2377
2616
|
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
2378
2617
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2379
|
-
threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
|
2618
|
+
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
|
2380
2619
|
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
2381
2620
|
for (int i = 0; i < 8; i++) {
|
2382
2621
|
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
2383
2622
|
}
|
2384
2623
|
|
2385
2624
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2386
|
-
|
2387
|
-
|
2625
|
+
|
2626
|
+
device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
|
2627
|
+
if (sgitg == 0) {
|
2388
2628
|
for (int i = 0; i < n_rows; i++) {
|
2389
|
-
for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
|
2629
|
+
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
2390
2630
|
*(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
|
2391
2631
|
}
|
2392
2632
|
}
|
@@ -2407,6 +2647,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
|
|
2407
2647
|
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
2408
2648
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
2409
2649
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
2650
|
+
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
|
2651
|
+
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
|
2410
2652
|
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
|
2411
2653
|
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
|
2412
2654
|
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
|
@@ -2435,6 +2677,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
|
|
2435
2677
|
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
2436
2678
|
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
2437
2679
|
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
2680
|
+
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
|
2681
|
+
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
|
2438
2682
|
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
2439
2683
|
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
2440
2684
|
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|