llama_cpp 0.9.5 → 0.10.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 4fd4e1a5e4d7e2442ab43255996da3ce92f898f9876f1bda343e2433c5050dd7
4
- data.tar.gz: dece2da6c9befa15e6990d18fb58e2bf13d8da6c62033969b6b5104f82df736d
3
+ metadata.gz: 7f406c15621a7c247adaacf1d588ddf278225e6846466afd1184c00f1ee61768
4
+ data.tar.gz: df73657c75a80cb44f41d34a3c1054676cf59a5d7d56cb1c2ce8a94264002293
5
5
  SHA512:
6
- metadata.gz: 51a383690b6e90e9493e1f318e916dfd94a909f4e554afd8ea822d047f05e96be3e2f371e83f0da5a37a9837d9ae5ecc6992bb9d9c0fd60a9de521bcd148e8f7
7
- data.tar.gz: 15bbe94edb232d1979f2907c6c3ab7325a1089f9dcdd5d4262d7f0955fd6183e6b01cfee16593165f6e9901991e765ea30740bc1a83cca8fad60df4417551e3b
6
+ metadata.gz: acd08d5099f14bf2bd4c8f9bf016253f0e316179b79d72fbe7066b0d645ca31e9bab427fcc53d93874f8df74cb1746731e2cd21864bfecdecff91f9778919b42
7
+ data.tar.gz: 5014a1bd545be90c56bebd48119a198cf7276513cb6c5f00d8322aa6eaa9a27442bc51bf06953a11c2fc04145f797c630cefee17b36589fe38f9226003416a09
data/CHANGELOG.md CHANGED
@@ -1,3 +1,13 @@
1
+ ## [[0.10.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.9.5...v0.10.0)] - 2023-12-09
2
+
3
+ - Bump bundled llama.cpp from b1593 to b1620.
4
+ - Add `ModelKVOverride` class.
5
+ - Add `offload_kqv`, `type_k`, and `type_v` to ContextParams.
6
+ - Add KV override type constants (`LLAMA_KV_OVERRIDE_INT`, `LLAMA_KV_OVERRIDE_FLOAT`, `LLAMA_KV_OVERRIDE_BOOL`).
7
+
8
+ **Breaking Changes**
9
+ - Remove `f16_kv` from ContextParams.
10
+
1
11
  ## [[0.9.5](https://github.com/yoshoku/llama_cpp.rb/compare/v0.9.4...v0.9.5)] - 2023-12-02
2
12
 
3
13
  - Bump bundled llama.cpp from b1555 to b1593.
@@ -3,6 +3,7 @@
3
3
  VALUE rb_mLLaMACpp;
4
4
  VALUE rb_cLLaMABatch;
5
5
  VALUE rb_cLLaMAModel;
6
+ VALUE rb_cLLaMAModelKVOverride;
6
7
  VALUE rb_cLLaMAModelParams;
7
8
  VALUE rb_cLLaMATimings;
8
9
  VALUE rb_cLLaMAContext;
@@ -612,6 +613,78 @@ const rb_data_type_t RbLLaMATimings::llama_timings_type = {
612
613
  RUBY_TYPED_FREE_IMMEDIATELY
613
614
  };
614
615
 
