llama_cpp 0.3.3 → 0.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: cf337091019bb773e47cf206ff2ff30ed0bef963094494e6493455cad7c59840
4
- data.tar.gz: fdbae8e08a6b87d49c5658d5c1857f20bf8efdf5a5371906630dccf4eb0f1159
3
+ metadata.gz: 991df3df6b16ec98a203a6c6565794988eec04697ccb963faab976f436e1bfcc
4
+ data.tar.gz: dafd3b8274640eb79353e11056f497a02392fef332a46dcc1717878c836f62bd
5
5
  SHA512:
6
- metadata.gz: f0fee68294960c5ab9f56ebfe7256a00f9330e55f4954f2b016e07cbc023570298fa8f8b578f3e187fe9183b869769085311931122f93a033c6c21158b4e9485
7
- data.tar.gz: 7eec8c98ae9ec1a56fa4bdb4e83a2dc2bdea407fc037af8d1b8f09a30c0d1246333d410707f4d66f3f473bf73574757cf12e56a86a0cb47074501f63f65f0c02
6
+ metadata.gz: 0a54fdf18c5be5273f01d4d991b59975a8e8b6a0a8f54087fb90df3f1a8a7ebad557d01fb119f3d61cafa8f6c59fb81641624779fda27b732eea2868cb4642e8
7
+ data.tar.gz: 6574064078070502e36ad933bd5efb2f479c94f47a1260286fe485e48d35a7b3985274943c6163189326891349b47b1d9815d77c600b015fe111cc3842179392
data/CHANGELOG.md CHANGED
@@ -1,3 +1,34 @@
1
+ ## [[0.3.5](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.4...v0.3.5)] - 2023-07-29
2
+
3
+ - Bump bundled llama.cpp from master-d924522 to master-1a94186.
4
+ - Add `GrammarElement` and `Grammar` classes.
5
+ - Add `sample_grammar` method to Context.
6
+ - Add `grammar_accept_token` method to Context.
7
+
8
+ ## [[0.3.4](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.3...v0.3.4)] - 2023-07-23
9
+
10
+ - Bump bundled llama.cpp from master-32c5411 to master-d924522.
11
+ - Add `rope_freq_base` and `rope_freq_scale` options to ContextParams.
12
+ - Add `max_devices` module function to LLaMACpp.
13
+ - Add `n_vocab`, `n_ctx`, and `n_embd` methods to Model.
14
+ - Add `vocab`, `tokenize`, and `token_to_str` methods to Model.
15
+ ```ruby
16
+ require 'llama_cpp'
17
+
18
+ params = LLaMACpp::ContextParams.new
19
+ model = LLaMACpp::Model.new(model_path: '/path/to/model.bin', params: params)
20
+
21
+ p model.tokenize(text: 'hello, world')
22
+ # => [12199, 29892, 3186]
23
+
24
+ p model.token_to_str(12199)
25
+ # => "hello"
26
+ ```
27
+
28
+ **Breaking Changes**
29
+ - Fix to automatically call `backend_free` method when Ruby script exits.
30
+ - Remove `smooth_factor` argument from `sample_classifier_free_guidance` method on Context.
31
+
1
32
  ## [[0.3.3](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.2...v0.3.3)] - 2023-07-15
2
33
 
3
34
  - Bump bundled llama.cpp from master-481f793 to master-32c5411.
@@ -85,6 +85,7 @@ if with_config('mpi')
85
85
  $CXXFLAGS << ' -DGGML_USE_MPI -Wno-cast-qual'
86
86
  end
87
87
 
88
+ # @!visibility private
88
89
  UNAME_M = RbConfig::CONFIG['build_cpu'] || RbConfig::CONFIG['host_cpu'] || RbConfig::CONFIG['target_cpu']
89
90
 
90
91
  # rubocop:disable Layout/LineLength
@@ -8,6 +8,8 @@ VALUE rb_cLLaMAContextParams;
8
8
  VALUE rb_cLLaMAModelQuantizeParams;
