llama_cpp 0.3.4 → 0.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 35afb5cc65c290036ae7e45459eadc9b509f34f33a3f7708244cf47f1a38829f
4
- data.tar.gz: 3301158526c63d9d2004e22bda0d1cc8025b4343d8d737df96260786531b074d
3
+ metadata.gz: 545786d4c9308ffe0f7e214a12427beaea0b26bec915ff84b16eed25ef1932a4
4
+ data.tar.gz: aaa0d4fc1710b13a26163306c8b51e423233c2f7e4b3d6127f94c9b6c4846f9c
5
5
  SHA512:
6
- metadata.gz: b0a50f9f012f44f119a70790d3de07c7fcc64151246791e270e4ff9fc479a85a01c53cf2775945eba3145a3ba89da55a8d14891c6236cfeae16aed5ae455cf0d
7
- data.tar.gz: ede388584e115ae93d509b6c15b288303c348f3cfe8ea46879a1b69e6c96be31a321edbb52cfbeb309a8fb456738f3f6b7cc1d3f71ce7addbd05b3a1e73d4755
6
+ metadata.gz: 12b3ac122fd7ea59b51e2d6ff905ed1a71cf8a8b3650a269d4a3793ae32a0149f6836a792c8f216d0fdb0c39aeb3b47914e73ffc74b574bbe686660e6be84ea1
7
+ data.tar.gz: 5056b95552f3434692a6c19653810d77bb28ddf9b28abd78712ccfb4ee4f7d836a5d54e283513fcfc617cc79ffa7bb9257d4ac2b6d96ec89158bf94acd4cec86
data/CHANGELOG.md CHANGED
@@ -1,3 +1,15 @@
1
+ ## [[0.3.6](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.5...v0.3.6)] - 2023-08-04
2
+
3
+ - Bump bundled llama.cpp from master-1a94186 to master-468ea24.
4
+ - Add `mul_mat_q` option to ContextParams.
5
+
6
+ ## [[0.3.5](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.4...v0.3.5)] - 2023-07-29
7
+
8
+ - Bump bundled llama.cpp from master-d924522 to master-1a94186.
9
+ - Add `GrammarElement` and `Grammar` classes.
10
+ - Add `sample_grammar` method to Context.
11
+ - Add `grammar_accept_token` method to Context.
12
+
1
13
  ## [[0.3.4](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.3...v0.3.4)] - 2023-07-23
2
14
 
3
15
  - Bump bundled llama.cpp from master-32c5411 to master-d924522.
data/README.md CHANGED
@@ -12,11 +12,27 @@ This gem is still under development and may undergo many changes in the future.
12
12
 
13
13
  Install the gem and add to the application's Gemfile by executing:
14
14
 
15
- $ bundle add llama_cpp
15
+ ```sh
16
+ $ bundle add llama_cpp
17
+ ```
16
18
 
17
19
  If bundler is not being used to manage dependencies, install the gem by executing:
18
20
 