616
+ class RbLLaMAModelKVOverride {
617
+ public:
618
+ static VALUE llama_model_kv_override_alloc(VALUE self) {
619
+ llama_model_kv_override* ptr = (llama_model_kv_override*)ruby_xmalloc(sizeof(llama_model_kv_override));
620
+ new (ptr) llama_model_kv_override();
621
+ return TypedData_Wrap_Struct(self, &llama_model_kv_override_type, ptr);
622
+ }
623
+
624
+ static void llama_model_kv_override_free(void* ptr) {
625
+ ((llama_model_kv_override*)ptr)->~llama_model_kv_override();
626
+ ruby_xfree(ptr);
627
+ }
628
+
629
+ static size_t llama_model_kv_override_size(const void* ptr) {
630
+ return sizeof(*((llama_model_kv_override*)ptr));
631
+ }
632
+
633
+ static llama_model_kv_override* get_llama_model_kv_override(VALUE self) {
634
+ llama_model_kv_override* ptr;
635
+ TypedData_Get_Struct(self, llama_model_kv_override, &llama_model_kv_override_type, ptr);
636
+ return ptr;
637
+ }
638
+
639
+ static void define_class(VALUE outer) {
640
+ rb_cLLaMAModelKVOverride = rb_define_class_under(outer, "ModelKVOverride", rb_cObject);
641
+ rb_define_alloc_func(rb_cLLaMAModelKVOverride, llama_model_kv_override_alloc);
642
+ rb_define_method(rb_cLLaMAModelKVOverride, "key", RUBY_METHOD_FUNC(_llama_model_kv_override_get_key), 0);
643
+ rb_define_method(rb_cLLaMAModelKVOverride, "tag", RUBY_METHOD_FUNC(_llama_model_kv_override_get_tag), 0);
644
+ rb_define_method(rb_cLLaMAModelKVOverride, "int_value", RUBY_METHOD_FUNC(_llama_model_kv_override_get_int_value), 0);
645
+ rb_define_method(rb_cLLaMAModelKVOverride, "float_value", RUBY_METHOD_FUNC(_llama_model_kv_override_get_float_value), 0);
646
+ rb_define_method(rb_cLLaMAModelKVOverride, "bool_value", RUBY_METHOD_FUNC(_llama_model_kv_override_get_bool_value), 0);
647
+ }
648
+
649
+ static const rb_data_type_t llama_model_kv_override_type;
650
+
651
+ private:
652
+ static VALUE _llama_model_kv_override_get_key(VALUE self) {
653
+ llama_model_kv_override* ptr = get_llama_model_kv_override(self);
654
+ return rb_utf8_str_new_cstr(ptr->key);
655
+ }
656
+
657
+ static VALUE _llama_model_kv_override_get_tag(VALUE self) {
658
+ llama_model_kv_override* ptr = get_llama_model_kv_override(self);
659
+ return INT2NUM(ptr->tag);
660
+ }
661
+
662
+ static VALUE _llama_model_kv_override_get_int_value(VALUE self) {
663
+ llama_model_kv_override* ptr = get_llama_model_kv_override(self);
664
+ return INT2NUM(ptr->int_value);
665
+ }
666
+
667
+ static VALUE _llama_model_kv_override_get_float_value(VALUE self) {
668
+ llama_model_kv_override* ptr = get_llama_model_kv_override(self);
669
+ return DBL2NUM(ptr->float_value);
670
+ }
671
+
672
+ static VALUE _llama_model_kv_override_get_bool_value(VALUE self) {
673
+ llama_model_kv_override* ptr = get_llama_model_kv_override(self);
674
+ return ptr->bool_value ? Qtrue : Qfalse;
675
+ }
676
+ };
677
+
678
+ const rb_data_type_t RbLLaMAModelKVOverride::llama_model_kv_override_type = {
679
+ "RbLLaMAModelKVOverride",
680
+ { NULL,
681
+ RbLLaMAModelKVOverride::llama_model_kv_override_free,
682
+ RbLLaMAModelKVOverride::llama_model_kv_override_size },
683
+ NULL,
684
+ NULL,
685
+ RUBY_TYPED_FREE_IMMEDIATELY
686
+ };
687
+
615
688
  class LLaMAModelParamsWrapper {
616
689
  public:
617
690
  struct llama_model_params params;
@@ -812,14 +885,18 @@ public:
812
885
  rb_define_method(rb_cLLaMAContextParams, "yarn_beta_slow", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_beta_slow), 0);
