llama_cpp 0.3.4 → 0.3.6
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/README.md +18 -2
- data/ext/llama_cpp/extconf.rb +2 -1
- data/ext/llama_cpp/llama_cpp.cpp +315 -8
- data/ext/llama_cpp/src/ggml-alloc.c +541 -0
- data/ext/llama_cpp/src/ggml-alloc.h +22 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +2271 -414
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +7 -0
- data/ext/llama_cpp/src/ggml-metal.m +218 -87
- data/ext/llama_cpp/src/ggml-metal.metal +72 -55
- data/ext/llama_cpp/src/ggml.c +754 -996
- data/ext/llama_cpp/src/ggml.h +94 -18
- data/ext/llama_cpp/src/k_quants.c +350 -24
- data/ext/llama_cpp/src/llama.cpp +713 -179
- data/ext/llama_cpp/src/llama.h +61 -5
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +26 -0
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 545786d4c9308ffe0f7e214a12427beaea0b26bec915ff84b16eed25ef1932a4
|
4
|
+
data.tar.gz: aaa0d4fc1710b13a26163306c8b51e423233c2f7e4b3d6127f94c9b6c4846f9c
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 12b3ac122fd7ea59b51e2d6ff905ed1a71cf8a8b3650a269d4a3793ae32a0149f6836a792c8f216d0fdb0c39aeb3b47914e73ffc74b574bbe686660e6be84ea1
|
7
|
+
data.tar.gz: 5056b95552f3434692a6c19653810d77bb28ddf9b28abd78712ccfb4ee4f7d836a5d54e283513fcfc617cc79ffa7bb9257d4ac2b6d96ec89158bf94acd4cec86
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,15 @@
|
|
1
|
+
## [[0.3.6](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.5...v0.3.6)] - 2023-08-04
|
2
|
+
|
3
|
+
- Bump bundled llama.cpp from master-1a94186 to master-468ea24.
|
4
|
+
- Add `mul_mat_q` option to ContextParams.
|
5
|
+
|
6
|
+
## [[0.3.5](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.4...v0.3.5)] - 2023-07-29
|
7
|
+
|
8
|
+
- Bump bundled llama.cpp from master-d924522 to master-1a94186.
|
9
|
+
- Add `GrammarElement` and `Grammar` classes.
|
10
|
+
- Add `sample_grammar` method to Context.
|
11
|
+
- Add `grammar_accept_token` method to Context.
|
12
|
+
|
1
13
|
## [[0.3.4](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.3...v0.3.4)] - 2023-07-23
|
2
14
|
|
3
15
|
- Bump bundled llama.cpp from master-32c5411 to master-d924522.
|
data/README.md
CHANGED
@@ -12,11 +12,27 @@ This gem is still under development and may undergo many changes in the future.
|
|
12
12
|
|
13
13
|
Install the gem and add it to the application's Gemfile by executing:
|
14
14
|
|
15
|
-
|
15
|
+
```sh
|
16
|
+
$ bundle add llama_cpp
|
17
|
+
```
|
16
18
|
|
17
19
|
If bundler is not being used to manage dependencies, install the gem by executing:
|
18
20
|
|
19
|
-
|
21
|
+
```sh
|
22
|
+
$ gem install llama_cpp
|
23
|
+
```
|
24
|
+
|
25
|
+
There are several installation options for improving execution performance:
|
26
|
+
|
27
|
+
```sh
|
28
|
+
# use OpenBLAS
|
29
|
+
$ gem install llama_cpp -- --with-openblas
|
30
|
+
|
31
|
+
# use Metal on macOS
|
32
|
+
$ gem install llama_cpp -- --with-metal
|
33
|
+
```
|
34
|
+
|
35
|
+
These options are defined in [extconf.rb](https://github.com/yoshoku/llama_cpp.rb/blob/main/ext/llama_cpp/extconf.rb) using the `with_config` method.
|
20
36
|
|
21
37
|
## Usage
|
22
38
|
|
data/ext/llama_cpp/extconf.rb
CHANGED
@@ -5,7 +5,7 @@ require 'fileutils'
|
|
5
5
|
|
6
6
|
abort 'libstdc++ is not found.' unless have_library('stdc++')
|
7
7
|
|
8
|
-
$srcs = %w[ggml.c llama.cpp llama_cpp.cpp]
|
8
|
+
$srcs = %w[ggml.c ggml-alloc.c llama.cpp llama_cpp.cpp]
|
9
9
|
$srcs << 'ggml-opencl.cpp' if with_config('clblast')
|
10
10
|
$srcs << 'ggml-mpi.c' if with_config('mpi')
|
11
11
|
$CFLAGS << ' -w -DNDEBUG'
|
@@ -85,6 +85,7 @@ if with_config('mpi')
|
|
85
85
|
$CXXFLAGS << ' -DGGML_USE_MPI -Wno-cast-qual'
|
86
86
|
end
|
87
87
|
|
88
|
+
# @!visibility private
|
88
89
|
UNAME_M = RbConfig::CONFIG['build_cpu'] || RbConfig::CONFIG['host_cpu'] || RbConfig::CONFIG['target_cpu']
|
89
90
|
|
90
91
|
# rubocop:disable Layout/LineLength
|
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -8,6 +8,8 @@ VALUE rb_cLLaMAContextParams;
|
|
8
8
|
VALUE rb_cLLaMAModelQuantizeParams;
|
9
9
|
VALUE rb_cLLaMATokenData;
|
10
10
|
VALUE rb_cLLaMATokenDataArray;
|
11
|
+
VALUE rb_cLLaMAGrammarElement;
|
12
|
+
VALUE rb_cLLaMAGrammar;
|
11
13
|
|
12
14
|
class LLaMATokenDataWrapper {
|
13
15
|
public:
|
@@ -412,6 +414,8 @@ public:
|
|
412
414
|
rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
|
413
415
|
rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
|
414
416
|
rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
|
417
|
+
rb_define_method(rb_cLLaMAContextParams, "mul_mat_q=", RUBY_METHOD_FUNC(_llama_context_params_set_mul_mat_q), 1);
|
418
|
+
rb_define_method(rb_cLLaMAContextParams, "mul_mat_q", RUBY_METHOD_FUNC(_llama_context_params_get_mul_mat_q), 0);
|
415
419
|
rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
|
416
420
|
rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
|
417
421
|
rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
|
@@ -525,7 +529,7 @@ private:
|
|
525
529
|
// low_vram
|
526
530
|
static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
|
527
531
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
528
|
-
ptr->params.low_vram = low_vram
|
532
|
+
ptr->params.low_vram = RTEST(low_vram) ? true : false;
|
529
533
|
return ptr->params.low_vram ? Qtrue : Qfalse;
|
530
534
|
}
|
531
535
|
|
@@ -534,6 +538,18 @@ private:
|
|
534
538
|
return ptr->params.low_vram ? Qtrue : Qfalse;
|
535
539
|
}
|
536
540
|
|
541
|
+
// mul_mat_q
|
542
|
+
static VALUE _llama_context_params_set_mul_mat_q(VALUE self, VALUE mul_mat_q) {
|
543
|
+
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
544
|
+
ptr->params.mul_mat_q = RTEST(mul_mat_q) ? true : false;
|
545
|
+
return ptr->params.mul_mat_q ? Qtrue : Qfalse;
|
546
|
+
}
|
547
|
+
|
548
|
+
static VALUE _llama_context_params_get_mul_mat_q(VALUE self) {
|
549
|
+
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
550
|
+
return ptr->params.mul_mat_q ? Qtrue : Qfalse;
|
551
|
+
}
|
552
|
+
|
537
553
|
// seed
|
538
554
|
static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
|
539
555
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
@@ -553,7 +569,7 @@ private:
|
|
553
569
|
// f16_kv
|
554
570
|
static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
|
555
571
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
556
|
-
ptr->params.f16_kv = f16_kv
|
572
|
+
ptr->params.f16_kv = RTEST(f16_kv) ? true : false;
|
557
573
|
return ptr->params.f16_kv ? Qtrue : Qfalse;
|
558
574
|
}
|
559
575
|
|
@@ -565,7 +581,7 @@ private:
|
|
565
581
|
// logits_all
|
566
582
|
static VALUE _llama_context_params_set_logits_all(VALUE self, VALUE logits_all) {
|
567
583
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
568
|
-
ptr->params.logits_all = logits_all
|
584
|
+
ptr->params.logits_all = RTEST(logits_all) ? true : false;
|
569
585
|
return ptr->params.logits_all ? Qtrue : Qfalse;
|
570
586
|
}
|
571
587
|
|
@@ -577,7 +593,7 @@ private:
|
|
577
593
|
// vocab_only
|
578
594
|
static VALUE _llama_context_params_set_vocab_only(VALUE self, VALUE vocab_only) {
|
579
595
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
580
|
-
ptr->params.vocab_only = vocab_only
|
596
|
+
ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
|
581
597
|
return ptr->params.vocab_only ? Qtrue : Qfalse;
|
582
598
|
}
|
583
599
|
|
@@ -589,7 +605,7 @@ private:
|
|
589
605
|
// use_mmap
|
590
606
|
static VALUE _llama_context_params_set_use_mmap(VALUE self, VALUE use_mmap) {
|
591
607
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
592
|
-
ptr->params.use_mmap = use_mmap
|
608
|
+
ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
|
593
609
|
return ptr->params.use_mmap ? Qtrue : Qfalse;
|
594
610
|
}
|
595
611
|
|
@@ -601,7 +617,7 @@ private:
|
|
601
617
|
// use_mlock
|
602
618
|
static VALUE _llama_context_params_set_use_mlock(VALUE self, VALUE use_mlock) {
|
603
619
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
604
|
-
ptr->params.use_mlock = use_mlock
|
620
|
+
ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
|
605
621
|
return ptr->params.use_mlock ? Qtrue : Qfalse;
|
606
622
|
}
|
607
623
|
|
@@ -613,7 +629,7 @@ private:
|
|
613
629
|
// embedding
|
614
630
|
static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
|
615
631
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
616
|
-
ptr->params.embedding = embedding
|
632
|
+
ptr->params.embedding = RTEST(embedding) ? true : false;
|
617
633
|
return ptr->params.embedding ? Qtrue : Qfalse;
|
618
634
|
}
|
619
635
|
|
@@ -1057,6 +1073,222 @@ const rb_data_type_t RbLLaMAModel::llama_model_type = {
|
|
1057
1073
|
RUBY_TYPED_FREE_IMMEDIATELY
|
1058
1074
|
};
|
1059
1075
|
|
1076
|
+
class LLaMAGrammarElementWrapper {
|
1077
|
+
public:
|
1078
|
+
llama_grammar_element element;
|
1079
|
+
|
1080
|
+
LLaMAGrammarElementWrapper() {
|
1081
|
+
element.type = LLAMA_GRETYPE_END;
|
1082
|
+
element.value = 0;
|
1083
|
+
}
|
1084
|
+
|
1085
|
+
~LLaMAGrammarElementWrapper() {}
|
1086
|
+
};
|
1087
|
+
|
1088
|
+
class RbLLaMAGrammarElement {
|
1089
|
+
public:
|
1090
|
+
static VALUE llama_grammar_element_alloc(VALUE self) {
|
1091
|
+
LLaMAGrammarElementWrapper* ptr = (LLaMAGrammarElementWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarElementWrapper));
|
1092
|
+
new (ptr) LLaMAGrammarElementWrapper();
|
1093
|
+
return TypedData_Wrap_Struct(self, &llama_grammar_element_type, ptr);
|
1094
|
+
}
|
1095
|
+
|
1096
|
+
static void llama_grammar_element_free(void* ptr) {
|
1097
|
+
((LLaMAGrammarElementWrapper*)ptr)->~LLaMAGrammarElementWrapper();
|
1098
|
+
ruby_xfree(ptr);
|
1099
|
+
}
|
1100
|
+
|
1101
|
+
static size_t llama_grammar_element_size(const void* ptr) {
|
1102
|
+
return sizeof(*((LLaMAGrammarElementWrapper*)ptr));
|
1103
|
+
}
|
1104
|
+
|
1105
|
+
static LLaMAGrammarElementWrapper* get_llama_grammar_element(VALUE self) {
|
1106
|
+
LLaMAGrammarElementWrapper* ptr;
|
1107
|
+
TypedData_Get_Struct(self, LLaMAGrammarElementWrapper, &llama_grammar_element_type, ptr);
|
1108
|
+
return ptr;
|
1109
|
+
}
|
1110
|
+
|
1111
|
+
static void define_class(VALUE outer) {
|
1112
|
+
rb_cLLaMAGrammarElement = rb_define_class_under(outer, "GrammarElement", rb_cObject);
|
1113
|
+
rb_define_alloc_func(rb_cLLaMAGrammarElement, llama_grammar_element_alloc);
|
1114
|
+
rb_define_method(rb_cLLaMAGrammarElement, "initialize", RUBY_METHOD_FUNC(_llama_grammar_element_init), -1);
|
1115
|
+
rb_define_method(rb_cLLaMAGrammarElement, "type=", RUBY_METHOD_FUNC(_llama_grammar_element_set_type), 1);
|
1116
|
+
rb_define_method(rb_cLLaMAGrammarElement, "type", RUBY_METHOD_FUNC(_llama_grammar_element_get_type), 0);
|
1117
|
+
rb_define_method(rb_cLLaMAGrammarElement, "value=", RUBY_METHOD_FUNC(_llama_grammar_element_set_value), 1);
|
1118
|
+
rb_define_method(rb_cLLaMAGrammarElement, "value", RUBY_METHOD_FUNC(_llama_grammar_element_get_value), 0);
|
1119
|
+
}
|
1120
|
+
|
1121
|
+
private:
|
1122
|
+
static const rb_data_type_t llama_grammar_element_type;
|
1123
|
+
|
1124
|
+
static VALUE _llama_grammar_element_init(int argc, VALUE* argv, VALUE self) {
|
1125
|
+
VALUE kw_args = Qnil;
|
1126
|
+
ID kw_table[2] = { rb_intern("type"), rb_intern("value") };
|
1127
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
1128
|
+
VALUE arr = Qnil;
|
1129
|
+
rb_scan_args(argc, argv, ":", &arr, &kw_args);
|
1130
|
+
rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
|
1131
|
+
|
1132
|
+
if (kw_values[0] != Qundef && !RB_INTEGER_TYPE_P(kw_values[0])) {
|
1133
|
+
rb_raise(rb_eArgError, "type must be an integer");
|
1134
|
+
return Qnil;
|
1135
|
+
}
|
1136
|
+
if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
|
1137
|
+
rb_raise(rb_eArgError, "value must be an integer");
|
1138
|
+
return Qnil;
|
1139
|
+
}
|
1140
|
+
|
1141
|
+
LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
|
1142
|
+
new (ptr) LLaMAGrammarElementWrapper();
|
1143
|
+
|
1144
|
+
if (kw_values[0] != Qundef) {
|
1145
|
+
ptr->element.type = (enum llama_gretype)NUM2INT(kw_values[0]);
|
1146
|
+
}
|
1147
|
+
if (kw_values[1] != Qundef) {
|
1148
|
+
ptr->element.value = NUM2INT(kw_values[1]);
|
1149
|
+
}
|
1150
|
+
|
1151
|
+
return self;
|
1152
|
+
}
|
1153
|
+
|
1154
|
+
// type
|
1155
|
+
static VALUE _llama_grammar_element_set_type(VALUE self, VALUE type) {
|
1156
|
+
LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
|
1157
|
+
ptr->element.type = (enum llama_gretype)NUM2INT(type);
|
1158
|
+
return INT2NUM(ptr->element.type);
|
1159
|
+
}
|
1160
|
+
|
1161
|
+
static VALUE _llama_grammar_element_get_type(VALUE self) {
|
1162
|
+
LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
|
1163
|
+
return INT2NUM(ptr->element.type);
|
1164
|
+
}
|
1165
|
+
|
1166
|
+
// value
|
1167
|
+
static VALUE _llama_grammar_element_set_value(VALUE self, VALUE type) {
|
1168
|
+
LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
|
1169
|
+
ptr->element.value = NUM2INT(type);
|
1170
|
+
return INT2NUM(ptr->element.value);
|
1171
|
+
}
|
1172
|
+
|
1173
|
+
static VALUE _llama_grammar_element_get_value(VALUE self) {
|
1174
|
+
LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
|
1175
|
+
return INT2NUM(ptr->element.value);
|
1176
|
+
}
|
1177
|
+
};
|
1178
|
+
|
1179
|
+
const rb_data_type_t RbLLaMAGrammarElement::llama_grammar_element_type = {
|
1180
|
+
"RbLLaMAGrammarElement",
|
1181
|
+
{ NULL,
|
1182
|
+
RbLLaMAGrammarElement::llama_grammar_element_free,
|
1183
|
+
RbLLaMAGrammarElement::llama_grammar_element_size },
|
1184
|
+
NULL,
|
1185
|
+
NULL,
|
1186
|
+
RUBY_TYPED_FREE_IMMEDIATELY
|
1187
|
+
};
|
1188
|
+
|
1189
|
+
class LLaMAGrammarWrapper {
|
1190
|
+
public:
|
1191
|
+
struct llama_grammar* grammar;
|
1192
|
+
|
1193
|
+
LLaMAGrammarWrapper() : grammar(nullptr) {}
|
1194
|
+
|
1195
|
+
~LLaMAGrammarWrapper() {
|
1196
|
+
if (grammar) {
|
1197
|
+
llama_grammar_free(grammar);
|
1198
|
+
}
|
1199
|
+
}
|
1200
|
+
};
|
1201
|
+
|
1202
|
+
class RbLLaMAGrammar {
|
1203
|
+
public:
|
1204
|
+
static VALUE llama_grammar_alloc(VALUE self) {
|
1205
|
+
LLaMAGrammarWrapper* ptr = (LLaMAGrammarWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarWrapper));
|
1206
|
+
new (ptr) LLaMAGrammarWrapper();
|
1207
|
+
return TypedData_Wrap_Struct(self, &llama_grammar_type, ptr);
|
1208
|
+
}
|
1209
|
+
|
1210
|
+
static void llama_grammar_free(void* ptr) {
|
1211
|
+
((LLaMAGrammarWrapper*)ptr)->~LLaMAGrammarWrapper();
|
1212
|
+
ruby_xfree(ptr);
|
1213
|
+
}
|
1214
|
+
|
1215
|
+
static size_t llama_grammar_size(const void* ptr) {
|
1216
|
+
return sizeof(*((LLaMAGrammarWrapper*)ptr));
|
1217
|
+
}
|
1218
|
+
|
1219
|
+
static LLaMAGrammarWrapper* get_llama_grammar(VALUE self) {
|
1220
|
+
LLaMAGrammarWrapper* ptr;
|
1221
|
+
TypedData_Get_Struct(self, LLaMAGrammarWrapper, &llama_grammar_type, ptr);
|
1222
|
+
return ptr;
|
1223
|
+
}
|
1224
|
+
|
1225
|
+
static void define_class(VALUE outer) {
|
1226
|
+
rb_cLLaMAGrammar = rb_define_class_under(outer, "Grammar", rb_cObject);
|
1227
|
+
rb_define_alloc_func(rb_cLLaMAGrammar, llama_grammar_alloc);
|
1228
|
+
rb_define_method(rb_cLLaMAGrammar, "initialize", RUBY_METHOD_FUNC(_llama_grammar_init), -1);
|
1229
|
+
}
|
1230
|
+
|
1231
|
+
private:
|
1232
|
+
static const rb_data_type_t llama_grammar_type;
|
1233
|
+
|
1234
|
+
static VALUE _llama_grammar_init(int argc, VALUE* argv, VALUE self) {
|
1235
|
+
VALUE kw_args = Qnil;
|
1236
|
+
ID kw_table[2] = { rb_intern("rules"), rb_intern("start_rule_index") };
|
1237
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
1238
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
1239
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
1240
|
+
|
1241
|
+
if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
|
1242
|
+
rb_raise(rb_eArgError, "rules must be an array");
|
1243
|
+
return Qnil;
|
1244
|
+
}
|
1245
|
+
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
1246
|
+
rb_raise(rb_eArgError, "start_rule_index must be an integer");
|
1247
|
+
return Qnil;
|
1248
|
+
}
|
1249
|
+
|
1250
|
+
const int n_rules = RARRAY_LEN(kw_values[0]);
|
1251
|
+
llama_grammar_element** rules = ALLOCA_N(llama_grammar_element*, n_rules);
|
1252
|
+
for (int i = 0; i < n_rules; ++i) {
|
1253
|
+
VALUE rule = rb_ary_entry(kw_values[0], i);
|
1254
|
+
if (!RB_TYPE_P(rule, T_ARRAY)) {
|
1255
|
+
rb_raise(rb_eArgError, "element of rules must be an array");
|
1256
|
+
return Qnil;
|
1257
|
+
}
|
1258
|
+
const int n_elements = RARRAY_LEN(rule);
|
1259
|
+
llama_grammar_element* elements = ALLOCA_N(llama_grammar_element, n_elements);
|
1260
|
+
for (int j = 0; j < n_elements; ++j) {
|
1261
|
+
VALUE element = rb_ary_entry(rule, j);
|
1262
|
+
if (!rb_obj_is_kind_of(element, rb_cLLaMAGrammarElement)) {
|
1263
|
+
rb_raise(rb_eArgError, "element of rule must be an instance of GrammarElement");
|
1264
|
+
return Qnil;
|
1265
|
+
}
|
1266
|
+
LLaMAGrammarElementWrapper* ptr = RbLLaMAGrammarElement::get_llama_grammar_element(element);
|
1267
|
+
elements[j] = ptr->element;
|
1268
|
+
}
|
1269
|
+
rules[i] = elements;
|
1270
|
+
}
|
1271
|
+
|
1272
|
+
const size_t start_rule_index = NUM2SIZET(kw_values[1]);
|
1273
|
+
|
1274
|
+
LLaMAGrammarWrapper* ptr = get_llama_grammar(self);
|
1275
|
+
new (ptr) LLaMAGrammarWrapper();
|
1276
|
+
ptr->grammar = llama_grammar_init((const llama_grammar_element**)rules, n_rules, start_rule_index);
|
1277
|
+
|
1278
|
+
return self;
|
1279
|
+
}
|
1280
|
+
};
|
1281
|
+
|
1282
|
+
const rb_data_type_t RbLLaMAGrammar::llama_grammar_type = {
|
1283
|
+
"RbLLaMAGrammar",
|
1284
|
+
{ NULL,
|
1285
|
+
RbLLaMAGrammar::llama_grammar_free,
|
1286
|
+
RbLLaMAGrammar::llama_grammar_size },
|
1287
|
+
NULL,
|
1288
|
+
NULL,
|
1289
|
+
RUBY_TYPED_FREE_IMMEDIATELY
|
1290
|
+
};
|
1291
|
+
|
1060
1292
|
class LLaMAContextWrapper {
|
1061
1293
|
public:
|
1062
1294
|
struct llama_context* ctx;
|
@@ -1128,6 +1360,8 @@ public:
|
|
1128
1360
|
rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
|
1129
1361
|
rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
|
1130
1362
|
rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
|
1363
|
+
rb_define_method(rb_cLLaMAContext, "sample_grammar", RUBY_METHOD_FUNC(_llama_context_sample_grammar), -1);
|
1364
|
+
rb_define_method(rb_cLLaMAContext, "grammar_accept_token", RUBY_METHOD_FUNC(_llama_context_grammar_accept_token), -1);
|
1131
1365
|
}
|
1132
1366
|
|
1133
1367
|
private:
|
@@ -2104,6 +2338,69 @@ private:
|
|
2104
2338
|
llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
|
2105
2339
|
return INT2NUM(id);
|
2106
2340
|
}
|
2341
|
+
|
2342
|
+
static VALUE _llama_context_sample_grammar(int argc, VALUE* argv, VALUE self) {
|
2343
|
+
VALUE kw_args = Qnil;
|
2344
|
+
ID kw_table[1] = { rb_intern("grammar") };
|
2345
|
+
VALUE kw_values[1] = { Qundef };
|
2346
|
+
VALUE candidates = Qnil;
|
2347
|
+
rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
|
2348
|
+
rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
|
2349
|
+
|
2350
|
+
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
2351
|
+
rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
|
2352
|
+
return Qnil;
|
2353
|
+
}
|
2354
|
+
if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
|
2355
|
+
rb_raise(rb_eArgError, "grammar must be a Grammar");
|
2356
|
+
return Qnil;
|
2357
|
+
}
|
2358
|
+
|
2359
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
2360
|
+
if (ctx_ptr->ctx == NULL) {
|
2361
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
2362
|
+
return Qnil;
|
2363
|
+
}
|
2364
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
2365
|
+
if (cnd_ptr->array.data == nullptr) {
|
2366
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
2367
|
+
return Qnil;
|
2368
|
+
}
|
2369
|
+
LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
|
2370
|
+
|
2371
|
+
llama_sample_grammar(ctx_ptr->ctx, &(cnd_ptr->array), grm_ptr->grammar);
|
2372
|
+
|
2373
|
+
return Qnil;
|
2374
|
+
}
|
2375
|
+
|
2376
|
+
static VALUE _llama_context_grammar_accept_token(int argc, VALUE* argv, VALUE self) {
|
2377
|
+
VALUE kw_args = Qnil;
|
2378
|
+
ID kw_table[2] = { rb_intern("grammar"), rb_intern("token") };
|
2379
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
2380
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
2381
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
2382
|
+
|
2383
|
+
if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
|
2384
|
+
rb_raise(rb_eArgError, "grammar must be a Grammar");
|
2385
|
+
return Qnil;
|
2386
|
+
}
|
2387
|
+
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
2388
|
+
rb_raise(rb_eArgError, "token must be an Integer");
|
2389
|
+
return Qnil;
|
2390
|
+
}
|
2391
|
+
|
2392
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
2393
|
+
if (ctx_ptr->ctx == NULL) {
|
2394
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
2395
|
+
return Qnil;
|
2396
|
+
}
|
2397
|
+
LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
|
2398
|
+
llama_token token = NUM2INT(kw_values[1]);
|
2399
|
+
|
2400
|
+
llama_grammar_accept_token(ctx_ptr->ctx, grm_ptr->grammar, token);
|
2401
|
+
|
2402
|
+
return Qnil;
|
2403
|
+
}
|
2107
2404
|
};
|
2108
2405
|
|
2109
2406
|
const rb_data_type_t RbLLaMAContext::llama_context_type = {
|
@@ -2125,7 +2422,7 @@ static VALUE rb_llama_llama_backend_init(int argc, VALUE* argv, VALUE self) {
|
|
2125
2422
|
rb_scan_args(argc, argv, ":", &kw_args);
|
2126
2423
|
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
2127
2424
|
|
2128
|
-
const bool numa = kw_values[0] == Qundef ? false : (RTEST ? true : false);
|
2425
|
+
const bool numa = kw_values[0] == Qundef ? false : (RTEST(kw_values[0]) ? true : false);
|
2129
2426
|
llama_backend_init(numa);
|
2130
2427
|
|
2131
2428
|
return Qnil;
|
@@ -2208,6 +2505,8 @@ extern "C" void Init_llama_cpp(void) {
|
|
2208
2505
|
RbLLaMAContext::define_class(rb_mLLaMACpp);
|
2209
2506
|
RbLLaMAContextParams::define_class(rb_mLLaMACpp);
|
2210
2507
|
RbLLaMAModelQuantizeParams::define_class(rb_mLLaMACpp);
|
2508
|
+
RbLLaMAGrammarElement::define_class(rb_mLLaMACpp);
|
2509
|
+
RbLLaMAGrammar::define_class(rb_mLLaMACpp);
|
2211
2510
|
|
2212
2511
|
rb_define_module_function(rb_mLLaMACpp, "backend_init", rb_llama_llama_backend_init, -1);
|
2213
2512
|
rb_define_module_function(rb_mLLaMACpp, "backend_free", rb_llama_llama_backend_free, 0);
|
@@ -2240,6 +2539,14 @@ extern "C" void Init_llama_cpp(void) {
|
|
2240
2539
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_K_M", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_K_M));
|
2241
2540
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q6_K", INT2NUM(LLAMA_FTYPE_MOSTLY_Q6_K));
|
2242
2541
|
|
2542
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_END", INT2NUM(LLAMA_GRETYPE_END));
|
2543
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_ALT", INT2NUM(LLAMA_GRETYPE_ALT));
|
2544
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_RULE_REF", INT2NUM(LLAMA_GRETYPE_RULE_REF));
|
2545
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR", INT2NUM(LLAMA_GRETYPE_CHAR));
|
2546
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_NOT", INT2NUM(LLAMA_GRETYPE_CHAR_NOT));
|
2547
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_RNG_UPPER", INT2NUM(LLAMA_GRETYPE_CHAR_RNG_UPPER));
|
2548
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_ALT", INT2NUM(LLAMA_GRETYPE_CHAR_ALT));
|
2549
|
+
|
2243
2550
|
std::stringstream ss_magic;
|
2244
2551
|
ss_magic << std::showbase << std::hex << LLAMA_FILE_MAGIC_GGJT;
|
2245
2552
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_MAGIC_GGJT", rb_str_new2(ss_magic.str().c_str()));
|