19
- $ gem install llama_cpp
21
+ ```sh
22
+ $ gem install llama_cpp
23
+ ```
24
+
25
+ There are several installation options for improving execution performance:
26
+
27
+ ```sh
28
+ # use OpenBLAS
29
+ $ gem install llama_cpp -- --with-openblas
30
+
31
+ # use Metal on macOS
32
+ $ gem install llama_cpp -- --with-metal
33
+ ```
34
+
35
+ These options are defined in [extconf.rb](https://github.com/yoshoku/llama_cpp.rb/blob/main/ext/llama_cpp/extconf.rb) via the `with_config` method.
20
36
 
21
37
  ## Usage
22
38
 
@@ -5,7 +5,7 @@ require 'fileutils'
5
5
 
6
6
  abort 'libstdc++ is not found.' unless have_library('stdc++')
7
7
 
8
- $srcs = %w[ggml.c llama.cpp llama_cpp.cpp]
8
+ $srcs = %w[ggml.c ggml-alloc.c llama.cpp llama_cpp.cpp]
9
9
  $srcs << 'ggml-opencl.cpp' if with_config('clblast')
10
10
  $srcs << 'ggml-mpi.c' if with_config('mpi')
11
11
  $CFLAGS << ' -w -DNDEBUG'
@@ -85,6 +85,7 @@ if with_config('mpi')
85
85
  $CXXFLAGS << ' -DGGML_USE_MPI -Wno-cast-qual'
86
86
  end
87
87
 
88
+ # @!visibility private
88
89
  UNAME_M = RbConfig::CONFIG['build_cpu'] || RbConfig::CONFIG['host_cpu'] || RbConfig::CONFIG['target_cpu']
89
90
 
90
91
  # rubocop:disable Layout/LineLength
@@ -8,6 +8,8 @@ VALUE rb_cLLaMAContextParams;
8
8
  VALUE rb_cLLaMAModelQuantizeParams;
9
9
  VALUE rb_cLLaMATokenData;
10
10
  VALUE rb_cLLaMATokenDataArray;
11
+ VALUE rb_cLLaMAGrammarElement;
12
+ VALUE rb_cLLaMAGrammar;
11
13
 
12
14
  class LLaMATokenDataWrapper {
13
15
  public:
@@ -412,6 +414,8 @@ public:
412
414
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
413
415
  rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
414
416
  rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
417
+ rb_define_method(rb_cLLaMAContextParams, "mul_mat_q=", RUBY_METHOD_FUNC(_llama_context_params_set_mul_mat_q), 1);
418
+ rb_define_method(rb_cLLaMAContextParams, "mul_mat_q", RUBY_METHOD_FUNC(_llama_context_params_get_mul_mat_q), 0);
415
419
  rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
416
420
  rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
417
421
  rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
@@ -525,7 +529,7 @@ private:
525
529
  // low_vram
526
530
  static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
  // Setter for ContextParams#low_vram=. The -/+ pair below replaces `low_vram == Qtrue` with RTEST(low_vram), so any Ruby value except nil/false now enables the flag (before, only `true` did). Returns the stored flag as true/false.
527
531
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
528
- ptr->params.low_vram = low_vram == Qtrue ? true : false;
532
+ ptr->params.low_vram = RTEST(low_vram) ? true : false;
529
533
  return ptr->params.low_vram ? Qtrue : Qfalse;
530
534
  }
531
535
 
@@ -534,6 +538,18 @@ private:
534
538
  return ptr->params.low_vram ? Qtrue : Qfalse;
535
539
  }
536
540
 
541
+ // mul_mat_q
542
+ static VALUE _llama_context_params_set_mul_mat_q(VALUE self, VALUE mul_mat_q) {
543
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
544
+ ptr->params.mul_mat_q = RTEST(mul_mat_q) ? true : false;
545
+ return ptr->params.mul_mat_q ? Qtrue : Qfalse;
546
+ }
547
+
548
+ static VALUE _llama_context_params_get_mul_mat_q(VALUE self) {
549
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
550
+ return ptr->params.mul_mat_q ? Qtrue : Qfalse;
551
+ }
552
+
537
553
  // seed
538
554
  static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
539
555
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -553,7 +569,7 @@ private:
553
569
  // f16_kv
554
570
  static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
  // Setter for ContextParams#f16_kv=. The -/+ pair below replaces `f16_kv == Qtrue` with RTEST(f16_kv), so any Ruby value except nil/false now enables the flag (before, only `true` did). Returns the stored flag as true/false.
555
571
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
556
- ptr->params.f16_kv = f16_kv == Qtrue ? true : false;
572
+ ptr->params.f16_kv = RTEST(f16_kv) ? true : false;
557
573
  return ptr->params.f16_kv ? Qtrue : Qfalse;
558
574
  }
559
575
 
@@ -565,7 +581,7 @@ private:
565
581
  // logits_all