813
886
  rb_define_method(rb_cLLaMAContextParams, "yarn_orig_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_orig_ctx), 1);
814
887
  rb_define_method(rb_cLLaMAContextParams, "yarn_orig_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_orig_ctx), 0);
888
+ rb_define_method(rb_cLLaMAContextParams, "type_k=", RUBY_METHOD_FUNC(_llama_context_params_set_type_k), 1);
889
+ rb_define_method(rb_cLLaMAContextParams, "type_k", RUBY_METHOD_FUNC(_llama_context_params_get_type_k), 0);
890
+ rb_define_method(rb_cLLaMAContextParams, "type_v=", RUBY_METHOD_FUNC(_llama_context_params_set_type_v), 1);
891
+ rb_define_method(rb_cLLaMAContextParams, "type_v", RUBY_METHOD_FUNC(_llama_context_params_get_type_v), 0);
815
892
  rb_define_method(rb_cLLaMAContextParams, "mul_mat_q=", RUBY_METHOD_FUNC(_llama_context_params_set_mul_mat_q), 1);
816
893
  rb_define_method(rb_cLLaMAContextParams, "mul_mat_q", RUBY_METHOD_FUNC(_llama_context_params_get_mul_mat_q), 0);
817
- rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
818
- rb_define_method(rb_cLLaMAContextParams, "f16_kv", RUBY_METHOD_FUNC(_llama_context_params_get_f16_kv), 0);
819
894
  rb_define_method(rb_cLLaMAContextParams, "logits_all=", RUBY_METHOD_FUNC(_llama_context_params_set_logits_all), 1);
820
895
  rb_define_method(rb_cLLaMAContextParams, "logits_all", RUBY_METHOD_FUNC(_llama_context_params_get_logits_all), 0);
821
896
  rb_define_method(rb_cLLaMAContextParams, "embedding=", RUBY_METHOD_FUNC(_llama_context_params_set_embedding), 1);
822
897
  rb_define_method(rb_cLLaMAContextParams, "embedding", RUBY_METHOD_FUNC(_llama_context_params_get_embedding), 0);
898
+ rb_define_method(rb_cLLaMAContextParams, "offload_kqv=", RUBY_METHOD_FUNC(_llama_context_params_set_offload_kqv), 1);
899
+ rb_define_method(rb_cLLaMAContextParams, "offload_kqv", RUBY_METHOD_FUNC(_llama_context_params_get_offload_kqv), 0);
823
900
  }
824
901
 
825
902
  private:
@@ -991,28 +1068,40 @@ private:
991
1068
  return UINT2NUM(ptr->params.yarn_orig_ctx);
992
1069
  }
993
1070
 