9
9
  VALUE rb_cLLaMATokenData;
10
10
  VALUE rb_cLLaMATokenDataArray;
11
+ VALUE rb_cLLaMAGrammarElement;
12
+ VALUE rb_cLLaMAGrammar;
11
13
 
12
14
  class LLaMATokenDataWrapper {
13
15
  public:
@@ -406,6 +408,10 @@ public:
406
408
  rb_define_method(rb_cLLaMAContextParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_context_params_set_main_gpu), 1);
407
409
  rb_define_method(rb_cLLaMAContextParams, "main_gpu", RUBY_METHOD_FUNC(_llama_context_params_get_main_gpu), 0);
408
410
  rb_define_method(rb_cLLaMAContextParams, "tensor_split", RUBY_METHOD_FUNC(_llama_context_params_get_tensor_split), 0);
411
+ rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
412
+ rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
413
+ rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
414
+ rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
409
415
  rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
410
416
  rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
411
417
  rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
@@ -494,6 +500,30 @@ private:
494
500
  return ret;
495
501
  }
496
502
 
503
+ // rope_freq_base
504
+ static VALUE _llama_context_params_set_rope_freq_base(VALUE self, VALUE rope_freq_base) {
505
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
506
+ ptr->params.rope_freq_base = NUM2DBL(rope_freq_base);
507
+ return DBL2NUM(ptr->params.rope_freq_base);
508
+ }
509
+
510
+ static VALUE _llama_context_params_get_rope_freq_base(VALUE self) {
511
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
512
+ return DBL2NUM(ptr->params.rope_freq_base);
513
+ }
514
+
515
+ // rope_freq_scale
516
+ static VALUE _llama_context_params_set_rope_freq_scale(VALUE self, VALUE rope_freq_scale) {
517
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
518
+ ptr->params.rope_freq_scale = NUM2DBL(rope_freq_scale);
519
+ return DBL2NUM(ptr->params.rope_freq_scale);
520
+ }
521
+
522
+ static VALUE _llama_context_params_get_rope_freq_scale(VALUE self) {
523
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
524
+ return DBL2NUM(ptr->params.rope_freq_scale);
525
+ }
526
+
497
527
  // low_vram
498
528
  static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
499
529
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -764,6 +794,12 @@ public:
764
794
  rb_define_method(rb_cLLaMAModel, "free", RUBY_METHOD_FUNC(_llama_model_free), 0);
765
795
  rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
766
796
  rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
797
+ rb_define_method(rb_cLLaMAModel, "n_vocab", RUBY_METHOD_FUNC(_llama_model_get_n_vocab_from_model), 0);
798
+ rb_define_method(rb_cLLaMAModel, "n_ctx", RUBY_METHOD_FUNC(_llama_model_get_n_ctx_from_model), 0);
799
+ rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_n_embd_from_model), 0);
800
+ rb_define_method(rb_cLLaMAModel, "vocab", RUBY_METHOD_FUNC(_llama_model_get_vocab_from_model), -1);
801
+ rb_define_method(rb_cLLaMAModel, "token_to_str", RUBY_METHOD_FUNC(_llama_model_token_to_str_with_model), 1);
802
+ rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize_with_model), -1);
767
803
  }
768
804
 
769
805
  private:
@@ -908,6 +944,109 @@ private:
908
944
  }
909
945
  return Qnil;
910
946
  }
