llama_cpp 0.13.0 → 0.14.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 8e8d23f3abceeea388895f198a3906b7a24d692cba97e46934a14567450fc3a2
4
- data.tar.gz: 9d1385671b76ea826fbc000910e102fbbb951970f77b7511fdf2653adbc97334
3
+ metadata.gz: c7d855ccd32ae097f26a671751d6a2178361cf8d8a6c1b99af37859f2c47ca03
4
+ data.tar.gz: 3b17318424d08c65ad34da3fa14956c86db0a2ea05ac174323a9b8d2b9e69d59
5
5
  SHA512:
6
- metadata.gz: 24746b8aaaa749b4058ddb64f6b07952356a6947ef1f40bc8bf7010a37b8b476e71632452ce28b6e61b11c66249a9d4fb6573de31e66e750bdb4391ce8f3286c
7
- data.tar.gz: 56f79812ecdeecfc2dce6f68a73fc72d4495c6a51cc1d2ea7ccfeeb3e1ac9b6e72e78cbed019108e05987e431c4634bbfa1029f380f813a7fb6e009b5f6ec4e3
6
+ metadata.gz: 2d90bf9fdd8dbaf5e67b7fb8797a9412168ae6ce5fcfc4c6aca34e194d5beb5204184b5bb36d65dc507a7a618ac9e938987e8d8bf5871e4eb6304b5e6de06020
7
+ data.tar.gz: eab524367ace146eb6e20786bd530cead145e1651bcdb726afbb5364609d04b22ca8a515016bb0c2d154ea97fb62f19222c122bc9bb5efe7fc389a6f259da6f0
data/CHANGELOG.md CHANGED
@@ -1,5 +1,18 @@
1
+ ## [[0.14.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.13.0...v0.14.0)] - 2024-03-09
2
+
3
+ **Breaking Changes**
4
+
5
+ - Bump bundled llama.cpp from b2303 to b2361.
6
+ - Rename the `embedding` accessor to `embeddings` in `ContextParams`.
7
+ - Remove `do_pooling` accessor from `ContextParams`.
8
+ - Add `pooling_type` accessor to `ContextParams`.
9
+ - Fix the size of the array returned by the `embeddings` method in `Context` from `n_embd` to `n_tokens * n_embd`.
10
+ - Add `embeddings_seq` method to `Context`.
11
+
1
12
  ## [[0.13.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.12.7...v0.13.0)] - 2024-03-02
2
13
 
14
+ **Breaking Changes**
15
+
3
16
  - Bump bundled llama.cpp from b2143 to b2303.
4
17
  - Remove deprecated methods:
5
18
  - `map_supported?`, `mlock_supported?`, `apply_lora_from_file`, `eval`, `eval_embd`, `sample_classifier_free_guidance`, `sample_temperature`, and `mul_mat_q`.
