llama_cpp 0.5.3 → 0.6.0
This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +6 -5
- data/examples/chat.rb +13 -13
- data/examples/embedding.rb +9 -9
- data/ext/llama_cpp/llama_cpp.cpp +547 -272
- data/ext/llama_cpp/src/ggml-alloc.c +8 -2
- data/ext/llama_cpp/src/ggml-alloc.h +1 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +209 -82
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.m +163 -84
- data/ext/llama_cpp/src/ggml-metal.metal +121 -38
- data/ext/llama_cpp/src/ggml.c +1596 -842
- data/ext/llama_cpp/src/ggml.h +116 -35
- data/ext/llama_cpp/src/llama.cpp +1015 -586
- data/ext/llama_cpp/src/llama.h +304 -119
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +5 -9
- data/sig/llama_cpp.rbs +65 -34
- metadata +3 -3
@@ -24,12 +24,59 @@ typedef struct {
|
|
24
24
|
int8_t qs[QK8_0]; // quants
|
25
25
|
} block_q8_0;
|
26
26
|
|
27
|
+
// general-purpose kernel for addition of two tensors
|
28
|
+
// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
|
29
|
+
// cons: not very efficient
|
27
30
|
kernel void kernel_add(
|
28
|
-
device const
|
29
|
-
device const
|
30
|
-
device
|
31
|
-
|
32
|
-
|
31
|
+
device const char * src0,
|
32
|
+
device const char * src1,
|
33
|
+
device char * dst,
|
34
|
+
constant int64_t & ne00,
|
35
|
+
constant int64_t & ne01,
|
36
|
+
constant int64_t & ne02,
|
37
|
+
constant int64_t & ne03,
|
38
|
+
constant int64_t & nb00,
|
39
|
+
constant int64_t & nb01,
|
40
|
+
constant int64_t & nb02,
|
41
|
+
constant int64_t & nb03,
|
42
|
+
constant int64_t & ne10,
|
43
|
+
constant int64_t & ne11,
|
44
|
+
constant int64_t & ne12,
|
45
|
+
constant int64_t & ne13,
|
46
|
+
constant int64_t & nb10,
|
47
|
+
constant int64_t & nb11,
|
48
|
+
constant int64_t & nb12,
|
49
|
+
constant int64_t & nb13,
|
50
|
+
constant int64_t & ne0,
|
51
|
+
constant int64_t & ne1,
|
52
|
+
constant int64_t & ne2,
|
53
|
+
constant int64_t & ne3,
|
54
|
+
constant int64_t & nb0,
|
55
|
+
constant int64_t & nb1,
|
56
|
+
constant int64_t & nb2,
|
57
|
+
constant int64_t & nb3,
|
58
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
59
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
60
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
61
|
+
const int64_t i03 = tgpig.z;
|
62
|
+
const int64_t i02 = tgpig.y;
|
63
|
+
const int64_t i01 = tgpig.x;
|
64
|
+
|
65
|
+
const int64_t i13 = i03 % ne13;
|
66
|
+
const int64_t i12 = i02 % ne12;
|
67
|
+
const int64_t i11 = i01 % ne11;
|
68
|
+
|
69
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
|
70
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
71
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
72
|
+
|
73
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
74
|
+
((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
|
75
|
+
|
76
|
+
src0_ptr += ntg.x*nb00;
|
77
|
+
src1_ptr += ntg.x*nb10;
|
78
|
+
dst_ptr += ntg.x*nb0;
|
79
|
+
}
|
33
80
|
}
|
34
81
|
|
35
82
|
// assumption: src1 is a row
|
@@ -38,7 +85,7 @@ kernel void kernel_add_row(
|
|
38
85
|
device const float4 * src0,
|
39
86
|
device const float4 * src1,
|
40
87
|
device float4 * dst,
|
41
|
-
constant int64_t & nb,
|
88
|
+
constant int64_t & nb [[buffer(27)]],
|
42
89
|
uint tpig[[thread_position_in_grid]]) {
|
43
90
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
44
91
|
}
|
@@ -806,30 +853,61 @@ kernel void kernel_alibi_f32(
|
|
806
853
|
}
|
807
854
|
}
|
808
855
|
|
856
|
+
typedef void (rope_t)(
|
857
|
+
device const void * src0,
|
858
|
+
device const int32_t * src1,
|
859
|
+
device float * dst,
|
860
|
+
constant int64_t & ne00,
|
861
|
+
constant int64_t & ne01,
|
862
|
+
constant int64_t & ne02,
|
863
|
+
constant int64_t & ne03,
|
864
|
+
constant uint64_t & nb00,
|
865
|
+
constant uint64_t & nb01,
|
866
|
+
constant uint64_t & nb02,
|
867
|
+
constant uint64_t & nb03,
|
868
|
+
constant int64_t & ne0,
|
869
|
+
constant int64_t & ne1,
|
870
|
+
constant int64_t & ne2,
|
871
|
+
constant int64_t & ne3,
|
872
|
+
constant uint64_t & nb0,
|
873
|
+
constant uint64_t & nb1,
|
874
|
+
constant uint64_t & nb2,
|
875
|
+
constant uint64_t & nb3,
|
876
|
+
constant int & n_past,
|
877
|
+
constant int & n_dims,
|
878
|
+
constant int & mode,
|
879
|
+
constant float & freq_base,
|
880
|
+
constant float & freq_scale,
|
881
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
882
|
+
uint3 tptg[[threads_per_threadgroup]],
|
883
|
+
uint3 tgpig[[threadgroup_position_in_grid]]);
|
884
|
+
|
885
|
+
template<typename T>
|
809
886
|
kernel void kernel_rope(
|
810
|
-
device const
|
811
|
-
device
|
812
|
-
|
813
|
-
constant
|
814
|
-
constant
|
815
|
-
constant
|
816
|
-
constant
|
817
|
-
constant
|
818
|
-
constant
|
819
|
-
constant
|
820
|
-
constant
|
821
|
-
constant
|
822
|
-
constant
|
823
|
-
constant
|
824
|
-
constant
|
825
|
-
constant
|
826
|
-
constant
|
827
|
-
constant
|
828
|
-
constant
|
829
|
-
constant
|
830
|
-
constant
|
831
|
-
constant
|
832
|
-
constant
|
887
|
+
device const void * src0,
|
888
|
+
device const int32_t * src1,
|
889
|
+
device float * dst,
|
890
|
+
constant int64_t & ne00,
|
891
|
+
constant int64_t & ne01,
|
892
|
+
constant int64_t & ne02,
|
893
|
+
constant int64_t & ne03,
|
894
|
+
constant uint64_t & nb00,
|
895
|
+
constant uint64_t & nb01,
|
896
|
+
constant uint64_t & nb02,
|
897
|
+
constant uint64_t & nb03,
|
898
|
+
constant int64_t & ne0,
|
899
|
+
constant int64_t & ne1,
|
900
|
+
constant int64_t & ne2,
|
901
|
+
constant int64_t & ne3,
|
902
|
+
constant uint64_t & nb0,
|
903
|
+
constant uint64_t & nb1,
|
904
|
+
constant uint64_t & nb2,
|
905
|
+
constant uint64_t & nb3,
|
906
|
+
constant int & n_past,
|
907
|
+
constant int & n_dims,
|
908
|
+
constant int & mode,
|
909
|
+
constant float & freq_base,
|
910
|
+
constant float & freq_scale,
|
833
911
|
uint tiitg[[thread_index_in_threadgroup]],
|
834
912
|
uint3 tptg[[threads_per_threadgroup]],
|
835
913
|
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
@@ -839,7 +917,9 @@ kernel void kernel_rope(
|
|
839
917
|
|
840
918
|
const bool is_neox = mode & 2;
|
841
919
|
|
842
|
-
const
|
920
|
+
device const int32_t * pos = src1;
|
921
|
+
|
922
|
+
const int64_t p = pos[i2];
|
843
923
|
|
844
924
|
const float theta_0 = freq_scale * (float)p;
|
845
925
|
const float inv_ndims = -1.f/n_dims;
|
@@ -851,11 +931,11 @@ kernel void kernel_rope(
|
|
851
931
|
const float cos_theta = cos(theta);
|
852
932
|
const float sin_theta = sin(theta);
|
853
933
|
|
854
|
-
device const
|
855
|
-
device
|
934
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
935
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
856
936
|
|
857
|
-
const
|
858
|
-
const
|
937
|
+
const T x0 = src[0];
|
938
|
+
const T x1 = src[1];
|
859
939
|
|
860
940
|
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
861
941
|
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
@@ -870,8 +950,8 @@ kernel void kernel_rope(
|
|
870
950
|
|
871
951
|
const int64_t i0 = ib*n_dims + ic/2;
|
872
952
|
|
873
|
-
device const
|
874
|
-
device
|
953
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
954
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
875
955
|
|
876
956
|
const float x0 = src[0];
|
877
957
|
const float x1 = src[n_dims/2];
|
@@ -883,6 +963,9 @@ kernel void kernel_rope(
|
|
883
963
|
}
|
884
964
|
}
|
885
965
|
|
966
|
+
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
967
|
+
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
968
|
+
|
886
969
|
kernel void kernel_cpy_f16_f16(
|
887
970
|
device const half * src0,
|
888
971
|
device half * dst,
|
@@ -1273,8 +1356,8 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1273
1356
|
|
1274
1357
|
float yl[32];
|
1275
1358
|
|
1276
|
-
const uint16_t kmask1 = 0x3030;
|
1277
|
-
const uint16_t kmask2 = 0x0f0f;
|
1359
|
+
//const uint16_t kmask1 = 0x3030;
|
1360
|
+
//const uint16_t kmask2 = 0x0f0f;
|
1278
1361
|
|
1279
1362
|
const int tid = tiisg/4;
|
1280
1363
|
const int ix = tiisg%4;
|