947
+
948
+ static VALUE _llama_model_get_n_vocab_from_model(VALUE self) {
949
+ LLaMAModelWrapper* ptr = get_llama_model(self);
950
+ return INT2NUM(llama_n_vocab_from_model(ptr->model));
951
+ }
952
+
953
+ static VALUE _llama_model_get_n_ctx_from_model(VALUE self) {
954
+ LLaMAModelWrapper* ptr = get_llama_model(self);
955
+ return INT2NUM(llama_n_ctx_from_model(ptr->model));
956
+ }
957
+
958
+ static VALUE _llama_model_get_n_embd_from_model(VALUE self) {
959
+ LLaMAModelWrapper* ptr = get_llama_model(self);
960
+ return INT2NUM(llama_n_embd_from_model(ptr->model));
961
+ }
962
+
963
+ static VALUE _llama_model_get_vocab_from_model(int argc, VALUE* argv, VALUE self) {
964
+ VALUE kw_args = Qnil;
965
+ ID kw_table[1] = { rb_intern("capacity") };
966
+ VALUE kw_values[1] = { Qundef };
967
+ rb_scan_args(argc, argv, ":", &kw_args);
968
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
969
+
970
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
971
+ rb_raise(rb_eArgError, "capacity must be an integer");
972
+ return Qnil;
973
+ }
974
+
975
+ const int capacity = NUM2INT(kw_values[0]);
976
+
977
+ LLaMAModelWrapper* ptr = get_llama_model(self);
978
+ const int n = std::min(capacity, llama_n_vocab_from_model(ptr->model));
979
+ const char** vocabs = ALLOCA_N(const char*, n);
980
+ float* scores = ALLOCA_N(float, n);
981
+
982
+ llama_get_vocab_from_model(ptr->model, vocabs, scores, capacity);
983
+
984
+ VALUE vocabs_ary = rb_ary_new();
985
+ VALUE scores_ary = rb_ary_new();
986
+
987
+ for (int i = 0; i < n; i++) {
988
+ rb_ary_push(vocabs_ary, rb_str_new_cstr(vocabs[i]));
989
+ rb_ary_push(scores_ary, DBL2NUM(scores[i]));
990
+ }
991
+
992
+ VALUE ret = rb_ary_new3(2, vocabs_ary, scores_ary);
993
+
994
+ return ret;
995
+ }
996
+
997
+ static VALUE _llama_model_token_to_str_with_model(VALUE self, VALUE token_) {
998
+ if (!RB_INTEGER_TYPE_P(token_)) {
999
+ rb_raise(rb_eArgError, "token must be an integer");
1000
+ return Qnil;
1001
+ }
1002
+ const llama_token token = NUM2INT(token_);
1003
+ LLaMAModelWrapper* ptr = get_llama_model(self);
1004
+ const char* str = llama_token_to_str_with_model(ptr->model, token);
1005
+ return rb_str_new_cstr(str);
1006
+ }
1007
+
1008
+ static VALUE _llama_model_tokenize_with_model(int argc, VALUE* argv, VALUE self) {
1009
+ VALUE kw_args = Qnil;
1010
+ ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
1011
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1012
+ rb_scan_args(argc, argv, ":", &kw_args);
1013
+ rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
1014
+
1015
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1016
+ rb_raise(rb_eArgError, "text must be a String");
1017
+ return Qnil;
1018
+ }
1019
+ if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
1020
+ rb_raise(rb_eArgError, "n_max_tokens must be an integer");
1021
+ return Qnil;
1022
+ }
1023
+ if (kw_values[2] != Qundef && (kw_values[2] != Qtrue && kw_values[2] != Qfalse)) {
1024
+ rb_raise(rb_eArgError, "add_bos must be a boolean");
1025
+ return Qnil;
1026
+ }
1027
+
1028
+ VALUE text_ = kw_values[0];
1029
+ std::string text = StringValueCStr(text_);
1030
+ const bool add_bos = kw_values[2] == Qtrue ? true : false;
1031
+ const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
1032
+
1033
+ llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
1034
+ LLaMAModelWrapper* ptr = get_llama_model(self);
1035
+ const int n_tokens = llama_tokenize_with_model(ptr->model, text.c_str(), tokens, n_max_tokens, add_bos);
1036
+
1037
+ if (n_tokens < 0) {
1038
+ rb_raise(rb_eRuntimeError, "failed to tokenize. The numebr of tokens (%d) is greater than n_max_tokens.", -n_tokens);
1039
+ return Qnil;
1040
+ }
1041
+
1042
+ VALUE ret = rb_ary_new2(n_tokens);
1043
+ for (int i = 0; i < n_tokens; i++) {
1044
+ rb_ary_store(ret, i, INT2NUM(tokens[i]));
1045
+ }
1046
+
1047
+ RB_GC_GUARD(text_);
1048
+ return ret;
1049
+ }
911
1050
  };
