llama_cpp 0.3.4 → 0.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/README.md +18 -2
- data/ext/llama_cpp/extconf.rb +2 -1
- data/ext/llama_cpp/llama_cpp.cpp +315 -8
- data/ext/llama_cpp/src/ggml-alloc.c +541 -0
- data/ext/llama_cpp/src/ggml-alloc.h +22 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +2271 -414
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +7 -0
- data/ext/llama_cpp/src/ggml-metal.m +218 -87
- data/ext/llama_cpp/src/ggml-metal.metal +72 -55
- data/ext/llama_cpp/src/ggml.c +754 -996
- data/ext/llama_cpp/src/ggml.h +94 -18
- data/ext/llama_cpp/src/k_quants.c +350 -24
- data/ext/llama_cpp/src/llama.cpp +713 -179
- data/ext/llama_cpp/src/llama.h +61 -5
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +26 -0
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 545786d4c9308ffe0f7e214a12427beaea0b26bec915ff84b16eed25ef1932a4
|
4
|
+
data.tar.gz: aaa0d4fc1710b13a26163306c8b51e423233c2f7e4b3d6127f94c9b6c4846f9c
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 12b3ac122fd7ea59b51e2d6ff905ed1a71cf8a8b3650a269d4a3793ae32a0149f6836a792c8f216d0fdb0c39aeb3b47914e73ffc74b574bbe686660e6be84ea1
|
7
|
+
data.tar.gz: 5056b95552f3434692a6c19653810d77bb28ddf9b28abd78712ccfb4ee4f7d836a5d54e283513fcfc617cc79ffa7bb9257d4ac2b6d96ec89158bf94acd4cec86
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,15 @@
|
|
1
|
+
## [[0.3.6](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.5...v0.3.6)] - 2023-08-04
|
2
|
+
|
3
|
+
- Bump bundled llama.cpp from master-1a94186 to master-468ea24.
|
4
|
+
- Add `mul_mat_q` option to ContextParams.
|
5
|
+
|
6
|
+
## [[0.3.5](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.4...v0.3.5)] - 2023-07-29
|
7
|
+
|
8
|
+
- Bump bundled llama.cpp from master-d924522 to master-1a94186.
|
9
|
+
- Add `GrammarElement` and `Grammar` classes.
|
10
|
+
- Add `sample_grammar` method to Context.
|
11
|
+
- Add `grammar_accept_token` method to Context.
|
12
|
+
|
1
13
|
## [[0.3.4](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.3...v0.3.4)] - 2023-07-23
|
2
14
|
|
3
15
|
- Bump bundled llama.cpp from master-32c5411 to master-d924522.
|
data/README.md
CHANGED
@@ -12,11 +12,27 @@ This gem is still under development and may undergo many changes in the future.
|
|
12
12
|
|
13
13
|
Install the gem and add to the application's Gemfile by executing:
|
14
14
|
|
15
|
-
|
15
|
+
```sh
|
16
|
+
$ bundle add llama_cpp
|
17
|
+
```
|
16
18
|
|
17
19
|
If bundler is not being used to manage dependencies, install the gem by executing:
|
18
20
|
|
19
|
-
|
21
|
+
```sh
|
22
|
+
$ gem install llama_cpp
|
23
|
+
```
|
24
|
+
|
25
|
+
There are several installation options for improving execution performance:
|
26
|
+
|
27
|
+
```sh
|
28
|
+
# use OpenBLAS
|
29
|
+
$ gem install llama_cpp -- --with-openblas
|
30
|
+
|
31
|
+
# use Metal on macOS
|
32
|
+
$ gem install llama_cpp -- --with-metal
|
33
|
+
```
|
34
|
+
|
35
|
+
Those options are defined in [extconf.rb](https://github.com/yoshoku/llama_cpp.rb/blob/main/ext/llama_cpp/extconf.rb) using the `with_config` method.
|
20
36
|
|
21
37
|
## Usage
|
22
38
|
|
data/ext/llama_cpp/extconf.rb
CHANGED
@@ -5,7 +5,7 @@ require 'fileutils'
|
|
5
5
|
|
6
6
|
abort 'libstdc++ is not found.' unless have_library('stdc++')
|
7
7
|
|
8
|
-
$srcs = %w[ggml.c llama.cpp llama_cpp.cpp]
|
8
|
+
$srcs = %w[ggml.c ggml-alloc.c llama.cpp llama_cpp.cpp]
|
9
9
|
$srcs << 'ggml-opencl.cpp' if with_config('clblast')
|
10
10
|
$srcs << 'ggml-mpi.c' if with_config('mpi')
|
11
11
|
$CFLAGS << ' -w -DNDEBUG'
|
@@ -85,6 +85,7 @@ if with_config('mpi')
|
|
85
85
|
$CXXFLAGS << ' -DGGML_USE_MPI -Wno-cast-qual'
|
86
86
|
end
|
87
87
|
|
88
|
+
# @!visibility private
|
88
89
|
UNAME_M = RbConfig::CONFIG['build_cpu'] || RbConfig::CONFIG['host_cpu'] || RbConfig::CONFIG['target_cpu']
|
89
90
|
|
90
91
|
# rubocop:disable Layout/LineLength
|
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -8,6 +8,8 @@ VALUE rb_cLLaMAContextParams;
|
|
8
8
|
VALUE rb_cLLaMAModelQuantizeParams;
|
9
9
|
VALUE rb_cLLaMATokenData;
|
10
10
|
VALUE rb_cLLaMATokenDataArray;
|
11
|
+
VALUE rb_cLLaMAGrammarElement;
|
12
|
+
VALUE rb_cLLaMAGrammar;
|
11
13
|
|
12
14
|
class LLaMATokenDataWrapper {
|
13
15
|
public:
|
@@ -412,6 +414,8 @@ public:
|
|
412
414
|
rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
|
413
415
|
rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
|
414
416
|
rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
|
417
|
+
rb_define_method(rb_cLLaMAContextParams, "mul_mat_q=", RUBY_METHOD_FUNC(_llama_context_params_set_mul_mat_q), 1);
|
418
|
+
rb_define_method(rb_cLLaMAContextParams, "mul_mat_q", RUBY_METHOD_FUNC(_llama_context_params_get_mul_mat_q), 0);
|
415
419
|
rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
|
416
420
|
rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
|
417
421
|
rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
|
@@ -525,7 +529,7 @@ private:
|
|
525
529
|
// low_vram
|
526
530
|
static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
|
527
531
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
528
|
-
ptr->params.low_vram = low_vram
|
532
|
+
ptr->params.low_vram = RTEST(low_vram) ? true : false;
|
529
533
|
return ptr->params.low_vram ? Qtrue : Qfalse;
|
530
534
|
}
|
531
535
|
|
@@ -534,6 +538,18 @@ private:
|
|
534
538
|
return ptr->params.low_vram ? Qtrue : Qfalse;
|
535
539
|
}
|
536
540
|
|
541
|
+
// mul_mat_q
|
542
|
+
static VALUE _llama_context_params_set_mul_mat_q(VALUE self, VALUE mul_mat_q) {
|
543
|
+
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
544
|
+
ptr->params.mul_mat_q = RTEST(mul_mat_q) ? true : false;
|
545
|
+
return ptr->params.mul_mat_q ? Qtrue : Qfalse;
|
546
|
+
}
|
547
|
+
|
548
|
+
static VALUE _llama_context_params_get_mul_mat_q(VALUE self) {
|
549
|
+
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
550
|
+
return ptr->params.mul_mat_q ? Qtrue : Qfalse;
|
551
|
+
}
|
552
|
+
|
537
553
|
// seed
|
538
554
|
static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
|
539
555
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
@@ -553,7 +569,7 @@ private:
|
|
553
569
|
// f16_kv
|
554
570
|
static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
|
555
571
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
556
|
-
ptr->params.f16_kv = f16_kv
|
572
|
+
ptr->params.f16_kv = RTEST(f16_kv) ? true : false;
|
557
573
|
return ptr->params.f16_kv ? Qtrue : Qfalse;
|
558
574
|
}
|
559
575
|
|
@@ -565,7 +581,7 @@ private:
|
|
565
581
|
// logits_all
|
566
582
|
static VALUE _llama_context_params_set_logits_all(VALUE self, VALUE logits_all) {
|
567
583
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
568
|
-
ptr->params.logits_all = logits_all
|
584
|
+
ptr->params.logits_all = RTEST(logits_all) ? true : false;
|
569
585
|
return ptr->params.logits_all ? Qtrue : Qfalse;
|
570
586
|
}
|
571
587
|
|
@@ -577,7 +593,7 @@ private:
|
|
577
593
|
// vocab_only
|
578
594
|
static VALUE _llama_context_params_set_vocab_only(VALUE self, VALUE vocab_only) {
|
579
595
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
580
|
-
ptr->params.vocab_only = vocab_only
|
596
|
+
ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
|
581
597
|
return ptr->params.vocab_only ? Qtrue : Qfalse;
|
582
598
|
}
|
583
599
|
|
@@ -589,7 +605,7 @@ private:
|
|
589
605
|
// use_mmap
|
590
606
|
static VALUE _llama_context_params_set_use_mmap(VALUE self, VALUE use_mmap) {
|
591
607
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
592
|
-
ptr->params.use_mmap = use_mmap
|
608
|
+
ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
|
593
609
|
return ptr->params.use_mmap ? Qtrue : Qfalse;
|
594
610
|
}
|
595
611
|
|
@@ -601,7 +617,7 @@ private:
|
|
601
617
|
// use_mlock
|
602
618
|
static VALUE _llama_context_params_set_use_mlock(VALUE self, VALUE use_mlock) {
|
603
619
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
604
|
-
ptr->params.use_mlock = use_mlock
|
620
|
+
ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
|
605
621
|
return ptr->params.use_mlock ? Qtrue : Qfalse;
|
606
622
|
}
|
607
623
|
|
@@ -613,7 +629,7 @@ private:
|
|
613
629
|
// embedding
|
614
630
|
static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
|
615
631
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
616
|
-
ptr->params.embedding = embedding
|
632
|
+
ptr->params.embedding = RTEST(embedding) ? true : false;
|
617
633
|
return ptr->params.embedding ? Qtrue : Qfalse;
|
618
634
|
}
|
619
635
|
|
@@ -1057,6 +1073,222 @@ const rb_data_type_t RbLLaMAModel::llama_model_type = {
|
|
1057
1073
|
RUBY_TYPED_FREE_IMMEDIATELY
|
1058
1074
|
};
|
1059
1075
|
|
1076
|
+
class LLaMAGrammarElementWrapper {
|
1077
|
+
public:
|
1078
|
+
llama_grammar_element element;
|
1079
|
+
|
1080
|
+
LLaMAGrammarElementWrapper() {
|
1081
|
+
element.type = LLAMA_GRETYPE_END;
|
1082
|
+
element.value = 0;
|
1083
|
+
}
|
1084
|
+
|
1085
|
+
~LLaMAGrammarElementWrapper() {}
|
1086
|
+
};
|
1087
|
+
|
1088
|
+
class RbLLaMAGrammarElement {
|
1089
|
+
public:
|
1090
|
+
static VALUE llama_grammar_element_alloc(VALUE self) {
|
1091
|
+
LLaMAGrammarElementWrapper* ptr = (LLaMAGrammarElementWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarElementWrapper));
|
1092
|
+
new (ptr) LLaMAGrammarElementWrapper();
|
1093
|
+
return TypedData_Wrap_Struct(self, &llama_grammar_element_type, ptr);
|
1094
|
+
}
|
1095
|
+
|
1096
|
+
static void llama_grammar_element_free(void* ptr) {
|
1097
|
+
((LLaMAGrammarElementWrapper*)ptr)->~LLaMAGrammarElementWrapper();
|
1098
|
+
ruby_xfree(ptr);
|
1099
|
+
}
|
1100
|
+
|
1101
|
+
static size_t llama_grammar_element_size(const void* ptr) {
|
1102
|
+
return sizeof(*((LLaMAGrammarElementWrapper*)ptr));
|
1103
|
+
}
|
1104
|
+
|
1105
|
+
static LLaMAGrammarElementWrapper* get_llama_grammar_element(VALUE self) {
|
1106
|
+
LLaMAGrammarElementWrapper* ptr;
|
1107
|
+
TypedData_Get_Struct(self, LLaMAGrammarElementWrapper, &llama_grammar_element_type, ptr);
|
1108
|
+
return ptr;
|
1109
|
+
}
|
1110
|
+
|
1111
|
+
static void define_class(VALUE outer) {
|
1112
|
+
rb_cLLaMAGrammarElement = rb_define_class_under(outer, "GrammarElement", rb_cObject);
|
1113
|
+
rb_define_alloc_func(rb_cLLaMAGrammarElement, llama_grammar_element_alloc);
|
1114
|
+
rb_define_method(rb_cLLaMAGrammarElement, "initialize", RUBY_METHOD_FUNC(_llama_grammar_element_init), -1);
|
1115
|
+
rb_define_method(rb_cLLaMAGrammarElement, "type=", RUBY_METHOD_FUNC(_llama_grammar_element_set_type), 1);
|
1116
|
+
rb_define_method(rb_cLLaMAGrammarElement, "type", RUBY_METHOD_FUNC(_llama_grammar_element_get_type), 0);
|
1117
|
+
rb_define_method(rb_cLLaMAGrammarElement, "value=", RUBY_METHOD_FUNC(_llama_grammar_element_set_value), 1);
|
1118
|
+
rb_define_method(rb_cLLaMAGrammarElement, "value", RUBY_METHOD_FUNC(_llama_grammar_element_get_value), 0);
|
1119
|
+
}
|
1120
|
+
|
1121
|
+
private:
|
1122
|
+
static const rb_data_type_t llama_grammar_element_type;
|
1123
|
+
|
1124
|
+
static VALUE _llama_grammar_element_init(int argc, VALUE* argv, VALUE self) {
|
1125
|
+
VALUE kw_args = Qnil;
|
1126
|
+
ID kw_table[2] = { rb_intern("type"), rb_intern("value") };
|
1127
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
1128
|
+
VALUE arr = Qnil;
|
1129
|
+
rb_scan_args(argc, argv, ":", &arr, &kw_args);
|
1130
|
+
rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
|
1131
|
+
|
1132
|
+
if (kw_values[0] != Qundef && !RB_INTEGER_TYPE_P(kw_values[0])) {
|
1133
|
+
rb_raise(rb_eArgError, "type must be an integer");
|
1134
|
+
return Qnil;
|
1135
|
+
}
|
1136
|
+
if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
|
1137
|
+
rb_raise(rb_eArgError, "value must be an integer");
|
1138
|
+
return Qnil;
|
1139
|
+
}
|
1140
|
+
|
1141
|
+
LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
|
1142
|
+
new (ptr) LLaMAGrammarElementWrapper();
|
1143
|
+
|
1144
|
+
if (kw_values[0] != Qundef) {
|
1145
|
+
ptr->element.type = (enum llama_gretype)NUM2INT(kw_values[0]);
|
1146
|
+
}
|
1147
|
+
if (kw_values[1] != Qundef) {
|
1148
|
+
ptr->element.value = NUM2INT(kw_values[1]);
|
1149
|
+
}
|
1150
|
+
|
1151
|
+
return self;
|
1152
|
+
}
|
1153
|
+
|
1154
|
+
// type
|
1155
|
+
static VALUE _llama_grammar_element_set_type(VALUE self, VALUE type) {
|
1156
|
+
LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
|
1157
|
+
ptr->element.type = (enum llama_gretype)NUM2INT(type);
|
1158
|
+
return INT2NUM(ptr->element.type);
|
1159
|
+
}
|
1160
|
+
|
1161
|
+
static VALUE _llama_grammar_element_get_type(VALUE self) {
|
1162
|
+
LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
|
1163
|
+
return INT2NUM(ptr->element.type);
|
1164
|
+
}
|
1165
|
+
|
1166
|
+
// value
|
1167
|
+
static VALUE _llama_grammar_element_set_value(VALUE self, VALUE type) {
|
1168
|
+
LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
|
1169
|
+
ptr->element.value = NUM2INT(type);
|
1170
|
+
return INT2NUM(ptr->element.value);
|
1171
|
+
}
|
1172
|
+
|
1173
|
+
static VALUE _llama_grammar_element_get_value(VALUE self) {
|
1174
|
+
LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
|
1175
|
+
return INT2NUM(ptr->element.value);
|
1176
|
+
}
|
1177
|
+
};
|
1178
|
+
|
1179
|
+
const rb_data_type_t RbLLaMAGrammarElement::llama_grammar_element_type = {
|
1180
|
+
"RbLLaMAGrammarElement",
|
1181
|
+
{ NULL,
|
1182
|
+
RbLLaMAGrammarElement::llama_grammar_element_free,
|
1183
|
+
RbLLaMAGrammarElement::llama_grammar_element_size },
|
1184
|
+
NULL,
|
1185
|
+
NULL,
|
1186
|
+
RUBY_TYPED_FREE_IMMEDIATELY
|
1187
|
+
};
|
1188
|
+
|
1189
|
+
class LLaMAGrammarWrapper {
|
1190
|
+
public:
|
1191
|
+
struct llama_grammar* grammar;
|
1192
|
+
|
1193
|
+
LLaMAGrammarWrapper() : grammar(nullptr) {}
|
1194
|
+
|
1195
|
+
~LLaMAGrammarWrapper() {
|
1196
|
+
if (grammar) {
|
1197
|
+
llama_grammar_free(grammar);
|
1198
|
+
}
|
1199
|
+
}
|
1200
|
+
};
|
1201
|
+
|
1202
|
+
class RbLLaMAGrammar {
|
1203
|
+
public:
|
1204
|
+
static VALUE llama_grammar_alloc(VALUE self) {
|
1205
|
+
LLaMAGrammarWrapper* ptr = (LLaMAGrammarWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarWrapper));
|
1206
|
+
new (ptr) LLaMAGrammarWrapper();
|
1207
|
+
return TypedData_Wrap_Struct(self, &llama_grammar_type, ptr);
|
1208
|
+
}
|
1209
|
+
|
1210
|
+
static void llama_grammar_free(void* ptr) {
|
1211
|
+
((LLaMAGrammarWrapper*)ptr)->~LLaMAGrammarWrapper();
|
1212
|
+
ruby_xfree(ptr);
|
1213
|
+
}
|
1214
|
+
|
1215
|
+
static size_t llama_grammar_size(const void* ptr) {
|
1216
|
+
return sizeof(*((LLaMAGrammarWrapper*)ptr));
|
1217
|
+
}
|
1218
|
+
|
1219
|
+
static LLaMAGrammarWrapper* get_llama_grammar(VALUE self) {
|
1220
|
+
LLaMAGrammarWrapper* ptr;
|
1221
|
+
TypedData_Get_Struct(self, LLaMAGrammarWrapper, &llama_grammar_type, ptr);
|
1222
|
+
return ptr;
|
1223
|
+
}
|
1224
|
+
|
1225
|
+
static void define_class(VALUE outer) {
|
1226
|
+
rb_cLLaMAGrammar = rb_define_class_under(outer, "Grammar", rb_cObject);
|
1227
|
+
rb_define_alloc_func(rb_cLLaMAGrammar, llama_grammar_alloc);
|
1228
|
+
rb_define_method(rb_cLLaMAGrammar, "initialize", RUBY_METHOD_FUNC(_llama_grammar_init), -1);
|
1229
|
+
}
|
1230
|
+
|
1231
|
+
private:
|
1232
|
+
static const rb_data_type_t llama_grammar_type;
|
1233
|
+
|
1234
|
+
static VALUE _llama_grammar_init(int argc, VALUE* argv, VALUE self) {
|
1235
|
+
VALUE kw_args = Qnil;
|
1236
|
+
ID kw_table[2] = { rb_intern("rules"), rb_intern("start_rule_index") };
|
1237
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
1238
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
1239
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
1240
|
+
|
1241
|
+
if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
|
1242
|
+
rb_raise(rb_eArgError, "rules must be an array");
|
1243
|
+
return Qnil;
|
1244
|
+
}
|
1245
|
+
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
1246
|
+
rb_raise(rb_eArgError, "start_rule_index must be an integer");
|
1247
|
+
return Qnil;
|
1248
|
+
}
|
1249
|
+
|
1250
|
+
const int n_rules = RARRAY_LEN(kw_values[0]);
|
1251
|
+
llama_grammar_element** rules = ALLOCA_N(llama_grammar_element*, n_rules);
|
1252
|
+
for (int i = 0; i < n_rules; ++i) {
|
1253
|
+
VALUE rule = rb_ary_entry(kw_values[0], i);
|
1254
|
+
if (!RB_TYPE_P(rule, T_ARRAY)) {
|
1255
|
+
rb_raise(rb_eArgError, "element of rules must be an array");
|
1256
|
+
return Qnil;
|
1257
|
+
}
|
1258
|
+
const int n_elements = RARRAY_LEN(rule);
|
1259
|
+
llama_grammar_element* elements = ALLOCA_N(llama_grammar_element, n_elements);
|
1260
|
+
for (int j = 0; j < n_elements; ++j) {
|
1261
|
+
VALUE element = rb_ary_entry(rule, j);
|
1262
|
+
if (!rb_obj_is_kind_of(element, rb_cLLaMAGrammarElement)) {
|
1263
|
+
rb_raise(rb_eArgError, "element of rule must be an instance of GrammarElement");
|
1264
|
+
return Qnil;
|
1265
|
+
}
|
1266
|
+
LLaMAGrammarElementWrapper* ptr = RbLLaMAGrammarElement::get_llama_grammar_element(element);
|
1267
|
+
elements[j] = ptr->element;
|
1268
|
+
}
|
1269
|
+
rules[i] = elements;
|
1270
|
+
}
|
1271
|
+
|
1272
|
+
const size_t start_rule_index = NUM2SIZET(kw_values[1]);
|
1273
|
+
|
1274
|
+
LLaMAGrammarWrapper* ptr = get_llama_grammar(self);
|
1275
|
+
new (ptr) LLaMAGrammarWrapper();
|
1276
|
+
ptr->grammar = llama_grammar_init((const llama_grammar_element**)rules, n_rules, start_rule_index);
|
1277
|
+
|
1278
|
+
return self;
|
1279
|
+
}
|
1280
|
+
};
|
1281
|
+
|
1282
|
+
const rb_data_type_t RbLLaMAGrammar::llama_grammar_type = {
|
1283
|
+
"RbLLaMAGrammar",
|
1284
|
+
{ NULL,
|
1285
|
+
RbLLaMAGrammar::llama_grammar_free,
|
1286
|
+
RbLLaMAGrammar::llama_grammar_size },
|
1287
|
+
NULL,
|
1288
|
+
NULL,
|
1289
|
+
RUBY_TYPED_FREE_IMMEDIATELY
|
1290
|
+
};
|
1291
|
+
|
1060
1292
|
class LLaMAContextWrapper {
|
1061
1293
|
public:
|
1062
1294
|
struct llama_context* ctx;
|
@@ -1128,6 +1360,8 @@ public:
|
|
1128
1360
|
rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
|
1129
1361
|
rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
|
1130
1362
|
rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
|
1363
|
+
rb_define_method(rb_cLLaMAContext, "sample_grammar", RUBY_METHOD_FUNC(_llama_context_sample_grammar), -1);
|
1364
|
+
rb_define_method(rb_cLLaMAContext, "grammar_accept_token", RUBY_METHOD_FUNC(_llama_context_grammar_accept_token), -1);
|
1131
1365
|
}
|
1132
1366
|
|
1133
1367
|
private:
|
@@ -2104,6 +2338,69 @@ private:
|
|
2104
2338
|
llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
|
2105
2339
|
return INT2NUM(id);
|
2106
2340
|
}
|
2341
|
+
|
2342
|
+
static VALUE _llama_context_sample_grammar(int argc, VALUE* argv, VALUE self) {
|
2343
|
+
VALUE kw_args = Qnil;
|
2344
|
+
ID kw_table[1] = { rb_intern("grammar") };
|
2345
|
+
VALUE kw_values[1] = { Qundef };
|
2346
|
+
VALUE candidates = Qnil;
|
2347
|
+
rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
|
2348
|
+
rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
|
2349
|
+
|
2350
|
+
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
2351
|
+
rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
|
2352
|
+
return Qnil;
|
2353
|
+
}
|
2354
|
+
if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
|
2355
|
+
rb_raise(rb_eArgError, "grammar must be a Grammar");
|
2356
|
+
return Qnil;
|
2357
|
+
}
|
2358
|
+
|
2359
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
2360
|
+
if (ctx_ptr->ctx == NULL) {
|
2361
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
2362
|
+
return Qnil;
|
2363
|
+
}
|
2364
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
2365
|
+
if (cnd_ptr->array.data == nullptr) {
|
2366
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
2367
|
+
return Qnil;
|
2368
|
+
}
|
2369
|
+
LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
|
2370
|
+
|
2371
|
+
llama_sample_grammar(ctx_ptr->ctx, &(cnd_ptr->array), grm_ptr->grammar);
|
2372
|
+
|
2373
|
+
return Qnil;
|
2374
|
+
}
|
2375
|
+
|
2376
|
+
static VALUE _llama_context_grammar_accept_token(int argc, VALUE* argv, VALUE self) {
|
2377
|
+
VALUE kw_args = Qnil;
|
2378
|
+
ID kw_table[2] = { rb_intern("grammar"), rb_intern("token") };
|
2379
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
2380
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
2381
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
2382
|
+
|
2383
|
+
if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
|
2384
|
+
rb_raise(rb_eArgError, "grammar must be a Grammar");
|
2385
|
+
return Qnil;
|
2386
|
+
}
|
2387
|
+
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
2388
|
+
rb_raise(rb_eArgError, "token must be an Integer");
|
2389
|
+
return Qnil;
|
2390
|
+
}
|
2391
|
+
|
2392
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
2393
|
+
if (ctx_ptr->ctx == NULL) {
|
2394
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
2395
|
+
return Qnil;
|
2396
|
+
}
|
2397
|
+
LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
|
2398
|
+
llama_token token = NUM2INT(kw_values[1]);
|
2399
|
+
|
2400
|
+
llama_grammar_accept_token(ctx_ptr->ctx, grm_ptr->grammar, token);
|
2401
|
+
|
2402
|
+
return Qnil;
|
2403
|
+
}
|
2107
2404
|
};
|
2108
2405
|
|
2109
2406
|
const rb_data_type_t RbLLaMAContext::llama_context_type = {
|
@@ -2125,7 +2422,7 @@ static VALUE rb_llama_llama_backend_init(int argc, VALUE* argv, VALUE self) {
|
|
2125
2422
|
rb_scan_args(argc, argv, ":", &kw_args);
|
2126
2423
|
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
2127
2424
|
|
2128
|
-
const bool numa = kw_values[0] == Qundef ? false : (RTEST ? true : false);
|
2425
|
+
const bool numa = kw_values[0] == Qundef ? false : (RTEST(kw_values[0]) ? true : false);
|
2129
2426
|
llama_backend_init(numa);
|
2130
2427
|
|
2131
2428
|
return Qnil;
|
@@ -2208,6 +2505,8 @@ extern "C" void Init_llama_cpp(void) {
|
|
2208
2505
|
RbLLaMAContext::define_class(rb_mLLaMACpp);
|
2209
2506
|
RbLLaMAContextParams::define_class(rb_mLLaMACpp);
|
2210
2507
|
RbLLaMAModelQuantizeParams::define_class(rb_mLLaMACpp);
|
2508
|
+
RbLLaMAGrammarElement::define_class(rb_mLLaMACpp);
|
2509
|
+
RbLLaMAGrammar::define_class(rb_mLLaMACpp);
|
2211
2510
|
|
2212
2511
|
rb_define_module_function(rb_mLLaMACpp, "backend_init", rb_llama_llama_backend_init, -1);
|
2213
2512
|
rb_define_module_function(rb_mLLaMACpp, "backend_free", rb_llama_llama_backend_free, 0);
|
@@ -2240,6 +2539,14 @@ extern "C" void Init_llama_cpp(void) {
|
|
2240
2539
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_K_M", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_K_M));
|
2241
2540
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q6_K", INT2NUM(LLAMA_FTYPE_MOSTLY_Q6_K));
|
2242
2541
|
|
2542
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_END", INT2NUM(LLAMA_GRETYPE_END));
|
2543
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_ALT", INT2NUM(LLAMA_GRETYPE_ALT));
|
2544
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_RULE_REF", INT2NUM(LLAMA_GRETYPE_RULE_REF));
|
2545
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR", INT2NUM(LLAMA_GRETYPE_CHAR));
|
2546
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_NOT", INT2NUM(LLAMA_GRETYPE_CHAR_NOT));
|
2547
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_RNG_UPPER", INT2NUM(LLAMA_GRETYPE_CHAR_RNG_UPPER));
|
2548
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_ALT", INT2NUM(LLAMA_GRETYPE_CHAR_ALT));
|
2549
|
+
|
2243
2550
|
std::stringstream ss_magic;
|
2244
2551
|
ss_magic << std::showbase << std::hex << LLAMA_FILE_MAGIC_GGJT;
|
2245
2552
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_MAGIC_GGJT", rb_str_new2(ss_magic.str().c_str()));
|