llama_cpp 0.3.4 → 0.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 35afb5cc65c290036ae7e45459eadc9b509f34f33a3f7708244cf47f1a38829f
4
- data.tar.gz: 3301158526c63d9d2004e22bda0d1cc8025b4343d8d737df96260786531b074d
3
+ metadata.gz: 991df3df6b16ec98a203a6c6565794988eec04697ccb963faab976f436e1bfcc
4
+ data.tar.gz: dafd3b8274640eb79353e11056f497a02392fef332a46dcc1717878c836f62bd
5
5
  SHA512:
6
- metadata.gz: b0a50f9f012f44f119a70790d3de07c7fcc64151246791e270e4ff9fc479a85a01c53cf2775945eba3145a3ba89da55a8d14891c6236cfeae16aed5ae455cf0d
7
- data.tar.gz: ede388584e115ae93d509b6c15b288303c348f3cfe8ea46879a1b69e6c96be31a321edbb52cfbeb309a8fb456738f3f6b7cc1d3f71ce7addbd05b3a1e73d4755
6
+ metadata.gz: 0a54fdf18c5be5273f01d4d991b59975a8e8b6a0a8f54087fb90df3f1a8a7ebad557d01fb119f3d61cafa8f6c59fb81641624779fda27b732eea2868cb4642e8
7
+ data.tar.gz: 6574064078070502e36ad933bd5efb2f479c94f47a1260286fe485e48d35a7b3985274943c6163189326891349b47b1d9815d77c600b015fe111cc3842179392
data/CHANGELOG.md CHANGED
@@ -1,3 +1,10 @@
1
+ ## [[0.3.5](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.4...v0.3.5)] - 2023-07-29
2
+
3
+ - Bump bundled llama.cpp from master-d924522 to master-1a94186.
4
+ - Add `GrammarElement` and `Grammar` classes.
5
+ - Add `sample_grammar` method to Context.
6
+ - Add `grammar_accept_token` method to Context.
7
+
1
8
  ## [[0.3.4](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.3...v0.3.4)] - 2023-07-23
2
9
 
3
10
  - Bump bundled llama.cpp from master-32c5411 to master-d924522.
@@ -85,6 +85,7 @@ if with_config('mpi')
85
85
  $CXXFLAGS << ' -DGGML_USE_MPI -Wno-cast-qual'
86
86
  end
87
87
 
88
+ # @!visibility private
88
89
  UNAME_M = RbConfig::CONFIG['build_cpu'] || RbConfig::CONFIG['host_cpu'] || RbConfig::CONFIG['target_cpu']
89
90
 
90
91
  # rubocop:disable Layout/LineLength
@@ -8,6 +8,8 @@ VALUE rb_cLLaMAContextParams;
8
8
  VALUE rb_cLLaMAModelQuantizeParams;
9
9
  VALUE rb_cLLaMATokenData;
10
10
  VALUE rb_cLLaMATokenDataArray;
11
+ VALUE rb_cLLaMAGrammarElement;
12
+ VALUE rb_cLLaMAGrammar;
11
13
 