566
582
  static VALUE _llama_context_params_set_logits_all(VALUE self, VALUE logits_all) {
  // Setter for ContextParams#logits_all=. The -/+ pair below replaces `logits_all == Qtrue` with RTEST(logits_all), so any Ruby value except nil/false now enables the flag (before, only `true` did). Returns the stored flag as true/false.
567
583
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
568
- ptr->params.logits_all = logits_all == Qtrue ? true : false;
584
+ ptr->params.logits_all = RTEST(logits_all) ? true : false;
569
585
  return ptr->params.logits_all ? Qtrue : Qfalse;
570
586
  }
571
587
 
@@ -577,7 +593,7 @@ private:
577
593
  // vocab_only
578
594
  static VALUE _llama_context_params_set_vocab_only(VALUE self, VALUE vocab_only) {
  // Setter for ContextParams#vocab_only=. The -/+ pair below replaces `vocab_only == Qtrue` with RTEST(vocab_only), so any Ruby value except nil/false now enables the flag (before, only `true` did). Returns the stored flag as true/false.
579
595
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
580
- ptr->params.vocab_only = vocab_only == Qtrue ? true : false;
596
+ ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
581
597
  return ptr->params.vocab_only ? Qtrue : Qfalse;
582
598
  }
583
599
 
@@ -589,7 +605,7 @@ private:
589
605
  // use_mmap
590
606
  static VALUE _llama_context_params_set_use_mmap(VALUE self, VALUE use_mmap) {
  // Setter for ContextParams#use_mmap=. The -/+ pair below replaces `use_mmap == Qtrue` with RTEST(use_mmap), so any Ruby value except nil/false now enables the flag (before, only `true` did). Returns the stored flag as true/false.
591
607
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
592
- ptr->params.use_mmap = use_mmap == Qtrue ? true : false;
608
+ ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
593
609
  return ptr->params.use_mmap ? Qtrue : Qfalse;
594
610
  }
595
611
 
@@ -601,7 +617,7 @@ private:
601
617
  // use_mlock
602
618
  static VALUE _llama_context_params_set_use_mlock(VALUE self, VALUE use_mlock) {
  // Setter for ContextParams#use_mlock=. The -/+ pair below replaces `use_mlock == Qtrue` with RTEST(use_mlock), so any Ruby value except nil/false now enables the flag (before, only `true` did). Returns the stored flag as true/false.
603
619
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
604
- ptr->params.use_mlock = use_mlock == Qtrue ? true : false;
620
+ ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
605
621
  return ptr->params.use_mlock ? Qtrue : Qfalse;
606
622
  }
607
623
 
@@ -613,7 +629,7 @@ private:
613
629
  // embedding
614
630
  static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
  // Setter for ContextParams#embedding=. The -/+ pair below replaces `embedding == Qtrue` with RTEST(embedding), so any Ruby value except nil/false now enables the flag (before, only `true` did). Returns the stored flag as true/false.
615
631
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
616
- ptr->params.embedding = embedding == Qtrue ? true : false;
632
+ ptr->params.embedding = RTEST(embedding) ? true : false;
617
633
  return ptr->params.embedding ? Qtrue : Qfalse;
618
634
  }
619
635
 
@@ -1057,6 +1073,222 @@ const rb_data_type_t RbLLaMAModel::llama_model_type = {
1057
1073
  RUBY_TYPED_FREE_IMMEDIATELY
1058
1074
  };
1059
1075
 