994
- // mul_mat_q
995
- static VALUE _llama_context_params_set_mul_mat_q(VALUE self, VALUE mul_mat_q) {
1071
+ // type_k
1072
+ static VALUE _llama_context_params_set_type_k(VALUE self, VALUE type_k) {
996
1073
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
997
- ptr->params.mul_mat_q = RTEST(mul_mat_q) ? true : false;
998
- return ptr->params.mul_mat_q ? Qtrue : Qfalse;
1074
+ ptr->params.type_k = static_cast<enum ggml_type>(NUM2INT(type_k));
1075
+ return INT2NUM(ptr->params.type_k);
999
1076
  }
1000
1077
 
1001
- static VALUE _llama_context_params_get_mul_mat_q(VALUE self) {
1078
+ static VALUE _llama_context_params_get_type_k(VALUE self) {
1002
1079
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1003
- return ptr->params.mul_mat_q ? Qtrue : Qfalse;
1080
+ return INT2NUM(ptr->params.type_k);
1004
1081
  }
1005
1082
 
1006
- // f16_kv
1007
- static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
1083
+ // type_v
1084
+ static VALUE _llama_context_params_set_type_v(VALUE self, VALUE type_v) {
1008
1085
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1009
- ptr->params.f16_kv = RTEST(f16_kv) ? true : false;
1010
- return ptr->params.f16_kv ? Qtrue : Qfalse;
1086
+ ptr->params.type_v = static_cast<enum ggml_type>(NUM2INT(type_v));
1087
+ return INT2NUM(ptr->params.type_v);
1011
1088
  }
1012
1089
 
1013
- static VALUE _llama_context_params_get_f16_kv(VALUE self) {
1090
+ static VALUE _llama_context_params_get_type_v(VALUE self) {
1014
1091
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1015
- return ptr->params.f16_kv ? Qtrue : Qfalse;
1092
+ return INT2NUM(ptr->params.type_v);
1093
+ }
1094
+
1095
+ // mul_mat_q
1096
+ static VALUE _llama_context_params_set_mul_mat_q(VALUE self, VALUE mul_mat_q) {
1097
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1098
+ ptr->params.mul_mat_q = RTEST(mul_mat_q) ? true : false;
1099
+ return ptr->params.mul_mat_q ? Qtrue : Qfalse;
1100
+ }
1101
+
1102
+ static VALUE _llama_context_params_get_mul_mat_q(VALUE self) {
1103
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1104
+ return ptr->params.mul_mat_q ? Qtrue : Qfalse;
1016
1105
  }
1017
1106
 
1018
1107
  // logits_all
@@ -1038,6 +1127,18 @@ private:
1038
1127
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1039
1128
  return ptr->params.embedding ? Qtrue : Qfalse;
1040
1129
  }
1130
+
1131
+ // offload_kqv
1132
+ static VALUE _llama_context_params_set_offload_kqv(VALUE self, VALUE offload_kqv) {
1133
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1134
+ ptr->params.offload_kqv = RTEST(offload_kqv) ? true : false;
1135
+ return ptr->params.offload_kqv ? Qtrue : Qfalse;
1136
+ }
1137
+
1138
+ static VALUE _llama_context_params_get_offload_kqv(VALUE self) {
1139
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1140
+ return ptr->params.offload_kqv ? Qtrue : Qfalse;
1141
+ }
1041
1142
  };
1042
1143
 
1043
1144
  const rb_data_type_t RbLLaMAContextParams::llama_context_params_type = {
@@ -2352,7 +2453,7 @@ private:
2352
2453
  const float penalty_present = NUM2DBL(kw_values[2]);
2353
2454
 
2354
2455
  llama_sample_repetition_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size,
2355
- penalty_repeat, penalty_freq, penalty_present);
2456
+ penalty_repeat, penalty_freq, penalty_present);
2356
2457
 
2357
2458
  return Qnil;
2358
2459
  }