@@ -952,6 +952,8 @@ public:
952
952
  rb_define_method(rb_cLLaMAContextParams, "n_threads_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads_batch), 0);
953
953
  rb_define_method(rb_cLLaMAContextParams, "rope_scaling_type=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_scaling_type), 1);
954
954
  rb_define_method(rb_cLLaMAContextParams, "rope_scaling_type", RUBY_METHOD_FUNC(_llama_context_params_get_rope_scaling_type), 0);
955
+ rb_define_method(rb_cLLaMAContextParams, "pooling_type=", RUBY_METHOD_FUNC(_llama_context_params_set_pooling_type), 1);
956
+ rb_define_method(rb_cLLaMAContextParams, "pooling_type", RUBY_METHOD_FUNC(_llama_context_params_get_pooling_type), 0);
955
957
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
956
958
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
957
959
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
@@ -974,12 +976,10 @@ public:
974
976
  rb_define_method(rb_cLLaMAContextParams, "type_v", RUBY_METHOD_FUNC(_llama_context_params_get_type_v), 0);
975
977
  rb_define_method(rb_cLLaMAContextParams, "logits_all=", RUBY_METHOD_FUNC(_llama_context_params_set_logits_all), 1);
976
978
  rb_define_method(rb_cLLaMAContextParams, "logits_all", RUBY_METHOD_FUNC(_llama_context_params_get_logits_all), 0);
977
- rb_define_method(rb_cLLaMAContextParams, "embedding=", RUBY_METHOD_FUNC(_llama_context_params_set_embedding), 1);
978
- rb_define_method(rb_cLLaMAContextParams, "embedding", RUBY_METHOD_FUNC(_llama_context_params_get_embedding), 0);
979
+ rb_define_method(rb_cLLaMAContextParams, "embeddings=", RUBY_METHOD_FUNC(_llama_context_params_set_embeddings), 1);
980
+ rb_define_method(rb_cLLaMAContextParams, "embeddings", RUBY_METHOD_FUNC(_llama_context_params_get_embeddings), 0);
979
981
  rb_define_method(rb_cLLaMAContextParams, "offload_kqv=", RUBY_METHOD_FUNC(_llama_context_params_set_offload_kqv), 1);
980
982
  rb_define_method(rb_cLLaMAContextParams, "offload_kqv", RUBY_METHOD_FUNC(_llama_context_params_get_offload_kqv), 0);
981
- rb_define_method(rb_cLLaMAContextParams, "do_pooling=", RUBY_METHOD_FUNC(_llama_context_params_set_do_pooling), 1);
982
- rb_define_method(rb_cLLaMAContextParams, "do_pooling", RUBY_METHOD_FUNC(_llama_context_params_get_do_pooling), 0);
983
983
  }
984
984
 
985
985
  private:
@@ -1058,7 +1058,7 @@ private:
1058
1058
  // rope_scaling_type
1059
1059
  static VALUE _llama_context_params_set_rope_scaling_type(VALUE self, VALUE scaling_type) {
1060
1060
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1061
- ptr->params.rope_scaling_type = NUM2INT(scaling_type);
1061
+ ptr->params.rope_scaling_type = static_cast<enum llama_rope_scaling_type>(NUM2INT(scaling_type));
1062
1062
  return INT2NUM(ptr->params.rope_scaling_type);
1063
1063
  }
1064
1064
 
@@ -1067,6 +1067,18 @@ private:
1067
1067
  return INT2NUM(ptr->params.rope_scaling_type);
1068
1068
  }
1069
1069
 
1070
+ // pooling_type
1071
+ static VALUE _llama_context_params_set_pooling_type(VALUE self, VALUE scaling_type) {
1072
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1073
+ ptr->params.pooling_type = static_cast<enum llama_pooling_type>(NUM2INT(scaling_type));
1074
+ return INT2NUM(ptr->params.pooling_type);
1075
+ }
1076
+
1077
+ static VALUE _llama_context_params_get_pooling_type(VALUE self) {
1078
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1079
+ return INT2NUM(ptr->params.pooling_type);
1080
+ }
1081
+
1070
1082
  // rope_freq_base
1071
1083
  static VALUE _llama_context_params_set_rope_freq_base(VALUE self, VALUE rope_freq_base) {
1072
1084
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -1199,16 +1211,16 @@ private:
1199
1211
  return ptr->params.logits_all ? Qtrue : Qfalse;
1200
1212
  }
1201
1213
 
1202
- // embedding
1203
- static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
1214
+ // embeddings
1215
+ static VALUE _llama_context_params_set_embeddings(VALUE self, VALUE embeddings) {
1204
1216
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1205
- ptr->params.embedding = RTEST(embedding) ? true : false;
1206
- return ptr->params.embedding ? Qtrue : Qfalse;
1217
+ ptr->params.embeddings = RTEST(embeddings) ? true : false;
1218
+ return ptr->params.embeddings ? Qtrue : Qfalse;
1207
1219
  }
1208
1220
 
