llama_cpp 0.3.3 → 0.3.5
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +31 -0
- data/ext/llama_cpp/extconf.rb +1 -0
- data/ext/llama_cpp/llama_cpp.cpp +439 -9
- data/ext/llama_cpp/src/ggml-cuda.cu +759 -136
- data/ext/llama_cpp/src/ggml-metal.h +7 -0
- data/ext/llama_cpp/src/ggml-metal.m +250 -111
- data/ext/llama_cpp/src/ggml-metal.metal +614 -483
- data/ext/llama_cpp/src/ggml.c +793 -1032
- data/ext/llama_cpp/src/ggml.h +95 -18
- data/ext/llama_cpp/src/k_quants.c +327 -3
- data/ext/llama_cpp/src/k_quants.h +8 -0
- data/ext/llama_cpp/src/llama.cpp +626 -166
- data/ext/llama_cpp/src/llama.h +94 -10
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -0
- data/sig/llama_cpp.rbs +36 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 991df3df6b16ec98a203a6c6565794988eec04697ccb963faab976f436e1bfcc
|
4
|
+
data.tar.gz: dafd3b8274640eb79353e11056f497a02392fef332a46dcc1717878c836f62bd
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 0a54fdf18c5be5273f01d4d991b59975a8e8b6a0a8f54087fb90df3f1a8a7ebad557d01fb119f3d61cafa8f6c59fb81641624779fda27b732eea2868cb4642e8
|
7
|
+
data.tar.gz: 6574064078070502e36ad933bd5efb2f479c94f47a1260286fe485e48d35a7b3985274943c6163189326891349b47b1d9815d77c600b015fe111cc3842179392
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,34 @@
|
|
1
|
+
## [[0.3.5](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.4...v0.3.5)] - 2023-07-29
|
2
|
+
|
3
|
+
- Bump bundled llama.cpp from master-d924522 to master-1a94186.
|
4
|
+
- Add `GrammarElement` and `Grammar` classes.
|
5
|
+
- Add `sample_grammar` method to Context.
|
6
|
+
- Add `grammar_accept_token` method to Context.
|
7
|
+
|
8
|
+
## [[0.3.4](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.3...v0.3.4)] - 2023-07-23
|
9
|
+
|
10
|
+
- Bump bundled llama.cpp from master-32c5411 to master-d924522.
|
11
|
+
- Add `rope_freq_base` and `rope_freq_scale` options to ContextParams.
|
12
|
+
- Add `max_devices` module function to LLaMACpp.
|
13
|
+
- Add `n_vocab`, `n_ctx`, and `n_embd` methods to Model.
|
14
|
+
- Add `vocab`, `tokenize`, and `token_to_str` methods to Model.
|
15
|
+
```ruby
|
16
|
+
require 'llama_cpp'
|
17
|
+
|
18
|
+
params = LLaMACpp::ContextParams.new
|
19
|
+
model = LLaMACpp::Model.new(model_path: '/path/to/model.bin', params: params)
|
20
|
+
|
21
|
+
p model.tokenize(text: 'hello, world')
|
22
|
+
# => [12199, 29892, 3186]
|
23
|
+
|
24
|
+
p model.token_to_str(12199)
|
25
|
+
# => "hello"
|
26
|
+
```
|
27
|
+
|
28
|
+
**Breaking Changes**
|
29
|
+
- Fix to automatically call `backend_free` method when Ruby script exits.
|
30
|
+
- Remove `smooth_factor` argument from `sample_classifier_free_guidance` method on Context.
|
31
|
+
|
1
32
|
## [[0.3.3](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.2...v0.3.3)] - 2023-07-15
|
2
33
|
|
3
34
|
- Bump bundled llama.cpp from master-481f793 to master-32c5411.
|
data/ext/llama_cpp/extconf.rb
CHANGED
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -8,6 +8,8 @@ VALUE rb_cLLaMAContextParams;
|
|
8
8
|
VALUE rb_cLLaMAModelQuantizeParams;
|
9
9
|
VALUE rb_cLLaMATokenData;
|
10
10
|
VALUE rb_cLLaMATokenDataArray;
|
11
|
+
VALUE rb_cLLaMAGrammarElement;
|
12
|
+
VALUE rb_cLLaMAGrammar;
|
11
13
|
|
12
14
|
class LLaMATokenDataWrapper {
|
13
15
|
public:
|
@@ -406,6 +408,10 @@ public:
|
|
406
408
|
rb_define_method(rb_cLLaMAContextParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_context_params_set_main_gpu), 1);
|
407
409
|
rb_define_method(rb_cLLaMAContextParams, "main_gpu", RUBY_METHOD_FUNC(_llama_context_params_get_main_gpu), 0);
|
408
410
|
rb_define_method(rb_cLLaMAContextParams, "tensor_split", RUBY_METHOD_FUNC(_llama_context_params_get_tensor_split), 0);
|
411
|
+
rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
|
412
|
+
rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
|
413
|
+
rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
|
414
|
+
rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
|
409
415
|
rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
|
410
416
|
rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
|
411
417
|
rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
|
@@ -494,6 +500,30 @@ private:
|
|
494
500
|
return ret;
|
495
501
|
}
|
496
502
|
|
503
|
+
// rope_freq_base
|
504
|
+
// ContextParams#rope_freq_base=: converts the Ruby number with NUM2DBL, stores it in
// the wrapped params struct, and returns the stored field (not the raw argument).
static VALUE _llama_context_params_set_rope_freq_base(VALUE self, VALUE rope_freq_base) {
  LLaMAContextParamsWrapper* wrapper = get_llama_context_params(self);
  const double requested = NUM2DBL(rope_freq_base);
  wrapper->params.rope_freq_base = requested;
  // Read the field back rather than echoing `requested`.
  return DBL2NUM(wrapper->params.rope_freq_base);
}
|
509
|
+
|
510
|
+
// ContextParams#rope_freq_base: returns the value currently stored in the wrapped
// params struct, converted with DBL2NUM.
static VALUE _llama_context_params_get_rope_freq_base(VALUE self) {
  const LLaMAContextParamsWrapper* wrapper = get_llama_context_params(self);
  const double stored = wrapper->params.rope_freq_base;
  return DBL2NUM(stored);
}
|
514
|
+
|
515
|
+
// rope_freq_scale
|
516
|
+
// ContextParams#rope_freq_scale=: converts the Ruby number with NUM2DBL, stores it in
// the wrapped params struct, and returns the stored field (not the raw argument).
static VALUE _llama_context_params_set_rope_freq_scale(VALUE self, VALUE rope_freq_scale) {
  LLaMAContextParamsWrapper* wrapper = get_llama_context_params(self);
  const double requested = NUM2DBL(rope_freq_scale);
  wrapper->params.rope_freq_scale = requested;
  // Read the field back rather than echoing `requested`.
  return DBL2NUM(wrapper->params.rope_freq_scale);
}
|
521
|
+
|
522
|
+
// ContextParams#rope_freq_scale: returns the value currently stored in the wrapped
// params struct, converted with DBL2NUM.
static VALUE _llama_context_params_get_rope_freq_scale(VALUE self) {
  const LLaMAContextParamsWrapper* wrapper = get_llama_context_params(self);
  const double stored = wrapper->params.rope_freq_scale;
  return DBL2NUM(stored);
}
|
526
|
+
|
497
527
|
// low_vram
|
498
528
|
static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
|
499
529
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
@@ -764,6 +794,12 @@ public:
|
|
764
794
|
rb_define_method(rb_cLLaMAModel, "free", RUBY_METHOD_FUNC(_llama_model_free), 0);
|
765
795
|
rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
|
766
796
|
rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
|
797
|
+
rb_define_method(rb_cLLaMAModel, "n_vocab", RUBY_METHOD_FUNC(_llama_model_get_n_vocab_from_model), 0);
|
798
|
+
rb_define_method(rb_cLLaMAModel, "n_ctx", RUBY_METHOD_FUNC(_llama_model_get_n_ctx_from_model), 0);
|
799
|
+
rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_n_embd_from_model), 0);
|
800
|
+
rb_define_method(rb_cLLaMAModel, "vocab", RUBY_METHOD_FUNC(_llama_model_get_vocab_from_model), -1);
|
801
|
+
rb_define_method(rb_cLLaMAModel, "token_to_str", RUBY_METHOD_FUNC(_llama_model_token_to_str_with_model), 1);
|
802
|
+
rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize_with_model), -1);
|
767
803
|
}
|
768
804
|
|
769
805
|
private:
|
@@ -908,6 +944,109 @@ private:
|
|
908
944
|
}
|
909
945
|
return Qnil;
|
910
946
|
}
|
947
|
+
|
948
|
+
// Model#n_vocab: forwards to llama_n_vocab_from_model for the wrapped model and
// returns the result as a Ruby Integer.
static VALUE _llama_model_get_n_vocab_from_model(VALUE self) {
  LLaMAModelWrapper* wrapper = get_llama_model(self);
  const auto n_vocab = llama_n_vocab_from_model(wrapper->model);
  return INT2NUM(n_vocab);
}
|
952
|
+
|
953
|
+
// Model#n_ctx: forwards to llama_n_ctx_from_model for the wrapped model and
// returns the result as a Ruby Integer.
static VALUE _llama_model_get_n_ctx_from_model(VALUE self) {
  LLaMAModelWrapper* wrapper = get_llama_model(self);
  const auto n_ctx = llama_n_ctx_from_model(wrapper->model);
  return INT2NUM(n_ctx);
}
|
957
|
+
|
958
|
+
// Model#n_embd: forwards to llama_n_embd_from_model for the wrapped model and
// returns the result as a Ruby Integer.
static VALUE _llama_model_get_n_embd_from_model(VALUE self) {
  LLaMAModelWrapper* wrapper = get_llama_model(self);
  const auto n_embd = llama_n_embd_from_model(wrapper->model);
  return INT2NUM(n_embd);
}
|
962
|
+
|
963
|
+
// Model#vocab(capacity:): returns a two-element Array [tokens, scores] holding at most
// `capacity` vocabulary entries (fewer when the model's vocabulary is smaller).
//
// Raises ArgumentError when `capacity` is missing, not an Integer, or negative.
static VALUE _llama_model_get_vocab_from_model(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[1] = { rb_intern("capacity") };
  VALUE kw_values[1] = { Qundef };
  rb_scan_args(argc, argv, ":", &kw_args);
  rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);

  if (!RB_INTEGER_TYPE_P(kw_values[0])) {
    rb_raise(rb_eArgError, "capacity must be an integer");
    return Qnil;
  }

  const int capacity = NUM2INT(kw_values[0]);
  // A negative count must never reach ALLOCA_N, where it would be used as an element count.
  if (capacity < 0) {
    rb_raise(rb_eArgError, "capacity must not be negative");
    return Qnil;
  }

  LLaMAModelWrapper* ptr = get_llama_model(self);
  const int n = std::min(capacity, llama_n_vocab_from_model(ptr->model));
  const char** vocabs = ALLOCA_N(const char*, n);
  float* scores = ALLOCA_N(float, n);

  // Ask for exactly as many entries as the buffers above can hold.
  llama_get_vocab_from_model(ptr->model, vocabs, scores, n);

  VALUE vocabs_ary = rb_ary_new2(n);
  VALUE scores_ary = rb_ary_new2(n);

  for (int i = 0; i < n; i++) {
    rb_ary_push(vocabs_ary, rb_str_new_cstr(vocabs[i]));
    rb_ary_push(scores_ary, DBL2NUM(scores[i]));
  }

  return rb_ary_new3(2, vocabs_ary, scores_ary);
}
|
996
|
+
|
997
|
+
// Model#token_to_str(token): returns the text of a single vocabulary token as a String.
//
// Raises ArgumentError when `token` is not an Integer or lies outside
// [0, n_vocab) for the wrapped model, and RuntimeError when the library
// returns no string for the token.
static VALUE _llama_model_token_to_str_with_model(VALUE self, VALUE token_) {
  if (!RB_INTEGER_TYPE_P(token_)) {
    rb_raise(rb_eArgError, "token must be an integer");
    return Qnil;
  }
  const llama_token token = NUM2INT(token_);
  LLaMAModelWrapper* ptr = get_llama_model(self);
  // Reject ids that cannot index the vocabulary before handing them to the library.
  if (token < 0 || token >= llama_n_vocab_from_model(ptr->model)) {
    rb_raise(rb_eArgError, "token is out of range");
    return Qnil;
  }
  const char* str = llama_token_to_str_with_model(ptr->model, token);
  // rb_str_new_cstr must not be handed a null pointer.
  if (str == nullptr) {
    rb_raise(rb_eRuntimeError, "failed to convert token to string");
    return Qnil;
  }
  return rb_str_new_cstr(str);
}
|
1007
|
+
|
1008
|
+
// Model#tokenize(text:, n_max_tokens: nil, add_bos: false): tokenizes `text` with the
// wrapped model and returns an Array of Integer token ids.
//
// When `n_max_tokens` is omitted, the buffer is sized to the byte length of `text`
// (plus one when `add_bos` is true).
//
// Raises ArgumentError for wrongly typed keywords or a negative `n_max_tokens`, and
// RuntimeError when the tokenizer reports that the text needs more than
// `n_max_tokens` tokens.
static VALUE _llama_model_tokenize_with_model(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
  VALUE kw_values[3] = { Qundef, Qundef, Qundef };
  rb_scan_args(argc, argv, ":", &kw_args);
  rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);

  if (!RB_TYPE_P(kw_values[0], T_STRING)) {
    rb_raise(rb_eArgError, "text must be a String");
    return Qnil;
  }
  if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
    rb_raise(rb_eArgError, "n_max_tokens must be an integer");
    return Qnil;
  }
  if (kw_values[2] != Qundef && (kw_values[2] != Qtrue && kw_values[2] != Qfalse)) {
    rb_raise(rb_eArgError, "add_bos must be a boolean");
    return Qnil;
  }

  VALUE text_ = kw_values[0];
  std::string text = StringValueCStr(text_);
  const bool add_bos = kw_values[2] == Qtrue;
  const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1])
                                                  : static_cast<int>(text.size()) + (add_bos ? 1 : 0);
  // A negative count must never reach ALLOCA_N, where it would be used as an element count.
  if (n_max_tokens < 0) {
    rb_raise(rb_eArgError, "n_max_tokens must not be negative");
    return Qnil;
  }

  llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
  LLaMAModelWrapper* ptr = get_llama_model(self);
  const int n_tokens = llama_tokenize_with_model(ptr->model, text.c_str(), tokens, n_max_tokens, add_bos);

  // A negative result is reported with its magnitude as the required token count.
  if (n_tokens < 0) {
    rb_raise(rb_eRuntimeError, "failed to tokenize. The number of tokens (%d) is greater than n_max_tokens.", -n_tokens);
    return Qnil;
  }

  VALUE ret = rb_ary_new2(n_tokens);
  for (int i = 0; i < n_tokens; i++) {
    rb_ary_store(ret, i, INT2NUM(tokens[i]));
  }

  RB_GC_GUARD(text_);
  return ret;
}
|
911
1050
|
};
|
912
1051
|
|
913
1052
|
const rb_data_type_t RbLLaMAModel::llama_model_type = {
|
@@ -920,6 +1059,222 @@ const rb_data_type_t RbLLaMAModel::llama_model_type = {
|
|
920
1059
|
RUBY_TYPED_FREE_IMMEDIATELY
|
921
1060
|
};
|
922
1061
|
|
1062
|
+
class LLaMAGrammarElementWrapper {
|
1063
|
+
public:
|
1064
|
+
llama_grammar_element element;
|
1065
|
+
|
1066
|
+
LLaMAGrammarElementWrapper() {
|
1067
|
+
element.type = LLAMA_GRETYPE_END;
|
1068
|
+
element.value = 0;
|
1069
|
+
}
|
1070
|
+
|
1071
|
+
~LLaMAGrammarElementWrapper() {}
|
1072
|
+
};
|
1073
|
+
|
1074
|
+
class RbLLaMAGrammarElement {
|
1075
|
+
public:
|
1076
|
+
static VALUE llama_grammar_element_alloc(VALUE self) {
|
1077
|
+
LLaMAGrammarElementWrapper* ptr = (LLaMAGrammarElementWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarElementWrapper));
|
1078
|
+
new (ptr) LLaMAGrammarElementWrapper();
|
1079
|
+
return TypedData_Wrap_Struct(self, &llama_grammar_element_type, ptr);
|
1080
|
+
}
|
1081
|
+
|
1082
|
+
static void llama_grammar_element_free(void* ptr) {
|
1083
|
+
((LLaMAGrammarElementWrapper*)ptr)->~LLaMAGrammarElementWrapper();
|
1084
|
+
ruby_xfree(ptr);
|
1085
|
+
}
|
1086
|
+
|
1087
|
+
static size_t llama_grammar_element_size(const void* ptr) {
|
1088
|
+
return sizeof(*((LLaMAGrammarElementWrapper*)ptr));
|
1089
|
+
}
|
1090
|
+
|
1091
|
+
static LLaMAGrammarElementWrapper* get_llama_grammar_element(VALUE self) {
|
1092
|
+
LLaMAGrammarElementWrapper* ptr;
|
1093
|
+
TypedData_Get_Struct(self, LLaMAGrammarElementWrapper, &llama_grammar_element_type, ptr);
|
1094
|
+
return ptr;
|
1095
|
+
}
|
1096
|
+
|
1097
|
+
static void define_class(VALUE outer) {
|
1098
|
+
rb_cLLaMAGrammarElement = rb_define_class_under(outer, "GrammarElement", rb_cObject);
|
1099
|
+
rb_define_alloc_func(rb_cLLaMAGrammarElement, llama_grammar_element_alloc);
|
1100
|
+
rb_define_method(rb_cLLaMAGrammarElement, "initialize", RUBY_METHOD_FUNC(_llama_grammar_element_init), -1);
|
1101
|
+
rb_define_method(rb_cLLaMAGrammarElement, "type=", RUBY_METHOD_FUNC(_llama_grammar_element_set_type), 1);
|
1102
|
+
rb_define_method(rb_cLLaMAGrammarElement, "type", RUBY_METHOD_FUNC(_llama_grammar_element_get_type), 0);
|
1103
|
+
rb_define_method(rb_cLLaMAGrammarElement, "value=", RUBY_METHOD_FUNC(_llama_grammar_element_set_value), 1);
|
1104
|
+
rb_define_method(rb_cLLaMAGrammarElement, "value", RUBY_METHOD_FUNC(_llama_grammar_element_get_value), 0);
|
1105
|
+
}
|
1106
|
+
|
1107
|
+
private:
|
1108
|
+
static const rb_data_type_t llama_grammar_element_type;
|
1109
|
+
|
1110
|
+
static VALUE _llama_grammar_element_init(int argc, VALUE* argv, VALUE self) {
|
1111
|
+
VALUE kw_args = Qnil;
|
1112
|
+
ID kw_table[2] = { rb_intern("type"), rb_intern("value") };
|
1113
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
1114
|
+
VALUE arr = Qnil;
|
1115
|
+
rb_scan_args(argc, argv, ":", &arr, &kw_args);
|
1116
|
+
rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
|
1117
|
+
|
1118
|
+
if (kw_values[0] != Qundef && !RB_INTEGER_TYPE_P(kw_values[0])) {
|
1119
|
+
rb_raise(rb_eArgError, "type must be an integer");
|
1120
|
+
return Qnil;
|
1121
|
+
}
|
1122
|
+
if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
|
1123
|
+
rb_raise(rb_eArgError, "value must be an integer");
|
1124
|
+
return Qnil;
|
1125
|
+
}
|
1126
|
+
|
1127
|
+
LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
|
1128
|
+
new (ptr) LLaMAGrammarElementWrapper();
|
1129
|
+
|
1130
|
+
if (kw_values[0] != Qundef) {
|
1131
|
+
ptr->element.type = (enum llama_gretype)NUM2INT(kw_values[0]);
|
1132
|
+
}
|
1133
|
+
if (kw_values[1] != Qundef) {
|
1134
|
+
ptr->element.value = NUM2INT(kw_values[1]);
|
1135
|
+
}
|
1136
|
+
|
1137
|
+
return self;
|
1138
|
+
}
|
1139
|
+
|
1140
|
+
// type
|
1141
|
+
static VALUE _llama_grammar_element_set_type(VALUE self, VALUE type) {
|
1142
|
+
LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
|
1143
|
+
ptr->element.type = (enum llama_gretype)NUM2INT(type);
|
1144
|
+
return INT2NUM(ptr->element.type);
|
1145
|
+
}
|
1146
|
+
|
1147
|
+
static VALUE _llama_grammar_element_get_type(VALUE self) {
|
1148
|
+
LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
|
1149
|
+
return INT2NUM(ptr->element.type);
|
1150
|
+
}
|
1151
|
+
|
1152
|
+
// value
|
1153
|
+
static VALUE _llama_grammar_element_set_value(VALUE self, VALUE type) {
|
1154
|
+
LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
|
1155
|
+
ptr->element.value = NUM2INT(type);
|
1156
|
+
return INT2NUM(ptr->element.value);
|
1157
|
+
}
|
1158
|
+
|
1159
|
+
static VALUE _llama_grammar_element_get_value(VALUE self) {
|
1160
|
+
LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
|
1161
|
+
return INT2NUM(ptr->element.value);
|
1162
|
+
}
|
1163
|
+
};
|
1164
|
+
|
1165
|
+
// Typed-data descriptor used when wrapping/unwrapping LLaMAGrammarElementWrapper.
const rb_data_type_t RbLLaMAGrammarElement::llama_grammar_element_type = {
  "RbLLaMAGrammarElement",
  {
    NULL,                                               // first callback slot left unset
    RbLLaMAGrammarElement::llama_grammar_element_free,  // destroys and frees the wrapper
    RbLLaMAGrammarElement::llama_grammar_element_size   // reports the wrapper size
  },
  NULL,
  NULL,
  RUBY_TYPED_FREE_IMMEDIATELY
};
|
1174
|
+
|
1175
|
+
class LLaMAGrammarWrapper {
|
1176
|
+
public:
|
1177
|
+
struct llama_grammar* grammar;
|
1178
|
+
|
1179
|
+
LLaMAGrammarWrapper() : grammar(nullptr) {}
|
1180
|
+
|
1181
|
+
~LLaMAGrammarWrapper() {
|
1182
|
+
if (grammar) {
|
1183
|
+
llama_grammar_free(grammar);
|
1184
|
+
}
|
1185
|
+
}
|
1186
|
+
};
|
1187
|
+
|
1188
|
+
class RbLLaMAGrammar {
|
1189
|
+
public:
|
1190
|
+
static VALUE llama_grammar_alloc(VALUE self) {
|
1191
|
+
LLaMAGrammarWrapper* ptr = (LLaMAGrammarWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarWrapper));
|
1192
|
+
new (ptr) LLaMAGrammarWrapper();
|
1193
|
+
return TypedData_Wrap_Struct(self, &llama_grammar_type, ptr);
|
1194
|
+
}
|
1195
|
+
|
1196
|
+
static void llama_grammar_free(void* ptr) {
|
1197
|
+
((LLaMAGrammarWrapper*)ptr)->~LLaMAGrammarWrapper();
|
1198
|
+
ruby_xfree(ptr);
|
1199
|
+
}
|
1200
|
+
|
1201
|
+
static size_t llama_grammar_size(const void* ptr) {
|
1202
|
+
return sizeof(*((LLaMAGrammarWrapper*)ptr));
|
1203
|
+
}
|
1204
|
+
|
1205
|
+
static LLaMAGrammarWrapper* get_llama_grammar(VALUE self) {
|
1206
|
+
LLaMAGrammarWrapper* ptr;
|
1207
|
+
TypedData_Get_Struct(self, LLaMAGrammarWrapper, &llama_grammar_type, ptr);
|
1208
|
+
return ptr;
|
1209
|
+
}
|
1210
|
+
|
1211
|
+
static void define_class(VALUE outer) {
|
1212
|
+
rb_cLLaMAGrammar = rb_define_class_under(outer, "Grammar", rb_cObject);
|
1213
|
+
rb_define_alloc_func(rb_cLLaMAGrammar, llama_grammar_alloc);
|
1214
|
+
rb_define_method(rb_cLLaMAGrammar, "initialize", RUBY_METHOD_FUNC(_llama_grammar_init), -1);
|
1215
|
+
}
|
1216
|
+
|
1217
|
+
private:
|
1218
|
+
static const rb_data_type_t llama_grammar_type;
|
1219
|
+
|
1220
|
+
static VALUE _llama_grammar_init(int argc, VALUE* argv, VALUE self) {
|
1221
|
+
VALUE kw_args = Qnil;
|
1222
|
+
ID kw_table[2] = { rb_intern("rules"), rb_intern("start_rule_index") };
|
1223
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
1224
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
1225
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
1226
|
+
|
1227
|
+
if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
|
1228
|
+
rb_raise(rb_eArgError, "rules must be an array");
|
1229
|
+
return Qnil;
|
1230
|
+
}
|
1231
|
+
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
1232
|
+
rb_raise(rb_eArgError, "start_rule_index must be an integer");
|
1233
|
+
return Qnil;
|
1234
|
+
}
|
1235
|
+
|
1236
|
+
const int n_rules = RARRAY_LEN(kw_values[0]);
|
1237
|
+
llama_grammar_element** rules = ALLOCA_N(llama_grammar_element*, n_rules);
|
1238
|
+
for (int i = 0; i < n_rules; ++i) {
|
1239
|
+
VALUE rule = rb_ary_entry(kw_values[0], i);
|
1240
|
+
if (!RB_TYPE_P(rule, T_ARRAY)) {
|
1241
|
+
rb_raise(rb_eArgError, "element of rules must be an array");
|
1242
|
+
return Qnil;
|
1243
|
+
}
|
1244
|
+
const int n_elements = RARRAY_LEN(rule);
|
1245
|
+
llama_grammar_element* elements = ALLOCA_N(llama_grammar_element, n_elements);
|
1246
|
+
for (int j = 0; j < n_elements; ++j) {
|
1247
|
+
VALUE element = rb_ary_entry(rule, j);
|
1248
|
+
if (!rb_obj_is_kind_of(element, rb_cLLaMAGrammarElement)) {
|
1249
|
+
rb_raise(rb_eArgError, "element of rule must be an instance of GrammarElement");
|
1250
|
+
return Qnil;
|
1251
|
+
}
|
1252
|
+
LLaMAGrammarElementWrapper* ptr = RbLLaMAGrammarElement::get_llama_grammar_element(element);
|
1253
|
+
elements[j] = ptr->element;
|
1254
|
+
}
|
1255
|
+
rules[i] = elements;
|
1256
|
+
}
|
1257
|
+
|
1258
|
+
const size_t start_rule_index = NUM2SIZET(kw_values[1]);
|
1259
|
+
|
1260
|
+
LLaMAGrammarWrapper* ptr = get_llama_grammar(self);
|
1261
|
+
new (ptr) LLaMAGrammarWrapper();
|
1262
|
+
ptr->grammar = llama_grammar_init((const llama_grammar_element**)rules, n_rules, start_rule_index);
|
1263
|
+
|
1264
|
+
return self;
|
1265
|
+
}
|
1266
|
+
};
|
1267
|
+
|
1268
|
+
// Typed-data descriptor used when wrapping/unwrapping LLaMAGrammarWrapper.
const rb_data_type_t RbLLaMAGrammar::llama_grammar_type = {
  "RbLLaMAGrammar",
  {
    NULL,                                // first callback slot left unset
    RbLLaMAGrammar::llama_grammar_free,  // destroys and frees the wrapper
    RbLLaMAGrammar::llama_grammar_size   // reports the wrapper size
  },
  NULL,
  NULL,
  RUBY_TYPED_FREE_IMMEDIATELY
};
|
1277
|
+
|
923
1278
|
class LLaMAContextWrapper {
|
924
1279
|
public:
|
925
1280
|
struct llama_context* ctx;
|
@@ -991,6 +1346,8 @@ public:
|
|
991
1346
|
rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
|
992
1347
|
rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
|
993
1348
|
rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
|
1349
|
+
rb_define_method(rb_cLLaMAContext, "sample_grammar", RUBY_METHOD_FUNC(_llama_context_sample_grammar), -1);
|
1350
|
+
rb_define_method(rb_cLLaMAContext, "grammar_accept_token", RUBY_METHOD_FUNC(_llama_context_grammar_accept_token), -1);
|
994
1351
|
}
|
995
1352
|
|
996
1353
|
private:
|
@@ -1581,11 +1938,11 @@ private:
|
|
1581
1938
|
|
1582
1939
|
static VALUE _llama_context_sample_classifier_free_guidance(int argc, VALUE* argv, VALUE self) {
|
1583
1940
|
VALUE kw_args = Qnil;
|
1584
|
-
ID kw_table[
|
1585
|
-
VALUE kw_values[
|
1941
|
+
ID kw_table[2] = { rb_intern("guidance"), rb_intern("scale") };
|
1942
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
1586
1943
|
VALUE candidates = Qnil;
|
1587
1944
|
rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
|
1588
|
-
rb_get_kwargs(kw_args, kw_table,
|
1945
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
1589
1946
|
|
1590
1947
|
if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAContext)) {
|
1591
1948
|
rb_raise(rb_eArgError, "guidance must be a Context");
|
@@ -1595,10 +1952,6 @@ private:
|
|
1595
1952
|
rb_raise(rb_eArgError, "scale must be a float");
|
1596
1953
|
return Qnil;
|
1597
1954
|
}
|
1598
|
-
if (!RB_FLOAT_TYPE_P(kw_values[2])) {
|
1599
|
-
rb_raise(rb_eArgError, "smooth_factor must be a float");
|
1600
|
-
return Qnil;
|
1601
|
-
}
|
1602
1955
|
|
1603
1956
|
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
1604
1957
|
if (ctx_ptr->ctx == NULL) {
|
@@ -1617,9 +1970,8 @@ private:
|
|
1617
1970
|
return Qnil;
|
1618
1971
|
}
|
1619
1972
|
const float scale = NUM2DBL(kw_values[1]);
|
1620
|
-
const float smooth_factor = NUM2DBL(kw_values[2]);
|
1621
1973
|
|
1622
|
-
llama_sample_classifier_free_guidance(ctx_ptr->ctx, &(cnd_ptr->array), guidance_ptr->ctx, scale
|
1974
|
+
llama_sample_classifier_free_guidance(ctx_ptr->ctx, &(cnd_ptr->array), guidance_ptr->ctx, scale);
|
1623
1975
|
|
1624
1976
|
return Qnil;
|
1625
1977
|
}
|
@@ -1972,6 +2324,69 @@ private:
|
|
1972
2324
|
llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
|
1973
2325
|
return INT2NUM(id);
|
1974
2326
|
}
|
2327
|
+
|
2328
|
+
// Context#sample_grammar(candidates, grammar:): passes the candidate tokens and the
// grammar to llama_sample_grammar. Always returns nil.
//
// Raises ArgumentError when `candidates` is not a TokenDataArray or `grammar` is not
// a Grammar, and RuntimeError when the context, the candidate array, or the grammar
// has no underlying native object.
static VALUE _llama_context_sample_grammar(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[1] = { rb_intern("grammar") };
  VALUE kw_values[1] = { Qundef };
  VALUE candidates = Qnil;
  rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
  rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);

  if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
    rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
    return Qnil;
  }
  if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
    rb_raise(rb_eArgError, "grammar must be a Grammar");
    return Qnil;
  }

  LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
  if (ctx_ptr->ctx == NULL) {
    rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
    return Qnil;
  }
  LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
  if (cnd_ptr->array.data == nullptr) {
    rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
    return Qnil;
  }
  LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
  // A Grammar object that was allocated but never initialized has no native handle.
  if (grm_ptr->grammar == nullptr) {
    rb_raise(rb_eRuntimeError, "Grammar is not initialized");
    return Qnil;
  }

  llama_sample_grammar(ctx_ptr->ctx, &(cnd_ptr->array), grm_ptr->grammar);

  return Qnil;
}
|
2361
|
+
|
2362
|
+
// Context#grammar_accept_token(grammar:, token:): passes the grammar and the token
// to llama_grammar_accept_token. Always returns nil.
//
// Raises ArgumentError when `grammar` is not a Grammar or `token` is not an Integer,
// and RuntimeError when the context or the grammar has no underlying native object.
static VALUE _llama_context_grammar_accept_token(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[2] = { rb_intern("grammar"), rb_intern("token") };
  VALUE kw_values[2] = { Qundef, Qundef };
  rb_scan_args(argc, argv, ":", &kw_args);
  rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);

  if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
    rb_raise(rb_eArgError, "grammar must be a Grammar");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(kw_values[1])) {
    rb_raise(rb_eArgError, "token must be an Integer");
    return Qnil;
  }

  LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
  if (ctx_ptr->ctx == NULL) {
    rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
    return Qnil;
  }
  LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
  // A Grammar object that was allocated but never initialized has no native handle.
  if (grm_ptr->grammar == nullptr) {
    rb_raise(rb_eRuntimeError, "Grammar is not initialized");
    return Qnil;
  }
  llama_token token = NUM2INT(kw_values[1]);

  llama_grammar_accept_token(ctx_ptr->ctx, grm_ptr->grammar, token);

  return Qnil;
}
|
1975
2390
|
};
|
1976
2391
|
|
1977
2392
|
const rb_data_type_t RbLLaMAContext::llama_context_type = {
|
@@ -2062,6 +2477,10 @@ static VALUE rb_llama_mlock_supported(VALUE self) {
|
|
2062
2477
|
return llama_mlock_supported() ? Qtrue : Qfalse;
|
2063
2478
|
}
|
2064
2479
|
|
2480
|
+
// LLaMACpp.max_devices: returns the result of llama_max_devices() as a Ruby Integer.
static VALUE rb_llama_max_devices(VALUE self) {
  const auto max_devices = llama_max_devices();
  return INT2NUM(max_devices);
}
|
2483
|
+
|
2065
2484
|
extern "C" void Init_llama_cpp(void) {
|
2066
2485
|
rb_mLLaMACpp = rb_define_module("LLaMACpp");
|
2067
2486
|
|
@@ -2072,6 +2491,8 @@ extern "C" void Init_llama_cpp(void) {
|
|
2072
2491
|
RbLLaMAContext::define_class(rb_mLLaMACpp);
|
2073
2492
|
RbLLaMAContextParams::define_class(rb_mLLaMACpp);
|
2074
2493
|
RbLLaMAModelQuantizeParams::define_class(rb_mLLaMACpp);
|
2494
|
+
RbLLaMAGrammarElement::define_class(rb_mLLaMACpp);
|
2495
|
+
RbLLaMAGrammar::define_class(rb_mLLaMACpp);
|
2075
2496
|
|
2076
2497
|
rb_define_module_function(rb_mLLaMACpp, "backend_init", rb_llama_llama_backend_init, -1);
|
2077
2498
|
rb_define_module_function(rb_mLLaMACpp, "backend_free", rb_llama_llama_backend_free, 0);
|
@@ -2082,6 +2503,7 @@ extern "C" void Init_llama_cpp(void) {
|
|
2082
2503
|
rb_define_module_function(rb_mLLaMACpp, "print_system_info", rb_llama_print_system_info, 0);
|
2083
2504
|
rb_define_module_function(rb_mLLaMACpp, "mmap_supported?", rb_llama_mmap_supported, 0);
|
2084
2505
|
rb_define_module_function(rb_mLLaMACpp, "mlock_supported?", rb_llama_mlock_supported, 0);
|
2506
|
+
rb_define_module_function(rb_mLLaMACpp, "max_devices", rb_llama_max_devices, 0);
|
2085
2507
|
|
2086
2508
|
rb_define_const(rb_mLLaMACpp, "LLAMA_MAX_DEVICES", INT2NUM(LLAMA_MAX_DEVICES));
|
2087
2509
|
|
@@ -2103,6 +2525,14 @@ extern "C" void Init_llama_cpp(void) {
|
|
2103
2525
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_K_M", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_K_M));
|
2104
2526
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q6_K", INT2NUM(LLAMA_FTYPE_MOSTLY_Q6_K));
|
2105
2527
|
|
2528
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_END", INT2NUM(LLAMA_GRETYPE_END));
|
2529
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_ALT", INT2NUM(LLAMA_GRETYPE_ALT));
|
2530
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_RULE_REF", INT2NUM(LLAMA_GRETYPE_RULE_REF));
|
2531
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR", INT2NUM(LLAMA_GRETYPE_CHAR));
|
2532
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_NOT", INT2NUM(LLAMA_GRETYPE_CHAR_NOT));
|
2533
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_RNG_UPPER", INT2NUM(LLAMA_GRETYPE_CHAR_RNG_UPPER));
|
2534
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_ALT", INT2NUM(LLAMA_GRETYPE_CHAR_ALT));
|
2535
|
+
|
2106
2536
|
std::stringstream ss_magic;
|
2107
2537
|
ss_magic << std::showbase << std::hex << LLAMA_FILE_MAGIC_GGJT;
|
2108
2538
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_MAGIC_GGJT", rb_str_new2(ss_magic.str().c_str()));
|