llama_cpp 0.3.4 → 0.3.6

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 35afb5cc65c290036ae7e45459eadc9b509f34f33a3f7708244cf47f1a38829f
- data.tar.gz: 3301158526c63d9d2004e22bda0d1cc8025b4343d8d737df96260786531b074d
+ metadata.gz: 545786d4c9308ffe0f7e214a12427beaea0b26bec915ff84b16eed25ef1932a4
+ data.tar.gz: aaa0d4fc1710b13a26163306c8b51e423233c2f7e4b3d6127f94c9b6c4846f9c
  SHA512:
- metadata.gz: b0a50f9f012f44f119a70790d3de07c7fcc64151246791e270e4ff9fc479a85a01c53cf2775945eba3145a3ba89da55a8d14891c6236cfeae16aed5ae455cf0d
- data.tar.gz: ede388584e115ae93d509b6c15b288303c348f3cfe8ea46879a1b69e6c96be31a321edbb52cfbeb309a8fb456738f3f6b7cc1d3f71ce7addbd05b3a1e73d4755
+ metadata.gz: 12b3ac122fd7ea59b51e2d6ff905ed1a71cf8a8b3650a269d4a3793ae32a0149f6836a792c8f216d0fdb0c39aeb3b47914e73ffc74b574bbe686660e6be84ea1
+ data.tar.gz: 5056b95552f3434692a6c19653810d77bb28ddf9b28abd78712ccfb4ee4f7d836a5d54e283513fcfc617cc79ffa7bb9257d4ac2b6d96ec89158bf94acd4cec86
data/CHANGELOG.md CHANGED
@@ -1,3 +1,15 @@
+ ## [[0.3.6](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.5...v0.3.6)] - 2023-08-04
+
+ - Bump bundled llama.cpp from master-1a94186 to master-468ea24.
+ - Add `mul_mat_q` option to ContextParams.
+
+ ## [[0.3.5](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.4...v0.3.5)] - 2023-07-29
+
+ - Bump bundled llama.cpp from master-d924522 to master-1a94186.
+ - Add `GrammarElement` and `Grammar` classes.
+ - Add `sample_grammar` method to Context.
+ - Add `grammar_accept_token` method to Context.
+
  ## [[0.3.4](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.3...v0.3.4)] - 2023-07-23
 
  - Bump bundled llama.cpp from master-32c5411 to master-d924522.
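For the 0.3.6 entry, the new option surfaces as a plain boolean accessor on `ContextParams` (see the `mul_mat_q=` / `mul_mat_q` bindings in the llama_cpp.cpp diff below). A minimal sketch, assuming the gem's module name `LLaMACpp` and a no-argument `ContextParams.new`, both taken from the existing API rather than this diff:

```ruby
require 'llama_cpp'

params = LLaMACpp::ContextParams.new
params.mul_mat_q = true   # new in 0.3.6: toggle llama.cpp's mul_mat_q kernels
puts params.mul_mat_q     # => true
```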
data/README.md CHANGED
@@ -12,11 +12,27 @@ This gem is still under development and may undergo many changes in the future.
 
  Install the gem and add to the application's Gemfile by executing:
 
- $ bundle add llama_cpp
+ ```sh
+ $ bundle add llama_cpp
+ ```
 
  If bundler is not being used to manage dependencies, install the gem by executing:
 
- $ gem install llama_cpp
+ ```sh
+ $ gem install llama_cpp
+ ```
+
+ There are several installation options for improving execution performance:
+
+ ```sh
+ # use OpenBLAS
+ $ gem install llama_cpp -- --with-openblas
+
+ # use Metal on macOS
+ $ gem install llama_cpp -- --with-metal
+ ```
+
+ Those options are defined in [extconf.rb](https://github.com/yoshoku/llama_cpp.rb/blob/main/ext/llama_cpp/extconf.rb) by the `with_config` method.
 
  ## Usage
 
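The `--with-*` flags mentioned in the README reach the build through mkmf's `with_config`. A hypothetical extconf-style sketch of the mechanism, for illustration only (not the gem's actual extconf.rb, which follows below):

```ruby
# hypothetical extconf.rb fragment, for illustration only
require 'mkmf'

# `gem install llama_cpp -- --with-openblas` makes with_config('openblas') truthy
if with_config('openblas')
  abort 'libopenblas is not found.' unless have_library('openblas')
  $CFLAGS << ' -DGGML_USE_OPENBLAS'
end

create_makefile('llama_cpp/llama_cpp')
```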
data/ext/llama_cpp/extconf.rb CHANGED
@@ -5,7 +5,7 @@ require 'fileutils'
 
  abort 'libstdc++ is not found.' unless have_library('stdc++')
 
- $srcs = %w[ggml.c llama.cpp llama_cpp.cpp]
+ $srcs = %w[ggml.c ggml-alloc.c llama.cpp llama_cpp.cpp]
  $srcs << 'ggml-opencl.cpp' if with_config('clblast')
  $srcs << 'ggml-mpi.c' if with_config('mpi')
  $CFLAGS << ' -w -DNDEBUG'
@@ -85,6 +85,7 @@ if with_config('mpi')
  $CXXFLAGS << ' -DGGML_USE_MPI -Wno-cast-qual'
  end
 
+ # @!visibility private
  UNAME_M = RbConfig::CONFIG['build_cpu'] || RbConfig::CONFIG['host_cpu'] || RbConfig::CONFIG['target_cpu']
 
  # rubocop:disable Layout/LineLength
data/ext/llama_cpp/llama_cpp.cpp CHANGED
@@ -8,6 +8,8 @@ VALUE rb_cLLaMAContextParams;
  VALUE rb_cLLaMAModelQuantizeParams;
  VALUE rb_cLLaMATokenData;
  VALUE rb_cLLaMATokenDataArray;
+ VALUE rb_cLLaMAGrammarElement;
+ VALUE rb_cLLaMAGrammar;
 
  class LLaMATokenDataWrapper {
  public:
@@ -412,6 +414,8 @@ public:
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
  rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
  rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
+ rb_define_method(rb_cLLaMAContextParams, "mul_mat_q=", RUBY_METHOD_FUNC(_llama_context_params_set_mul_mat_q), 1);
+ rb_define_method(rb_cLLaMAContextParams, "mul_mat_q", RUBY_METHOD_FUNC(_llama_context_params_get_mul_mat_q), 0);
  rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
  rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
  rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
@@ -525,7 +529,7 @@ private:
  // low_vram
  static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
- ptr->params.low_vram = low_vram == Qtrue ? true : false;
+ ptr->params.low_vram = RTEST(low_vram) ? true : false;
  return ptr->params.low_vram ? Qtrue : Qfalse;
  }
 
@@ -534,6 +538,18 @@ private:
  return ptr->params.low_vram ? Qtrue : Qfalse;
  }
 
+ // mul_mat_q
+ static VALUE _llama_context_params_set_mul_mat_q(VALUE self, VALUE mul_mat_q) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ ptr->params.mul_mat_q = RTEST(mul_mat_q) ? true : false;
+ return ptr->params.mul_mat_q ? Qtrue : Qfalse;
+ }
+
+ static VALUE _llama_context_params_get_mul_mat_q(VALUE self) {
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+ return ptr->params.mul_mat_q ? Qtrue : Qfalse;
+ }
+
  // seed
  static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -553,7 +569,7 @@ private:
  // f16_kv
  static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
- ptr->params.f16_kv = f16_kv == Qtrue ? true : false;
+ ptr->params.f16_kv = RTEST(f16_kv) ? true : false;
  return ptr->params.f16_kv ? Qtrue : Qfalse;
  }
 
@@ -565,7 +581,7 @@ private:
  // logits_all
  static VALUE _llama_context_params_set_logits_all(VALUE self, VALUE logits_all) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
- ptr->params.logits_all = logits_all == Qtrue ? true : false;
+ ptr->params.logits_all = RTEST(logits_all) ? true : false;
  return ptr->params.logits_all ? Qtrue : Qfalse;
  }
 
@@ -577,7 +593,7 @@ private:
  // vocab_only
  static VALUE _llama_context_params_set_vocab_only(VALUE self, VALUE vocab_only) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
- ptr->params.vocab_only = vocab_only == Qtrue ? true : false;
+ ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
  return ptr->params.vocab_only ? Qtrue : Qfalse;
  }
 
@@ -589,7 +605,7 @@ private:
  // use_mmap
  static VALUE _llama_context_params_set_use_mmap(VALUE self, VALUE use_mmap) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
- ptr->params.use_mmap = use_mmap == Qtrue ? true : false;
+ ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
  return ptr->params.use_mmap ? Qtrue : Qfalse;
  }
 
@@ -601,7 +617,7 @@ private:
  // use_mlock
  static VALUE _llama_context_params_set_use_mlock(VALUE self, VALUE use_mlock) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
- ptr->params.use_mlock = use_mlock == Qtrue ? true : false;
+ ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
  return ptr->params.use_mlock ? Qtrue : Qfalse;
  }
 
@@ -613,7 +629,7 @@ private:
  // embedding
  static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
- ptr->params.embedding = embedding == Qtrue ? true : false;
+ ptr->params.embedding = RTEST(embedding) ? true : false;
  return ptr->params.embedding ? Qtrue : Qfalse;
  }
 
@@ -1057,6 +1073,222 @@ const rb_data_type_t RbLLaMAModel::llama_model_type = {
  RUBY_TYPED_FREE_IMMEDIATELY
  };
 
+ class LLaMAGrammarElementWrapper {
+ public:
+ llama_grammar_element element;
+
+ LLaMAGrammarElementWrapper() {
+ element.type = LLAMA_GRETYPE_END;
+ element.value = 0;
+ }
+
+ ~LLaMAGrammarElementWrapper() {}
+ };
+
+ class RbLLaMAGrammarElement {
+ public:
+ static VALUE llama_grammar_element_alloc(VALUE self) {
+ LLaMAGrammarElementWrapper* ptr = (LLaMAGrammarElementWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarElementWrapper));
+ new (ptr) LLaMAGrammarElementWrapper();
+ return TypedData_Wrap_Struct(self, &llama_grammar_element_type, ptr);
+ }
+
+ static void llama_grammar_element_free(void* ptr) {
+ ((LLaMAGrammarElementWrapper*)ptr)->~LLaMAGrammarElementWrapper();
+ ruby_xfree(ptr);
+ }
+
+ static size_t llama_grammar_element_size(const void* ptr) {
+ return sizeof(*((LLaMAGrammarElementWrapper*)ptr));
+ }
+
+ static LLaMAGrammarElementWrapper* get_llama_grammar_element(VALUE self) {
+ LLaMAGrammarElementWrapper* ptr;
+ TypedData_Get_Struct(self, LLaMAGrammarElementWrapper, &llama_grammar_element_type, ptr);
+ return ptr;
+ }
+
+ static void define_class(VALUE outer) {
+ rb_cLLaMAGrammarElement = rb_define_class_under(outer, "GrammarElement", rb_cObject);
+ rb_define_alloc_func(rb_cLLaMAGrammarElement, llama_grammar_element_alloc);
+ rb_define_method(rb_cLLaMAGrammarElement, "initialize", RUBY_METHOD_FUNC(_llama_grammar_element_init), -1);
+ rb_define_method(rb_cLLaMAGrammarElement, "type=", RUBY_METHOD_FUNC(_llama_grammar_element_set_type), 1);
+ rb_define_method(rb_cLLaMAGrammarElement, "type", RUBY_METHOD_FUNC(_llama_grammar_element_get_type), 0);
+ rb_define_method(rb_cLLaMAGrammarElement, "value=", RUBY_METHOD_FUNC(_llama_grammar_element_set_value), 1);
+ rb_define_method(rb_cLLaMAGrammarElement, "value", RUBY_METHOD_FUNC(_llama_grammar_element_get_value), 0);
+ }
+
+ private:
+ static const rb_data_type_t llama_grammar_element_type;
+
+ static VALUE _llama_grammar_element_init(int argc, VALUE* argv, VALUE self) {
+ VALUE kw_args = Qnil;
+ ID kw_table[2] = { rb_intern("type"), rb_intern("value") };
+ VALUE kw_values[2] = { Qundef, Qundef };
+ VALUE arr = Qnil;
+ rb_scan_args(argc, argv, ":", &arr, &kw_args);
+ rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
+
+ if (kw_values[0] != Qundef && !RB_INTEGER_TYPE_P(kw_values[0])) {
+ rb_raise(rb_eArgError, "type must be an integer");
+ return Qnil;
+ }
+ if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
+ rb_raise(rb_eArgError, "value must be an integer");
+ return Qnil;
+ }
+
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
+ new (ptr) LLaMAGrammarElementWrapper();
+
+ if (kw_values[0] != Qundef) {
+ ptr->element.type = (enum llama_gretype)NUM2INT(kw_values[0]);
+ }
+ if (kw_values[1] != Qundef) {
+ ptr->element.value = NUM2INT(kw_values[1]);
+ }
+
+ return self;
+ }
+
+ // type
+ static VALUE _llama_grammar_element_set_type(VALUE self, VALUE type) {
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
+ ptr->element.type = (enum llama_gretype)NUM2INT(type);
+ return INT2NUM(ptr->element.type);
+ }
+
+ static VALUE _llama_grammar_element_get_type(VALUE self) {
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
+ return INT2NUM(ptr->element.type);
+ }
+
+ // value
+ static VALUE _llama_grammar_element_set_value(VALUE self, VALUE type) {
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
+ ptr->element.value = NUM2INT(type);
+ return INT2NUM(ptr->element.value);
+ }
+
+ static VALUE _llama_grammar_element_get_value(VALUE self) {
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
+ return INT2NUM(ptr->element.value);
+ }
+ };
+
+ const rb_data_type_t RbLLaMAGrammarElement::llama_grammar_element_type = {
+ "RbLLaMAGrammarElement",
+ { NULL,
+ RbLLaMAGrammarElement::llama_grammar_element_free,
+ RbLLaMAGrammarElement::llama_grammar_element_size },
+ NULL,
+ NULL,
+ RUBY_TYPED_FREE_IMMEDIATELY
+ };
+
+ class LLaMAGrammarWrapper {
+ public:
+ struct llama_grammar* grammar;
+
+ LLaMAGrammarWrapper() : grammar(nullptr) {}
+
+ ~LLaMAGrammarWrapper() {
+ if (grammar) {
+ llama_grammar_free(grammar);
+ }
+ }
+ };
+
+ class RbLLaMAGrammar {
+ public:
+ static VALUE llama_grammar_alloc(VALUE self) {
+ LLaMAGrammarWrapper* ptr = (LLaMAGrammarWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarWrapper));
+ new (ptr) LLaMAGrammarWrapper();
+ return TypedData_Wrap_Struct(self, &llama_grammar_type, ptr);
+ }
+
+ static void llama_grammar_free(void* ptr) {
+ ((LLaMAGrammarWrapper*)ptr)->~LLaMAGrammarWrapper();
+ ruby_xfree(ptr);
+ }
+
+ static size_t llama_grammar_size(const void* ptr) {
+ return sizeof(*((LLaMAGrammarWrapper*)ptr));
+ }
+
+ static LLaMAGrammarWrapper* get_llama_grammar(VALUE self) {
+ LLaMAGrammarWrapper* ptr;
+ TypedData_Get_Struct(self, LLaMAGrammarWrapper, &llama_grammar_type, ptr);
+ return ptr;
+ }
+
+ static void define_class(VALUE outer) {
+ rb_cLLaMAGrammar = rb_define_class_under(outer, "Grammar", rb_cObject);
+ rb_define_alloc_func(rb_cLLaMAGrammar, llama_grammar_alloc);
+ rb_define_method(rb_cLLaMAGrammar, "initialize", RUBY_METHOD_FUNC(_llama_grammar_init), -1);
+ }
+
+ private:
+ static const rb_data_type_t llama_grammar_type;
+
+ static VALUE _llama_grammar_init(int argc, VALUE* argv, VALUE self) {
+ VALUE kw_args = Qnil;
+ ID kw_table[2] = { rb_intern("rules"), rb_intern("start_rule_index") };
+ VALUE kw_values[2] = { Qundef, Qundef };
+ rb_scan_args(argc, argv, ":", &kw_args);
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
+
+ if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
+ rb_raise(rb_eArgError, "rules must be an array");
+ return Qnil;
+ }
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
+ rb_raise(rb_eArgError, "start_rule_index must be an integer");
+ return Qnil;
+ }
+
+ const int n_rules = RARRAY_LEN(kw_values[0]);
+ llama_grammar_element** rules = ALLOCA_N(llama_grammar_element*, n_rules);
+ for (int i = 0; i < n_rules; ++i) {
+ VALUE rule = rb_ary_entry(kw_values[0], i);
+ if (!RB_TYPE_P(rule, T_ARRAY)) {
+ rb_raise(rb_eArgError, "element of rules must be an array");
+ return Qnil;
+ }
+ const int n_elements = RARRAY_LEN(rule);
+ llama_grammar_element* elements = ALLOCA_N(llama_grammar_element, n_elements);
+ for (int j = 0; j < n_elements; ++j) {
+ VALUE element = rb_ary_entry(rule, j);
+ if (!rb_obj_is_kind_of(element, rb_cLLaMAGrammarElement)) {
+ rb_raise(rb_eArgError, "element of rule must be an instance of GrammarElement");
+ return Qnil;
+ }
+ LLaMAGrammarElementWrapper* ptr = RbLLaMAGrammarElement::get_llama_grammar_element(element);
+ elements[j] = ptr->element;
+ }
+ rules[i] = elements;
+ }
+
+ const size_t start_rule_index = NUM2SIZET(kw_values[1]);
+
+ LLaMAGrammarWrapper* ptr = get_llama_grammar(self);
+ new (ptr) LLaMAGrammarWrapper();
+ ptr->grammar = llama_grammar_init((const llama_grammar_element**)rules, n_rules, start_rule_index);
+
+ return self;
+ }
+ };
+
+ const rb_data_type_t RbLLaMAGrammar::llama_grammar_type = {
+ "RbLLaMAGrammar",
+ { NULL,
+ RbLLaMAGrammar::llama_grammar_free,
+ RbLLaMAGrammar::llama_grammar_size },
+ NULL,
+ NULL,
+ RUBY_TYPED_FREE_IMMEDIATELY
+ };
+
  class LLaMAContextWrapper {
  public:
  struct llama_context* ctx;
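Read together with the `_llama_grammar_init` kwargs above (`rules:` as an array of GrammarElement arrays, `start_rule_index:` as an integer), the new classes can be driven from Ruby roughly like this. A hedged sketch: the module name `LLaMACpp` comes from the existing gem API, and the `CHAR`/`END` element encoding follows llama.cpp's `llama_gretype` conventions rather than anything shown in this diff.

```ruby
require 'llama_cpp'

# Rule 0: root ::= "hi" -- each rule is an array of GrammarElement values
# terminated by an END element, mirroring llama_grammar_element in llama.cpp.
root_rule = [
  LLaMACpp::GrammarElement.new(type: LLaMACpp::LLAMA_GRETYPE_CHAR, value: 'h'.ord),
  LLaMACpp::GrammarElement.new(type: LLaMACpp::LLAMA_GRETYPE_CHAR, value: 'i'.ord),
  LLaMACpp::GrammarElement.new(type: LLaMACpp::LLAMA_GRETYPE_END, value: 0)
]

grammar = LLaMACpp::Grammar.new(rules: [root_rule], start_rule_index: 0)
```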
@@ -1128,6 +1360,8 @@ public:
  rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
  rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
  rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
+ rb_define_method(rb_cLLaMAContext, "sample_grammar", RUBY_METHOD_FUNC(_llama_context_sample_grammar), -1);
+ rb_define_method(rb_cLLaMAContext, "grammar_accept_token", RUBY_METHOD_FUNC(_llama_context_grammar_accept_token), -1);
  }
 
  private:
@@ -2104,6 +2338,69 @@ private:
  llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
  return INT2NUM(id);
  }
+
+ static VALUE _llama_context_sample_grammar(int argc, VALUE* argv, VALUE self) {
+ VALUE kw_args = Qnil;
+ ID kw_table[1] = { rb_intern("grammar") };
+ VALUE kw_values[1] = { Qundef };
+ VALUE candidates = Qnil;
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
+
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
+ return Qnil;
+ }
+ if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
+ rb_raise(rb_eArgError, "grammar must be a Grammar");
+ return Qnil;
+ }
+
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+ if (ctx_ptr->ctx == NULL) {
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+ return Qnil;
+ }
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+ if (cnd_ptr->array.data == nullptr) {
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+ return Qnil;
+ }
+ LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
+
+ llama_sample_grammar(ctx_ptr->ctx, &(cnd_ptr->array), grm_ptr->grammar);
+
+ return Qnil;
+ }
+
+ static VALUE _llama_context_grammar_accept_token(int argc, VALUE* argv, VALUE self) {
+ VALUE kw_args = Qnil;
+ ID kw_table[2] = { rb_intern("grammar"), rb_intern("token") };
+ VALUE kw_values[2] = { Qundef, Qundef };
+ rb_scan_args(argc, argv, ":", &kw_args);
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
+
+ if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
+ rb_raise(rb_eArgError, "grammar must be a Grammar");
+ return Qnil;
+ }
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
+ rb_raise(rb_eArgError, "token must be an Integer");
+ return Qnil;
+ }
+
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+ if (ctx_ptr->ctx == NULL) {
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+ return Qnil;
+ }
+ LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
+ llama_token token = NUM2INT(kw_values[1]);
+
+ llama_grammar_accept_token(ctx_ptr->ctx, grm_ptr->grammar, token);
+
+ return Qnil;
+ }
  };
 
  const rb_data_type_t RbLLaMAContext::llama_context_type = {
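The two context methods added above slot into the usual sampling loop: constrain the candidates with the grammar, pick a token, then feed the chosen token back into the grammar state. A rough sketch, assuming `context`, `candidates` (a `LLaMACpp::TokenDataArray` built from the current logits), and `grammar` were set up elsewhere:

```ruby
# hedged sketch of one grammar-constrained sampling step
context.sample_grammar(candidates, grammar: grammar)          # mask candidates the grammar disallows
token = context.sample_token(candidates)                      # pick from the constrained distribution
context.grammar_accept_token(grammar: grammar, token: token)  # advance the grammar state
```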
@@ -2125,7 +2422,7 @@ static VALUE rb_llama_llama_backend_init(int argc, VALUE* argv, VALUE self) {
  rb_scan_args(argc, argv, ":", &kw_args);
  rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
 
- const bool numa = kw_values[0] == Qundef ? false : (RTEST ? true : false);
+ const bool numa = kw_values[0] == Qundef ? false : (RTEST(kw_values[0]) ? true : false);
  llama_backend_init(numa);
 
  return Qnil;
@@ -2208,6 +2505,8 @@ extern "C" void Init_llama_cpp(void) {
  RbLLaMAContext::define_class(rb_mLLaMACpp);
  RbLLaMAContextParams::define_class(rb_mLLaMACpp);
  RbLLaMAModelQuantizeParams::define_class(rb_mLLaMACpp);
+ RbLLaMAGrammarElement::define_class(rb_mLLaMACpp);
+ RbLLaMAGrammar::define_class(rb_mLLaMACpp);
 
  rb_define_module_function(rb_mLLaMACpp, "backend_init", rb_llama_llama_backend_init, -1);
  rb_define_module_function(rb_mLLaMACpp, "backend_free", rb_llama_llama_backend_free, 0);
@@ -2240,6 +2539,14 @@ extern "C" void Init_llama_cpp(void) {
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_K_M", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_K_M));
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q6_K", INT2NUM(LLAMA_FTYPE_MOSTLY_Q6_K));
 
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_END", INT2NUM(LLAMA_GRETYPE_END));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_ALT", INT2NUM(LLAMA_GRETYPE_ALT));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_RULE_REF", INT2NUM(LLAMA_GRETYPE_RULE_REF));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR", INT2NUM(LLAMA_GRETYPE_CHAR));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_NOT", INT2NUM(LLAMA_GRETYPE_CHAR_NOT));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_RNG_UPPER", INT2NUM(LLAMA_GRETYPE_CHAR_RNG_UPPER));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_ALT", INT2NUM(LLAMA_GRETYPE_CHAR_ALT));
+
  std::stringstream ss_magic;
  ss_magic << std::showbase << std::hex << LLAMA_FILE_MAGIC_GGJT;
  rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_MAGIC_GGJT", rb_str_new2(ss_magic.str().c_str()));