12
14
  class LLaMATokenDataWrapper {
13
15
  public:
@@ -1057,6 +1059,222 @@ const rb_data_type_t RbLLaMAModel::llama_model_type = {
1057
1059
  RUBY_TYPED_FREE_IMMEDIATELY
1058
1060
  };
1059
1061
 
1062
+ class LLaMAGrammarElementWrapper { // Holds a single llama_grammar_element by value.
1063
+ public:
1064
+ llama_grammar_element element; // the wrapped element (type and value fields)
1065
+
1066
+ LLaMAGrammarElementWrapper() { // Default state: type LLAMA_GRETYPE_END, value 0.
1067
+ element.type = LLAMA_GRETYPE_END;
1068
+ element.value = 0;
1069
+ }
1070
+
1071
+ ~LLaMAGrammarElementWrapper() {} // nothing to release; the element is stored inline
1072
+ };
1073
+
1074
+ class RbLLaMAGrammarElement { // Ruby binding that exposes LLaMAGrammarElementWrapper as the GrammarElement class.
1075
+ public:
1076
+ static VALUE llama_grammar_element_alloc(VALUE self) { // Allocator: default-constructs a wrapper and attaches it to a new typed-data object.
1077
+ LLaMAGrammarElementWrapper* ptr = (LLaMAGrammarElementWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarElementWrapper));
1078
+ new (ptr) LLaMAGrammarElementWrapper(); // construct in place inside the ruby_xmalloc'd storage
1079
+ return TypedData_Wrap_Struct(self, &llama_grammar_element_type, ptr);
1080
+ }
1081
+
1082
+ static void llama_grammar_element_free(void* ptr) { // Free callback: runs the destructor, then returns the storage with ruby_xfree.
1083
+ ((LLaMAGrammarElementWrapper*)ptr)->~LLaMAGrammarElementWrapper();
1084
+ ruby_xfree(ptr);
1085
+ }
1086
+
1087
+ static size_t llama_grammar_element_size(const void* ptr) { // Size callback: reports sizeof the wrapper.
1088
+ return sizeof(*((LLaMAGrammarElementWrapper*)ptr));
1089
+ }
1090
+
1091
+ static LLaMAGrammarElementWrapper* get_llama_grammar_element(VALUE self) { // Extracts the wrapper pointer from a GrammarElement object.
1092
+ LLaMAGrammarElementWrapper* ptr;
1093
+ TypedData_Get_Struct(self, LLaMAGrammarElementWrapper, &llama_grammar_element_type, ptr);
1094
+ return ptr;
1095
+ }
1096
+
1097
+ static void define_class(VALUE outer) { // Registers GrammarElement under `outer` with initialize and type/value accessors.
1098
+ rb_cLLaMAGrammarElement = rb_define_class_under(outer, "GrammarElement", rb_cObject);
1099
+ rb_define_alloc_func(rb_cLLaMAGrammarElement, llama_grammar_element_alloc);
1100
+ rb_define_method(rb_cLLaMAGrammarElement, "initialize", RUBY_METHOD_FUNC(_llama_grammar_element_init), -1);
1101
+ rb_define_method(rb_cLLaMAGrammarElement, "type=", RUBY_METHOD_FUNC(_llama_grammar_element_set_type), 1);
1102
+ rb_define_method(rb_cLLaMAGrammarElement, "type", RUBY_METHOD_FUNC(_llama_grammar_element_get_type), 0);
1103
+ rb_define_method(rb_cLLaMAGrammarElement, "value=", RUBY_METHOD_FUNC(_llama_grammar_element_set_value), 1);
1104
+ rb_define_method(rb_cLLaMAGrammarElement, "value", RUBY_METHOD_FUNC(_llama_grammar_element_get_value), 0);
1105
+ }
1106
+
1107
+ private:
1108
+ static const rb_data_type_t llama_grammar_element_type;
1109
+
1110
+ static VALUE _llama_grammar_element_init(int argc, VALUE* argv, VALUE self) { // initialize(type:, value:): both keywords optional; each must be an Integer when given.
1111
+ VALUE kw_args = Qnil;
1112
+ ID kw_table[2] = { rb_intern("type"), rb_intern("value") };
1113
+ VALUE kw_values[2] = { Qundef, Qundef };
1114
+ // The ":" format fills exactly one out-parameter (the keyword hash), so only &kw_args is passed.
1115
+ rb_scan_args(argc, argv, ":", &kw_args);
1116
+ rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
1117
+
1118
+ if (kw_values[0] != Qundef && !RB_INTEGER_TYPE_P(kw_values[0])) {
1119
+ rb_raise(rb_eArgError, "type must be an integer");
1120
+ return Qnil;
1121
+ }
1122
+ if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
1123
+ rb_raise(rb_eArgError, "value must be an integer");
1124
+ return Qnil;
1125
+ }
1126
+
1127
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
1128
+ new (ptr) LLaMAGrammarElementWrapper(); // reset to the default element before applying keywords
1129
+
1130
+ if (kw_values[0] != Qundef) {
1131
+ ptr->element.type = (enum llama_gretype)NUM2INT(kw_values[0]);
1132
+ }
1133
+ if (kw_values[1] != Qundef) {
1134
+ ptr->element.value = NUM2INT(kw_values[1]);
1135
+ }
1136
+
1137
+ return self;
1138
+ }
1139
+
1140
+ // type
1141
+ static VALUE _llama_grammar_element_set_type(VALUE self, VALUE type) { // type=: stores the Integer as llama_gretype and returns the stored value.
1142
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
1143
+ ptr->element.type = (enum llama_gretype)NUM2INT(type);
1144
+ return INT2NUM(ptr->element.type);
1145
+ }
1146
+
1147
+ static VALUE _llama_grammar_element_get_type(VALUE self) { // type: returns element.type as an Integer.
1148
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
1149
+ return INT2NUM(ptr->element.type);
1150
+ }
1151
+
1152
+ // value
1153
+ static VALUE _llama_grammar_element_set_value(VALUE self, VALUE value) { // value=: stores the Integer in element.value and returns the stored value.
1154
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
1155
+ ptr->element.value = NUM2INT(value);
1156
+ return INT2NUM(ptr->element.value);
1157
+ }
1158
+
1159
+ static VALUE _llama_grammar_element_get_value(VALUE self) { // value: returns element.value as an Integer.
1160
+ LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
1161
+ return INT2NUM(ptr->element.value);
1162
+ }
1163
+ };
1164
+
1165
+ const rb_data_type_t RbLLaMAGrammarElement::llama_grammar_element_type = { // Typed-data descriptor used when wrapping and unwrapping GrammarElement objects.
1166
+ "RbLLaMAGrammarElement",
1167
+ { NULL, // first callback slot left unset
1168
+ RbLLaMAGrammarElement::llama_grammar_element_free,
1169
+ RbLLaMAGrammarElement::llama_grammar_element_size },
1170
+ NULL,
1171
+ NULL,
1172
+ RUBY_TYPED_FREE_IMMEDIATELY
1173
+ };
1174
+
1175
+ class LLaMAGrammarWrapper { // Holds a llama_grammar pointer and frees it on destruction.
1176
+ public:
1177
+ struct llama_grammar* grammar; // nullptr until a grammar is assigned
1178
+
1179
+ LLaMAGrammarWrapper() : grammar(nullptr) {}
1180
+
1181
+ ~LLaMAGrammarWrapper() { // Calls llama_grammar_free only when a grammar is set.
1182
+ if (grammar) {
1183
+ llama_grammar_free(grammar);
1184
+ }
1185
+ }
1186
+ };
1187
+
1188
+ class RbLLaMAGrammar { // Ruby binding that exposes LLaMAGrammarWrapper as the Grammar class.
1189
+ public:
1190
+ static VALUE llama_grammar_alloc(VALUE self) { // Allocator: default-constructs a wrapper (grammar == nullptr) and attaches it to a new typed-data object.
1191
+ LLaMAGrammarWrapper* ptr = (LLaMAGrammarWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarWrapper));
1192
+ new (ptr) LLaMAGrammarWrapper(); // construct in place inside the ruby_xmalloc'd storage
1193
+ return TypedData_Wrap_Struct(self, &llama_grammar_type, ptr);
1194
+ }
1195
+
1196
+ static void llama_grammar_free(void* ptr) { // Free callback for the wrapper (not the llama.cpp function of the same name): runs the destructor, then ruby_xfree.
1197
+ ((LLaMAGrammarWrapper*)ptr)->~LLaMAGrammarWrapper();
1198
+ ruby_xfree(ptr);
1199
+ }
1200
+
1201
+ static size_t llama_grammar_size(const void* ptr) { // Size callback: reports sizeof the wrapper.
1202
+ return sizeof(*((LLaMAGrammarWrapper*)ptr));
1203
+ }
1204
+
1205
+ static LLaMAGrammarWrapper* get_llama_grammar(VALUE self) { // Extracts the wrapper pointer from a Grammar object.
1206
+ LLaMAGrammarWrapper* ptr;
1207
+ TypedData_Get_Struct(self, LLaMAGrammarWrapper, &llama_grammar_type, ptr);
1208
+ return ptr;
1209
+ }
1210
+
1211
+ static void define_class(VALUE outer) { // Registers Grammar under `outer` with its allocator and initialize.
1212
+ rb_cLLaMAGrammar = rb_define_class_under(outer, "Grammar", rb_cObject);
1213
+ rb_define_alloc_func(rb_cLLaMAGrammar, llama_grammar_alloc);
1214
+ rb_define_method(rb_cLLaMAGrammar, "initialize", RUBY_METHOD_FUNC(_llama_grammar_init), -1);
1215
+ }
1216
+
1217
+ private:
1218
+ static const rb_data_type_t llama_grammar_type;
1219
+
1220
+ static VALUE _llama_grammar_init(int argc, VALUE* argv, VALUE self) { // initialize(rules:, start_rule_index:): both keywords required; rules is an Array of Arrays of GrammarElement.
1221
+ VALUE kw_args = Qnil;
1222
+ ID kw_table[2] = { rb_intern("rules"), rb_intern("start_rule_index") };
1223
+ VALUE kw_values[2] = { Qundef, Qundef };
1224
+ rb_scan_args(argc, argv, ":", &kw_args);
1225
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
1226
+
1227
+ if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
1228
+ rb_raise(rb_eArgError, "rules must be an array");
1229
+ return Qnil;
1230
+ }
1231
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
1232
+ rb_raise(rb_eArgError, "start_rule_index must be an integer");
1233
+ return Qnil;
1234
+ }
1235
+
1236
+ const int n_rules = RARRAY_LEN(kw_values[0]);
1237
+ llama_grammar_element** rules = ALLOCA_N(llama_grammar_element*, n_rules); // NOTE(review): sized by caller-supplied lengths - confirm large rule sets cannot exhaust the stack
1238
+ for (int i = 0; i < n_rules; ++i) {
1239
+ VALUE rule = rb_ary_entry(kw_values[0], i);
1240
+ if (!RB_TYPE_P(rule, T_ARRAY)) {
1241
+ rb_raise(rb_eArgError, "element of rules must be an array");
1242
+ return Qnil;
1243
+ }
1244
+ const int n_elements = RARRAY_LEN(rule);
1245
+ llama_grammar_element* elements = ALLOCA_N(llama_grammar_element, n_elements);
1246
+ for (int j = 0; j < n_elements; ++j) {
1247
+ VALUE element = rb_ary_entry(rule, j);
1248
+ if (!rb_obj_is_kind_of(element, rb_cLLaMAGrammarElement)) {
1249
+ rb_raise(rb_eArgError, "element of rule must be an instance of GrammarElement");
1250
+ return Qnil;
1251
+ }
1252
+ LLaMAGrammarElementWrapper* ptr = RbLLaMAGrammarElement::get_llama_grammar_element(element);
1253
+ elements[j] = ptr->element; // copy the element by value into the scratch buffer
1254
+ }
1255
+ rules[i] = elements;
1256
+ }
1257
+
1258
+ const size_t start_rule_index = NUM2SIZET(kw_values[1]);
1259
+
1260
+ LLaMAGrammarWrapper* ptr = get_llama_grammar(self);
1261
+ ptr->~LLaMAGrammarWrapper(); new (ptr) LLaMAGrammarWrapper(); // destroy first so a repeated initialize frees the earlier grammar instead of leaking it
1262
+ ptr->grammar = llama_grammar_init((const llama_grammar_element**)rules, n_rules, start_rule_index); // presumably copies the rules, since the buffers above are temporary - TODO confirm
1263
+
1264
+ return self;
1265
+ }
1266
+ };
1267
+
1268
+ const rb_data_type_t RbLLaMAGrammar::llama_grammar_type = { // Typed-data descriptor used when wrapping and unwrapping Grammar objects.
1269
+ "RbLLaMAGrammar",
1270
+ { NULL, // first callback slot left unset
1271
+ RbLLaMAGrammar::llama_grammar_free,
1272
+ RbLLaMAGrammar::llama_grammar_size },
1273
+ NULL,
1274
+ NULL,
1275
+ RUBY_TYPED_FREE_IMMEDIATELY
1276
+ };
1277
+
1060
1278
  class LLaMAContextWrapper {
1061
1279
  public:
1062
1280
  struct llama_context* ctx;
@@ -1128,6 +1346,8 @@ public:
1128
1346
  rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
1129
1347
  rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
1130
1348
  rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
1349
+ rb_define_method(rb_cLLaMAContext, "sample_grammar", RUBY_METHOD_FUNC(_llama_context_sample_grammar), -1);
1350
+ rb_define_method(rb_cLLaMAContext, "grammar_accept_token", RUBY_METHOD_FUNC(_llama_context_grammar_accept_token), -1);
1131
1351
  }