@@ -2973,6 +3074,7 @@ extern "C" void Init_llama_cpp(void) {
2973
3074
  RbLLaMATokenData::define_class(rb_mLLaMACpp);
2974
3075
  RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
2975
3076
  RbLLaMAModel::define_class(rb_mLLaMACpp);
3077
+ RbLLaMAModelKVOverride::define_class(rb_mLLaMACpp);
2976
3078
  RbLLaMAModelParams::define_class(rb_mLLaMACpp);
2977
3079
  RbLLaMATimings::define_class(rb_mLLaMACpp);
2978
3080
  RbLLaMAContext::define_class(rb_mLLaMACpp);
@@ -3023,6 +3125,10 @@ extern "C" void Init_llama_cpp(void) {
3023
3125
 
3024
3126
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_GUESSED", INT2NUM(LLAMA_FTYPE_GUESSED));
3025
3127
 
3128
+ rb_define_const(rb_mLLaMACpp, "LLAMA_KV_OVERRIDE_INT", INT2NUM(LLAMA_KV_OVERRIDE_INT));
3129
+ rb_define_const(rb_mLLaMACpp, "LLAMA_KV_OVERRIDE_FLOAT", INT2NUM(LLAMA_KV_OVERRIDE_FLOAT));
3130
+ rb_define_const(rb_mLLaMACpp, "LLAMA_KV_OVERRIDE_BOOL", INT2NUM(LLAMA_KV_OVERRIDE_BOOL));
3131
+
3026
3132
  rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_END", INT2NUM(LLAMA_GRETYPE_END));
3027
3133
  rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_ALT", INT2NUM(LLAMA_GRETYPE_ALT));
3028
3134
  rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_RULE_REF", INT2NUM(LLAMA_GRETYPE_RULE_REF));
@@ -168,10 +168,6 @@ static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor *
168
168
  size = aligned_offset(NULL, size, alloc->alignment);
169
169
  AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
170
170
 
171
- if (!alloc->measure) {
172
- ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
173
- }
174
-
175
171
  #ifdef GGML_ALLOCATOR_DEBUG
176
172
  remove_allocated_tensor(alloc, tensor);
177
173
  #endif
@@ -237,7 +233,7 @@ void ggml_tallocr_reset(ggml_tallocr_t alloc) {
237
233
  }
238
234
 
239
235
  ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) {
240
- struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(NULL, data, size);
236
+ struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(data, size);
241
237
 
242
238
  ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
243
239
 
@@ -449,7 +445,6 @@ static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * n
449
445
  static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) {
450
446
  ggml_tallocr_t alloc = node_tallocr(galloc, view);
451
447
 
452
- //printf("init_view: %s from src %s\n", view->name, view->view_src->name);
453
448
  GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
454
449
  if (update_backend) {
455
450
  view->backend = view->view_src->backend;
@@ -459,7 +454,7 @@ static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool upd
459
454
 
460
455
  // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
461
456
  // due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
462
- assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
457
+ assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft);
463
458
 
464
459
  if (!alloc->measure) {
465
460
  ggml_backend_buffer_init_tensor(alloc->buffer, view);
@@ -765,3 +760,43 @@ size_t ggml_allocr_max_size(ggml_allocr_t alloc) {
765
760
  size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) {
766
761
  return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
767
762
  }
763
+
764
+ // utils
765
+ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
766
+ GGML_ASSERT(ggml_get_no_alloc(ctx) == true);
767
+
768
+ size_t alignment = ggml_backend_buft_get_alignment(buft);
769
+
770
+ size_t nbytes = 0;
771
+ for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
772
+ if (t->data == NULL && t->view_src == NULL) {
773
+ nbytes += GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
774
+ }
775
+ }
776
+
777
+ if (nbytes == 0) {
778
+ fprintf(stderr, "%s: no tensors to allocate\n", __func__);
779
+ return NULL;
780
+ }
781
+
782
+ ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, nbytes);
783
+ ggml_tallocr_t tallocr = ggml_tallocr_new_from_buffer(buffer);
784
+
785
+ for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
786
+ if (t->data == NULL) {
787
+ if (t->view_src == NULL) {
788
+ ggml_tallocr_alloc(tallocr, t);
789
+ } else {
790
+ ggml_backend_view_init(buffer, t);
791
+ }
792
+ }
793
+ }
794
+
795
+ ggml_tallocr_free(tallocr);
796
+
797
+ return buffer;
798
+ }
799
+
800
+ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {
801
+ return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));
802
+ }
@@ -8,6 +8,7 @@ extern "C" {
8
8
 
9
9
  struct ggml_backend;
10
10
  struct ggml_backend_buffer;
11
+ struct ggml_backend_buffer_type;
11
12
 
12
13
  //
13
14
  // Legacy API
@@ -80,6 +81,12 @@ GGML_API void ggml_gallocr_alloc_graph_n(
80
81
  struct ggml_hash_set hash_set,
81
82
  ggml_tallocr_t * hash_node_talloc);
82
83
 
84
+
85
+ // Utils
86
+ // Create a buffer and allocate all the tensors in a ggml_context
87
+ GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, struct ggml_backend_buffer_type * buft);
88
+ GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, struct ggml_backend * backend);
89
+
83
90
  #ifdef __cplusplus
