llama_cpp 0.3.3 → 0.3.5

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: cf337091019bb773e47cf206ff2ff30ed0bef963094494e6493455cad7c59840
- data.tar.gz: fdbae8e08a6b87d49c5658d5c1857f20bf8efdf5a5371906630dccf4eb0f1159
+ metadata.gz: 991df3df6b16ec98a203a6c6565794988eec04697ccb963faab976f436e1bfcc
+ data.tar.gz: dafd3b8274640eb79353e11056f497a02392fef332a46dcc1717878c836f62bd
  SHA512:
- metadata.gz: f0fee68294960c5ab9f56ebfe7256a00f9330e55f4954f2b016e07cbc023570298fa8f8b578f3e187fe9183b869769085311931122f93a033c6c21158b4e9485
- data.tar.gz: 7eec8c98ae9ec1a56fa4bdb4e83a2dc2bdea407fc037af8d1b8f09a30c0d1246333d410707f4d66f3f473bf73574757cf12e56a86a0cb47074501f63f65f0c02
+ metadata.gz: 0a54fdf18c5be5273f01d4d991b59975a8e8b6a0a8f54087fb90df3f1a8a7ebad557d01fb119f3d61cafa8f6c59fb81641624779fda27b732eea2868cb4642e8
+ data.tar.gz: 6574064078070502e36ad933bd5efb2f479c94f47a1260286fe485e48d35a7b3985274943c6163189326891349b47b1d9815d77c600b015fe111cc3842179392
data/CHANGELOG.md CHANGED
@@ -1,3 +1,34 @@
+ ## [[0.3.5](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.4...v0.3.5)] - 2023-07-29
+
+ - Bump bundled llama.cpp from master-d924522 to master-1a94186.
+ - Add `GrammarElement` and `Grammar` classes.
+ - Add `sample_grammar` method to Context.
+ - Add `grammar_accept_token` method to Context.
+
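A minimal sketch of how these grammar pieces fit together, assuming an initialized `Context` (`context`) and a `TokenDataArray` of candidates (`candidates`); the one-character rule is purely illustrative:

```ruby
require 'llama_cpp'

# One rule that matches the single character 'a' and then ends.
# LLAMA_GRETYPE_CHAR matches a code point; LLAMA_GRETYPE_END terminates the rule.
rule = [
  LLaMACpp::GrammarElement.new(type: LLaMACpp::LLAMA_GRETYPE_CHAR, value: 'a'.ord),
  LLaMACpp::GrammarElement.new(type: LLaMACpp::LLAMA_GRETYPE_END, value: 0)
]
grammar = LLaMACpp::Grammar.new(rules: [rule], start_rule_index: 0)

# In a sampling loop: filter the candidates by the grammar, pick a token,
# then advance the grammar state with the accepted token.
context.sample_grammar(candidates, grammar: grammar)
token = context.sample_token(candidates)
context.grammar_accept_token(grammar: grammar, token: token)
```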
+ ## [[0.3.4](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.3...v0.3.4)] - 2023-07-23
+
+ - Bump bundled llama.cpp from master-32c5411 to master-d924522.
+ - Add `rope_freq_base` and `rope_freq_scale` options to ContextParams.
+ - Add `max_devices` module function to LLaMACpp.
+ - Add `n_vocab`, `n_ctx`, and `n_embd` methods to Model.
+ - Add `vocab`, `tokenize`, and `token_to_str` methods to Model.
+ ```ruby
+ require 'llama_cpp'
+
+ params = LLaMACpp::ContextParams.new
+ model = LLaMACpp::Model.new(model_path: '/path/to/model.bin', params: params)
+
+ p model.tokenize(text: 'hello, world')
+ # => [12199, 29892, 3186]
+
+ p model.token_to_str(12199)
+ # => "hello"
+ ```
+
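The new ContextParams options and the module-level `max_devices` function can be exercised together like this (a sketch; the model path and the RoPE values are placeholders):

```ruby
require 'llama_cpp'

params = LLaMACpp::ContextParams.new
params.rope_freq_base = 10_000.0 # RoPE frequency base (placeholder value)
params.rope_freq_scale = 1.0     # RoPE frequency scaling factor (placeholder value)

model = LLaMACpp::Model.new(model_path: '/path/to/model.bin', params: params)
p model.n_vocab        # vocabulary size of the loaded model
p model.n_ctx          # context length
p model.n_embd         # embedding dimension
p LLaMACpp.max_devices # number of devices available for tensor splitting
```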
+ **Breaking Changes**
+ - Fix to automatically call the `backend_free` method when the Ruby script exits.
+ - Remove `smooth_factor` argument from the `sample_classifier_free_guidance` method on Context.
+
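With `smooth_factor` removed, a classifier-free-guidance sampling call now passes only the guidance context and the scale (a sketch; `context`, `candidates`, `guidance_context`, and the scale value are placeholders):

```ruby
# 0.3.3: context.sample_classifier_free_guidance(candidates, guidance: guidance_context, scale: 1.5, smooth_factor: 0.5)
# 0.3.4:
context.sample_classifier_free_guidance(candidates, guidance: guidance_context, scale: 1.5)
```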
  ## [[0.3.3](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.2...v0.3.3)] - 2023-07-15

  - Bump bundled llama.cpp from master-481f793 to master-32c5411.
ext/llama_cpp/extconf.rb CHANGED
@@ -85,6 +85,7 @@ if with_config('mpi')
  $CXXFLAGS << ' -DGGML_USE_MPI -Wno-cast-qual'
  end

+ # @!visibility private
  UNAME_M = RbConfig::CONFIG['build_cpu'] || RbConfig::CONFIG['host_cpu'] || RbConfig::CONFIG['target_cpu']

  # rubocop:disable Layout/LineLength
ext/llama_cpp/llama_cpp.cpp CHANGED
@@ -8,6 +8,8 @@ VALUE rb_cLLaMAContextParams;
  VALUE rb_cLLaMAModelQuantizeParams;
  VALUE rb_cLLaMATokenData;
  VALUE rb_cLLaMATokenDataArray;
+ VALUE rb_cLLaMAGrammarElement;
+ VALUE rb_cLLaMAGrammar;

  class LLaMATokenDataWrapper {
  public:
@@ -406,6 +408,10 @@ public:
  rb_define_method(rb_cLLaMAContextParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_context_params_set_main_gpu), 1);
  rb_define_method(rb_cLLaMAContextParams, "main_gpu", RUBY_METHOD_FUNC(_llama_context_params_get_main_gpu), 0);
  rb_define_method(rb_cLLaMAContextParams, "tensor_split", RUBY_METHOD_FUNC(_llama_context_params_get_tensor_split), 0);
+ rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
+ rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
+ rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
+ rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
  rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
  rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
  rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
@@ -494,6 +500,30 @@ private:
  return ret;
  }

+ // rope_freq_base
+ static VALUE _llama_context_params_set_rope_freq_base(VALUE self, VALUE rope_freq_base) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ ptr->params.rope_freq_base = NUM2DBL(rope_freq_base);
+ return DBL2NUM(ptr->params.rope_freq_base);
+ }
+
+ static VALUE _llama_context_params_get_rope_freq_base(VALUE self) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ return DBL2NUM(ptr->params.rope_freq_base);
+ }
+
+ // rope_freq_scale
+ static VALUE _llama_context_params_set_rope_freq_scale(VALUE self, VALUE rope_freq_scale) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ ptr->params.rope_freq_scale = NUM2DBL(rope_freq_scale);
+ return DBL2NUM(ptr->params.rope_freq_scale);
+ }
+
+ static VALUE _llama_context_params_get_rope_freq_scale(VALUE self) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ return DBL2NUM(ptr->params.rope_freq_scale);
+ }
+
  // low_vram
  static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -764,6 +794,12 @@ public:
  rb_define_method(rb_cLLaMAModel, "free", RUBY_METHOD_FUNC(_llama_model_free), 0);
  rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
  rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
+ rb_define_method(rb_cLLaMAModel, "n_vocab", RUBY_METHOD_FUNC(_llama_model_get_n_vocab_from_model), 0);
+ rb_define_method(rb_cLLaMAModel, "n_ctx", RUBY_METHOD_FUNC(_llama_model_get_n_ctx_from_model), 0);
+ rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_n_embd_from_model), 0);
+ rb_define_method(rb_cLLaMAModel, "vocab", RUBY_METHOD_FUNC(_llama_model_get_vocab_from_model), -1);
+ rb_define_method(rb_cLLaMAModel, "token_to_str", RUBY_METHOD_FUNC(_llama_model_token_to_str_with_model), 1);
+ rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize_with_model), -1);
  }

  private:
@@ -908,6 +944,109 @@ private:
  }
  return Qnil;
  }
+
+ static VALUE _llama_model_get_n_vocab_from_model(VALUE self) {
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ return INT2NUM(llama_n_vocab_from_model(ptr->model));
+ }
+
+ static VALUE _llama_model_get_n_ctx_from_model(VALUE self) {
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ return INT2NUM(llama_n_ctx_from_model(ptr->model));
+ }
+
+ static VALUE _llama_model_get_n_embd_from_model(VALUE self) {
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ return INT2NUM(llama_n_embd_from_model(ptr->model));
+ }
+
+ static VALUE _llama_model_get_vocab_from_model(int argc, VALUE* argv, VALUE self) {
+ VALUE kw_args = Qnil;
+ ID kw_table[1] = { rb_intern("capacity") };
+ VALUE kw_values[1] = { Qundef };
+ rb_scan_args(argc, argv, ":", &kw_args);
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
+
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
+ rb_raise(rb_eArgError, "capacity must be an integer");
+ return Qnil;
+ }
+
+ const int capacity = NUM2INT(kw_values[0]);
+
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ const int n = std::min(capacity, llama_n_vocab_from_model(ptr->model));
+ const char** vocabs = ALLOCA_N(const char*, n);
+ float* scores = ALLOCA_N(float, n);
+
+ llama_get_vocab_from_model(ptr->model, vocabs, scores, capacity);
+
+ VALUE vocabs_ary = rb_ary_new();
+ VALUE scores_ary = rb_ary_new();
+
+ for (int i = 0; i < n; i++) {
+ rb_ary_push(vocabs_ary, rb_str_new_cstr(vocabs[i]));
+ rb_ary_push(scores_ary, DBL2NUM(scores[i]));
+ }
+
+ VALUE ret = rb_ary_new3(2, vocabs_ary, scores_ary);
+
+ return ret;
+ }
+
+ static VALUE _llama_model_token_to_str_with_model(VALUE self, VALUE token_) {
+ if (!RB_INTEGER_TYPE_P(token_)) {
+ rb_raise(rb_eArgError, "token must be an integer");
+ return Qnil;
+ }
+ const llama_token token = NUM2INT(token_);
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ const char* str = llama_token_to_str_with_model(ptr->model, token);
+ return rb_str_new_cstr(str);
+ }
+
+ static VALUE _llama_model_tokenize_with_model(int argc, VALUE* argv, VALUE self) {
+ VALUE kw_args = Qnil;
+ ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
+ rb_scan_args(argc, argv, ":", &kw_args);
+ rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
+
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
+ rb_raise(rb_eArgError, "text must be a String");
+ return Qnil;
+ }
+ if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
+ rb_raise(rb_eArgError, "n_max_tokens must be an integer");
+ return Qnil;
+ }
+ if (kw_values[2] != Qundef && (kw_values[2] != Qtrue && kw_values[2] != Qfalse)) {
+ rb_raise(rb_eArgError, "add_bos must be a boolean");
+ return Qnil;
+ }
+
+ VALUE text_ = kw_values[0];
+ std::string text = StringValueCStr(text_);
+ const bool add_bos = kw_values[2] == Qtrue ? true : false;
+ const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
+
+ llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ const int n_tokens = llama_tokenize_with_model(ptr->model, text.c_str(), tokens, n_max_tokens, add_bos);
+
+ if (n_tokens < 0) {
+ rb_raise(rb_eRuntimeError, "failed to tokenize. The number of tokens (%d) is greater than n_max_tokens.", -n_tokens);
+ return Qnil;
+ }
+
+ VALUE ret = rb_ary_new2(n_tokens);
+ for (int i = 0; i < n_tokens; i++) {
+ rb_ary_store(ret, i, INT2NUM(tokens[i]));
+ }
+
+ RB_GC_GUARD(text_);
+ return ret;
+ }
  };

  const rb_data_type_t RbLLaMAModel::llama_model_type = {
@@ -920,6 +1059,222 @@ const rb_data_type_t RbLLaMAModel::llama_model_type = {
  RUBY_TYPED_FREE_IMMEDIATELY
  };

+ class LLaMAGrammarElementWrapper {
+ public:
+ llama_grammar_element element;
+
+ LLaMAGrammarElementWrapper() {
+ element.type = LLAMA_GRETYPE_END;
+ element.value = 0;
+ }
+
+ ~LLaMAGrammarElementWrapper() {}
+ };
+
+ class RbLLaMAGrammarElement {
+ public:
+ static VALUE llama_grammar_element_alloc(VALUE self) {
+ LLaMAGrammarElementWrapper* ptr = (LLaMAGrammarElementWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarElementWrapper));
+ new (ptr) LLaMAGrammarElementWrapper();
+ return TypedData_Wrap_Struct(self, &llama_grammar_element_type, ptr);
+ }
+
+ static void llama_grammar_element_free(void* ptr) {
+ ((LLaMAGrammarElementWrapper*)ptr)->~LLaMAGrammarElementWrapper();
+ ruby_xfree(ptr);
+ }
+
+ static size_t llama_grammar_element_size(const void* ptr) {
+ return sizeof(*((LLaMAGrammarElementWrapper*)ptr));
+ }
+
+ static LLaMAGrammarElementWrapper* get_llama_grammar_element(VALUE self) {
+ LLaMAGrammarElementWrapper* ptr;
+ TypedData_Get_Struct(self, LLaMAGrammarElementWrapper, &llama_grammar_element_type, ptr);
+ return ptr;
+ }
+
+ static void define_class(VALUE outer) {
+ rb_cLLaMAGrammarElement = rb_define_class_under(outer, "GrammarElement", rb_cObject);
+ rb_define_alloc_func(rb_cLLaMAGrammarElement, llama_grammar_element_alloc);
+ rb_define_method(rb_cLLaMAGrammarElement, "initialize", RUBY_METHOD_FUNC(_llama_grammar_element_init), -1);
+ rb_define_method(rb_cLLaMAGrammarElement, "type=", RUBY_METHOD_FUNC(_llama_grammar_element_set_type), 1);
+ rb_define_method(rb_cLLaMAGrammarElement, "type", RUBY_METHOD_FUNC(_llama_grammar_element_get_type), 0);
+ rb_define_method(rb_cLLaMAGrammarElement, "value=", RUBY_METHOD_FUNC(_llama_grammar_element_set_value), 1);
+ rb_define_method(rb_cLLaMAGrammarElement, "value", RUBY_METHOD_FUNC(_llama_grammar_element_get_value), 0);
+ }
+
+ private:
+ static const rb_data_type_t llama_grammar_element_type;
+
+ static VALUE _llama_grammar_element_init(int argc, VALUE* argv, VALUE self) {
+ VALUE kw_args = Qnil;
+ ID kw_table[2] = { rb_intern("type"), rb_intern("value") };
+ VALUE kw_values[2] = { Qundef, Qundef };
+ rb_scan_args(argc, argv, ":", &kw_args);
+ rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
+
+ if (kw_values[0] != Qundef && !RB_INTEGER_TYPE_P(kw_values[0])) {
+ rb_raise(rb_eArgError, "type must be an integer");
+ return Qnil;
+ }
+ if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
+ rb_raise(rb_eArgError, "value must be an integer");
+ return Qnil;
+ }
+
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
+ new (ptr) LLaMAGrammarElementWrapper();
+
+ if (kw_values[0] != Qundef) {
+ ptr->element.type = (enum llama_gretype)NUM2INT(kw_values[0]);
+ }
+ if (kw_values[1] != Qundef) {
+ ptr->element.value = NUM2INT(kw_values[1]);
+ }
+
+ return self;
+ }
+
+ // type
+ static VALUE _llama_grammar_element_set_type(VALUE self, VALUE type) {
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
+ ptr->element.type = (enum llama_gretype)NUM2INT(type);
+ return INT2NUM(ptr->element.type);
+ }
+
+ static VALUE _llama_grammar_element_get_type(VALUE self) {
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
+ return INT2NUM(ptr->element.type);
+ }
+
+ // value
+ static VALUE _llama_grammar_element_set_value(VALUE self, VALUE value) {
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
+ ptr->element.value = NUM2INT(value);
+ return INT2NUM(ptr->element.value);
+ }
+
+ static VALUE _llama_grammar_element_get_value(VALUE self) {
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
+ return INT2NUM(ptr->element.value);
+ }
+ };
+
+ const rb_data_type_t RbLLaMAGrammarElement::llama_grammar_element_type = {
+ "RbLLaMAGrammarElement",
+ { NULL,
+ RbLLaMAGrammarElement::llama_grammar_element_free,
+ RbLLaMAGrammarElement::llama_grammar_element_size },
+ NULL,
+ NULL,
+ RUBY_TYPED_FREE_IMMEDIATELY
+ };
+
+ class LLaMAGrammarWrapper {
+ public:
+ struct llama_grammar* grammar;
+
+ LLaMAGrammarWrapper() : grammar(nullptr) {}
+
+ ~LLaMAGrammarWrapper() {
+ if (grammar) {
+ llama_grammar_free(grammar);
+ }
+ }
+ };
+
+ class RbLLaMAGrammar {
+ public:
+ static VALUE llama_grammar_alloc(VALUE self) {
+ LLaMAGrammarWrapper* ptr = (LLaMAGrammarWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarWrapper));
+ new (ptr) LLaMAGrammarWrapper();
+ return TypedData_Wrap_Struct(self, &llama_grammar_type, ptr);
+ }
+
+ static void llama_grammar_free(void* ptr) {
+ ((LLaMAGrammarWrapper*)ptr)->~LLaMAGrammarWrapper();
+ ruby_xfree(ptr);
+ }
+
+ static size_t llama_grammar_size(const void* ptr) {
+ return sizeof(*((LLaMAGrammarWrapper*)ptr));
+ }
+
+ static LLaMAGrammarWrapper* get_llama_grammar(VALUE self) {
+ LLaMAGrammarWrapper* ptr;
+ TypedData_Get_Struct(self, LLaMAGrammarWrapper, &llama_grammar_type, ptr);
+ return ptr;
+ }
+
+ static void define_class(VALUE outer) {
+ rb_cLLaMAGrammar = rb_define_class_under(outer, "Grammar", rb_cObject);
+ rb_define_alloc_func(rb_cLLaMAGrammar, llama_grammar_alloc);
+ rb_define_method(rb_cLLaMAGrammar, "initialize", RUBY_METHOD_FUNC(_llama_grammar_init), -1);
+ }
+
+ private:
+ static const rb_data_type_t llama_grammar_type;
+
+ static VALUE _llama_grammar_init(int argc, VALUE* argv, VALUE self) {
+ VALUE kw_args = Qnil;
+ ID kw_table[2] = { rb_intern("rules"), rb_intern("start_rule_index") };
+ VALUE kw_values[2] = { Qundef, Qundef };
+ rb_scan_args(argc, argv, ":", &kw_args);
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
+
+ if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
+ rb_raise(rb_eArgError, "rules must be an array");
+ return Qnil;
+ }
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
+ rb_raise(rb_eArgError, "start_rule_index must be an integer");
+ return Qnil;
+ }
+
+ const int n_rules = RARRAY_LEN(kw_values[0]);
+ llama_grammar_element** rules = ALLOCA_N(llama_grammar_element*, n_rules);
+ for (int i = 0; i < n_rules; ++i) {
+ VALUE rule = rb_ary_entry(kw_values[0], i);
+ if (!RB_TYPE_P(rule, T_ARRAY)) {
+ rb_raise(rb_eArgError, "element of rules must be an array");
+ return Qnil;
+ }
+ const int n_elements = RARRAY_LEN(rule);
+ llama_grammar_element* elements = ALLOCA_N(llama_grammar_element, n_elements);
+ for (int j = 0; j < n_elements; ++j) {
+ VALUE element = rb_ary_entry(rule, j);
+ if (!rb_obj_is_kind_of(element, rb_cLLaMAGrammarElement)) {
+ rb_raise(rb_eArgError, "element of rule must be an instance of GrammarElement");
+ return Qnil;
+ }
+ LLaMAGrammarElementWrapper* ptr = RbLLaMAGrammarElement::get_llama_grammar_element(element);
+ elements[j] = ptr->element;
+ }
+ rules[i] = elements;
+ }
+
+ const size_t start_rule_index = NUM2SIZET(kw_values[1]);
+
+ LLaMAGrammarWrapper* ptr = get_llama_grammar(self);
+ new (ptr) LLaMAGrammarWrapper();
+ ptr->grammar = llama_grammar_init((const llama_grammar_element**)rules, n_rules, start_rule_index);
+
+ return self;
+ }
+ };
+
+ const rb_data_type_t RbLLaMAGrammar::llama_grammar_type = {
+ "RbLLaMAGrammar",
+ { NULL,
+ RbLLaMAGrammar::llama_grammar_free,
+ RbLLaMAGrammar::llama_grammar_size },
+ NULL,
+ NULL,
+ RUBY_TYPED_FREE_IMMEDIATELY
+ };
+
  class LLaMAContextWrapper {
  public:
  struct llama_context* ctx;
@@ -991,6 +1346,8 @@ public:
  rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
  rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
  rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
+ rb_define_method(rb_cLLaMAContext, "sample_grammar", RUBY_METHOD_FUNC(_llama_context_sample_grammar), -1);
+ rb_define_method(rb_cLLaMAContext, "grammar_accept_token", RUBY_METHOD_FUNC(_llama_context_grammar_accept_token), -1);
  }

  private:
@@ -1581,11 +1938,11 @@ private:

  static VALUE _llama_context_sample_classifier_free_guidance(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
- ID kw_table[3] = { rb_intern("guidance"), rb_intern("scale"), rb_intern("smooth_factor") };
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
+ ID kw_table[2] = { rb_intern("guidance"), rb_intern("scale") };
+ VALUE kw_values[2] = { Qundef, Qundef };
  VALUE candidates = Qnil;
  rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
- rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);

  if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAContext)) {
  rb_raise(rb_eArgError, "guidance must be a Context");
@@ -1595,10 +1952,6 @@ private:
  rb_raise(rb_eArgError, "scale must be a float");
  return Qnil;
  }
- if (!RB_FLOAT_TYPE_P(kw_values[2])) {
- rb_raise(rb_eArgError, "smooth_factor must be a float");
- return Qnil;
- }

  LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
  if (ctx_ptr->ctx == NULL) {
@@ -1617,9 +1970,8 @@ private:
  return Qnil;
  }
  const float scale = NUM2DBL(kw_values[1]);
- const float smooth_factor = NUM2DBL(kw_values[2]);

- llama_sample_classifier_free_guidance(ctx_ptr->ctx, &(cnd_ptr->array), guidance_ptr->ctx, scale, smooth_factor);
+ llama_sample_classifier_free_guidance(ctx_ptr->ctx, &(cnd_ptr->array), guidance_ptr->ctx, scale);

  return Qnil;
  }
@@ -1972,6 +2324,69 @@ private:
  llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
  return INT2NUM(id);
  }
+
+ static VALUE _llama_context_sample_grammar(int argc, VALUE* argv, VALUE self) {
+ VALUE kw_args = Qnil;
+ ID kw_table[1] = { rb_intern("grammar") };
+ VALUE kw_values[1] = { Qundef };
+ VALUE candidates = Qnil;
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
+
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
+ return Qnil;
+ }
+ if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
+ rb_raise(rb_eArgError, "grammar must be a Grammar");
+ return Qnil;
+ }
+
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+ if (ctx_ptr->ctx == NULL) {
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+ return Qnil;
+ }
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+ if (cnd_ptr->array.data == nullptr) {
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+ return Qnil;
+ }
+ LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
+
+ llama_sample_grammar(ctx_ptr->ctx, &(cnd_ptr->array), grm_ptr->grammar);
+
+ return Qnil;
+ }
+
+ static VALUE _llama_context_grammar_accept_token(int argc, VALUE* argv, VALUE self) {
+ VALUE kw_args = Qnil;
+ ID kw_table[2] = { rb_intern("grammar"), rb_intern("token") };
+ VALUE kw_values[2] = { Qundef, Qundef };
+ rb_scan_args(argc, argv, ":", &kw_args);
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
+
+ if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
+ rb_raise(rb_eArgError, "grammar must be a Grammar");
+ return Qnil;
+ }
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
+ rb_raise(rb_eArgError, "token must be an Integer");
+ return Qnil;
+ }
+
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+ if (ctx_ptr->ctx == NULL) {
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+ return Qnil;
+ }
+ LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
+ llama_token token = NUM2INT(kw_values[1]);
+
+ llama_grammar_accept_token(ctx_ptr->ctx, grm_ptr->grammar, token);
+
+ return Qnil;
+ }
  };

  const rb_data_type_t RbLLaMAContext::llama_context_type = {
@@ -2062,6 +2477,10 @@ static VALUE rb_llama_mlock_supported(VALUE self) {
  return llama_mlock_supported() ? Qtrue : Qfalse;
  }

+ static VALUE rb_llama_max_devices(VALUE self) {
+ return INT2NUM(llama_max_devices());
+ }
+
  extern "C" void Init_llama_cpp(void) {
  rb_mLLaMACpp = rb_define_module("LLaMACpp");

@@ -2072,6 +2491,8 @@ extern "C" void Init_llama_cpp(void) {
  RbLLaMAContext::define_class(rb_mLLaMACpp);
  RbLLaMAContextParams::define_class(rb_mLLaMACpp);
  RbLLaMAModelQuantizeParams::define_class(rb_mLLaMACpp);
+ RbLLaMAGrammarElement::define_class(rb_mLLaMACpp);
+ RbLLaMAGrammar::define_class(rb_mLLaMACpp);

  rb_define_module_function(rb_mLLaMACpp, "backend_init", rb_llama_llama_backend_init, -1);
  rb_define_module_function(rb_mLLaMACpp, "backend_free", rb_llama_llama_backend_free, 0);
@@ -2082,6 +2503,7 @@ extern "C" void Init_llama_cpp(void) {
  rb_define_module_function(rb_mLLaMACpp, "print_system_info", rb_llama_print_system_info, 0);
  rb_define_module_function(rb_mLLaMACpp, "mmap_supported?", rb_llama_mmap_supported, 0);
  rb_define_module_function(rb_mLLaMACpp, "mlock_supported?", rb_llama_mlock_supported, 0);
+ rb_define_module_function(rb_mLLaMACpp, "max_devices", rb_llama_max_devices, 0);

  rb_define_const(rb_mLLaMACpp, "LLAMA_MAX_DEVICES", INT2NUM(LLAMA_MAX_DEVICES));

@@ -2103,6 +2525,14 @@ extern "C" void Init_llama_cpp(void) {
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_K_M", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_K_M));
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q6_K", INT2NUM(LLAMA_FTYPE_MOSTLY_Q6_K));

+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_END", INT2NUM(LLAMA_GRETYPE_END));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_ALT", INT2NUM(LLAMA_GRETYPE_ALT));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_RULE_REF", INT2NUM(LLAMA_GRETYPE_RULE_REF));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR", INT2NUM(LLAMA_GRETYPE_CHAR));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_NOT", INT2NUM(LLAMA_GRETYPE_CHAR_NOT));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_RNG_UPPER", INT2NUM(LLAMA_GRETYPE_CHAR_RNG_UPPER));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_ALT", INT2NUM(LLAMA_GRETYPE_CHAR_ALT));
+
  std::stringstream ss_magic;
  ss_magic << std::showbase << std::hex << LLAMA_FILE_MAGIC_GGJT;
  rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_MAGIC_GGJT", rb_str_new2(ss_magic.str().c_str()));