llama_cpp 0.5.3 → 0.6.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +6 -5
- data/examples/chat.rb +13 -13
- data/examples/embedding.rb +9 -9
- data/ext/llama_cpp/llama_cpp.cpp +547 -272
- data/ext/llama_cpp/src/ggml-alloc.c +8 -2
- data/ext/llama_cpp/src/ggml-alloc.h +1 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +209 -82
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.m +163 -84
- data/ext/llama_cpp/src/ggml-metal.metal +121 -38
- data/ext/llama_cpp/src/ggml.c +1596 -842
- data/ext/llama_cpp/src/ggml.h +116 -35
- data/ext/llama_cpp/src/llama.cpp +1015 -586
- data/ext/llama_cpp/src/llama.h +304 -119
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +5 -9
- data/sig/llama_cpp.rbs +65 -34
- metadata +3 -3
@@ -24,12 +24,59 @@ typedef struct {
|
|
24
24
|
int8_t qs[QK8_0]; // quants
|
25
25
|
} block_q8_0;
|
26
26
|
|
27
|
+
// general-purpose kernel for addition of two tensors
|
28
|
+
// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
|
29
|
+
// cons: not very efficient
|
27
30
|
kernel void kernel_add(
|
28
|
-
device const
|
29
|
-
device const
|
30
|
-
device
|
31
|
-
|
32
|
-
|
31
|
+
device const char * src0,
|
32
|
+
device const char * src1,
|
33
|
+
device char * dst,
|
34
|
+
constant int64_t & ne00,
|
35
|
+
constant int64_t & ne01,
|
36
|
+
constant int64_t & ne02,
|
37
|
+
constant int64_t & ne03,
|
38
|
+
constant int64_t & nb00,
|
39
|
+
constant int64_t & nb01,
|
40
|
+
constant int64_t & nb02,
|
41
|
+
constant int64_t & nb03,
|
42
|
+
constant int64_t & ne10,
|
43
|
+
constant int64_t & ne11,
|
44
|
+
constant int64_t & ne12,
|
45
|
+
constant int64_t & ne13,
|
46
|
+
constant int64_t & nb10,
|
47
|
+
constant int64_t & nb11,
|
48
|
+
constant int64_t & nb12,
|
49
|
+
constant int64_t & nb13,
|
50
|
+
constant int64_t & ne0,
|
51
|
+
constant int64_t & ne1,
|
52
|
+
constant int64_t & ne2,
|
53
|
+
constant int64_t & ne3,
|
54
|
+
constant int64_t & nb0,
|
55
|
+
constant int64_t & nb1,
|
56
|
+
constant int64_t & nb2,
|
57
|
+
constant int64_t & nb3,
|
58
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
59
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
60
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
61
|
+
const int64_t i03 = tgpig.z;
|
62
|
+
const int64_t i02 = tgpig.y;
|
63
|
+
const int64_t i01 = tgpig.x;
|
64
|
+
|
65
|
+
const int64_t i13 = i03 % ne13;
|
66
|
+
const int64_t i12 = i02 % ne12;
|
67
|
+
const int64_t i11 = i01 % ne11;
|
68
|
+
|
69
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
|
70
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
71
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
72
|
+
|
73
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
74
|
+
((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
|
75
|
+
|
76
|
+
src0_ptr += ntg.x*nb00;
|
77
|
+
src1_ptr += ntg.x*nb10;
|
78
|
+
dst_ptr += ntg.x*nb0;
|
79
|
+
}
|
33
80
|
}
|
34
81
|
|
35
82
|
// assumption: src1 is a row
|
@@ -38,7 +85,7 @@ kernel void kernel_add_row(
|
|
38
85
|
device const float4 * src0,
|
39
86
|
device const float4 * src1,
|
40
87
|
device float4 * dst,
|
41
|
-
constant int64_t & nb,
|
88
|
+
constant int64_t & nb [[buffer(27)]],
|
42
89
|
uint tpig[[thread_position_in_grid]]) {
|
43
90
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
44
91
|
}
|
@@ -806,30 +853,61 @@ kernel void kernel_alibi_f32(
|
|
806
853
|
}
|
807
854
|
}
|
808
855
|
|
856
|
+
typedef void (rope_t)(
|
857
|
+
device const void * src0,
|
858
|
+
device const int32_t * src1,
|
859
|
+
device float * dst,
|
860
|
+
constant int64_t & ne00,
|
861
|
+
constant int64_t & ne01,
|
862
|
+
constant int64_t & ne02,
|
863
|
+
constant int64_t & ne03,
|
864
|
+
constant uint64_t & nb00,
|
865
|
+
constant uint64_t & nb01,
|
866
|
+
constant uint64_t & nb02,
|
867
|
+
constant uint64_t & nb03,
|
868
|
+
constant int64_t & ne0,
|
869
|
+
constant int64_t & ne1,
|
870
|
+
constant int64_t & ne2,
|
871
|
+
constant int64_t & ne3,
|
872
|
+
constant uint64_t & nb0,
|
873
|
+
constant uint64_t & nb1,
|
874
|
+
constant uint64_t & nb2,
|
875
|
+
constant uint64_t & nb3,
|
876
|
+
constant int & n_past,
|
877
|
+
constant int & n_dims,
|
878
|
+
constant int & mode,
|
879
|
+
constant float & freq_base,
|
880
|
+
constant float & freq_scale,
|
881
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
882
|
+
uint3 tptg[[threads_per_threadgroup]],
|
883
|
+
uint3 tgpig[[threadgroup_position_in_grid]]);
|
884
|
+
|
885
|
+
template<typename T>
|
809
886
|
kernel void kernel_rope(
|
810
|
-
device const
|
811
|
-
device
|
812
|
-
|
813
|
-
constant
|
814
|
-
constant
|
815
|
-
constant
|
816
|
-
constant
|
817
|
-
constant
|
818
|
-
constant
|
819
|
-
constant
|
820
|
-
constant
|
821
|
-
constant
|
822
|
-
constant
|
823
|
-
constant
|
824
|
-
constant
|
825
|
-
constant
|
826
|
-
constant
|
827
|
-
constant
|
828
|
-
constant
|
829
|
-
constant
|
830
|
-
constant
|
831
|
-
constant
|
832
|
-
constant
|
887
|
+
device const void * src0,
|
888
|
+
device const int32_t * src1,
|
889
|
+
device float * dst,
|
890
|
+
constant int64_t & ne00,
|
891
|
+
constant int64_t & ne01,
|
892
|
+
constant int64_t & ne02,
|
893
|
+
constant int64_t & ne03,
|
894
|
+
constant uint64_t & nb00,
|
895
|
+
constant uint64_t & nb01,
|
896
|
+
constant uint64_t & nb02,
|
897
|
+
constant uint64_t & nb03,
|
898
|
+
constant int64_t & ne0,
|
899
|
+
constant int64_t & ne1,
|
900
|
+
constant int64_t & ne2,
|
901
|
+
constant int64_t & ne3,
|
902
|
+
constant uint64_t & nb0,
|
903
|
+
constant uint64_t & nb1,
|
904
|
+
constant uint64_t & nb2,
|
905
|
+
constant uint64_t & nb3,
|
906
|
+
constant int & n_past,
|
907
|
+
constant int & n_dims,
|
908
|
+
constant int & mode,
|
909
|
+
constant float & freq_base,
|
910
|
+
constant float & freq_scale,
|
833
911
|
uint tiitg[[thread_index_in_threadgroup]],
|
834
912
|
uint3 tptg[[threads_per_threadgroup]],
|
835
913
|
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
@@ -839,7 +917,9 @@ kernel void kernel_rope(
|
|
839
917
|
|
840
918
|
const bool is_neox = mode & 2;
|
841
919
|
|
842
|
-
const
|
920
|
+
device const int32_t * pos = src1;
|
921
|
+
|
922
|
+
const int64_t p = pos[i2];
|
843
923
|
|
844
924
|
const float theta_0 = freq_scale * (float)p;
|
845
925
|
const float inv_ndims = -1.f/n_dims;
|
@@ -851,11 +931,11 @@ kernel void kernel_rope(
|
|
851
931
|
const float cos_theta = cos(theta);
|
852
932
|
const float sin_theta = sin(theta);
|
853
933
|
|
854
|
-
device const
|
855
|
-
device
|
934
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
935
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
856
936
|
|
857
|
-
const
|
858
|
-
const
|
937
|
+
const T x0 = src[0];
|
938
|
+
const T x1 = src[1];
|
859
939
|
|
860
940
|
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
861
941
|
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
@@ -870,8 +950,8 @@ kernel void kernel_rope(
|
|
870
950
|
|
871
951
|
const int64_t i0 = ib*n_dims + ic/2;
|
872
952
|
|
873
|
-
device const
|
874
|
-
device
|
953
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
954
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
875
955
|
|
876
956
|
const float x0 = src[0];
|
877
957
|
const float x1 = src[n_dims/2];
|
@@ -883,6 +963,9 @@ kernel void kernel_rope(
|
|
883
963
|
}
|
884
964
|
}
|
885
965
|
|
966
|
+
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
967
|
+
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
968
|
+
|
886
969
|
kernel void kernel_cpy_f16_f16(
|
887
970
|
device const half * src0,
|
888
971
|
device half * dst,
|
@@ -1273,8 +1356,8 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1273
1356
|
|
1274
1357
|
float yl[32];
|
1275
1358
|
|
1276
|
-
const uint16_t kmask1 = 0x3030;
|
1277
|
-
const uint16_t kmask2 = 0x0f0f;
|
1359
|
+
//const uint16_t kmask1 = 0x3030;
|
1360
|
+
//const uint16_t kmask2 = 0x0f0f;
|
1278
1361
|
|
1279
1362
|
const int tid = tiisg/4;
|
1280
1363
|
const int ix = tiisg%4;
|