1209
- static VALUE _llama_context_params_get_embedding(VALUE self) {
1221
+ static VALUE _llama_context_params_get_embeddings(VALUE self) {
1210
1222
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1211
- return ptr->params.embedding ? Qtrue : Qfalse;
1223
+ return ptr->params.embeddings ? Qtrue : Qfalse;
1212
1224
  }
1213
1225
 
1214
1226
  // offload_kqv
@@ -1222,18 +1234,6 @@ private:
1222
1234
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1223
1235
  return ptr->params.offload_kqv ? Qtrue : Qfalse;
1224
1236
  }
1225
-
1226
- // do_pooling
1227
- static VALUE _llama_context_params_set_do_pooling(VALUE self, VALUE do_pooling) {
1228
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1229
- ptr->params.do_pooling = RTEST(do_pooling) ? true : false;
1230
- return ptr->params.do_pooling ? Qtrue : Qfalse;
1231
- }
1232
-
1233
- static VALUE _llama_context_params_get_do_pooling(VALUE self) {
1234
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1235
- return ptr->params.do_pooling ? Qtrue : Qfalse;
1236
- }
1237
1237
  };
1238
1238
 
1239
1239
  const rb_data_type_t RbLLaMAContextParams::llama_context_params_type = {
@@ -2016,6 +2016,7 @@ public:
2016
2016
  rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
2017
2017
  rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
2018
2018
  rb_define_method(rb_cLLaMAContext, "embeddings_ith", RUBY_METHOD_FUNC(_llama_context_embeddings_ith), 1);
2019
+ rb_define_method(rb_cLLaMAContext, "embeddings_seq", RUBY_METHOD_FUNC(_llama_context_embeddings_seq), 1);
2019
2020
  rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
2020
2021
  rb_define_method(rb_cLLaMAContext, "n_batch", RUBY_METHOD_FUNC(_llama_context_n_batch), 0);
2021
2022
  rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
@@ -2151,7 +2152,7 @@ private:
2151
2152
  LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
2152
2153
  VALUE params = rb_iv_get(self, "@params");
2153
2154
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
2154
- if (!prms_ptr->params.embedding) {
2155
+ if (!prms_ptr->params.embeddings) {
2155
2156
  rb_raise(rb_eRuntimeError, "embedding parameter is false");
2156
2157
  return Qnil;
2157
2158
  }
@@ -2160,10 +2161,11 @@ private:
2160
2161
  return Qnil;
2161
2162
  }
2162
2163
 
2164
+ const int n_tokens = NUM2INT(rb_iv_get(self, "@n_tokens"));
2163
2165
  const int n_embd = llama_n_embd(model_ptr->model);
2164
2166
  const float* embd = llama_get_embeddings(ptr->ctx);
2165
2167
  VALUE output = rb_ary_new();
2166
- for (int i = 0; i < n_embd; i++) {
2168
+ for (int i = 0; i < n_tokens * n_embd; i++) {
2167
2169
  rb_ary_push(output, DBL2NUM((double)(embd[i])));
2168
2170
  }
2169
2171
 
@@ -2182,7 +2184,7 @@ private:
2182
2184
  }
2183
2185
  VALUE params = rb_iv_get(self, "@params");
2184
2186
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
2185
- if (!prms_ptr->params.embedding) {
2187
+ if (!prms_ptr->params.embeddings) {
2186
2188
  rb_raise(rb_eRuntimeError, "embedding parameter is false");
2187
2189
  return Qnil;
2188
2190
  }
@@ -2200,6 +2202,36 @@ private:
2200
2202
  return output;
2201
2203
  }
2202
2204
 
2205
+ static VALUE _llama_context_embeddings_seq(VALUE self, VALUE seq_id) {
2206
+ if (!RB_INTEGER_TYPE_P(seq_id)) {
2207
+ rb_raise(rb_eArgError, "seq_id must be an integer");
2208
+ return Qnil;
2209
+ }
2210
+ LLaMAContextWrapper* ptr = get_llama_context(self);
2211
+ if (ptr->ctx == NULL) {
2212
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2213
+ return Qnil;
2214
+ }
2215
+ VALUE params = rb_iv_get(self, "@params");
2216
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
2217
+ if (!prms_ptr->params.embeddings) {
2218
+ rb_raise(rb_eRuntimeError, "embedding parameter is false");
2219
+ return Qnil;
2220
+ }
2221
+
2222
+ VALUE model = rb_iv_get(self, "@model");
2223
+ LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
2224
+ const int n_embd = llama_n_embd(model_ptr->model);
2225
+
2226
+ VALUE output = rb_ary_new();
2227
+ const float* embd = llama_get_embeddings_seq(ptr->ctx, NUM2INT(seq_id));
2228
+ for (int i = 0; i < n_embd; i++) {
2229
+ rb_ary_push(output, DBL2NUM((double)(embd[i])));
2230
+ }
2231
+
2232
+ return output;
2233
+ }
2234
+
2203
2235
  static VALUE _llama_context_n_ctx(VALUE self) {
2204
2236
  LLaMAContextWrapper* ptr = get_llama_context(self);
2205
2237
  if (ptr->ctx == NULL) {
@@ -3229,6 +3261,7 @@ extern "C" void Init_llama_cpp(void) {
3229
3261
  rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_TYPE_YARN", INT2NUM(LLAMA_ROPE_SCALING_TYPE_YARN));
3230
3262
  rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_TYPE_MAX_VALUE", INT2NUM(LLAMA_ROPE_SCALING_TYPE_MAX_VALUE));
3231
3263
 
3264
+ rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_UNSPECIFIED", INT2NUM(LLAMA_POOLING_TYPE_UNSPECIFIED));
3232
3265
  rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_NONE", INT2NUM(LLAMA_POOLING_TYPE_NONE));
3233
3266
  rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_MEAN", INT2NUM(LLAMA_POOLING_TYPE_MEAN));
3234
3267
  rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_CLS", INT2NUM(LLAMA_POOLING_TYPE_CLS));
@@ -3,8 +3,8 @@
3
3
  # llama_cpp.rb provides Ruby bindings for llama.cpp.
4
4
  module LLaMACpp
5
5
  # The version of llama_cpp.rb you install.
6
- VERSION = '0.13.0'
6
+ VERSION = '0.14.0'
7
7
 
8
8
  # The version of llama.cpp bundled with llama_cpp.rb.
9
- LLAMA_CPP_VERSION = 'b2303'
9
+ LLAMA_CPP_VERSION = 'b2361'
10
10
  end
data/sig/llama_cpp.rbs CHANGED
@@ -50,6 +50,7 @@ module LLaMACpp
50
50
  LLAMA_ROPE_SCALING_TYPE_YARN: Integer
51
51
  LLAMA_ROPE_SCALING_TYPE_MAX_VALUE: Integer
52
52
 
53
+ LLAMA_POOLING_TYPE_UNSPECIFIED: Integer
53
54
  LLAMA_POOLING_TYPE_NONE: Integer
54
55
  LLAMA_POOLING_TYPE_MEAN: Integer
55
56
  LLAMA_POOLING_TYPE_CLS: Integer
@@ -201,6 +202,7 @@ module LLaMACpp
201
202
  def initialize: (model: ::LLaMACpp::Model, params: ::LLaMACpp::ContextParams) -> void
202
203
  def embeddings: () -> Array[Float]
203
204
  def embeddings_ith: (Integer) -> Array[Float]
205
+ def embeddings_seq: (Integer) -> Array[Float]
204
206
  def decode: (::LLaMACpp::Batch) -> void
205
207
  def logits: () -> Array[Float]
206
208
  def n_ctx: () -> Integer
@@ -254,6 +256,8 @@ module LLaMACpp
254
256
  def n_threads_batch=: (Integer) -> Integer
255
257
  def rope_scaling_type=: (Integer) -> Integer
256
258
  def rope_scaling_type: () -> Integer
259
+ def pooling_type=: (Integer) -> Integer
260
+ def pooling_type: () -> Integer
257
261
  def rope_freq_base=: (Float) -> Float
258
262
  def rope_freq_base: () -> Float
259
263
  def rope_freq_scale=: (Float) -> Float
@@ -276,12 +280,10 @@ module LLaMACpp
276
280
  def type_v: () -> Integer
277
281
  def logits_all: () -> bool
278
282
  def logits_all=: (bool) -> bool
279
- def embedding: () -> bool
280
- def embedding=: (bool) -> bool
283
+ def embeddings: () -> bool
284
+ def embeddings=: (bool) -> bool
281
285
  def offload_kqv: () -> bool
282
286
  def offload_kqv=: (bool) -> bool
283
- def do_pooling: () -> bool
284
- def do_pooling=: (bool) -> bool
285
287
  end
286
288
 
287
289
  class ModelQuantizeParams
@@ -729,10 +729,9 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
729
729
  $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
730
730
  $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
731
731
 
732
- server: examples/server/server.cpp examples/server/oai.hpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h examples/llava/llava.h examples/llava/llava.cpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
732
+ server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
733
733
  $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
734
- $(CXX) $(CXXFLAGS) -c examples/llava/clip.cpp -o $(call GET_OBJ_FILE, examples/llava/clip.cpp) -Wno-cast-qual
735
- $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h %.hpp $< examples/llava/clip.cpp,$^) $(call GET_OBJ_FILE, $<) $(call GET_OBJ_FILE, examples/llava/clip.cpp) -o $@ $(LDFLAGS) $(LWINSOCK2)
734
+ $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
736
735
 
737
736
  gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)
738
737
  $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
@@ -91,13 +91,14 @@ extern "C" {
91
91
  // (optional) complete all pending operations
92
92
  void (*GGML_CALL synchronize)(ggml_backend_t backend);
93
93
 
94
- // compute graph with a plan
94
+ // create a plan for ggml_cgraph and free it
95
95
  ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
96
96
  void (*GGML_CALL graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
97
- void (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
98
97
 
98
+ // compute graph with a plan
99
+ enum ggml_status (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
99
100
  // compute graph without a plan (async)
100
- bool (*GGML_CALL graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
101
+ enum ggml_status (*GGML_CALL graph_compute) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
101
102
 
102
103
  // check if the backend supports an operation
103
104
  bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
@@ -262,11 +262,11 @@ void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_pla
262
262
  backend->iface.graph_plan_free(backend, plan);
263
263
  }
264
264
 
265
- void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
266
- backend->iface.graph_plan_compute(backend, plan);
265
+ enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
266
+ return backend->iface.graph_plan_compute(backend, plan);
267
267
  }
268
268
 
269
- bool ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
269
+ enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
270
270
  return backend->iface.graph_compute(backend, cgraph);
271
271
  }
272
272
 
@@ -732,15 +732,15 @@ GGML_CALL static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, g
732
732
  GGML_UNUSED(backend);
733
733
  }
734
734
 
735
- GGML_CALL static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
735
+ GGML_CALL static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
736
736
  struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
737
737
 
738
- ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
738
+ return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
739
739
 
740
740
  GGML_UNUSED(backend);
741
741
  }
742
742
 
743
- GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
743
+ GGML_CALL static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
744
744
  struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
745
745
 
746
746
  struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
@@ -755,8 +755,7 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str
755
755
  cplan.abort_callback = cpu_ctx->abort_callback;
756
756
  cplan.abort_callback_data = cpu_ctx->abort_callback_data;
757
757
 
758
- ggml_graph_compute(cgraph, &cplan);
759
- return true;
758
+ return ggml_graph_compute(cgraph, &cplan);
760
759
  }
761
760
 
762
761
  GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
@@ -1437,7 +1436,7 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
1437
1436
  return true;
1438
1437
  }