1076
+ class LLaMAGrammarElementWrapper {
1077
+ public:
1078
+ llama_grammar_element element;
1079
+
1080
+ LLaMAGrammarElementWrapper() {
1081
+ element.type = LLAMA_GRETYPE_END;
1082
+ element.value = 0;
1083
+ }
1084
+
1085
+ ~LLaMAGrammarElementWrapper() {}
1086
+ };
1087
+
1088
+ class RbLLaMAGrammarElement {
1089
+ public:
1090
+ static VALUE llama_grammar_element_alloc(VALUE self) {
1091
+ LLaMAGrammarElementWrapper* ptr = (LLaMAGrammarElementWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarElementWrapper));
1092
+ new (ptr) LLaMAGrammarElementWrapper();
1093
+ return TypedData_Wrap_Struct(self, &llama_grammar_element_type, ptr);
1094
+ }
1095
+
1096
+ static void llama_grammar_element_free(void* ptr) {
1097
+ ((LLaMAGrammarElementWrapper*)ptr)->~LLaMAGrammarElementWrapper();
1098
+ ruby_xfree(ptr);
1099
+ }
1100
+
1101
+ static size_t llama_grammar_element_size(const void* ptr) {
1102
+ return sizeof(*((LLaMAGrammarElementWrapper*)ptr));
1103
+ }
1104
+
1105
+ static LLaMAGrammarElementWrapper* get_llama_grammar_element(VALUE self) {
1106
+ LLaMAGrammarElementWrapper* ptr;
1107
+ TypedData_Get_Struct(self, LLaMAGrammarElementWrapper, &llama_grammar_element_type, ptr);
1108
+ return ptr;
1109
+ }
1110
+
1111
+ static void define_class(VALUE outer) {
1112
+ rb_cLLaMAGrammarElement = rb_define_class_under(outer, "GrammarElement", rb_cObject);
1113
+ rb_define_alloc_func(rb_cLLaMAGrammarElement, llama_grammar_element_alloc);
1114
+ rb_define_method(rb_cLLaMAGrammarElement, "initialize", RUBY_METHOD_FUNC(_llama_grammar_element_init), -1);
1115
+ rb_define_method(rb_cLLaMAGrammarElement, "type=", RUBY_METHOD_FUNC(_llama_grammar_element_set_type), 1);
1116
+ rb_define_method(rb_cLLaMAGrammarElement, "type", RUBY_METHOD_FUNC(_llama_grammar_element_get_type), 0);
1117
+ rb_define_method(rb_cLLaMAGrammarElement, "value=", RUBY_METHOD_FUNC(_llama_grammar_element_set_value), 1);
1118
+ rb_define_method(rb_cLLaMAGrammarElement, "value", RUBY_METHOD_FUNC(_llama_grammar_element_get_value), 0);
1119
+ }
1120
+
1121
+ private:
1122
+ static const rb_data_type_t llama_grammar_element_type;
1123
+
1124
+ static VALUE _llama_grammar_element_init(int argc, VALUE* argv, VALUE self) {
1125
+ VALUE kw_args = Qnil;
1126
+ ID kw_table[2] = { rb_intern("type"), rb_intern("value") };
1127
+ VALUE kw_values[2] = { Qundef, Qundef };
1128
+ VALUE arr = Qnil;
1129
+ rb_scan_args(argc, argv, ":", &arr, &kw_args);
1130
+ rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
1131
+
1132
+ if (kw_values[0] != Qundef && !RB_INTEGER_TYPE_P(kw_values[0])) {
1133
+ rb_raise(rb_eArgError, "type must be an integer");
1134
+ return Qnil;
1135
+ }
1136
+ if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
1137
+ rb_raise(rb_eArgError, "value must be an integer");
1138
+ return Qnil;
1139
+ }
1140
+
1141
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
1142
+ new (ptr) LLaMAGrammarElementWrapper();
1143
+
1144
+ if (kw_values[0] != Qundef) {
1145
+ ptr->element.type = (enum llama_gretype)NUM2INT(kw_values[0]);
1146
+ }
1147
+ if (kw_values[1] != Qundef) {
1148
+ ptr->element.value = NUM2INT(kw_values[1]);
1149
+ }
1150
+
1151
+ return self;
1152
+ }
1153
+
1154
+ // type
1155
+ static VALUE _llama_grammar_element_set_type(VALUE self, VALUE type) {
1156
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
1157
+ ptr->element.type = (enum llama_gretype)NUM2INT(type);
1158
+ return INT2NUM(ptr->element.type);
1159
+ }
1160
+
1161
+ static VALUE _llama_grammar_element_get_type(VALUE self) {
1162
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
1163
+ return INT2NUM(ptr->element.type);
1164
+ }
1165
+
1166
+ // value
1167
+ static VALUE _llama_grammar_element_set_value(VALUE self, VALUE type) {
1168
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
1169
+ ptr->element.value = NUM2INT(type);
1170
+ return INT2NUM(ptr->element.value);
1171
+ }
1172
+
1173
+ static VALUE _llama_grammar_element_get_value(VALUE self) {
1174
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
1175
+ return INT2NUM(ptr->element.value);
1176
+ }
1177
+ };
1178
+
1179
+ const rb_data_type_t RbLLaMAGrammarElement::llama_grammar_element_type = {
1180
+ "RbLLaMAGrammarElement",
1181
+ { NULL,
1182
+ RbLLaMAGrammarElement::llama_grammar_element_free,
1183
+ RbLLaMAGrammarElement::llama_grammar_element_size },
1184
+ NULL,
1185
+ NULL,
1186
+ RUBY_TYPED_FREE_IMMEDIATELY
1187
+ };
1188
+
1189
+ class LLaMAGrammarWrapper {
1190
+ public:
1191
+ struct llama_grammar* grammar;
1192
+
1193
+ LLaMAGrammarWrapper() : grammar(nullptr) {}
1194
+
1195
+ ~LLaMAGrammarWrapper() {
1196
+ if (grammar) {
1197
+ llama_grammar_free(grammar);
1198
+ }
1199
+ }
1200
+ };
1201
+
1202
+ class RbLLaMAGrammar {
1203
+ public:
1204
+ static VALUE llama_grammar_alloc(VALUE self) {
1205
+ LLaMAGrammarWrapper* ptr = (LLaMAGrammarWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarWrapper));
1206
+ new (ptr) LLaMAGrammarWrapper();
1207
+ return TypedData_Wrap_Struct(self, &llama_grammar_type, ptr);
1208
+ }
1209
+
1210
+ static void llama_grammar_free(void* ptr) {
1211
+ ((LLaMAGrammarWrapper*)ptr)->~LLaMAGrammarWrapper();
1212
+ ruby_xfree(ptr);
1213
+ }
1214
+
1215
+ static size_t llama_grammar_size(const void* ptr) {
1216
+ return sizeof(*((LLaMAGrammarWrapper*)ptr));
1217
+ }
1218
+
1219
+ static LLaMAGrammarWrapper* get_llama_grammar(VALUE self) {
1220
+ LLaMAGrammarWrapper* ptr;
1221
+ TypedData_Get_Struct(self, LLaMAGrammarWrapper, &llama_grammar_type, ptr);
1222
+ return ptr;
1223
+ }
1224
+
1225
+ static void define_class(VALUE outer) {
1226
+ rb_cLLaMAGrammar = rb_define_class_under(outer, "Grammar", rb_cObject);
1227
+ rb_define_alloc_func(rb_cLLaMAGrammar, llama_grammar_alloc);
1228
+ rb_define_method(rb_cLLaMAGrammar, "initialize", RUBY_METHOD_FUNC(_llama_grammar_init), -1);
1229
+ }
1230
+
1231
+ private:
1232
+ static const rb_data_type_t llama_grammar_type;
1233
+
1234
+ static VALUE _llama_grammar_init(int argc, VALUE* argv, VALUE self) {
1235
+ VALUE kw_args = Qnil;
1236
+ ID kw_table[2] = { rb_intern("rules"), rb_intern("start_rule_index") };
1237
+ VALUE kw_values[2] = { Qundef, Qundef };
1238
+ rb_scan_args(argc, argv, ":", &kw_args);
1239
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
1240
+
1241
+ if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
1242
+ rb_raise(rb_eArgError, "rules must be an array");
1243
+ return Qnil;
1244
+ }
1245
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
1246
+ rb_raise(rb_eArgError, "start_rule_index must be an integer");
1247
+ return Qnil;
1248
+ }
1249
+
1250
+ const int n_rules = RARRAY_LEN(kw_values[0]);
1251
+ llama_grammar_element** rules = ALLOCA_N(llama_grammar_element*, n_rules);
1252
+ for (int i = 0; i < n_rules; ++i) {
1253
+ VALUE rule = rb_ary_entry(kw_values[0], i);
1254
+ if (!RB_TYPE_P(rule, T_ARRAY)) {
1255
+ rb_raise(rb_eArgError, "element of rules must be an array");
1256
+ return Qnil;
1257
+ }
1258
+ const int n_elements = RARRAY_LEN(rule);
1259
+ llama_grammar_element* elements = ALLOCA_N(llama_grammar_element, n_elements);
1260
+ for (int j = 0; j < n_elements; ++j) {
1261
+ VALUE element = rb_ary_entry(rule, j);
1262
+ if (!rb_obj_is_kind_of(element, rb_cLLaMAGrammarElement)) {
1263
+ rb_raise(rb_eArgError, "element of rule must be an instance of GrammarElement");
1264
+ return Qnil;
1265
+ }
1266
+ LLaMAGrammarElementWrapper* ptr = RbLLaMAGrammarElement::get_llama_grammar_element(element);
1267
+ elements[j] = ptr->element;
1268
+ }
1269
+ rules[i] = elements;
1270
+ }
1271
+
1272
+ const size_t start_rule_index = NUM2SIZET(kw_values[1]);
1273
+
1274
+ LLaMAGrammarWrapper* ptr = get_llama_grammar(self);
1275
+ new (ptr) LLaMAGrammarWrapper();
1276
+ ptr->grammar = llama_grammar_init((const llama_grammar_element**)rules, n_rules, start_rule_index);
1277
+
1278
+ return self;
1279
+ }
1280
+ };
1281
+
1282
+ const rb_data_type_t RbLLaMAGrammar::llama_grammar_type = {
1283
+ "RbLLaMAGrammar",
1284
+ { NULL,
1285
+ RbLLaMAGrammar::llama_grammar_free,
1286
+ RbLLaMAGrammar::llama_grammar_size },
1287
+ NULL,
1288
+ NULL,
1289
+ RUBY_TYPED_FREE_IMMEDIATELY
1290
+ };
1291
+
1060
1292
  class LLaMAContextWrapper {
1061
1293
  public:
1062
1294
  struct llama_context* ctx;
@@ -1128,6 +1360,8 @@ public:
1128
1360
  rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
1129
1361
  rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
1130
1362
  rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
1363
+ rb_define_method(rb_cLLaMAContext, "sample_grammar", RUBY_METHOD_FUNC(_llama_context_sample_grammar), -1);
1364
+ rb_define_method(rb_cLLaMAContext, "grammar_accept_token", RUBY_METHOD_FUNC(_llama_context_grammar_accept_token), -1);
1131
1365
  }
