llama_cpp 0.9.5 → 0.10.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 4fd4e1a5e4d7e2442ab43255996da3ce92f898f9876f1bda343e2433c5050dd7
- data.tar.gz: dece2da6c9befa15e6990d18fb58e2bf13d8da6c62033969b6b5104f82df736d
+ metadata.gz: 7f406c15621a7c247adaacf1d588ddf278225e6846466afd1184c00f1ee61768
+ data.tar.gz: df73657c75a80cb44f41d34a3c1054676cf59a5d7d56cb1c2ce8a94264002293
  SHA512:
- metadata.gz: 51a383690b6e90e9493e1f318e916dfd94a909f4e554afd8ea822d047f05e96be3e2f371e83f0da5a37a9837d9ae5ecc6992bb9d9c0fd60a9de521bcd148e8f7
- data.tar.gz: 15bbe94edb232d1979f2907c6c3ab7325a1089f9dcdd5d4262d7f0955fd6183e6b01cfee16593165f6e9901991e765ea30740bc1a83cca8fad60df4417551e3b
+ metadata.gz: acd08d5099f14bf2bd4c8f9bf016253f0e316179b79d72fbe7066b0d645ca31e9bab427fcc53d93874f8df74cb1746731e2cd21864bfecdecff91f9778919b42
+ data.tar.gz: 5014a1bd545be90c56bebd48119a198cf7276513cb6c5f00d8322aa6eaa9a27442bc51bf06953a11c2fc04145f797c630cefee17b36589fe38f9226003416a09
data/CHANGELOG.md CHANGED
@@ -1,3 +1,13 @@
+ ## [[0.10.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.9.5...v0.10.0)] - 2023-12-09
+
+ - Bump bundled llama.cpp from b1593 to b1620.
+ - Add `ModelKVOverride` class.
+ - Add `offload_kqv`, `type_k`, and `type_v` to ContextParams.
+ - Add KV override type constants.
+
+ **Breaking Changes**
+ - Remove `f16_kv` from ContextParams.
+
  ## [[0.9.5](https://github.com/yoshoku/llama_cpp.rb/compare/v0.9.4...v0.9.5)] - 2023-12-02

  - Bump bundled llama.cpp from b1555 to b1593.
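For callers, the breaking change above amounts to a small migration: the boolean `f16_kv` flag is gone, the KV cache element types are now set explicitly via `type_k`/`type_v`, and GPU offloading of the KQV ops is controlled by `offload_kqv`. A minimal, illustrative Ruby sketch (the accessor names come from the diff below; the rest is an assumption about typical usage):

```ruby
require 'llama_cpp'

params = LLaMACpp::ContextParams.new
# params.f16_kv = true      # 0.9.5 API -- removed in 0.10.0
params.offload_kqv = true   # 0.10.0: offload KQV ops (including the KV cache) to the GPU
# KV cache types are now explicit; see the type_k/type_v sketch further below.
```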
@@ -3,6 +3,7 @@
  VALUE rb_mLLaMACpp;
  VALUE rb_cLLaMABatch;
  VALUE rb_cLLaMAModel;
+ VALUE rb_cLLaMAModelKVOverride;
  VALUE rb_cLLaMAModelParams;
  VALUE rb_cLLaMATimings;
  VALUE rb_cLLaMAContext;
@@ -612,6 +613,78 @@ const rb_data_type_t RbLLaMATimings::llama_timings_type = {
  RUBY_TYPED_FREE_IMMEDIATELY
  };

+ class RbLLaMAModelKVOverride {
+ public:
+ static VALUE llama_model_kv_override_alloc(VALUE self) {
+ llama_model_kv_override* ptr = (llama_model_kv_override*)ruby_xmalloc(sizeof(llama_model_kv_override));
+ new (ptr) llama_model_kv_override();
+ return TypedData_Wrap_Struct(self, &llama_model_kv_override_type, ptr);
+ }
+
+ static void llama_model_kv_override_free(void* ptr) {
+ ((llama_model_kv_override*)ptr)->~llama_model_kv_override();
+ ruby_xfree(ptr);
+ }
+
+ static size_t llama_model_kv_override_size(const void* ptr) {
+ return sizeof(*((llama_model_kv_override*)ptr));
+ }
+
+ static llama_model_kv_override* get_llama_model_kv_override(VALUE self) {
+ llama_model_kv_override* ptr;
+ TypedData_Get_Struct(self, llama_model_kv_override, &llama_model_kv_override_type, ptr);
+ return ptr;
+ }
+
+ static void define_class(VALUE outer) {
+ rb_cLLaMAModelKVOverride = rb_define_class_under(outer, "ModelKVOverride", rb_cObject);
+ rb_define_alloc_func(rb_cLLaMAModelKVOverride, llama_model_kv_override_alloc);
+ rb_define_method(rb_cLLaMAModelKVOverride, "key", RUBY_METHOD_FUNC(_llama_model_kv_override_get_key), 0);
+ rb_define_method(rb_cLLaMAModelKVOverride, "tag", RUBY_METHOD_FUNC(_llama_model_kv_override_get_tag), 0);
+ rb_define_method(rb_cLLaMAModelKVOverride, "int_value", RUBY_METHOD_FUNC(_llama_model_kv_override_get_int_value), 0);
+ rb_define_method(rb_cLLaMAModelKVOverride, "float_value", RUBY_METHOD_FUNC(_llama_model_kv_override_get_float_value), 0);
+ rb_define_method(rb_cLLaMAModelKVOverride, "bool_value", RUBY_METHOD_FUNC(_llama_model_kv_override_get_bool_value), 0);
+ }
+
+ static const rb_data_type_t llama_model_kv_override_type;
+
+ private:
+ static VALUE _llama_model_kv_override_get_key(VALUE self) {
+ llama_model_kv_override* ptr = get_llama_model_kv_override(self);
+ return rb_utf8_str_new_cstr(ptr->key);
+ }
+
+ static VALUE _llama_model_kv_override_get_tag(VALUE self) {
+ llama_model_kv_override* ptr = get_llama_model_kv_override(self);
+ return INT2NUM(ptr->tag);
+ }
+
+ static VALUE _llama_model_kv_override_get_int_value(VALUE self) {
+ llama_model_kv_override* ptr = get_llama_model_kv_override(self);
+ return INT2NUM(ptr->int_value);
+ }
+
+ static VALUE _llama_model_kv_override_get_float_value(VALUE self) {
+ llama_model_kv_override* ptr = get_llama_model_kv_override(self);
+ return DBL2NUM(ptr->float_value);
+ }
+
+ static VALUE _llama_model_kv_override_get_bool_value(VALUE self) {
+ llama_model_kv_override* ptr = get_llama_model_kv_override(self);
+ return ptr->bool_value ? Qtrue : Qfalse;
+ }
+ };
+
+ const rb_data_type_t RbLLaMAModelKVOverride::llama_model_kv_override_type = {
+ "RbLLaMAModelKVOverride",
+ { NULL,
+ RbLLaMAModelKVOverride::llama_model_kv_override_free,
+ RbLLaMAModelKVOverride::llama_model_kv_override_size },
+ NULL,
+ NULL,
+ RUBY_TYPED_FREE_IMMEDIATELY
+ };
+
  class LLaMAModelParamsWrapper {
  public:
  struct llama_model_params params;
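The wrapper above only defines readers for the underlying `llama_model_kv_override` struct. A hedged sketch of inspecting such an object from Ruby, assuming an `override` instance obtained elsewhere (this diff does not show how instances are populated), using the `LLAMA_KV_OVERRIDE_*` constants registered later in this diff:

```ruby
# Sketch only: `override` is assumed to be a LLaMACpp::ModelKVOverride instance.
value =
  case override.tag
  when LLaMACpp::LLAMA_KV_OVERRIDE_INT   then override.int_value
  when LLaMACpp::LLAMA_KV_OVERRIDE_FLOAT then override.float_value
  when LLaMACpp::LLAMA_KV_OVERRIDE_BOOL  then override.bool_value
  end
puts "#{override.key} => #{value.inspect}"
```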
@@ -812,14 +885,18 @@ public:
  rb_define_method(rb_cLLaMAContextParams, "yarn_beta_slow", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_beta_slow), 0);
  rb_define_method(rb_cLLaMAContextParams, "yarn_orig_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_orig_ctx), 1);
  rb_define_method(rb_cLLaMAContextParams, "yarn_orig_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_orig_ctx), 0);
+ rb_define_method(rb_cLLaMAContextParams, "type_k=", RUBY_METHOD_FUNC(_llama_context_params_set_type_k), 1);
+ rb_define_method(rb_cLLaMAContextParams, "type_k", RUBY_METHOD_FUNC(_llama_context_params_get_type_k), 0);
+ rb_define_method(rb_cLLaMAContextParams, "type_v=", RUBY_METHOD_FUNC(_llama_context_params_set_type_v), 1);
+ rb_define_method(rb_cLLaMAContextParams, "type_v", RUBY_METHOD_FUNC(_llama_context_params_get_type_v), 0);
  rb_define_method(rb_cLLaMAContextParams, "mul_mat_q=", RUBY_METHOD_FUNC(_llama_context_params_set_mul_mat_q), 1);
  rb_define_method(rb_cLLaMAContextParams, "mul_mat_q", RUBY_METHOD_FUNC(_llama_context_params_get_mul_mat_q), 0);
- rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
- rb_define_method(rb_cLLaMAContextParams, "f16_kv", RUBY_METHOD_FUNC(_llama_context_params_get_f16_kv), 0);
  rb_define_method(rb_cLLaMAContextParams, "logits_all=", RUBY_METHOD_FUNC(_llama_context_params_set_logits_all), 1);
  rb_define_method(rb_cLLaMAContextParams, "logits_all", RUBY_METHOD_FUNC(_llama_context_params_get_logits_all), 0);
  rb_define_method(rb_cLLaMAContextParams, "embedding=", RUBY_METHOD_FUNC(_llama_context_params_set_embedding), 1);
  rb_define_method(rb_cLLaMAContextParams, "embedding", RUBY_METHOD_FUNC(_llama_context_params_get_embedding), 0);
+ rb_define_method(rb_cLLaMAContextParams, "offload_kqv=", RUBY_METHOD_FUNC(_llama_context_params_set_offload_kqv), 1);
+ rb_define_method(rb_cLLaMAContextParams, "offload_kqv", RUBY_METHOD_FUNC(_llama_context_params_get_offload_kqv), 0);
  }

  private:
@@ -991,28 +1068,40 @@ private:
  return UINT2NUM(ptr->params.yarn_orig_ctx);
  }

- // mul_mat_q
- static VALUE _llama_context_params_set_mul_mat_q(VALUE self, VALUE mul_mat_q) {
+ // type_k
+ static VALUE _llama_context_params_set_type_k(VALUE self, VALUE type_k) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
- ptr->params.mul_mat_q = RTEST(mul_mat_q) ? true : false;
- return ptr->params.mul_mat_q ? Qtrue : Qfalse;
+ ptr->params.type_k = static_cast<enum ggml_type>(NUM2INT(type_k));
+ return INT2NUM(ptr->params.type_k);
  }

- static VALUE _llama_context_params_get_mul_mat_q(VALUE self) {
+ static VALUE _llama_context_params_get_type_k(VALUE self) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
- return ptr->params.mul_mat_q ? Qtrue : Qfalse;
+ return INT2NUM(ptr->params.type_k);
  }

- // f16_kv
- static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
+ // type_v
+ static VALUE _llama_context_params_set_type_v(VALUE self, VALUE type_v) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
- ptr->params.f16_kv = RTEST(f16_kv) ? true : false;
- return ptr->params.f16_kv ? Qtrue : Qfalse;
+ ptr->params.type_v = static_cast<enum ggml_type>(NUM2INT(type_v));
+ return INT2NUM(ptr->params.type_v);
  }

- static VALUE _llama_context_params_get_f16_kv(VALUE self) {
+ static VALUE _llama_context_params_get_type_v(VALUE self) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
- return ptr->params.f16_kv ? Qtrue : Qfalse;
+ return INT2NUM(ptr->params.type_v);
+ }
+
+ // mul_mat_q
+ static VALUE _llama_context_params_set_mul_mat_q(VALUE self, VALUE mul_mat_q) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ ptr->params.mul_mat_q = RTEST(mul_mat_q) ? true : false;
+ return ptr->params.mul_mat_q ? Qtrue : Qfalse;
+ }
+
+ static VALUE _llama_context_params_get_mul_mat_q(VALUE self) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ return ptr->params.mul_mat_q ? Qtrue : Qfalse;
  }

  // logits_all
@@ -1038,6 +1127,18 @@ private:
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
  return ptr->params.embedding ? Qtrue : Qfalse;
  }
+
+ // offload_kqv
+ static VALUE _llama_context_params_set_offload_kqv(VALUE self, VALUE offload_kqv) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ ptr->params.offload_kqv = RTEST(offload_kqv) ? true : false;
+ return ptr->params.offload_kqv ? Qtrue : Qfalse;
+ }
+
+ static VALUE _llama_context_params_get_offload_kqv(VALUE self) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ return ptr->params.offload_kqv ? Qtrue : Qfalse;
+ }
  };

  const rb_data_type_t RbLLaMAContextParams::llama_context_params_type = {
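As the setters above show, `type_k` and `type_v` store raw `ggml_type` enum integers (no symbolic constants are added in this diff), while `offload_kqv` coerces any truthy value with RTEST. A hedged round-trip sketch; the concrete enum values are assumptions based on upstream ggml (e.g. 1 for GGML_TYPE_F16, 8 for GGML_TYPE_Q8_0):

```ruby
params = LLaMACpp::ContextParams.new
params.type_k = 1        # assumed GGML_TYPE_F16
params.type_v = 8        # assumed GGML_TYPE_Q8_0
params.offload_kqv = 1   # any truthy value is stored as true (RTEST)
p params.type_k          # => 1
p params.type_v          # => 8
p params.offload_kqv     # => true
```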
@@ -2352,7 +2453,7 @@ private:
  const float penalty_present = NUM2DBL(kw_values[2]);

  llama_sample_repetition_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size,
- penalty_repeat, penalty_freq, penalty_present);
+ penalty_repeat, penalty_freq, penalty_present);

  return Qnil;
  }
@@ -2973,6 +3074,7 @@ extern "C" void Init_llama_cpp(void) {
  RbLLaMATokenData::define_class(rb_mLLaMACpp);
  RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
  RbLLaMAModel::define_class(rb_mLLaMACpp);
+ RbLLaMAModelKVOverride::define_class(rb_mLLaMACpp);
  RbLLaMAModelParams::define_class(rb_mLLaMACpp);
  RbLLaMATimings::define_class(rb_mLLaMACpp);
  RbLLaMAContext::define_class(rb_mLLaMACpp);
@@ -3023,6 +3125,10 @@ extern "C" void Init_llama_cpp(void) {

  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_GUESSED", INT2NUM(LLAMA_FTYPE_GUESSED));

+ rb_define_const(rb_mLLaMACpp, "LLAMA_KV_OVERRIDE_INT", INT2NUM(LLAMA_KV_OVERRIDE_INT));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_KV_OVERRIDE_FLOAT", INT2NUM(LLAMA_KV_OVERRIDE_FLOAT));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_KV_OVERRIDE_BOOL", INT2NUM(LLAMA_KV_OVERRIDE_BOOL));
+
  rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_END", INT2NUM(LLAMA_GRETYPE_END));
  rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_ALT", INT2NUM(LLAMA_GRETYPE_ALT));
  rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_RULE_REF", INT2NUM(LLAMA_GRETYPE_RULE_REF));
@@ -168,10 +168,6 @@ static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor *
  size = aligned_offset(NULL, size, alloc->alignment);
  AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);

- if (!alloc->measure) {
- ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
- }
-
  #ifdef GGML_ALLOCATOR_DEBUG
  remove_allocated_tensor(alloc, tensor);
  #endif
@@ -237,7 +233,7 @@ void ggml_tallocr_reset(ggml_tallocr_t alloc) {
  }

  ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) {
- struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(NULL, data, size);
+ struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(data, size);

  ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));

@@ -449,7 +445,6 @@ static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * n
  static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) {
  ggml_tallocr_t alloc = node_tallocr(galloc, view);

- //printf("init_view: %s from src %s\n", view->name, view->view_src->name);
  GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
  if (update_backend) {
  view->backend = view->view_src->backend;
@@ -459,7 +454,7 @@ static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool upd

  // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
  // due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
- assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
+ assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft);

  if (!alloc->measure) {
  ggml_backend_buffer_init_tensor(alloc->buffer, view);
@@ -765,3 +760,43 @@ size_t ggml_allocr_max_size(ggml_allocr_t alloc) {
  size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) {
  return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
  }
+
+ // utils
+ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
+ GGML_ASSERT(ggml_get_no_alloc(ctx) == true);
+
+ size_t alignment = ggml_backend_buft_get_alignment(buft);
+
+ size_t nbytes = 0;
+ for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ if (t->data == NULL && t->view_src == NULL) {
+ nbytes += GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
+ }
+ }
+
+ if (nbytes == 0) {
+ fprintf(stderr, "%s: no tensors to allocate\n", __func__);
+ return NULL;
+ }
+
+ ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, nbytes);
+ ggml_tallocr_t tallocr = ggml_tallocr_new_from_buffer(buffer);
+
+ for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ if (t->data == NULL) {
+ if (t->view_src == NULL) {
+ ggml_tallocr_alloc(tallocr, t);
+ } else {
+ ggml_backend_view_init(buffer, t);
+ }
+ }
+ }
+
+ ggml_tallocr_free(tallocr);
+
+ return buffer;
+ }
+
+ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {
+ return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));
+ }
@@ -8,6 +8,7 @@ extern "C" {

  struct ggml_backend;
  struct ggml_backend_buffer;
+ struct ggml_backend_buffer_type;

  //
  // Legacy API
@@ -80,6 +81,12 @@ GGML_API void ggml_gallocr_alloc_graph_n(
  struct ggml_hash_set hash_set,
  ggml_tallocr_t * hash_node_talloc);

+
+ // Utils
+ // Create a buffer and allocate all the tensors in a ggml_context
+ GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, struct ggml_backend_buffer_type * buft);
+ GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, struct ggml_backend * backend);
+
  #ifdef __cplusplus
  }
  #endif
@@ -12,31 +12,50 @@ extern "C" {
  // Backend buffer
  //

+ // buffer type
+ typedef void * ggml_backend_buffer_type_context_t;
+
+ struct ggml_backend_buffer_type_i {
+ ggml_backend_buffer_t (*alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size);
+ size_t (*get_alignment) (ggml_backend_buffer_type_t buft); // tensor alignment
+ size_t (*get_alloc_size) (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
+ bool (*supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend
+ };
+
+ struct ggml_backend_buffer_type {
+ struct ggml_backend_buffer_type_i iface;
+ ggml_backend_buffer_type_context_t context;
+ };
+
+ // buffer
  typedef void * ggml_backend_buffer_context_t;

  struct ggml_backend_buffer_i {
- void (*free_buffer) (ggml_backend_buffer_t buffer);
- void * (*get_base) (ggml_backend_buffer_t buffer); // get base pointer
- size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
- void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
- void (*free_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
+ void (*free_buffer)(ggml_backend_buffer_t buffer);
+ //void (*reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
+ void * (*get_base) (ggml_backend_buffer_t buffer);
+ void (*init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+ void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+ void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+ // (optional) copy tensor between different buffer-type, allow for single-copy tranfers
+ void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
+ void (*cpy_tensor_to) (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
  };

  struct ggml_backend_buffer {
- struct ggml_backend_buffer_i iface;
-
- ggml_backend_t backend;
+ struct ggml_backend_buffer_i iface;
+ ggml_backend_buffer_type_t buft;
  ggml_backend_buffer_context_t context;
-
  size_t size;
  };

- GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
- struct ggml_backend * backend,
+ ggml_backend_buffer_t ggml_backend_buffer_init(
+ ggml_backend_buffer_type_t buft,
  struct ggml_backend_buffer_i iface,
  ggml_backend_buffer_context_t context,
  size_t size);

+
  //
  // Backend
  //
@@ -49,20 +68,17 @@ extern "C" {
  void (*free)(ggml_backend_t backend);

  // buffer allocation
- ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
+ ggml_backend_buffer_type_t (*get_default_buffer_type)(ggml_backend_t backend);

- // get buffer alignment
- size_t (*get_alignment)(ggml_backend_t backend);
-
- // tensor data access
- // these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize
+ // (optional) asynchroneous tensor data access
  void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
  void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
- void (*synchronize) (ggml_backend_t backend);

- // (optional) copy tensor between different backends, allow for single-copy tranfers
- void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
- void (*cpy_tensor_to) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+ // (optional) asynchroneous tensor copy
+ void (*cpy_tensor_from_async)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+ void (*cpy_tensor_to_async) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+
+ void (*synchronize) (ggml_backend_t backend);

  // compute graph with a plan
  ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
@@ -82,6 +98,15 @@ extern "C" {
  ggml_backend_context_t context;
  };

+
+ //
+ // Backend registry
+ //
+
+ typedef ggml_backend_t (*ggml_backend_init_fn)(const char * params, void * user_data);
+
+ void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data);
+
  #ifdef __cplusplus
  }
  #endif