1132
1352
 
1133
1353
  private:
@@ -2104,6 +2324,69 @@ private:
2104
2324
  llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
2105
2325
  return INT2NUM(id);
2106
2326
  }
2327
+
2328
+ static VALUE _llama_context_sample_grammar(int argc, VALUE* argv, VALUE self) { // sample_grammar(candidates, grammar:): passes the candidates and grammar to llama_sample_grammar; returns nil.
2329
+ VALUE kw_args = Qnil;
2330
+ ID kw_table[1] = { rb_intern("grammar") };
2331
+ VALUE kw_values[1] = { Qundef };
2332
+ VALUE candidates = Qnil;
2333
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args); // one required positional argument plus keywords
2334
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values); // grammar: is required
2335
+
2336
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
2337
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
2338
+ return Qnil;
2339
+ }
2340
+ if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
2341
+ rb_raise(rb_eArgError, "grammar must be a Grammar");
2342
+ return Qnil;
2343
+ }
2344
+
2345
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2346
+ if (ctx_ptr->ctx == NULL) {
2347
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2348
+ return Qnil;
2349
+ }
2350
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
2351
+ if (cnd_ptr->array.data == nullptr) {
2352
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
2353
+ return Qnil;
2354
+ }
2355
+ LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]); // NOTE(review): grm_ptr->grammar is not checked for nullptr before use - confirm it cannot be unset here
2356
+
2357
+ llama_sample_grammar(ctx_ptr->ctx, &(cnd_ptr->array), grm_ptr->grammar);
2358
+
2359
+ return Qnil;
2360
+ }
2361
+
2362
+ static VALUE _llama_context_grammar_accept_token(int argc, VALUE* argv, VALUE self) { // grammar_accept_token(grammar:, token:): passes the grammar and token to llama_grammar_accept_token; returns nil.
2363
+ VALUE kw_args = Qnil;
2364
+ ID kw_table[2] = { rb_intern("grammar"), rb_intern("token") };
2365
+ VALUE kw_values[2] = { Qundef, Qundef };
2366
+ rb_scan_args(argc, argv, ":", &kw_args); // keywords only
2367
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values); // both keywords are required
2368
+
2369
+ if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
2370
+ rb_raise(rb_eArgError, "grammar must be a Grammar");
2371
+ return Qnil;
2372
+ }
2373
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
2374
+ rb_raise(rb_eArgError, "token must be an Integer");
2375
+ return Qnil;
2376
+ }
2377
+
2378
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2379
+ if (ctx_ptr->ctx == NULL) {
2380
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2381
+ return Qnil;
2382
+ }
2383
+ LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]); // NOTE(review): grm_ptr->grammar is not checked for nullptr before use - confirm it cannot be unset here
2384
+ llama_token token = NUM2INT(kw_values[1]);
2385
+
2386
+ llama_grammar_accept_token(ctx_ptr->ctx, grm_ptr->grammar, token);
2387
+
2388
+ return Qnil;
2389
+ }
2107
2390
  };