1132
1366
 
1133
1367
  private:
@@ -2104,6 +2338,69 @@ private:
2104
2338
  llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
2105
2339
  return INT2NUM(id);
2106
2340
  }
2341
+
2342
+ static VALUE _llama_context_sample_grammar(int argc, VALUE* argv, VALUE self) {
2343
+ VALUE kw_args = Qnil;
2344
+ ID kw_table[1] = { rb_intern("grammar") };
2345
+ VALUE kw_values[1] = { Qundef };
2346
+ VALUE candidates = Qnil;
2347
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
2348
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
2349
+
2350
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
2351
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
2352
+ return Qnil;
2353
+ }
2354
+ if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
2355
+ rb_raise(rb_eArgError, "grammar must be a Grammar");
2356
+ return Qnil;
2357
+ }
2358
+
2359
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2360
+ if (ctx_ptr->ctx == NULL) {
2361
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2362
+ return Qnil;
2363
+ }
2364
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
2365
+ if (cnd_ptr->array.data == nullptr) {
2366
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
2367
+ return Qnil;
2368
+ }
2369
+ LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
2370
+
2371
+ llama_sample_grammar(ctx_ptr->ctx, &(cnd_ptr->array), grm_ptr->grammar);
2372
+
2373
+ return Qnil;
2374
+ }
2375
+
2376
+ static VALUE _llama_context_grammar_accept_token(int argc, VALUE* argv, VALUE self) {
2377
+ VALUE kw_args = Qnil;
2378
+ ID kw_table[2] = { rb_intern("grammar"), rb_intern("token") };
2379
+ VALUE kw_values[2] = { Qundef, Qundef };
2380
+ rb_scan_args(argc, argv, ":", &kw_args);
2381
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
2382
+
2383
+ if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
2384
+ rb_raise(rb_eArgError, "grammar must be a Grammar");
2385
+ return Qnil;
2386
+ }
2387
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
2388
+ rb_raise(rb_eArgError, "token must be an Integer");
2389
+ return Qnil;
2390
+ }
2391
+
2392
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2393
+ if (ctx_ptr->ctx == NULL) {
2394
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2395
+ return Qnil;
2396
+ }
2397
+ LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
2398
+ llama_token token = NUM2INT(kw_values[1]);
2399
+
2400
+ llama_grammar_accept_token(ctx_ptr->ctx, grm_ptr->grammar, token);
2401
+
2402
+ return Qnil;
2403
+ }
2107
2404
  };