912
1051
 
913
1052
  const rb_data_type_t RbLLaMAModel::llama_model_type = {
@@ -920,6 +1059,222 @@ const rb_data_type_t RbLLaMAModel::llama_model_type = {
920
1059
  RUBY_TYPED_FREE_IMMEDIATELY
921
1060
  };
922
1061
 
1062
+ class LLaMAGrammarElementWrapper {
1063
+ public:
1064
+ llama_grammar_element element;
1065
+
1066
+ LLaMAGrammarElementWrapper() {
1067
+ element.type = LLAMA_GRETYPE_END;
1068
+ element.value = 0;
1069
+ }
1070
+
1071
+ ~LLaMAGrammarElementWrapper() {}
1072
+ };
1073
+
1074
+ class RbLLaMAGrammarElement {
1075
+ public:
1076
+ static VALUE llama_grammar_element_alloc(VALUE self) {
1077
+ LLaMAGrammarElementWrapper* ptr = (LLaMAGrammarElementWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarElementWrapper));
1078
+ new (ptr) LLaMAGrammarElementWrapper();
1079
+ return TypedData_Wrap_Struct(self, &llama_grammar_element_type, ptr);
1080
+ }
1081
+
1082
+ static void llama_grammar_element_free(void* ptr) {
1083
+ ((LLaMAGrammarElementWrapper*)ptr)->~LLaMAGrammarElementWrapper();
1084
+ ruby_xfree(ptr);
1085
+ }
1086
+
1087
+ static size_t llama_grammar_element_size(const void* ptr) {
1088
+ return sizeof(*((LLaMAGrammarElementWrapper*)ptr));
1089
+ }
1090
+
1091
+ static LLaMAGrammarElementWrapper* get_llama_grammar_element(VALUE self) {
1092
+ LLaMAGrammarElementWrapper* ptr;
1093
+ TypedData_Get_Struct(self, LLaMAGrammarElementWrapper, &llama_grammar_element_type, ptr);
1094
+ return ptr;
1095
+ }
1096
+
1097
+ static void define_class(VALUE outer) {
1098
+ rb_cLLaMAGrammarElement = rb_define_class_under(outer, "GrammarElement", rb_cObject);
1099
+ rb_define_alloc_func(rb_cLLaMAGrammarElement, llama_grammar_element_alloc);
1100
+ rb_define_method(rb_cLLaMAGrammarElement, "initialize", RUBY_METHOD_FUNC(_llama_grammar_element_init), -1);
1101
+ rb_define_method(rb_cLLaMAGrammarElement, "type=", RUBY_METHOD_FUNC(_llama_grammar_element_set_type), 1);
1102
+ rb_define_method(rb_cLLaMAGrammarElement, "type", RUBY_METHOD_FUNC(_llama_grammar_element_get_type), 0);
1103
+ rb_define_method(rb_cLLaMAGrammarElement, "value=", RUBY_METHOD_FUNC(_llama_grammar_element_set_value), 1);
1104
+ rb_define_method(rb_cLLaMAGrammarElement, "value", RUBY_METHOD_FUNC(_llama_grammar_element_get_value), 0);
1105
+ }
1106
+
1107
+ private:
1108
+ static const rb_data_type_t llama_grammar_element_type;
1109
+
1110
+ static VALUE _llama_grammar_element_init(int argc, VALUE* argv, VALUE self) {
1111
+ VALUE kw_args = Qnil;
1112
+ ID kw_table[2] = { rb_intern("type"), rb_intern("value") };
1113
+ VALUE kw_values[2] = { Qundef, Qundef };
1114
+ VALUE arr = Qnil;
1115
+ rb_scan_args(argc, argv, ":", &arr, &kw_args);
1116
+ rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
1117
+
1118
+ if (kw_values[0] != Qundef && !RB_INTEGER_TYPE_P(kw_values[0])) {
1119
+ rb_raise(rb_eArgError, "type must be an integer");
1120
+ return Qnil;
1121
+ }
1122
+ if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
1123
+ rb_raise(rb_eArgError, "value must be an integer");
1124
+ return Qnil;
1125
+ }
1126
+
1127
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
1128
+ new (ptr) LLaMAGrammarElementWrapper();
1129
+
1130
+ if (kw_values[0] != Qundef) {
1131
+ ptr->element.type = (enum llama_gretype)NUM2INT(kw_values[0]);
1132
+ }
1133
+ if (kw_values[1] != Qundef) {
1134
+ ptr->element.value = NUM2INT(kw_values[1]);
1135
+ }
1136
+
1137
+ return self;
1138
+ }
1139
+
1140
+ // type
1141
+ static VALUE _llama_grammar_element_set_type(VALUE self, VALUE type) {
1142
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
1143
+ ptr->element.type = (enum llama_gretype)NUM2INT(type);
1144
+ return INT2NUM(ptr->element.type);
1145
+ }
1146
+
1147
+ static VALUE _llama_grammar_element_get_type(VALUE self) {
1148
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
1149
+ return INT2NUM(ptr->element.type);
1150
+ }
1151
+
1152
+ // value
1153
+ static VALUE _llama_grammar_element_set_value(VALUE self, VALUE type) {
1154
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
1155
+ ptr->element.value = NUM2INT(type);
1156
+ return INT2NUM(ptr->element.value);
1157
+ }
1158
+
1159
+ static VALUE _llama_grammar_element_get_value(VALUE self) {
1160
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
1161
+ return INT2NUM(ptr->element.value);
1162
+ }
1163
+ };
1164
+
1165
+ const rb_data_type_t RbLLaMAGrammarElement::llama_grammar_element_type = {
1166
+ "RbLLaMAGrammarElement",
1167
+ { NULL,
1168
+ RbLLaMAGrammarElement::llama_grammar_element_free,
1169
+ RbLLaMAGrammarElement::llama_grammar_element_size },
1170
+ NULL,
1171
+ NULL,
1172
+ RUBY_TYPED_FREE_IMMEDIATELY
1173
+ };
1174
+
1175
+ class LLaMAGrammarWrapper {
1176
+ public:
1177
+ struct llama_grammar* grammar;
1178
+
1179
+ LLaMAGrammarWrapper() : grammar(nullptr) {}
1180
+
1181
+ ~LLaMAGrammarWrapper() {
1182
+ if (grammar) {
1183
+ llama_grammar_free(grammar);
1184
+ }
1185
+ }
1186
+ };
1187
+
1188
+ class RbLLaMAGrammar {
1189
+ public:
1190
+ static VALUE llama_grammar_alloc(VALUE self) {
1191
+ LLaMAGrammarWrapper* ptr = (LLaMAGrammarWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarWrapper));
1192
+ new (ptr) LLaMAGrammarWrapper();
1193
+ return TypedData_Wrap_Struct(self, &llama_grammar_type, ptr);
1194
+ }
1195
+
1196
+ static void llama_grammar_free(void* ptr) {
1197
+ ((LLaMAGrammarWrapper*)ptr)->~LLaMAGrammarWrapper();
1198
+ ruby_xfree(ptr);
1199
+ }
1200
+
1201
+ static size_t llama_grammar_size(const void* ptr) {
1202
+ return sizeof(*((LLaMAGrammarWrapper*)ptr));
1203
+ }
1204
+
1205
+ static LLaMAGrammarWrapper* get_llama_grammar(VALUE self) {
1206
+ LLaMAGrammarWrapper* ptr;
1207
+ TypedData_Get_Struct(self, LLaMAGrammarWrapper, &llama_grammar_type, ptr);
1208
+ return ptr;
1209
+ }
1210
+
1211
+ static void define_class(VALUE outer) {
1212
+ rb_cLLaMAGrammar = rb_define_class_under(outer, "Grammar", rb_cObject);
1213
+ rb_define_alloc_func(rb_cLLaMAGrammar, llama_grammar_alloc);
1214
+ rb_define_method(rb_cLLaMAGrammar, "initialize", RUBY_METHOD_FUNC(_llama_grammar_init), -1);
1215
+ }
1216
+
1217
+ private:
1218
+ static const rb_data_type_t llama_grammar_type;
1219
+
1220
+ static VALUE _llama_grammar_init(int argc, VALUE* argv, VALUE self) {
1221
+ VALUE kw_args = Qnil;
1222
+ ID kw_table[2] = { rb_intern("rules"), rb_intern("start_rule_index") };
1223
+ VALUE kw_values[2] = { Qundef, Qundef };
1224
+ rb_scan_args(argc, argv, ":", &kw_args);
1225
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
1226
+
1227
+ if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
1228
+ rb_raise(rb_eArgError, "rules must be an array");
1229
+ return Qnil;
1230
+ }
1231
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
1232
+ rb_raise(rb_eArgError, "start_rule_index must be an integer");
1233
+ return Qnil;
1234
+ }
1235
+
1236
+ const int n_rules = RARRAY_LEN(kw_values[0]);
1237
+ llama_grammar_element** rules = ALLOCA_N(llama_grammar_element*, n_rules);
1238
+ for (int i = 0; i < n_rules; ++i) {
1239
+ VALUE rule = rb_ary_entry(kw_values[0], i);
1240
+ if (!RB_TYPE_P(rule, T_ARRAY)) {
1241
+ rb_raise(rb_eArgError, "element of rules must be an array");
1242
+ return Qnil;
1243
+ }
1244
+ const int n_elements = RARRAY_LEN(rule);
1245
+ llama_grammar_element* elements = ALLOCA_N(llama_grammar_element, n_elements);
1246
+ for (int j = 0; j < n_elements; ++j) {
1247
+ VALUE element = rb_ary_entry(rule, j);
1248
+ if (!rb_obj_is_kind_of(element, rb_cLLaMAGrammarElement)) {
1249
+ rb_raise(rb_eArgError, "element of rule must be an instance of GrammarElement");
1250
+ return Qnil;
1251
+ }
1252
+ LLaMAGrammarElementWrapper* ptr = RbLLaMAGrammarElement::get_llama_grammar_element(element);
1253
+ elements[j] = ptr->element;
1254
+ }
1255
+ rules[i] = elements;
1256
+ }
1257
+
1258
+ const size_t start_rule_index = NUM2SIZET(kw_values[1]);
1259
+
1260
+ LLaMAGrammarWrapper* ptr = get_llama_grammar(self);
1261
+ new (ptr) LLaMAGrammarWrapper();
1262
+ ptr->grammar = llama_grammar_init((const llama_grammar_element**)rules, n_rules, start_rule_index);
1263
+
1264
+ return self;
1265
+ }
1266
+ };
1267
+
1268
+ const rb_data_type_t RbLLaMAGrammar::llama_grammar_type = {
1269
+ "RbLLaMAGrammar",
1270
+ { NULL,
1271
+ RbLLaMAGrammar::llama_grammar_free,
1272
+ RbLLaMAGrammar::llama_grammar_size },
1273
+ NULL,
1274
+ NULL,
1275
+ RUBY_TYPED_FREE_IMMEDIATELY
1276
+ };
1277
+
923
1278
  class LLaMAContextWrapper {
924
1279
  public:
925
1280
  struct llama_context* ctx;
@@ -991,6 +1346,8 @@ public:
991
1346
  rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
992
1347
  rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
993
1348
  rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
1349
+ rb_define_method(rb_cLLaMAContext, "sample_grammar", RUBY_METHOD_FUNC(_llama_context_sample_grammar), -1);
1350
+ rb_define_method(rb_cLLaMAContext, "grammar_accept_token", RUBY_METHOD_FUNC(_llama_context_grammar_accept_token), -1);
994
1351
  }