1439
1438
 
1440
- static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
1439
+ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
1441
1440
  uint64_t copy_us[GGML_MAX_BACKENDS] = {0};
1442
1441
  uint64_t compute_us[GGML_MAX_BACKENDS] = {0};
1443
1442
 
@@ -1472,8 +1471,9 @@ static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
1472
1471
 
1473
1472
  uint64_t compute_start_us = ggml_time_us();
1474
1473
  if (!sched->callback_eval) {
1475
- if (!ggml_backend_graph_compute(split_backend, &split->graph)) {
1476
- return false;
1474
+ enum ggml_status ec = ggml_backend_graph_compute(split_backend, &split->graph);
1475
+ if (ec != GGML_STATUS_SUCCESS) {
1476
+ return ec;
1477
1477
  }
1478
1478
  //ggml_backend_synchronize(split_backend); // necessary to measure compute time
1479
1479
  } else {
@@ -1494,8 +1494,9 @@ static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
1494
1494
 
1495
1495
  struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
1496
1496
 
1497
- if (!ggml_backend_graph_compute(split_backend, &gv)) {
1498
- return false;
1497
+ enum ggml_status ec = ggml_backend_graph_compute(split_backend, &gv);
1498
+ if (ec != GGML_STATUS_SUCCESS) {
1499
+ return ec;
1499
1500
  }
1500
1501
 
1501
1502
  if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {
@@ -1519,7 +1520,7 @@ static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
1519
1520
  }
1520
1521
  #endif
1521
1522
 
1522
- return true;
1523
+ return GGML_STATUS_SUCCESS;
1523
1524
  }
1524
1525
 
1525
1526
  ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size) {
@@ -1581,7 +1582,7 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
1581
1582
  return true;
1582
1583
  }
1583
1584
 
1584
- bool ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
1585
+ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
1585
1586
  GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);
