llama_cpp 0.13.0 → 0.14.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 8e8d23f3abceeea388895f198a3906b7a24d692cba97e46934a14567450fc3a2
- data.tar.gz: 9d1385671b76ea826fbc000910e102fbbb951970f77b7511fdf2653adbc97334
+ metadata.gz: c7d855ccd32ae097f26a671751d6a2178361cf8d8a6c1b99af37859f2c47ca03
+ data.tar.gz: 3b17318424d08c65ad34da3fa14956c86db0a2ea05ac174323a9b8d2b9e69d59
  SHA512:
- metadata.gz: 24746b8aaaa749b4058ddb64f6b07952356a6947ef1f40bc8bf7010a37b8b476e71632452ce28b6e61b11c66249a9d4fb6573de31e66e750bdb4391ce8f3286c
- data.tar.gz: 56f79812ecdeecfc2dce6f68a73fc72d4495c6a51cc1d2ea7ccfeeb3e1ac9b6e72e78cbed019108e05987e431c4634bbfa1029f380f813a7fb6e009b5f6ec4e3
+ metadata.gz: 2d90bf9fdd8dbaf5e67b7fb8797a9412168ae6ce5fcfc4c6aca34e194d5beb5204184b5bb36d65dc507a7a618ac9e938987e8d8bf5871e4eb6304b5e6de06020
+ data.tar.gz: eab524367ace146eb6e20786bd530cead145e1651bcdb726afbb5364609d04b22ca8a515016bb0c2d154ea97fb62f19222c122bc9bb5efe7fc389a6f259da6f0
data/CHANGELOG.md CHANGED
@@ -1,5 +1,18 @@
+ ## [[0.14.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.13.0...v0.14.0)] - 2024-03-09
+
+ **Breaking Changes**
+
+ - Bump bundled llama.cpp from b2303 to b2361.
+ - Rename the `embedding` accessor to `embeddings` in `ContextParams`.
+ - Remove the `do_pooling` accessor from `ContextParams`.
+ - Add the `pooling_type` accessor to `ContextParams`.
+ - Fix the size of the array returned by the `embeddings` method in `Context` from `n_embd` to `n_tokens * n_embd`.
+ - Add the `embeddings_seq` method to `Context`.
+
  ## [[0.13.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.12.7...v0.13.0)] - 2024-03-02

+ **Breaking Changes**
+
  - Bump bundled llama.cpp from b2143 to b2303.
  - Remove deprecated methods:
  - `map_supported?`, `mlock_supported?`, `apply_lora_from_file`, `eval`, `eval_embd`, `sample_classifier_free_guidance`, `sample_temperature`, and `mul_mat_q`.
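Taken together, the `ContextParams` changes in 0.14.0 amount to a small migration for existing 0.13.0 code. A minimal sketch (mean pooling is only an example choice here, not something this release mandates):

```ruby
require 'llama_cpp'

params = LLaMACpp::ContextParams.new

# 0.13.0                        # 0.14.0
# params.embedding = true   =>  params.embeddings = true
# params.do_pooling = true  =>  params.pooling_type = LLaMACpp::LLAMA_POOLING_TYPE_MEAN
params.embeddings = true
params.pooling_type = LLaMACpp::LLAMA_POOLING_TYPE_MEAN
```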
@@ -952,6 +952,8 @@ public:
  rb_define_method(rb_cLLaMAContextParams, "n_threads_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads_batch), 0);
  rb_define_method(rb_cLLaMAContextParams, "rope_scaling_type=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_scaling_type), 1);
  rb_define_method(rb_cLLaMAContextParams, "rope_scaling_type", RUBY_METHOD_FUNC(_llama_context_params_get_rope_scaling_type), 0);
+ rb_define_method(rb_cLLaMAContextParams, "pooling_type=", RUBY_METHOD_FUNC(_llama_context_params_set_pooling_type), 1);
+ rb_define_method(rb_cLLaMAContextParams, "pooling_type", RUBY_METHOD_FUNC(_llama_context_params_get_pooling_type), 0);
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
@@ -974,12 +976,10 @@ public:
  rb_define_method(rb_cLLaMAContextParams, "type_v", RUBY_METHOD_FUNC(_llama_context_params_get_type_v), 0);
  rb_define_method(rb_cLLaMAContextParams, "logits_all=", RUBY_METHOD_FUNC(_llama_context_params_set_logits_all), 1);
  rb_define_method(rb_cLLaMAContextParams, "logits_all", RUBY_METHOD_FUNC(_llama_context_params_get_logits_all), 0);
- rb_define_method(rb_cLLaMAContextParams, "embedding=", RUBY_METHOD_FUNC(_llama_context_params_set_embedding), 1);
- rb_define_method(rb_cLLaMAContextParams, "embedding", RUBY_METHOD_FUNC(_llama_context_params_get_embedding), 0);
+ rb_define_method(rb_cLLaMAContextParams, "embeddings=", RUBY_METHOD_FUNC(_llama_context_params_set_embeddings), 1);
+ rb_define_method(rb_cLLaMAContextParams, "embeddings", RUBY_METHOD_FUNC(_llama_context_params_get_embeddings), 0);
  rb_define_method(rb_cLLaMAContextParams, "offload_kqv=", RUBY_METHOD_FUNC(_llama_context_params_set_offload_kqv), 1);
  rb_define_method(rb_cLLaMAContextParams, "offload_kqv", RUBY_METHOD_FUNC(_llama_context_params_get_offload_kqv), 0);
- rb_define_method(rb_cLLaMAContextParams, "do_pooling=", RUBY_METHOD_FUNC(_llama_context_params_set_do_pooling), 1);
- rb_define_method(rb_cLLaMAContextParams, "do_pooling", RUBY_METHOD_FUNC(_llama_context_params_get_do_pooling), 0);
  }

  private:
@@ -1058,7 +1058,7 @@ private:
  // rope_scaling_type
  static VALUE _llama_context_params_set_rope_scaling_type(VALUE self, VALUE scaling_type) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
- ptr->params.rope_scaling_type = NUM2INT(scaling_type);
+ ptr->params.rope_scaling_type = static_cast<enum llama_rope_scaling_type>(NUM2INT(scaling_type));
  return INT2NUM(ptr->params.rope_scaling_type);
  }
@@ -1067,6 +1067,18 @@ private:
  return INT2NUM(ptr->params.rope_scaling_type);
  }

+ // pooling_type
+ static VALUE _llama_context_params_set_pooling_type(VALUE self, VALUE scaling_type) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ ptr->params.pooling_type = static_cast<enum llama_pooling_type>(NUM2INT(scaling_type));
+ return INT2NUM(ptr->params.pooling_type);
+ }
+
+ static VALUE _llama_context_params_get_pooling_type(VALUE self) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ return INT2NUM(ptr->params.pooling_type);
+ }
+
  // rope_freq_base
  static VALUE _llama_context_params_set_rope_freq_base(VALUE self, VALUE rope_freq_base) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -1199,16 +1211,16 @@ private:
  return ptr->params.logits_all ? Qtrue : Qfalse;
  }

- // embedding
- static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
+ // embeddings
+ static VALUE _llama_context_params_set_embeddings(VALUE self, VALUE embeddings) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
- ptr->params.embedding = RTEST(embedding) ? true : false;
- return ptr->params.embedding ? Qtrue : Qfalse;
+ ptr->params.embeddings = RTEST(embeddings) ? true : false;
+ return ptr->params.embeddings ? Qtrue : Qfalse;
  }

- static VALUE _llama_context_params_get_embedding(VALUE self) {
+ static VALUE _llama_context_params_get_embeddings(VALUE self) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
- return ptr->params.embedding ? Qtrue : Qfalse;
+ return ptr->params.embeddings ? Qtrue : Qfalse;
  }

  // offload_kqv
@@ -1222,18 +1234,6 @@ private:
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
  return ptr->params.offload_kqv ? Qtrue : Qfalse;
  }
-
- // do_pooling
- static VALUE _llama_context_params_set_do_pooling(VALUE self, VALUE do_pooling) {
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
- ptr->params.do_pooling = RTEST(do_pooling) ? true : false;
- return ptr->params.do_pooling ? Qtrue : Qfalse;
- }
-
- static VALUE _llama_context_params_get_do_pooling(VALUE self) {
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
- return ptr->params.do_pooling ? Qtrue : Qfalse;
- }
  };

  const rb_data_type_t RbLLaMAContextParams::llama_context_params_type = {
@@ -2016,6 +2016,7 @@ public:
  rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
  rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
  rb_define_method(rb_cLLaMAContext, "embeddings_ith", RUBY_METHOD_FUNC(_llama_context_embeddings_ith), 1);
+ rb_define_method(rb_cLLaMAContext, "embeddings_seq", RUBY_METHOD_FUNC(_llama_context_embeddings_seq), 1);
  rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
  rb_define_method(rb_cLLaMAContext, "n_batch", RUBY_METHOD_FUNC(_llama_context_n_batch), 0);
  rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
@@ -2151,7 +2152,7 @@ private:
  LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
  VALUE params = rb_iv_get(self, "@params");
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
- if (!prms_ptr->params.embedding) {
+ if (!prms_ptr->params.embeddings) {
  rb_raise(rb_eRuntimeError, "embedding parameter is false");
  return Qnil;
  }
@@ -2160,10 +2161,11 @@ private:
  return Qnil;
  }

+ const int n_tokens = NUM2INT(rb_iv_get(self, "@n_tokens"));
  const int n_embd = llama_n_embd(model_ptr->model);
  const float* embd = llama_get_embeddings(ptr->ctx);
  VALUE output = rb_ary_new();
- for (int i = 0; i < n_embd; i++) {
+ for (int i = 0; i < n_tokens * n_embd; i++) {
  rb_ary_push(output, DBL2NUM((double)(embd[i])));
  }
@@ -2182,7 +2184,7 @@ private:
  }
  VALUE params = rb_iv_get(self, "@params");
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
- if (!prms_ptr->params.embedding) {
+ if (!prms_ptr->params.embeddings) {
  rb_raise(rb_eRuntimeError, "embedding parameter is false");
  return Qnil;
  }
@@ -2200,6 +2202,36 @@ private:
  return output;
  }

+ static VALUE _llama_context_embeddings_seq(VALUE self, VALUE seq_id) {
+ if (!RB_INTEGER_TYPE_P(seq_id)) {
+ rb_raise(rb_eArgError, "seq_id must be an integer");
+ return Qnil;
+ }
+ LLaMAContextWrapper* ptr = get_llama_context(self);
+ if (ptr->ctx == NULL) {
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+ return Qnil;
+ }
+ VALUE params = rb_iv_get(self, "@params");
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
+ if (!prms_ptr->params.embeddings) {
+ rb_raise(rb_eRuntimeError, "embedding parameter is false");
+ return Qnil;
+ }
+
+ VALUE model = rb_iv_get(self, "@model");
+ LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
+ const int n_embd = llama_n_embd(model_ptr->model);
+
+ VALUE output = rb_ary_new();
+ const float* embd = llama_get_embeddings_seq(ptr->ctx, NUM2INT(seq_id));
+ for (int i = 0; i < n_embd; i++) {
+ rb_ary_push(output, DBL2NUM((double)(embd[i])));
+ }
+
+ return output;
+ }
+
  static VALUE _llama_context_n_ctx(VALUE self) {
  LLaMAContextWrapper* ptr = get_llama_context(self);
  if (ptr->ctx == NULL) {
@@ -3229,6 +3261,7 @@ extern "C" void Init_llama_cpp(void) {
  rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_TYPE_YARN", INT2NUM(LLAMA_ROPE_SCALING_TYPE_YARN));
  rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_TYPE_MAX_VALUE", INT2NUM(LLAMA_ROPE_SCALING_TYPE_MAX_VALUE));

+ rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_UNSPECIFIED", INT2NUM(LLAMA_POOLING_TYPE_UNSPECIFIED));
  rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_NONE", INT2NUM(LLAMA_POOLING_TYPE_NONE));
  rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_MEAN", INT2NUM(LLAMA_POOLING_TYPE_MEAN));
  rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_CLS", INT2NUM(LLAMA_POOLING_TYPE_CLS));
@@ -3,8 +3,8 @@
  # llama_cpp.rb provides Ruby bindings for the llama.cpp.
  module LLaMACpp
  # The version of llama_cpp.rb you install.
- VERSION = '0.13.0'
+ VERSION = '0.14.0'

  # The version of llama.cpp bundled with llama_cpp.rb.
- LLAMA_CPP_VERSION = 'b2303'
+ LLAMA_CPP_VERSION = 'b2361'
  end
data/sig/llama_cpp.rbs CHANGED
@@ -50,6 +50,7 @@ module LLaMACpp
  LLAMA_ROPE_SCALING_TYPE_YARN: Integer
  LLAMA_ROPE_SCALING_TYPE_MAX_VALUE: Integer

+ LLAMA_POOLING_TYPE_UNSPECIFIED: Integer
  LLAMA_POOLING_TYPE_NONE: Integer
  LLAMA_POOLING_TYPE_MEAN: Integer
  LLAMA_POOLING_TYPE_CLS: Integer
@@ -201,6 +202,7 @@ module LLaMACpp
  def initialize: (model: ::LLaMACpp::Model, params: ::LLaMACpp::ContextParams) -> void
  def embeddings: () -> Array[Float]
  def embeddings_ith: (Integer) -> Array[Float]
+ def embeddings_seq: (Integer) -> Array[Float]
  def decode: (::LLaMACpp::Batch) -> void
  def logits: () -> Array[Float]
  def n_ctx: () -> Integer
@@ -254,6 +256,8 @@ module LLaMACpp
  def n_threads_batch=: (Integer) -> Integer
  def rope_scaling_type=: (Integer) -> Integer
  def rope_scaling_type: () -> Integer
+ def pooling_type=: (Integer) -> Integer
+ def pooling_type: () -> Integer
  def rope_freq_base=: (Float) -> Float
  def rope_freq_base: () -> Float
  def rope_freq_scale=: (Float) -> Float
@@ -276,12 +280,10 @@ module LLaMACpp
  def type_v: () -> Integer
  def logits_all: () -> bool
  def logits_all=: (bool) -> bool
- def embedding: () -> bool
- def embedding=: (bool) -> bool
+ def embeddings: () -> bool
+ def embeddings=: (bool) -> bool
  def offload_kqv: () -> bool
  def offload_kqv=: (bool) -> bool
- def do_pooling: () -> bool
- def do_pooling=: (bool) -> bool
  end

  class ModelQuantizeParams
@@ -729,10 +729,9 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
  $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
  $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

- server: examples/server/server.cpp examples/server/oai.hpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h examples/llava/llava.h examples/llava/llava.cpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
+ server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
  $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
- $(CXX) $(CXXFLAGS) -c examples/llava/clip.cpp -o $(call GET_OBJ_FILE, examples/llava/clip.cpp) -Wno-cast-qual
- $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h %.hpp $< examples/llava/clip.cpp,$^) $(call GET_OBJ_FILE, $<) $(call GET_OBJ_FILE, examples/llava/clip.cpp) -o $@ $(LDFLAGS) $(LWINSOCK2)
+ $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)

  gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)
  $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
@@ -91,13 +91,14 @@ extern "C" {
  // (optional) complete all pending operations
  void (*GGML_CALL synchronize)(ggml_backend_t backend);

- // compute graph with a plan
+ // create a plan for ggml_cgraph and free it
  ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
  void (*GGML_CALL graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
- void (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);

+ // compute graph with a plan
+ enum ggml_status (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
  // compute graph without a plan (async)
- bool (*GGML_CALL graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+ enum ggml_status (*GGML_CALL graph_compute) (ggml_backend_t backend, struct ggml_cgraph * cgraph);

  // check if the backend supports an operation
  bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
@@ -262,11 +262,11 @@ void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_pla
  backend->iface.graph_plan_free(backend, plan);
  }

- void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
- backend->iface.graph_plan_compute(backend, plan);
+ enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+ return backend->iface.graph_plan_compute(backend, plan);
  }

- bool ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
  return backend->iface.graph_compute(backend, cgraph);
  }
@@ -732,15 +732,15 @@ GGML_CALL static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, g
  GGML_UNUSED(backend);
  }

- GGML_CALL static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+ GGML_CALL static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
  struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;

- ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
+ return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);

  GGML_UNUSED(backend);
  }

- GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ GGML_CALL static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
  struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;

  struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
@@ -755,8 +755,7 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str
  cplan.abort_callback = cpu_ctx->abort_callback;
  cplan.abort_callback_data = cpu_ctx->abort_callback_data;

- ggml_graph_compute(cgraph, &cplan);
- return true;
+ return ggml_graph_compute(cgraph, &cplan);
  }

  GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
@@ -1437,7 +1436,7 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
  return true;
  }

- static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
+ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
  uint64_t copy_us[GGML_MAX_BACKENDS] = {0};
  uint64_t compute_us[GGML_MAX_BACKENDS] = {0};
@@ -1472,8 +1471,9 @@ static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
  uint64_t compute_start_us = ggml_time_us();
  if (!sched->callback_eval) {
- if (!ggml_backend_graph_compute(split_backend, &split->graph)) {
- return false;
+ enum ggml_status ec = ggml_backend_graph_compute(split_backend, &split->graph);
+ if (ec != GGML_STATUS_SUCCESS) {
+ return ec;
  }
  //ggml_backend_synchronize(split_backend); // necessary to measure compute time
  } else {
@@ -1494,8 +1494,9 @@ static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
  struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);

- if (!ggml_backend_graph_compute(split_backend, &gv)) {
- return false;
+ enum ggml_status ec = ggml_backend_graph_compute(split_backend, &gv);
+ if (ec != GGML_STATUS_SUCCESS) {
+ return ec;
  }

  if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {
@@ -1519,7 +1520,7 @@ static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
  }
  #endif

- return true;
+ return GGML_STATUS_SUCCESS;
  }

  ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size) {
@@ -1581,7 +1582,7 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
  return true;
  }

- bool ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
  GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);

  if (!sched->is_reset) {
@@ -1590,14 +1591,10 @@ bool ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cg
  ggml_backend_sched_split_graph(sched, graph);
  if (!ggml_backend_sched_alloc_splits(sched)) {
- return false;
+ return GGML_STATUS_ALLOC_FAILED;
  }

- if (!ggml_backend_sched_compute_splits(sched)) {
- return false;
- }
-
- return true;
+ return ggml_backend_sched_compute_splits(sched);
  }

  void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
@@ -66,12 +66,13 @@ extern "C" {
  GGML_API void ggml_backend_synchronize(ggml_backend_t backend);

- GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+ GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+ GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);

- GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
- GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
- GGML_API bool ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
- GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op);
+ GGML_API enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+ GGML_API enum ggml_status ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+
+ GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);

  // tensor copy between different backends
  GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
@@ -157,26 +158,26 @@ extern "C" {
  typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);

  // Initialize a backend scheduler
- GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
- GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
+ GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
+ GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
  // Initialize backend buffers from a measure graph
- GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
+ GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
  // Get the number of splits of the last graph
- GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
+ GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);

- GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
+ GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);

- GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
- GGML_API ggml_backend_t ggml_backend_sched_get_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
+ GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
+ GGML_API ggml_backend_t ggml_backend_sched_get_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);

  // Allocate and compute graph on the backend scheduler
- GGML_API bool ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+ GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);

  // Reset all assignments and allocators - must be called before changing the node backends
- GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
+ GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);

  // Set a callback to be called for each resulting node during graph compute
- GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
+ GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);

  //
  // Utils