995
1352
 
996
1353
  private:
@@ -1581,11 +1938,11 @@ private:
1581
1938
 
1582
1939
  static VALUE _llama_context_sample_classifier_free_guidance(int argc, VALUE* argv, VALUE self) {
1583
1940
  VALUE kw_args = Qnil;
1584
- ID kw_table[3] = { rb_intern("guidance"), rb_intern("scale"), rb_intern("smooth_factor") };
1585
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1941
+ ID kw_table[2] = { rb_intern("guidance"), rb_intern("scale") };
1942
+ VALUE kw_values[2] = { Qundef, Qundef };
1586
1943
  VALUE candidates = Qnil;
1587
1944
  rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
1588
- rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
1945
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
1589
1946
 
1590
1947
  if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAContext)) {
1591
1948
  rb_raise(rb_eArgError, "guidance must be a Context");
@@ -1595,10 +1952,6 @@ private:
1595
1952
  rb_raise(rb_eArgError, "scale must be a float");
1596
1953
  return Qnil;
1597
1954
  }
1598
- if (!RB_FLOAT_TYPE_P(kw_values[2])) {
1599
- rb_raise(rb_eArgError, "smooth_factor must be a float");
1600
- return Qnil;
1601
- }
1602
1955
 
1603
1956
  LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1604
