llama_cpp 0.3.8 → 0.5.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +19 -0
- data/README.md +1 -1
- data/examples/chat.rb +4 -6
- data/ext/llama_cpp/extconf.rb +3 -3
- data/ext/llama_cpp/llama_cpp.cpp +129 -124
- data/ext/llama_cpp/src/ggml-alloc.c +90 -113
- data/ext/llama_cpp/src/ggml-alloc.h +1 -1
- data/ext/llama_cpp/src/ggml-cuda.cu +350 -77
- data/ext/llama_cpp/src/ggml-cuda.h +13 -0
- data/ext/llama_cpp/src/ggml-metal.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.m +226 -121
- data/ext/llama_cpp/src/ggml-metal.metal +157 -35
- data/ext/llama_cpp/src/ggml.c +2724 -584
- data/ext/llama_cpp/src/ggml.h +282 -31
- data/ext/llama_cpp/src/k_quants.c +112 -56
- data/ext/llama_cpp/src/llama.cpp +4857 -2986
- data/ext/llama_cpp/src/llama.h +180 -126
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +2 -2
- data/sig/llama_cpp.rbs +12 -11
- metadata +2 -2
@@ -18,10 +18,16 @@ typedef struct {
|
|
18
18
|
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
19
19
|
} block_q4_1;
|
20
20
|
|
21
|
+
#define QK8_0 32
|
22
|
+
typedef struct {
|
23
|
+
half d; // delta
|
24
|
+
int8_t qs[QK8_0]; // quants
|
25
|
+
} block_q8_0;
|
26
|
+
|
21
27
|
kernel void kernel_add(
|
22
|
-
device const
|
23
|
-
device const
|
24
|
-
device
|
28
|
+
device const float4 * src0,
|
29
|
+
device const float4 * src1,
|
30
|
+
device float4 * dst,
|
25
31
|
uint tpig[[thread_position_in_grid]]) {
|
26
32
|
dst[tpig] = src0[tpig] + src1[tpig];
|
27
33
|
}
|
@@ -29,18 +35,18 @@ kernel void kernel_add(
|
|
29
35
|
// assumption: src1 is a row
|
30
36
|
// broadcast src1 into src0
|
31
37
|
kernel void kernel_add_row(
|
32
|
-
device const
|
33
|
-
device const
|
34
|
-
device
|
35
|
-
constant int64_t &
|
38
|
+
device const float4 * src0,
|
39
|
+
device const float4 * src1,
|
40
|
+
device float4 * dst,
|
41
|
+
constant int64_t & nb,
|
36
42
|
uint tpig[[thread_position_in_grid]]) {
|
37
|
-
dst[tpig] = src0[tpig] + src1[tpig %
|
43
|
+
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
38
44
|
}
|
39
45
|
|
40
46
|
kernel void kernel_mul(
|
41
|
-
device const
|
42
|
-
device const
|
43
|
-
device
|
47
|
+
device const float4 * src0,
|
48
|
+
device const float4 * src1,
|
49
|
+
device float4 * dst,
|
44
50
|
uint tpig[[thread_position_in_grid]]) {
|
45
51
|
dst[tpig] = src0[tpig] * src1[tpig];
|
46
52
|
}
|
@@ -48,12 +54,12 @@ kernel void kernel_mul(
|
|
48
54
|
// assumption: src1 is a row
|
49
55
|
// broadcast src1 into src0
|
50
56
|
kernel void kernel_mul_row(
|
51
|
-
device const
|
52
|
-
device const
|
53
|
-
device
|
54
|
-
constant
|
57
|
+
device const float4 * src0,
|
58
|
+
device const float4 * src1,
|
59
|
+
device float4 * dst,
|
60
|
+
constant int64_t & nb,
|
55
61
|
uint tpig[[thread_position_in_grid]]) {
|
56
|
-
dst[tpig] = src0[tpig] * src1[tpig %
|
62
|
+
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
57
63
|
}
|
58
64
|
|
59
65
|
kernel void kernel_scale(
|
@@ -87,7 +93,12 @@ kernel void kernel_gelu(
|
|
87
93
|
device float * dst,
|
88
94
|
uint tpig[[thread_position_in_grid]]) {
|
89
95
|
float x = src0[tpig];
|
90
|
-
|
96
|
+
|
97
|
+
// BEWARE !!!
|
98
|
+
// Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
|
99
|
+
// This was observed with Falcon 7B and 40B models
|
100
|
+
//
|
101
|
+
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
91
102
|
}
|
92
103
|
|
93
104
|
kernel void kernel_soft_max(
|
@@ -352,7 +363,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
352
363
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
353
364
|
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
|
354
365
|
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
355
|
-
device const float
|
366
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
356
367
|
float yl[16]; // src1 vector cache
|
357
368
|
float sumf[nr]={0.f};
|
358
369
|
|
@@ -424,6 +435,68 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|
424
435
|
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
425
436
|
}
|
426
437
|
|
438
|
+
kernel void kernel_mul_mat_q8_0_f32(
|
439
|
+
device const void * src0,
|
440
|
+
device const float * src1,
|
441
|
+
device float * dst,
|
442
|
+
constant int64_t & ne00,
|
443
|
+
constant int64_t & ne01[[buffer(4)]],
|
444
|
+
constant int64_t & ne02[[buffer(5)]],
|
445
|
+
constant int64_t & ne10[[buffer(9)]],
|
446
|
+
constant int64_t & ne12[[buffer(11)]],
|
447
|
+
constant int64_t & ne0[[buffer(15)]],
|
448
|
+
constant int64_t & ne1[[buffer(16)]],
|
449
|
+
constant uint & gqa[[buffer(17)]],
|
450
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
451
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
452
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
453
|
+
const int nr = N_DST;
|
454
|
+
const int nsg = N_SIMDGROUP;
|
455
|
+
const int nw = N_SIMDWIDTH;
|
456
|
+
|
457
|
+
const int nb = ne00/QK8_0;
|
458
|
+
const int r0 = tgpig.x;
|
459
|
+
const int r1 = tgpig.y;
|
460
|
+
const int im = tgpig.z;
|
461
|
+
const int first_row = (r0 * nsg + sgitg) * nr;
|
462
|
+
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
|
463
|
+
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
|
464
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
465
|
+
|
466
|
+
float yl[16];
|
467
|
+
float sumf[nr]={0.f};
|
468
|
+
|
469
|
+
const int ix = tiisg/2;
|
470
|
+
const int il = tiisg%2;
|
471
|
+
|
472
|
+
device const float * yb = y + ix * QK8_0 + 16*il;
|
473
|
+
|
474
|
+
// each thread in a SIMD group deals with half a block.
|
475
|
+
for (int ib = ix; ib < nb; ib += nw/2) {
|
476
|
+
for (int i = 0; i < 16; ++i) {
|
477
|
+
yl[i] = yb[i];
|
478
|
+
}
|
479
|
+
|
480
|
+
for (int row = 0; row < nr; row++) {
|
481
|
+
device const int8_t * qs = x[ib+row*nb].qs + 16*il;
|
482
|
+
float sumq = 0.f;
|
483
|
+
for (int iq = 0; iq < 16; ++iq) {
|
484
|
+
sumq += qs[iq] * yl[iq];
|
485
|
+
}
|
486
|
+
sumf[row] += sumq*x[ib+row*nb].d;
|
487
|
+
}
|
488
|
+
|
489
|
+
yb += QK8_0 * 16;
|
490
|
+
}
|
491
|
+
|
492
|
+
for (int row = 0; row < nr; ++row) {
|
493
|
+
const float tot = simd_sum(sumf[row]);
|
494
|
+
if (tiisg == 0 && first_row + row < ne01) {
|
495
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
496
|
+
}
|
497
|
+
}
|
498
|
+
}
|
499
|
+
|
427
500
|
kernel void kernel_mul_mat_f16_f32(
|
428
501
|
device const char * src0,
|
429
502
|
device const char * src1,
|
@@ -455,26 +528,43 @@ kernel void kernel_mul_mat_f16_f32(
|
|
455
528
|
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
456
529
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
457
530
|
|
458
|
-
|
531
|
+
uint ith = tpitg.x;
|
532
|
+
uint nth = tptg.x;
|
459
533
|
|
460
|
-
|
461
|
-
|
534
|
+
sum[ith] = 0.0f;
|
535
|
+
|
536
|
+
for (int i = ith; i < ne00; i += nth) {
|
537
|
+
sum[ith] += (float) x[i] * (float) y[i];
|
462
538
|
}
|
463
539
|
|
464
540
|
// accumulate the sum from all threads in the threadgroup
|
465
541
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
466
|
-
|
467
|
-
|
468
|
-
sum[tpitg.x] += sum[tpitg.x + i];
|
469
|
-
}
|
470
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
542
|
+
if (ith%4 == 0) {
|
543
|
+
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
|
471
544
|
}
|
472
|
-
|
473
|
-
if (
|
545
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
546
|
+
if (ith%16 == 0) {
|
547
|
+
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
|
548
|
+
}
|
549
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
550
|
+
if (ith == 0) {
|
551
|
+
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
474
552
|
dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
|
475
553
|
}
|
476
|
-
}
|
477
554
|
|
555
|
+
// Original implementation. Left behind commented out for now
|
556
|
+
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
557
|
+
//for (uint i = tptg.x/2; i > 0; i /= 2) {
|
558
|
+
// if (tpitg.x < i) {
|
559
|
+
// sum[tpitg.x] += sum[tpitg.x + i];
|
560
|
+
// }
|
561
|
+
// threadgroup_barrier(mem_flags::mem_threadgroup);
|
562
|
+
//}
|
563
|
+
//
|
564
|
+
//if (tpitg.x == 0) {
|
565
|
+
// dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
|
566
|
+
//}
|
567
|
+
}
|
478
568
|
|
479
569
|
kernel void kernel_alibi_f32(
|
480
570
|
device const float * src0,
|
@@ -571,7 +661,25 @@ kernel void kernel_rope(
|
|
571
661
|
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
572
662
|
}
|
573
663
|
} else {
|
574
|
-
|
664
|
+
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
665
|
+
for (int64_t ic = 0; ic < n_dims; ic += 2) {
|
666
|
+
const float cos_theta = cos(theta);
|
667
|
+
const float sin_theta = sin(theta);
|
668
|
+
|
669
|
+
theta *= theta_scale;
|
670
|
+
|
671
|
+
const int64_t i0 = ib*n_dims + ic/2;
|
672
|
+
|
673
|
+
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
674
|
+
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
675
|
+
|
676
|
+
const float x0 = src[0];
|
677
|
+
const float x1 = src[n_dims/2];
|
678
|
+
|
679
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
680
|
+
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
681
|
+
}
|
682
|
+
}
|
575
683
|
}
|
576
684
|
}
|
577
685
|
|
@@ -1598,12 +1706,12 @@ template <typename type4x4>
|
|
1598
1706
|
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
1599
1707
|
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
1600
1708
|
const half d = il ? (xb->d / 16.h) : xb->d;
|
1601
|
-
const half m = il ? (-8.h * 16.h) : -8.h;
|
1709
|
+
const half m = il ? ( -8.h * 16.h) : -8.h;
|
1602
1710
|
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
1603
1711
|
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
1604
1712
|
|
1605
1713
|
for (int i=0;i<8;i++) {
|
1606
|
-
reg[i/2][2*(i%2)]
|
1714
|
+
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
|
1607
1715
|
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
|
1608
1716
|
}
|
1609
1717
|
}
|
@@ -1617,11 +1725,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
|
|
1617
1725
|
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
1618
1726
|
|
1619
1727
|
for (int i=0;i<8;i++) {
|
1620
|
-
reg[i/2][2*(i%2)]
|
1728
|
+
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
|
1621
1729
|
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
|
1622
1730
|
}
|
1623
1731
|
}
|
1624
1732
|
|
1733
|
+
template <typename type4x4>
|
1734
|
+
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
1735
|
+
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
1736
|
+
const half d = xb->d;
|
1737
|
+
|
1738
|
+
for (int i=0;i<16;i++) {
|
1739
|
+
reg[i/4][i%4] = (qs[i + 16*il] * d);
|
1740
|
+
}
|
1741
|
+
}
|
1742
|
+
|
1625
1743
|
template <typename type4x4>
|
1626
1744
|
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
1627
1745
|
const half d = xb->d;
|
@@ -1850,6 +1968,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
1850
1968
|
//load data and store to threadgroup memory
|
1851
1969
|
half4x4 temp_a;
|
1852
1970
|
dequantize_func(x, il, temp_a);
|
1971
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1853
1972
|
#pragma unroll(16)
|
1854
1973
|
for (int i = 0; i < 16; i++) {
|
1855
1974
|
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
@@ -1895,6 +2014,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
1895
2014
|
}
|
1896
2015
|
} else {
|
1897
2016
|
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
2017
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1898
2018
|
threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
|
1899
2019
|
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
1900
2020
|
for (int i = 0; i < 8; i++) {
|
@@ -1922,9 +2042,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
1922
2042
|
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
|
1923
2043
|
constant uint64_t &, constant uint64_t &, uint, uint, uint);
|
1924
2044
|
|
1925
|
-
template [[host_name("kernel_get_rows_f16")]]
|
2045
|
+
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
1926
2046
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
1927
2047
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
2048
|
+
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
|
1928
2049
|
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
|
1929
2050
|
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
|
1930
2051
|
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
|
@@ -1935,9 +2056,10 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float
|
|
1935
2056
|
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
|
1936
2057
|
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
|
1937
2058
|
|
1938
|
-
template [[host_name("kernel_mul_mm_f16_f32")]]
|
2059
|
+
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
1939
2060
|
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
1940
2061
|
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
2062
|
+
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
1941
2063
|
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
1942
2064
|
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
1943
2065
|
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|