llama_cpp 0.13.0 → 0.14.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/ext/llama_cpp/llama_cpp.cpp +59 -26
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +6 -4
- data/vendor/tmp/llama.cpp/Makefile +2 -3
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +4 -3
- data/vendor/tmp/llama.cpp/ggml-backend.c +18 -21
- data/vendor/tmp/llama.cpp/ggml-backend.h +16 -15
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +949 -168
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +2 -2
- data/vendor/tmp/llama.cpp/ggml-metal.m +63 -7
- data/vendor/tmp/llama.cpp/ggml-metal.metal +120 -75
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +2 -2
- data/vendor/tmp/llama.cpp/ggml-quants.c +178 -133
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +3432 -1118
- data/vendor/tmp/llama.cpp/ggml-sycl.h +5 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +39336 -43461
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +1327 -773
- data/vendor/tmp/llama.cpp/ggml-vulkan.h +1 -0
- data/vendor/tmp/llama.cpp/ggml.c +227 -15
- data/vendor/tmp/llama.cpp/ggml.h +30 -4
- data/vendor/tmp/llama.cpp/llama.cpp +631 -211
- data/vendor/tmp/llama.cpp/llama.h +28 -10
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: c7d855ccd32ae097f26a671751d6a2178361cf8d8a6c1b99af37859f2c47ca03
|
4
|
+
data.tar.gz: 3b17318424d08c65ad34da3fa14956c86db0a2ea05ac174323a9b8d2b9e69d59
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 2d90bf9fdd8dbaf5e67b7fb8797a9412168ae6ce5fcfc4c6aca34e194d5beb5204184b5bb36d65dc507a7a618ac9e938987e8d8bf5871e4eb6304b5e6de06020
|
7
|
+
data.tar.gz: eab524367ace146eb6e20786bd530cead145e1651bcdb726afbb5364609d04b22ca8a515016bb0c2d154ea97fb62f19222c122bc9bb5efe7fc389a6f259da6f0
|
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,18 @@
|
|
1
|
+
## [[0.14.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.13.0...v0.14.0)] - 2024-03-09
|
2
|
+
|
3
|
+
**Breaking Changes**
|
4
|
+
|
5
|
+
- Bump bundled llama.cpp from b2303 to b2361.
|
6
|
+
- Rename embedding accessor to `embeddings` in `ContextParams`.
|
7
|
+
- Remove `do_pooling` accessor from `ContextParams`.
|
8
|
+
- Add `pooling_type` accessor to `ContextParams`.
|
9
|
+
- Fix the size of the array returned by the `embeddings` method in `Context` from `n_embd` to `n_tokens * n_embd`.
|
10
|
+
- Add `embeddings_seq` method to `Context`.
|
11
|
+
|
1
12
|
## [[0.13.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.12.7...v0.13.0)] - 2024-03-02
|
2
13
|
|
14
|
+
**Breaking Changes**
|
15
|
+
|
3
16
|
- Bump bundled llama.cpp from b2143 to b2303.
|
4
17
|
- Remove deprecated methods:
|
5
18
|
- `map_supported?`, `mlock_supported?`, `apply_lora_from_file`, `eval`, `eval_embd`, `sample_classifier_free_guidance`, `sample_temperature`, and `mul_mat_q`.
|
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -952,6 +952,8 @@ public:
|
|
952
952
|
rb_define_method(rb_cLLaMAContextParams, "n_threads_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads_batch), 0);
|
953
953
|
rb_define_method(rb_cLLaMAContextParams, "rope_scaling_type=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_scaling_type), 1);
|
954
954
|
rb_define_method(rb_cLLaMAContextParams, "rope_scaling_type", RUBY_METHOD_FUNC(_llama_context_params_get_rope_scaling_type), 0);
|
955
|
+
rb_define_method(rb_cLLaMAContextParams, "pooling_type=", RUBY_METHOD_FUNC(_llama_context_params_set_pooling_type), 1);
|
956
|
+
rb_define_method(rb_cLLaMAContextParams, "pooling_type", RUBY_METHOD_FUNC(_llama_context_params_get_pooling_type), 0);
|
955
957
|
rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
|
956
958
|
rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
|
957
959
|
rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
|
@@ -974,12 +976,10 @@ public:
|
|
974
976
|
rb_define_method(rb_cLLaMAContextParams, "type_v", RUBY_METHOD_FUNC(_llama_context_params_get_type_v), 0);
|
975
977
|
rb_define_method(rb_cLLaMAContextParams, "logits_all=", RUBY_METHOD_FUNC(_llama_context_params_set_logits_all), 1);
|
976
978
|
rb_define_method(rb_cLLaMAContextParams, "logits_all", RUBY_METHOD_FUNC(_llama_context_params_get_logits_all), 0);
|
977
|
-
rb_define_method(rb_cLLaMAContextParams, "
|
978
|
-
rb_define_method(rb_cLLaMAContextParams, "
|
979
|
+
rb_define_method(rb_cLLaMAContextParams, "embeddings=", RUBY_METHOD_FUNC(_llama_context_params_set_embeddings), 1);
|
980
|
+
rb_define_method(rb_cLLaMAContextParams, "embeddings", RUBY_METHOD_FUNC(_llama_context_params_get_embeddings), 0);
|
979
981
|
rb_define_method(rb_cLLaMAContextParams, "offload_kqv=", RUBY_METHOD_FUNC(_llama_context_params_set_offload_kqv), 1);
|
980
982
|
rb_define_method(rb_cLLaMAContextParams, "offload_kqv", RUBY_METHOD_FUNC(_llama_context_params_get_offload_kqv), 0);
|
981
|
-
rb_define_method(rb_cLLaMAContextParams, "do_pooling=", RUBY_METHOD_FUNC(_llama_context_params_set_do_pooling), 1);
|
982
|
-
rb_define_method(rb_cLLaMAContextParams, "do_pooling", RUBY_METHOD_FUNC(_llama_context_params_get_do_pooling), 0);
|
983
983
|
}
|
984
984
|
|
985
985
|
private:
|
@@ -1058,7 +1058,7 @@ private:
|
|
1058
1058
|
// rope_scaling_type
|
1059
1059
|
static VALUE _llama_context_params_set_rope_scaling_type(VALUE self, VALUE scaling_type) {
|
1060
1060
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
1061
|
-
ptr->params.rope_scaling_type = NUM2INT(scaling_type);
|
1061
|
+
ptr->params.rope_scaling_type = static_cast<enum llama_rope_scaling_type>(NUM2INT(scaling_type));
|
1062
1062
|
return INT2NUM(ptr->params.rope_scaling_type);
|
1063
1063
|
}
|
1064
1064
|
|
@@ -1067,6 +1067,18 @@ private:
|
|
1067
1067
|
return INT2NUM(ptr->params.rope_scaling_type);
|
1068
1068
|
}
|
1069
1069
|
|
1070
|
+
// pooling_type
|
1071
|
+
static VALUE _llama_context_params_set_pooling_type(VALUE self, VALUE scaling_type) {
|
1072
|
+
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
1073
|
+
ptr->params.pooling_type = static_cast<enum llama_pooling_type>(NUM2INT(scaling_type));
|
1074
|
+
return INT2NUM(ptr->params.pooling_type);
|
1075
|
+
}
|
1076
|
+
|
1077
|
+
static VALUE _llama_context_params_get_pooling_type(VALUE self) {
|
1078
|
+
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
1079
|
+
return INT2NUM(ptr->params.pooling_type);
|
1080
|
+
}
|
1081
|
+
|
1070
1082
|
// rope_freq_base
|
1071
1083
|
static VALUE _llama_context_params_set_rope_freq_base(VALUE self, VALUE rope_freq_base) {
|
1072
1084
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
@@ -1199,16 +1211,16 @@ private:
|
|
1199
1211
|
return ptr->params.logits_all ? Qtrue : Qfalse;
|
1200
1212
|
}
|
1201
1213
|
|
1202
|
-
//
|
1203
|
-
static VALUE
|
1214
|
+
// embeddings
|
1215
|
+
static VALUE _llama_context_params_set_embeddings(VALUE self, VALUE embeddings) {
|
1204
1216
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
1205
|
-
ptr->params.
|
1206
|
-
return ptr->params.
|
1217
|
+
ptr->params.embeddings = RTEST(embeddings) ? true : false;
|
1218
|
+
return ptr->params.embeddings ? Qtrue : Qfalse;
|
1207
1219
|
}
|
1208
1220
|
|
1209
|
-
static VALUE
|
1221
|
+
static VALUE _llama_context_params_get_embeddings(VALUE self) {
|
1210
1222
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
1211
|
-
return ptr->params.
|
1223
|
+
return ptr->params.embeddings ? Qtrue : Qfalse;
|
1212
1224
|
}
|
1213
1225
|
|
1214
1226
|
// offload_kqv
|
@@ -1222,18 +1234,6 @@ private:
|
|
1222
1234
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
1223
1235
|
return ptr->params.offload_kqv ? Qtrue : Qfalse;
|
1224
1236
|
}
|
1225
|
-
|
1226
|
-
// do_pooling
|
1227
|
-
static VALUE _llama_context_params_set_do_pooling(VALUE self, VALUE do_pooling) {
|
1228
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
1229
|
-
ptr->params.do_pooling = RTEST(do_pooling) ? true : false;
|
1230
|
-
return ptr->params.do_pooling ? Qtrue : Qfalse;
|
1231
|
-
}
|
1232
|
-
|
1233
|
-
static VALUE _llama_context_params_get_do_pooling(VALUE self) {
|
1234
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
1235
|
-
return ptr->params.do_pooling ? Qtrue : Qfalse;
|
1236
|
-
}
|
1237
1237
|
};
|
1238
1238
|
|
1239
1239
|
const rb_data_type_t RbLLaMAContextParams::llama_context_params_type = {
|
@@ -2016,6 +2016,7 @@ public:
|
|
2016
2016
|
rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
|
2017
2017
|
rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
|
2018
2018
|
rb_define_method(rb_cLLaMAContext, "embeddings_ith", RUBY_METHOD_FUNC(_llama_context_embeddings_ith), 1);
|
2019
|
+
rb_define_method(rb_cLLaMAContext, "embeddings_seq", RUBY_METHOD_FUNC(_llama_context_embeddings_seq), 1);
|
2019
2020
|
rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
|
2020
2021
|
rb_define_method(rb_cLLaMAContext, "n_batch", RUBY_METHOD_FUNC(_llama_context_n_batch), 0);
|
2021
2022
|
rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
|
@@ -2151,7 +2152,7 @@ private:
|
|
2151
2152
|
LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
|
2152
2153
|
VALUE params = rb_iv_get(self, "@params");
|
2153
2154
|
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
|
2154
|
-
if (!prms_ptr->params.
|
2155
|
+
if (!prms_ptr->params.embeddings) {
|
2155
2156
|
rb_raise(rb_eRuntimeError, "embedding parameter is false");
|
2156
2157
|
return Qnil;
|
2157
2158
|
}
|
@@ -2160,10 +2161,11 @@ private:
|
|
2160
2161
|
return Qnil;
|
2161
2162
|
}
|
2162
2163
|
|
2164
|
+
const int n_tokens = NUM2INT(rb_iv_get(self, "@n_tokens"));
|
2163
2165
|
const int n_embd = llama_n_embd(model_ptr->model);
|
2164
2166
|
const float* embd = llama_get_embeddings(ptr->ctx);
|
2165
2167
|
VALUE output = rb_ary_new();
|
2166
|
-
for (int i = 0; i < n_embd; i++) {
|
2168
|
+
for (int i = 0; i < n_tokens * n_embd; i++) {
|
2167
2169
|
rb_ary_push(output, DBL2NUM((double)(embd[i])));
|
2168
2170
|
}
|
2169
2171
|
|
@@ -2182,7 +2184,7 @@ private:
|
|
2182
2184
|
}
|
2183
2185
|
VALUE params = rb_iv_get(self, "@params");
|
2184
2186
|
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
|
2185
|
-
if (!prms_ptr->params.
|
2187
|
+
if (!prms_ptr->params.embeddings) {
|
2186
2188
|
rb_raise(rb_eRuntimeError, "embedding parameter is false");
|
2187
2189
|
return Qnil;
|
2188
2190
|
}
|
@@ -2200,6 +2202,36 @@ private:
|
|
2200
2202
|
return output;
|
2201
2203
|
}
|
2202
2204
|
|
2205
|
+
static VALUE _llama_context_embeddings_seq(VALUE self, VALUE seq_id) {
|
2206
|
+
if (!RB_INTEGER_TYPE_P(seq_id)) {
|
2207
|
+
rb_raise(rb_eArgError, "seq_id must be an integer");
|
2208
|
+
return Qnil;
|
2209
|
+
}
|
2210
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2211
|
+
if (ptr->ctx == NULL) {
|
2212
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
2213
|
+
return Qnil;
|
2214
|
+
}
|
2215
|
+
VALUE params = rb_iv_get(self, "@params");
|
2216
|
+
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
|
2217
|
+
if (!prms_ptr->params.embeddings) {
|
2218
|
+
rb_raise(rb_eRuntimeError, "embedding parameter is false");
|
2219
|
+
return Qnil;
|
2220
|
+
}
|
2221
|
+
|
2222
|
+
VALUE model = rb_iv_get(self, "@model");
|
2223
|
+
LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
|
2224
|
+
const int n_embd = llama_n_embd(model_ptr->model);
|
2225
|
+
|
2226
|
+
VALUE output = rb_ary_new();
|
2227
|
+
const float* embd = llama_get_embeddings_seq(ptr->ctx, NUM2INT(seq_id));
|
2228
|
+
for (int i = 0; i < n_embd; i++) {
|
2229
|
+
rb_ary_push(output, DBL2NUM((double)(embd[i])));
|
2230
|
+
}
|
2231
|
+
|
2232
|
+
return output;
|
2233
|
+
}
|
2234
|
+
|
2203
2235
|
static VALUE _llama_context_n_ctx(VALUE self) {
|
2204
2236
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2205
2237
|
if (ptr->ctx == NULL) {
|
@@ -3229,6 +3261,7 @@ extern "C" void Init_llama_cpp(void) {
|
|
3229
3261
|
rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_TYPE_YARN", INT2NUM(LLAMA_ROPE_SCALING_TYPE_YARN));
|
3230
3262
|
rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_TYPE_MAX_VALUE", INT2NUM(LLAMA_ROPE_SCALING_TYPE_MAX_VALUE));
|
3231
3263
|
|
3264
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_UNSPECIFIED", INT2NUM(LLAMA_POOLING_TYPE_UNSPECIFIED));
|
3232
3265
|
rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_NONE", INT2NUM(LLAMA_POOLING_TYPE_NONE));
|
3233
3266
|
rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_MEAN", INT2NUM(LLAMA_POOLING_TYPE_MEAN));
|
3234
3267
|
rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_CLS", INT2NUM(LLAMA_POOLING_TYPE_CLS));
|
data/lib/llama_cpp/version.rb
CHANGED
@@ -3,8 +3,8 @@
|
|
3
3
|
# llama_cpp.rb provides Ruby bindings for llama.cpp.
|
4
4
|
module LLaMACpp
|
5
5
|
# The version of llama_cpp.rb you install.
|
6
|
-
VERSION = '0.
|
6
|
+
VERSION = '0.14.0'
|
7
7
|
|
8
8
|
# The version of llama.cpp bundled with llama_cpp.rb.
|
9
|
-
LLAMA_CPP_VERSION = '
|
9
|
+
LLAMA_CPP_VERSION = 'b2361'
|
10
10
|
end
|
data/sig/llama_cpp.rbs
CHANGED
@@ -50,6 +50,7 @@ module LLaMACpp
|
|
50
50
|
LLAMA_ROPE_SCALING_TYPE_YARN: Integer
|
51
51
|
LLAMA_ROPE_SCALING_TYPE_MAX_VALUE: Integer
|
52
52
|
|
53
|
+
LLAMA_POOLING_TYPE_UNSPECIFIED: Integer
|
53
54
|
LLAMA_POOLING_TYPE_NONE: Integer
|
54
55
|
LLAMA_POOLING_TYPE_MEAN: Integer
|
55
56
|
LLAMA_POOLING_TYPE_CLS: Integer
|
@@ -201,6 +202,7 @@ module LLaMACpp
|
|
201
202
|
def initialize: (model: ::LLaMACpp::Model, params: ::LLaMACpp::ContextParams) -> void
|
202
203
|
def embeddings: () -> Array[Float]
|
203
204
|
def embeddings_ith: (Integer) -> Array[Float]
|
205
|
+
def embeddings_seq: (Integer) -> Array[Float]
|
204
206
|
def decode: (::LLaMACpp::Batch) -> void
|
205
207
|
def logits: () -> Array[Float]
|
206
208
|
def n_ctx: () -> Integer
|
@@ -254,6 +256,8 @@ module LLaMACpp
|
|
254
256
|
def n_threads_batch=: (Integer) -> Integer
|
255
257
|
def rope_scaling_type=: (Integer) -> Integer
|
256
258
|
def rope_scaling_type: () -> Integer
|
259
|
+
def pooling_type=: (Integer) -> Integer
|
260
|
+
def pooling_type: () -> Integer
|
257
261
|
def rope_freq_base=: (Float) -> Float
|
258
262
|
def rope_freq_base: () -> Float
|
259
263
|
def rope_freq_scale=: (Float) -> Float
|
@@ -276,12 +280,10 @@ module LLaMACpp
|
|
276
280
|
def type_v: () -> Integer
|
277
281
|
def logits_all: () -> bool
|
278
282
|
def logits_all=: (bool) -> bool
|
279
|
-
def
|
280
|
-
def
|
283
|
+
def embeddings: () -> bool
|
284
|
+
def embeddings=: (bool) -> bool
|
281
285
|
def offload_kqv: () -> bool
|
282
286
|
def offload_kqv=: (bool) -> bool
|
283
|
-
def do_pooling: () -> bool
|
284
|
-
def do_pooling=: (bool) -> bool
|
285
287
|
end
|
286
288
|
|
287
289
|
class ModelQuantizeParams
|
@@ -729,10 +729,9 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
|
|
729
729
|
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
730
730
|
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
731
731
|
|
732
|
-
server: examples/server/server.cpp examples/server/
|
732
|
+
server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
|
733
733
|
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
734
|
-
$(CXX) $(CXXFLAGS) -
|
735
|
-
$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h %.hpp $< examples/llava/clip.cpp,$^) $(call GET_OBJ_FILE, $<) $(call GET_OBJ_FILE, examples/llava/clip.cpp) -o $@ $(LDFLAGS) $(LWINSOCK2)
|
734
|
+
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
|
736
735
|
|
737
736
|
gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)
|
738
737
|
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
@@ -91,13 +91,14 @@ extern "C" {
|
|
91
91
|
// (optional) complete all pending operations
|
92
92
|
void (*GGML_CALL synchronize)(ggml_backend_t backend);
|
93
93
|
|
94
|
-
//
|
94
|
+
// create a plan for ggml_cgraph and free it
|
95
95
|
ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
|
96
96
|
void (*GGML_CALL graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
97
|
-
void (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
98
97
|
|
98
|
+
// compute graph with a plan
|
99
|
+
enum ggml_status (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
99
100
|
// compute graph without a plan (async)
|
100
|
-
|
101
|
+
enum ggml_status (*GGML_CALL graph_compute) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
101
102
|
|
102
103
|
// check if the backend supports an operation
|
103
104
|
bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
|
@@ -262,11 +262,11 @@ void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_pla
|
|
262
262
|
backend->iface.graph_plan_free(backend, plan);
|
263
263
|
}
|
264
264
|
|
265
|
-
|
266
|
-
backend->iface.graph_plan_compute(backend, plan);
|
265
|
+
enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
|
266
|
+
return backend->iface.graph_plan_compute(backend, plan);
|
267
267
|
}
|
268
268
|
|
269
|
-
|
269
|
+
enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
270
270
|
return backend->iface.graph_compute(backend, cgraph);
|
271
271
|
}
|
272
272
|
|
@@ -732,15 +732,15 @@ GGML_CALL static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, g
|
|
732
732
|
GGML_UNUSED(backend);
|
733
733
|
}
|
734
734
|
|
735
|
-
GGML_CALL static
|
735
|
+
GGML_CALL static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
|
736
736
|
struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
|
737
737
|
|
738
|
-
ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
|
738
|
+
return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
|
739
739
|
|
740
740
|
GGML_UNUSED(backend);
|
741
741
|
}
|
742
742
|
|
743
|
-
GGML_CALL static
|
743
|
+
GGML_CALL static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
744
744
|
struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
|
745
745
|
|
746
746
|
struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
|
@@ -755,8 +755,7 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str
|
|
755
755
|
cplan.abort_callback = cpu_ctx->abort_callback;
|
756
756
|
cplan.abort_callback_data = cpu_ctx->abort_callback_data;
|
757
757
|
|
758
|
-
ggml_graph_compute(cgraph, &cplan);
|
759
|
-
return true;
|
758
|
+
return ggml_graph_compute(cgraph, &cplan);
|
760
759
|
}
|
761
760
|
|
762
761
|
GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
@@ -1437,7 +1436,7 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
|
|
1437
1436
|
return true;
|
1438
1437
|
}
|
1439
1438
|
|
1440
|
-
static
|
1439
|
+
static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
|
1441
1440
|
uint64_t copy_us[GGML_MAX_BACKENDS] = {0};
|
1442
1441
|
uint64_t compute_us[GGML_MAX_BACKENDS] = {0};
|
1443
1442
|
|
@@ -1472,8 +1471,9 @@ static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
|
|
1472
1471
|
|
1473
1472
|
uint64_t compute_start_us = ggml_time_us();
|
1474
1473
|
if (!sched->callback_eval) {
|
1475
|
-
|
1476
|
-
|
1474
|
+
enum ggml_status ec = ggml_backend_graph_compute(split_backend, &split->graph);
|
1475
|
+
if (ec != GGML_STATUS_SUCCESS) {
|
1476
|
+
return ec;
|
1477
1477
|
}
|
1478
1478
|
//ggml_backend_synchronize(split_backend); // necessary to measure compute time
|
1479
1479
|
} else {
|
@@ -1494,8 +1494,9 @@ static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
|
|
1494
1494
|
|
1495
1495
|
struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
|
1496
1496
|
|
1497
|
-
|
1498
|
-
|
1497
|
+
enum ggml_status ec = ggml_backend_graph_compute(split_backend, &gv);
|
1498
|
+
if (ec != GGML_STATUS_SUCCESS) {
|
1499
|
+
return ec;
|
1499
1500
|
}
|
1500
1501
|
|
1501
1502
|
if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {
|
@@ -1519,7 +1520,7 @@ static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
|
|
1519
1520
|
}
|
1520
1521
|
#endif
|
1521
1522
|
|
1522
|
-
return
|
1523
|
+
return GGML_STATUS_SUCCESS;
|
1523
1524
|
}
|
1524
1525
|
|
1525
1526
|
ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size) {
|
@@ -1581,7 +1582,7 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
|
|
1581
1582
|
return true;
|
1582
1583
|
}
|
1583
1584
|
|
1584
|
-
|
1585
|
+
enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
|
1585
1586
|
GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);
|
1586
1587
|
|
1587
1588
|
if (!sched->is_reset) {
|
@@ -1590,14 +1591,10 @@ bool ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cg
|
|
1590
1591
|
|
1591
1592
|
ggml_backend_sched_split_graph(sched, graph);
|
1592
1593
|
if (!ggml_backend_sched_alloc_splits(sched)) {
|
1593
|
-
return
|
1594
|
+
return GGML_STATUS_ALLOC_FAILED;
|
1594
1595
|
}
|
1595
1596
|
|
1596
|
-
|
1597
|
-
return false;
|
1598
|
-
}
|
1599
|
-
|
1600
|
-
return true;
|
1597
|
+
return ggml_backend_sched_compute_splits(sched);
|
1601
1598
|
}
|
1602
1599
|
|
1603
1600
|
void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
|
@@ -66,12 +66,13 @@ extern "C" {
|
|
66
66
|
|
67
67
|
GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
|
68
68
|
|
69
|
-
GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create
|
69
|
+
GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
70
|
+
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
70
71
|
|
71
|
-
GGML_API
|
72
|
-
GGML_API
|
73
|
-
|
74
|
-
GGML_API bool ggml_backend_supports_op
|
72
|
+
GGML_API enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
73
|
+
GGML_API enum ggml_status ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
74
|
+
|
75
|
+
GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
|
75
76
|
|
76
77
|
// tensor copy between different backends
|
77
78
|
GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
|
@@ -157,26 +158,26 @@ extern "C" {
|
|
157
158
|
typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
|
158
159
|
|
159
160
|
// Initialize a backend scheduler
|
160
|
-
GGML_API ggml_backend_sched_t
|
161
|
-
GGML_API void
|
161
|
+
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
|
162
|
+
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
|
162
163
|
// Initialize backend buffers from a measure graph
|
163
|
-
GGML_API bool
|
164
|
+
GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
|
164
165
|
// Get the number of splits of the last graph
|
165
|
-
GGML_API int
|
166
|
+
GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
|
166
167
|
|
167
|
-
GGML_API size_t
|
168
|
+
GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
|
168
169
|
|
169
|
-
GGML_API void
|
170
|
-
GGML_API ggml_backend_t
|
170
|
+
GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
|
171
|
+
GGML_API ggml_backend_t ggml_backend_sched_get_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
|
171
172
|
|
172
173
|
// Allocate and compute graph on the backend scheduler
|
173
|
-
GGML_API
|
174
|
+
GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
|
174
175
|
|
175
176
|
// Reset all assignments and allocators - must be called before changing the node backends
|
176
|
-
GGML_API void
|
177
|
+
GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
|
177
178
|
|
178
179
|
// Set a callback to be called for each resulting node during graph compute
|
179
|
-
GGML_API void
|
180
|
+
GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
|
180
181
|
|
181
182
|
//
|
182
183
|
// Utils
|