1586
1587
 
1587
1588
  if (!sched->is_reset) {
@@ -1590,14 +1591,10 @@ bool ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cg
1590
1591
 
1591
1592
  ggml_backend_sched_split_graph(sched, graph);
1592
1593
  if (!ggml_backend_sched_alloc_splits(sched)) {
1593
- return false;
1594
+ return GGML_STATUS_ALLOC_FAILED;
1594
1595
  }
1595
1596
 
1596
- if (!ggml_backend_sched_compute_splits(sched)) {
1597
- return false;
1598
- }
1599
-
1600
- return true;
1597
+ return ggml_backend_sched_compute_splits(sched);
1601
1598
  }
1602
1599
 
1603
1600
  void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
@@ -66,12 +66,13 @@ extern "C" {
66
66
 
67
67
  GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
68
68
 
69
- GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create (ggml_backend_t backend, struct ggml_cgraph * cgraph);
69
+ GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);
70
+ GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
70
71
 
71
- GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
72
- GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
73
- GGML_API bool ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
74
- GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op);
72
+ GGML_API enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
73
+ GGML_API enum ggml_status ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
74
+
75
+ GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
75
76
 
76
77
  // tensor copy between different backends
77
78
  GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
@@ -157,26 +158,26 @@ extern "C" {
157
158
  typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
158
159
 
159
160
  // Initialize a backend scheduler
160
- GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
161
- GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
161
+ GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
162
+ GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
162
163
  // Initialize backend buffers from a measure graph
163
- GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
164
+ GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
164
165
  // Get the number of splits of the last graph
165
- GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
166
+ GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
166
167
 
167
- GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
168
+ GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
168
169
 
169
- GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
170
- GGML_API ggml_backend_t ggml_backend_sched_get_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
170
+ GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
171
+ GGML_API ggml_backend_t ggml_backend_sched_get_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
171
172
 
172
173
  // Allocate and compute graph on the backend scheduler
173
- GGML_API bool ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
174
+ GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
174
175
 
175
176
  // Reset all assignments and allocators - must be called before changing the node backends
176
- GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
177
+ GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
177
178
 
178
179
  // Set a callback to be called for each resulting node during graph compute
179
- GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
180
+ GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
180
181
 
181
182
  //
182
183
  // Utils