1957
  if (ctx_ptr->ctx == NULL) {
@@ -1617,9 +1970,8 @@ private:
1617
1970
  return Qnil;
1618
1971
  }
1619
1972
  const float scale = NUM2DBL(kw_values[1]);
1620
- const float smooth_factor = NUM2DBL(kw_values[2]);
1621
1973
 
1622
- llama_sample_classifier_free_guidance(ctx_ptr->ctx, &(cnd_ptr->array), guidance_ptr->ctx, scale, smooth_factor);
1974
+ llama_sample_classifier_free_guidance(ctx_ptr->ctx, &(cnd_ptr->array), guidance_ptr->ctx, scale);
1623
1975
 
1624
1976
  return Qnil;
1625
1977
  }
@@ -1972,6 +2324,69 @@ private:
1972
2324
  llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
1973
2325
  return INT2NUM(id);
1974
2326
  }
2327
+
2328
+ static VALUE _llama_context_sample_grammar(int argc, VALUE* argv, VALUE self) {
2329
+ VALUE kw_args = Qnil;
2330
+ ID kw_table[1] = { rb_intern("grammar") };
2331
+ VALUE kw_values[1] = { Qundef };
2332
+ VALUE candidates = Qnil;
2333
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
2334
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
2335
+
2336
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
2337
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
2338
+ return Qnil;
2339
+ }
2340
+ if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
2341
+ rb_raise(rb_eArgError, "grammar must be a Grammar");
2342
+ return Qnil;
2343
+ }
2344
+
2345
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2346
+ if (ctx_ptr->ctx == NULL) {
2347
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2348
+ return Qnil;
2349
+ }
2350
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
2351
+ if (cnd_ptr->array.data == nullptr) {
2352
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
2353
+ return Qnil;
2354
+ }
2355
+ LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
2356
+
2357
+ llama_sample_grammar(ctx_ptr->ctx, &(cnd_ptr->array), grm_ptr->grammar);
2358
+
2359
+ return Qnil;
2360
+ }
2361
+
2362
+ static VALUE _llama_context_grammar_accept_token(int argc, VALUE* argv, VALUE self) {
2363
+ VALUE kw_args = Qnil;
2364
+ ID kw_table[2] = { rb_intern("grammar"), rb_intern("token") };
2365
+ VALUE kw_values[2] = { Qundef, Qundef };
2366
+ rb_scan_args(argc, argv, ":", &kw_args);
2367
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
2368
+
2369
+ if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
2370
+ rb_raise(rb_eArgError, "grammar must be a Grammar");
2371
+ return Qnil;
2372
+ }
2373
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
2374
+ rb_raise(rb_eArgError, "token must be an Integer");
2375
+ return Qnil;
2376
+ }
2377
+
2378
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2379
+ if (ctx_ptr->ctx == NULL) {
2380
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2381
+ return Qnil;
2382
+ }
2383
+ LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
2384
+ llama_token token = NUM2INT(kw_values[1]);
2385
+
2386
+ llama_grammar_accept_token(ctx_ptr->ctx, grm_ptr->grammar, token);
2387
+
2388
+ return Qnil;
2389
+ }
1975
2390
  };