2108
2391
 
2109
2392
  const rb_data_type_t RbLLaMAContext::llama_context_type = {
@@ -2208,6 +2491,8 @@ extern "C" void Init_llama_cpp(void) {
2208
2491
  RbLLaMAContext::define_class(rb_mLLaMACpp);
2209
2492
  RbLLaMAContextParams::define_class(rb_mLLaMACpp);
2210
2493
  RbLLaMAModelQuantizeParams::define_class(rb_mLLaMACpp);
2494
+ RbLLaMAGrammarElement::define_class(rb_mLLaMACpp);
2495
+ RbLLaMAGrammar::define_class(rb_mLLaMACpp);
2211
2496
 
2212
2497
  rb_define_module_function(rb_mLLaMACpp, "backend_init", rb_llama_llama_backend_init, -1);
2213
2498
  rb_define_module_function(rb_mLLaMACpp, "backend_free", rb_llama_llama_backend_free, 0);
@@ -2240,6 +2525,14 @@ extern "C" void Init_llama_cpp(void) {
2240
2525
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_K_M", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_K_M));
2241
2526
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q6_K", INT2NUM(LLAMA_FTYPE_MOSTLY_Q6_K));
2242
2527
 
2528
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_END", INT2NUM(LLAMA_GRETYPE_END));
2529
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_ALT", INT2NUM(LLAMA_GRETYPE_ALT));
2530
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_RULE_REF", INT2NUM(LLAMA_GRETYPE_RULE_REF));
2531
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR", INT2NUM(LLAMA_GRETYPE_CHAR));
2532
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_NOT", INT2NUM(LLAMA_GRETYPE_CHAR_NOT));
2533
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_RNG_UPPER", INT2NUM(LLAMA_GRETYPE_CHAR_RNG_UPPER));
2534
+ rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_ALT", INT2NUM(LLAMA_GRETYPE_CHAR_ALT));
2535
+
2243
2536
  std::stringstream ss_magic;
2244
2537
  ss_magic << std::showbase << std::hex << LLAMA_FILE_MAGIC_GGJT;
2245
2538
  rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_MAGIC_GGJT", rb_str_new2(ss_magic.str().c_str()));