2108
2405
 
2109
2406
  const rb_data_type_t RbLLaMAContext::llama_context_type = {
@@ -2125,7 +2422,7 @@ static VALUE rb_llama_llama_backend_init(int argc, VALUE* argv, VALUE self) {
2125
2422
  rb_scan_args(argc, argv, ":", &kw_args);
2126
2423
  rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
2127
2424
 
2128
- const bool numa = kw_values[0] == Qundef ? false : (RTEST ? true : false);
2425
+ const bool numa = kw_values[0] == Qundef ? false : (RTEST(kw_values[0]) ? true : false);
2129
2426
  llama_backend_init(numa);
2130
2427
 
2131
2428
  return Qnil;
@@ -2208,6 +2505,8 @@ extern "C" void Init_llama_cpp(void) {
2208
2505
  RbLLaMAContext::define_class(rb_mLLaMACpp);
2209
2506
  RbLLaMAContextParams::define_class(rb_mLLaMACpp);
2210
2507
  RbLLaMAModelQuantizeParams::define_class(rb_mLLaMACpp);
2508
+ RbLLaMAGrammarElement::define_class(rb_mLLaMACpp);
2509
+ RbLLaMAGrammar::define_class(rb_mLLaMACpp);
2211
2510
 
2212
2511
  rb_define_module_function(rb_mLLaMACpp, "backend_init", rb_llama_llama_backend_init, -1);
2213
2512
  rb_define_module_function(rb_mLLaMACpp, "backend_free", rb_llama_llama_backend_free, 0);
@@ -2240,6 +2539,14 @@ extern "C" void Init_llama_cpp(void) {
2240
2539
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_K_M", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_K_M));
2241
2540
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q6_K", INT2NUM(LLAMA_FTYPE_MOSTLY_Q6_K));
2242
2541
 
2542
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_END", INT2NUM(LLAMA_GRETYPE_END));
2543
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_ALT", INT2NUM(LLAMA_GRETYPE_ALT));
2544
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_RULE_REF", INT2NUM(LLAMA_GRETYPE_RULE_REF));
2545
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR", INT2NUM(LLAMA_GRETYPE_CHAR));
2546
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_NOT", INT2NUM(LLAMA_GRETYPE_CHAR_NOT));
2547
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_RNG_UPPER", INT2NUM(LLAMA_GRETYPE_CHAR_RNG_UPPER));
2548
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_ALT", INT2NUM(LLAMA_GRETYPE_CHAR_ALT));
2549
+
2243
2550
  std::stringstream ss_magic;
2244
2551
  ss_magic << std::showbase << std::hex << LLAMA_FILE_MAGIC_GGJT;
2245
2552
  rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_MAGIC_GGJT", rb_str_new2(ss_magic.str().c_str()));