1976
2391
 
1977
2392
  const rb_data_type_t RbLLaMAContext::llama_context_type = {
@@ -2062,6 +2477,10 @@ static VALUE rb_llama_mlock_supported(VALUE self) {
2062
2477
  return llama_mlock_supported() ? Qtrue : Qfalse;
2063
2478
  }
2064
2479
 
2480
+ static VALUE rb_llama_max_devices(VALUE self) {
2481
+ return INT2NUM(llama_max_devices());
2482
+ }
2483
+
2065
2484
  extern "C" void Init_llama_cpp(void) {
2066
2485
  rb_mLLaMACpp = rb_define_module("LLaMACpp");
2067
2486
 
@@ -2072,6 +2491,8 @@ extern "C" void Init_llama_cpp(void) {
2072
2491
  RbLLaMAContext::define_class(rb_mLLaMACpp);
2073
2492
  RbLLaMAContextParams::define_class(rb_mLLaMACpp);
2074
2493
  RbLLaMAModelQuantizeParams::define_class(rb_mLLaMACpp);
2494
+ RbLLaMAGrammarElement::define_class(rb_mLLaMACpp);
2495
+ RbLLaMAGrammar::define_class(rb_mLLaMACpp);
2075
2496
 
2076
2497
  rb_define_module_function(rb_mLLaMACpp, "backend_init", rb_llama_llama_backend_init, -1);
2077
2498
  rb_define_module_function(rb_mLLaMACpp, "backend_free", rb_llama_llama_backend_free, 0);
@@ -2082,6 +2503,7 @@ extern "C" void Init_llama_cpp(void) {
2082
2503
  rb_define_module_function(rb_mLLaMACpp, "print_system_info", rb_llama_print_system_info, 0);
2083
2504
  rb_define_module_function(rb_mLLaMACpp, "mmap_supported?", rb_llama_mmap_supported, 0);
2084
2505
  rb_define_module_function(rb_mLLaMACpp, "mlock_supported?", rb_llama_mlock_supported, 0);
2506
+ rb_define_module_function(rb_mLLaMACpp, "max_devices", rb_llama_max_devices, 0);
2085
2507
 
2086
2508
  rb_define_const(rb_mLLaMACpp, "LLAMA_MAX_DEVICES", INT2NUM(LLAMA_MAX_DEVICES));
2087
2509
 
@@ -2103,6 +2525,14 @@ extern "C" void Init_llama_cpp(void) {
2103
2525
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_K_M", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_K_M));
2104
2526
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q6_K", INT2NUM(LLAMA_FTYPE_MOSTLY_Q6_K));
2105
2527
 
2528
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_END", INT2NUM(LLAMA_GRETYPE_END));
2529
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_ALT", INT2NUM(LLAMA_GRETYPE_ALT));
2530
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_RULE_REF", INT2NUM(LLAMA_GRETYPE_RULE_REF));
2531
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR", INT2NUM(LLAMA_GRETYPE_CHAR));
2532
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_NOT", INT2NUM(LLAMA_GRETYPE_CHAR_NOT));
2533
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_RNG_UPPER", INT2NUM(LLAMA_GRETYPE_CHAR_RNG_UPPER));
2534
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_ALT", INT2NUM(LLAMA_GRETYPE_CHAR_ALT));
2535
+
2106
2536
  std::stringstream ss_magic;
2107
2537
  ss_magic << std::showbase << std::hex << LLAMA_FILE_MAGIC_GGJT;
2108
2538
  rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_MAGIC_GGJT", rb_str_new2(ss_magic.str().c_str()));