84
91
  }
85
92
  #endif
@@ -12,31 +12,50 @@ extern "C" {
12
12
  // Backend buffer
13
13
  //
14
14
 
15
+ // buffer type
16
+ typedef void * ggml_backend_buffer_type_context_t;
17
+
18
+ struct ggml_backend_buffer_type_i {
19
+ ggml_backend_buffer_t (*alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size);
20
+ size_t (*get_alignment) (ggml_backend_buffer_type_t buft); // tensor alignment
21
+ size_t (*get_alloc_size) (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
22
+ bool (*supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend
23
+ };
24
+
25
+ struct ggml_backend_buffer_type {
26
+ struct ggml_backend_buffer_type_i iface;
27
+ ggml_backend_buffer_type_context_t context;
28
+ };
29
+
30
+ // buffer
15
31
  typedef void * ggml_backend_buffer_context_t;
16
32
 
17
33
  struct ggml_backend_buffer_i {
18
- void (*free_buffer) (ggml_backend_buffer_t buffer);
19
- void * (*get_base) (ggml_backend_buffer_t buffer); // get base pointer
20
- size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
21
- void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
22
- void (*free_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
34
+ void (*free_buffer)(ggml_backend_buffer_t buffer);
35
+ //void (*reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
36
+ void * (*get_base) (ggml_backend_buffer_t buffer);
37
+ void (*init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
38
+ void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
39
+ void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
40
+ // (optional) copy tensor between different buffer-type, allow for single-copy tranfers
41
+ void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
42
+ void (*cpy_tensor_to) (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
23
43
  };
24
44
 
25
45
  struct ggml_backend_buffer {
26
- struct ggml_backend_buffer_i iface;
27
-
28
- ggml_backend_t backend;
46
+ struct ggml_backend_buffer_i iface;
47
+ ggml_backend_buffer_type_t buft;
29
48
  ggml_backend_buffer_context_t context;
30
-
31
49
  size_t size;
32
50
  };
33
51
 
34
- GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
35
- struct ggml_backend * backend,
52
+ ggml_backend_buffer_t ggml_backend_buffer_init(
53
+ ggml_backend_buffer_type_t buft,
36
54
  struct ggml_backend_buffer_i iface,
37
55
  ggml_backend_buffer_context_t context,
38
56
  size_t size);
39
57
 
58
+
40
59
  //
41
60
  // Backend
42
61
  //
@@ -49,20 +68,17 @@ extern "C" {
49
68
  void (*free)(ggml_backend_t backend);
50
69
 
51
70
  // buffer allocation
52
- ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
71
+ ggml_backend_buffer_type_t (*get_default_buffer_type)(ggml_backend_t backend);
53
72
 
54
- // get buffer alignment
55
- size_t (*get_alignment)(ggml_backend_t backend);
56
-
57
- // tensor data access
58
- // these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize
73
+ // (optional) asynchroneous tensor data access
59
74
  void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
60
75
  void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
61
- void (*synchronize) (ggml_backend_t backend);
62
76
 
63
- // (optional) copy tensor between different backends, allow for single-copy tranfers
64
- void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
65
- void (*cpy_tensor_to) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
77
+ // (optional) asynchroneous tensor copy
78
+ void (*cpy_tensor_from_async)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
79
+ void (*cpy_tensor_to_async) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
80
+
81
+ void (*synchronize) (ggml_backend_t backend);
66
82
 
67
83
  // compute graph with a plan
68
84
  ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
@@ -82,6 +98,15 @@ extern "C" {
82
98
  ggml_backend_context_t context;
83
99
  };
84
100
 
101
+
102
+ //
103
+ // Backend registry
104
+ //
105
+
106
+ typedef ggml_backend_t (*ggml_backend_init_fn)(const char * params, void * user_data);
107
+
108
+ void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data);
109
+
85
110
  #ifdef __cplusplus
86
111
  }
87
112
  #endif