llama_cpp 0.3.3 → 0.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +31 -0
- data/ext/llama_cpp/extconf.rb +1 -0
- data/ext/llama_cpp/llama_cpp.cpp +439 -9
- data/ext/llama_cpp/src/ggml-cuda.cu +759 -136
- data/ext/llama_cpp/src/ggml-metal.h +7 -0
- data/ext/llama_cpp/src/ggml-metal.m +250 -111
- data/ext/llama_cpp/src/ggml-metal.metal +614 -483
- data/ext/llama_cpp/src/ggml.c +793 -1032
- data/ext/llama_cpp/src/ggml.h +95 -18
- data/ext/llama_cpp/src/k_quants.c +327 -3
- data/ext/llama_cpp/src/k_quants.h +8 -0
- data/ext/llama_cpp/src/llama.cpp +626 -166
- data/ext/llama_cpp/src/llama.h +94 -10
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -0
- data/sig/llama_cpp.rbs +36 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 991df3df6b16ec98a203a6c6565794988eec04697ccb963faab976f436e1bfcc
+  data.tar.gz: dafd3b8274640eb79353e11056f497a02392fef332a46dcc1717878c836f62bd
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 0a54fdf18c5be5273f01d4d991b59975a8e8b6a0a8f54087fb90df3f1a8a7ebad557d01fb119f3d61cafa8f6c59fb81641624779fda27b732eea2868cb4642e8
+  data.tar.gz: 6574064078070502e36ad933bd5efb2f479c94f47a1260286fe485e48d35a7b3985274943c6163189326891349b47b1d9815d77c600b015fe111cc3842179392
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,34 @@
+## [[0.3.5](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.4...v0.3.5)] - 2023-07-29
+
+- Bump bundled llama.cpp from master-d924522 to master-1a94186.
+- Add `GrammarElement` and `Grammar` classes.
+- Add `sample_grammar` method to Context.
+- Add `grammar_accept_token` method to Context.
+
+## [[0.3.4](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.3...v0.3.4)] - 2023-07-23
+
+- Bump bundled llama.cpp from master-32c5411 to master-d924522.
+- Add `rope_freq_base` and `rope_freq_scale` options to ContextParams.
+- Add `max_devices` module function to LLaMACpp.
+- Add `n_vocab`, `n_ctx`, and `n_embd` methods to Model.
+- Add `vocab`, `tokenize`, and `token_to_str` methods to Model.
+```ruby
+require 'llama_cpp'
+
+params = LLaMACpp::ContextParams.new
+model = LLaMACpp::Model.new(model_path: '/path/to/model.bin', params: params)
+
+p model.tokenize(text: 'hello, world')
+# => [12199, 29892, 3186]
+
+p model.token_to_str(12199)
+# => "hello"
+```
+
+**Breaking Changes**
+
+- Fix to automatically call `backend_free` method when Ruby script exits.
+- Remove `smooth_factor` argument from `sample_classifier_free_guidance` method on Context.
+
 ## [[0.3.3](https://github.com/yoshoku/llama_cpp.rb/compare/v0.3.2...v0.3.3)] - 2023-07-15
 
 - Bump bundled llama.cpp from master-481f793 to master-32c5411.
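The breaking change above alters the call shape of `sample_classifier_free_guidance`. A minimal sketch of the new signature, assuming `context`, `guidance_context`, and `candidates` (a `TokenDataArray`) are already set up; the `scale` value is illustrative:

```ruby
# 0.3.4 and later: smooth_factor is gone; only the guidance context and scale remain.
context.sample_classifier_free_guidance(candidates, guidance: guidance_context, scale: 1.5)
```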
data/ext/llama_cpp/extconf.rb
CHANGED
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -8,6 +8,8 @@ VALUE rb_cLLaMAContextParams;
 VALUE rb_cLLaMAModelQuantizeParams;
 VALUE rb_cLLaMATokenData;
 VALUE rb_cLLaMATokenDataArray;
+VALUE rb_cLLaMAGrammarElement;
+VALUE rb_cLLaMAGrammar;
 
 class LLaMATokenDataWrapper {
 public:
@@ -406,6 +408,10 @@ public:
     rb_define_method(rb_cLLaMAContextParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_context_params_set_main_gpu), 1);
     rb_define_method(rb_cLLaMAContextParams, "main_gpu", RUBY_METHOD_FUNC(_llama_context_params_get_main_gpu), 0);
     rb_define_method(rb_cLLaMAContextParams, "tensor_split", RUBY_METHOD_FUNC(_llama_context_params_get_tensor_split), 0);
+    rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
+    rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
+    rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
+    rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
     rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
     rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
     rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
@@ -494,6 +500,30 @@ private:
     return ret;
   }
 
+  // rope_freq_base
+  static VALUE _llama_context_params_set_rope_freq_base(VALUE self, VALUE rope_freq_base) {
+    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+    ptr->params.rope_freq_base = NUM2DBL(rope_freq_base);
+    return DBL2NUM(ptr->params.rope_freq_base);
+  }
+
+  static VALUE _llama_context_params_get_rope_freq_base(VALUE self) {
+    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+    return DBL2NUM(ptr->params.rope_freq_base);
+  }
+
+  // rope_freq_scale
+  static VALUE _llama_context_params_set_rope_freq_scale(VALUE self, VALUE rope_freq_scale) {
+    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+    ptr->params.rope_freq_scale = NUM2DBL(rope_freq_scale);
+    return DBL2NUM(ptr->params.rope_freq_scale);
+  }
+
+  static VALUE _llama_context_params_get_rope_freq_scale(VALUE self) {
+    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+    return DBL2NUM(ptr->params.rope_freq_scale);
+  }
+
   // low_vram
   static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
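The accessors above expose llama.cpp's RoPE frequency parameters on `ContextParams`. A minimal sketch of setting them from Ruby; the values are illustrative (halving `rope_freq_scale` is a common linear-scaling trick for stretching a model's context window), not defaults taken from this diff:

```ruby
require 'llama_cpp'

params = LLaMACpp::ContextParams.new
params.rope_freq_base  = 10000.0
params.rope_freq_scale = 0.5

p params.rope_freq_base  # => 10000.0
p params.rope_freq_scale # => 0.5
```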
@@ -764,6 +794,12 @@ public:
     rb_define_method(rb_cLLaMAModel, "free", RUBY_METHOD_FUNC(_llama_model_free), 0);
     rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
     rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
+    rb_define_method(rb_cLLaMAModel, "n_vocab", RUBY_METHOD_FUNC(_llama_model_get_n_vocab_from_model), 0);
+    rb_define_method(rb_cLLaMAModel, "n_ctx", RUBY_METHOD_FUNC(_llama_model_get_n_ctx_from_model), 0);
+    rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_n_embd_from_model), 0);
+    rb_define_method(rb_cLLaMAModel, "vocab", RUBY_METHOD_FUNC(_llama_model_get_vocab_from_model), -1);
+    rb_define_method(rb_cLLaMAModel, "token_to_str", RUBY_METHOD_FUNC(_llama_model_token_to_str_with_model), 1);
+    rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize_with_model), -1);
   }
 
 private:
@@ -908,6 +944,109 @@ private:
     }
     return Qnil;
   }
+
+  static VALUE _llama_model_get_n_vocab_from_model(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_n_vocab_from_model(ptr->model));
+  }
+
+  static VALUE _llama_model_get_n_ctx_from_model(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_n_ctx_from_model(ptr->model));
+  }
+
+  static VALUE _llama_model_get_n_embd_from_model(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_n_embd_from_model(ptr->model));
+  }
+
+  static VALUE _llama_model_get_vocab_from_model(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[1] = { rb_intern("capacity") };
+    VALUE kw_values[1] = { Qundef };
+    rb_scan_args(argc, argv, ":", &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
+
+    if (!RB_INTEGER_TYPE_P(kw_values[0])) {
+      rb_raise(rb_eArgError, "capacity must be an integer");
+      return Qnil;
+    }
+
+    const int capacity = NUM2INT(kw_values[0]);
+
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    const int n = std::min(capacity, llama_n_vocab_from_model(ptr->model));
+    const char** vocabs = ALLOCA_N(const char*, n);
+    float* scores = ALLOCA_N(float, n);
+
+    llama_get_vocab_from_model(ptr->model, vocabs, scores, capacity);
+
+    VALUE vocabs_ary = rb_ary_new();
+    VALUE scores_ary = rb_ary_new();
+
+    for (int i = 0; i < n; i++) {
+      rb_ary_push(vocabs_ary, rb_str_new_cstr(vocabs[i]));
+      rb_ary_push(scores_ary, DBL2NUM(scores[i]));
+    }
+
+    VALUE ret = rb_ary_new3(2, vocabs_ary, scores_ary);
+
+    return ret;
+  }
+
+  static VALUE _llama_model_token_to_str_with_model(VALUE self, VALUE token_) {
+    if (!RB_INTEGER_TYPE_P(token_)) {
+      rb_raise(rb_eArgError, "token must be an integer");
+      return Qnil;
+    }
+    const llama_token token = NUM2INT(token_);
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    const char* str = llama_token_to_str_with_model(ptr->model, token);
+    return rb_str_new_cstr(str);
+  }
+
+  static VALUE _llama_model_tokenize_with_model(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
+    VALUE kw_values[3] = { Qundef, Qundef, Qundef };
+    rb_scan_args(argc, argv, ":", &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
+
+    if (!RB_TYPE_P(kw_values[0], T_STRING)) {
+      rb_raise(rb_eArgError, "text must be a String");
+      return Qnil;
+    }
+    if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
+      rb_raise(rb_eArgError, "n_max_tokens must be an integer");
+      return Qnil;
+    }
+    if (kw_values[2] != Qundef && (kw_values[2] != Qtrue && kw_values[2] != Qfalse)) {
+      rb_raise(rb_eArgError, "add_bos must be a boolean");
+      return Qnil;
+    }
+
+    VALUE text_ = kw_values[0];
+    std::string text = StringValueCStr(text_);
+    const bool add_bos = kw_values[2] == Qtrue ? true : false;
+    const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
+
+    llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    const int n_tokens = llama_tokenize_with_model(ptr->model, text.c_str(), tokens, n_max_tokens, add_bos);
+
+    if (n_tokens < 0) {
+      rb_raise(rb_eRuntimeError, "failed to tokenize. The number of tokens (%d) is greater than n_max_tokens.", -n_tokens);
+      return Qnil;
+    }
+
+    VALUE ret = rb_ary_new2(n_tokens);
+    for (int i = 0; i < n_tokens; i++) {
+      rb_ary_store(ret, i, INT2NUM(tokens[i]));
+    }
+
+    RB_GC_GUARD(text_);
+    return ret;
+  }
 };
 
 const rb_data_type_t RbLLaMAModel::llama_model_type = {
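These bindings let Ruby code inspect a model without creating a `Context`. A sketch using the new methods; the model path is a placeholder, and `vocab` returns a pair of arrays (token strings and their scores), truncated to `capacity`:

```ruby
require 'llama_cpp'

params = LLaMACpp::ContextParams.new
model = LLaMACpp::Model.new(model_path: '/path/to/model.bin', params: params)

p model.n_vocab # vocabulary size, e.g. 32000 for LLaMA-style models
p model.n_ctx   # training context length
p model.n_embd  # embedding dimension

tokens, scores = model.vocab(capacity: 10) # first 10 vocabulary entries
```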
@@ -920,6 +1059,222 @@ const rb_data_type_t RbLLaMAModel::llama_model_type = {
   RUBY_TYPED_FREE_IMMEDIATELY
 };
 
+class LLaMAGrammarElementWrapper {
+public:
+  llama_grammar_element element;
+
+  LLaMAGrammarElementWrapper() {
+    element.type = LLAMA_GRETYPE_END;
+    element.value = 0;
+  }
+
+  ~LLaMAGrammarElementWrapper() {}
+};
+
+class RbLLaMAGrammarElement {
+public:
+  static VALUE llama_grammar_element_alloc(VALUE self) {
+    LLaMAGrammarElementWrapper* ptr = (LLaMAGrammarElementWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarElementWrapper));
+    new (ptr) LLaMAGrammarElementWrapper();
+    return TypedData_Wrap_Struct(self, &llama_grammar_element_type, ptr);
+  }
+
+  static void llama_grammar_element_free(void* ptr) {
+    ((LLaMAGrammarElementWrapper*)ptr)->~LLaMAGrammarElementWrapper();
+    ruby_xfree(ptr);
+  }
+
+  static size_t llama_grammar_element_size(const void* ptr) {
+    return sizeof(*((LLaMAGrammarElementWrapper*)ptr));
+  }
+
+  static LLaMAGrammarElementWrapper* get_llama_grammar_element(VALUE self) {
+    LLaMAGrammarElementWrapper* ptr;
+    TypedData_Get_Struct(self, LLaMAGrammarElementWrapper, &llama_grammar_element_type, ptr);
+    return ptr;
+  }
+
+  static void define_class(VALUE outer) {
+    rb_cLLaMAGrammarElement = rb_define_class_under(outer, "GrammarElement", rb_cObject);
+    rb_define_alloc_func(rb_cLLaMAGrammarElement, llama_grammar_element_alloc);
+    rb_define_method(rb_cLLaMAGrammarElement, "initialize", RUBY_METHOD_FUNC(_llama_grammar_element_init), -1);
+    rb_define_method(rb_cLLaMAGrammarElement, "type=", RUBY_METHOD_FUNC(_llama_grammar_element_set_type), 1);
+    rb_define_method(rb_cLLaMAGrammarElement, "type", RUBY_METHOD_FUNC(_llama_grammar_element_get_type), 0);
+    rb_define_method(rb_cLLaMAGrammarElement, "value=", RUBY_METHOD_FUNC(_llama_grammar_element_set_value), 1);
+    rb_define_method(rb_cLLaMAGrammarElement, "value", RUBY_METHOD_FUNC(_llama_grammar_element_get_value), 0);
+  }
+
+private:
+  static const rb_data_type_t llama_grammar_element_type;
+
+  static VALUE _llama_grammar_element_init(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[2] = { rb_intern("type"), rb_intern("value") };
+    VALUE kw_values[2] = { Qundef, Qundef };
+    VALUE arr = Qnil;
+    rb_scan_args(argc, argv, ":", &arr, &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
+
+    if (kw_values[0] != Qundef && !RB_INTEGER_TYPE_P(kw_values[0])) {
+      rb_raise(rb_eArgError, "type must be an integer");
+      return Qnil;
+    }
+    if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
+      rb_raise(rb_eArgError, "value must be an integer");
+      return Qnil;
+    }
+
+    LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
+    new (ptr) LLaMAGrammarElementWrapper();
+
+    if (kw_values[0] != Qundef) {
+      ptr->element.type = (enum llama_gretype)NUM2INT(kw_values[0]);
+    }
+    if (kw_values[1] != Qundef) {
+      ptr->element.value = NUM2INT(kw_values[1]);
+    }
+
+    return self;
+  }
+
+  // type
+  static VALUE _llama_grammar_element_set_type(VALUE self, VALUE type) {
+    LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
+    ptr->element.type = (enum llama_gretype)NUM2INT(type);
+    return INT2NUM(ptr->element.type);
+  }
+
+  static VALUE _llama_grammar_element_get_type(VALUE self) {
+    LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
+    return INT2NUM(ptr->element.type);
+  }
+
+  // value
+  static VALUE _llama_grammar_element_set_value(VALUE self, VALUE type) {
+    LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
+    ptr->element.value = NUM2INT(type);
+    return INT2NUM(ptr->element.value);
+  }
+
+  static VALUE _llama_grammar_element_get_value(VALUE self) {
+    LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
+    return INT2NUM(ptr->element.value);
+  }
+};
+
+const rb_data_type_t RbLLaMAGrammarElement::llama_grammar_element_type = {
+  "RbLLaMAGrammarElement",
+  { NULL,
+    RbLLaMAGrammarElement::llama_grammar_element_free,
+    RbLLaMAGrammarElement::llama_grammar_element_size },
+  NULL,
+  NULL,
+  RUBY_TYPED_FREE_IMMEDIATELY
+};
+
+class LLaMAGrammarWrapper {
+public:
+  struct llama_grammar* grammar;
+
+  LLaMAGrammarWrapper() : grammar(nullptr) {}
+
+  ~LLaMAGrammarWrapper() {
+    if (grammar) {
+      llama_grammar_free(grammar);
+    }
+  }
+};
+
+class RbLLaMAGrammar {
+public:
+  static VALUE llama_grammar_alloc(VALUE self) {
+    LLaMAGrammarWrapper* ptr = (LLaMAGrammarWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarWrapper));
+    new (ptr) LLaMAGrammarWrapper();
+    return TypedData_Wrap_Struct(self, &llama_grammar_type, ptr);
+  }
+
+  static void llama_grammar_free(void* ptr) {
+    ((LLaMAGrammarWrapper*)ptr)->~LLaMAGrammarWrapper();
+    ruby_xfree(ptr);
+  }
+
+  static size_t llama_grammar_size(const void* ptr) {
+    return sizeof(*((LLaMAGrammarWrapper*)ptr));
+  }
+
+  static LLaMAGrammarWrapper* get_llama_grammar(VALUE self) {
+    LLaMAGrammarWrapper* ptr;
+    TypedData_Get_Struct(self, LLaMAGrammarWrapper, &llama_grammar_type, ptr);
+    return ptr;
+  }
+
+  static void define_class(VALUE outer) {
+    rb_cLLaMAGrammar = rb_define_class_under(outer, "Grammar", rb_cObject);
+    rb_define_alloc_func(rb_cLLaMAGrammar, llama_grammar_alloc);
+    rb_define_method(rb_cLLaMAGrammar, "initialize", RUBY_METHOD_FUNC(_llama_grammar_init), -1);
+  }
+
+private:
+  static const rb_data_type_t llama_grammar_type;
+
+  static VALUE _llama_grammar_init(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[2] = { rb_intern("rules"), rb_intern("start_rule_index") };
+    VALUE kw_values[2] = { Qundef, Qundef };
+    rb_scan_args(argc, argv, ":", &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
+
+    if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
+      rb_raise(rb_eArgError, "rules must be an array");
+      return Qnil;
+    }
+    if (!RB_INTEGER_TYPE_P(kw_values[1])) {
+      rb_raise(rb_eArgError, "start_rule_index must be an integer");
+      return Qnil;
+    }
+
+    const int n_rules = RARRAY_LEN(kw_values[0]);
+    llama_grammar_element** rules = ALLOCA_N(llama_grammar_element*, n_rules);
+    for (int i = 0; i < n_rules; ++i) {
+      VALUE rule = rb_ary_entry(kw_values[0], i);
+      if (!RB_TYPE_P(rule, T_ARRAY)) {
+        rb_raise(rb_eArgError, "element of rules must be an array");
+        return Qnil;
+      }
+      const int n_elements = RARRAY_LEN(rule);
+      llama_grammar_element* elements = ALLOCA_N(llama_grammar_element, n_elements);
+      for (int j = 0; j < n_elements; ++j) {
+        VALUE element = rb_ary_entry(rule, j);
+        if (!rb_obj_is_kind_of(element, rb_cLLaMAGrammarElement)) {
+          rb_raise(rb_eArgError, "element of rule must be an instance of GrammarElement");
+          return Qnil;
+        }
+        LLaMAGrammarElementWrapper* ptr = RbLLaMAGrammarElement::get_llama_grammar_element(element);
+        elements[j] = ptr->element;
+      }
+      rules[i] = elements;
+    }
+
+    const size_t start_rule_index = NUM2SIZET(kw_values[1]);
+
+    LLaMAGrammarWrapper* ptr = get_llama_grammar(self);
+    new (ptr) LLaMAGrammarWrapper();
+    ptr->grammar = llama_grammar_init((const llama_grammar_element**)rules, n_rules, start_rule_index);
+
+    return self;
+  }
+};
+
+const rb_data_type_t RbLLaMAGrammar::llama_grammar_type = {
+  "RbLLaMAGrammar",
+  { NULL,
+    RbLLaMAGrammar::llama_grammar_free,
+    RbLLaMAGrammar::llama_grammar_size },
+  NULL,
+  NULL,
+  RUBY_TYPED_FREE_IMMEDIATELY
+};
+
 class LLaMAContextWrapper {
 public:
   struct llama_context* ctx;
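In `Grammar#initialize` above, each inner array of `rules:` is flattened into a `llama_grammar_element` sequence and handed to `llama_grammar_init`. A hedged sketch of that encoding, assuming llama.cpp's usual element semantics (a `CHAR` element followed by `CHAR_RNG_UPPER` denotes a character range, `RULE_REF` points at another rule by index, and every rule ends with `END`):

```ruby
require 'llama_cpp'

# root (rule 0) ::= digit      digit (rule 1) ::= [0-9]
root = [
  LLaMACpp::GrammarElement.new(type: LLaMACpp::LLAMA_GRETYPE_RULE_REF, value: 1),
  LLaMACpp::GrammarElement.new(type: LLaMACpp::LLAMA_GRETYPE_END, value: 0)
]
digit = [
  LLaMACpp::GrammarElement.new(type: LLaMACpp::LLAMA_GRETYPE_CHAR, value: '0'.ord),
  LLaMACpp::GrammarElement.new(type: LLaMACpp::LLAMA_GRETYPE_CHAR_RNG_UPPER, value: '9'.ord),
  LLaMACpp::GrammarElement.new(type: LLaMACpp::LLAMA_GRETYPE_END, value: 0)
]

grammar = LLaMACpp::Grammar.new(rules: [root, digit], start_rule_index: 0)
```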
@@ -991,6 +1346,8 @@ public:
     rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
     rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
     rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
+    rb_define_method(rb_cLLaMAContext, "sample_grammar", RUBY_METHOD_FUNC(_llama_context_sample_grammar), -1);
+    rb_define_method(rb_cLLaMAContext, "grammar_accept_token", RUBY_METHOD_FUNC(_llama_context_grammar_accept_token), -1);
   }
 
 private:
@@ -1581,11 +1938,11 @@ private:
 
   static VALUE _llama_context_sample_classifier_free_guidance(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
-    ID kw_table[
-    VALUE kw_values[
+    ID kw_table[2] = { rb_intern("guidance"), rb_intern("scale") };
+    VALUE kw_values[2] = { Qundef, Qundef };
     VALUE candidates = Qnil;
     rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
-    rb_get_kwargs(kw_args, kw_table,
+    rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
 
     if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAContext)) {
       rb_raise(rb_eArgError, "guidance must be a Context");
@@ -1595,10 +1952,6 @@ private:
       rb_raise(rb_eArgError, "scale must be a float");
       return Qnil;
     }
-    if (!RB_FLOAT_TYPE_P(kw_values[2])) {
-      rb_raise(rb_eArgError, "smooth_factor must be a float");
-      return Qnil;
-    }
 
     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
     if (ctx_ptr->ctx == NULL) {
@@ -1617,9 +1970,8 @@ private:
       return Qnil;
     }
     const float scale = NUM2DBL(kw_values[1]);
-    const float smooth_factor = NUM2DBL(kw_values[2]);
 
-    llama_sample_classifier_free_guidance(ctx_ptr->ctx, &(cnd_ptr->array), guidance_ptr->ctx, scale
+    llama_sample_classifier_free_guidance(ctx_ptr->ctx, &(cnd_ptr->array), guidance_ptr->ctx, scale);
 
     return Qnil;
   }
@@ -1972,6 +2324,69 @@ private:
     llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
     return INT2NUM(id);
   }
+
+  static VALUE _llama_context_sample_grammar(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[1] = { rb_intern("grammar") };
+    VALUE kw_values[1] = { Qundef };
+    VALUE candidates = Qnil;
+    rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
+
+    if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
+      rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
+      return Qnil;
+    }
+    if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
+      rb_raise(rb_eArgError, "grammar must be a Grammar");
+      return Qnil;
+    }
+
+    LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+    if (ctx_ptr->ctx == NULL) {
+      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+      return Qnil;
+    }
+    LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+    if (cnd_ptr->array.data == nullptr) {
+      rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+      return Qnil;
+    }
+    LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
+
+    llama_sample_grammar(ctx_ptr->ctx, &(cnd_ptr->array), grm_ptr->grammar);
+
+    return Qnil;
+  }
+
+  static VALUE _llama_context_grammar_accept_token(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[2] = { rb_intern("grammar"), rb_intern("token") };
+    VALUE kw_values[2] = { Qundef, Qundef };
+    rb_scan_args(argc, argv, ":", &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
+
+    if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
+      rb_raise(rb_eArgError, "grammar must be a Grammar");
+      return Qnil;
+    }
+    if (!RB_INTEGER_TYPE_P(kw_values[1])) {
+      rb_raise(rb_eArgError, "token must be an Integer");
+      return Qnil;
+    }
+
+    LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+    if (ctx_ptr->ctx == NULL) {
+      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+      return Qnil;
+    }
+    LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
+    llama_token token = NUM2INT(kw_values[1]);
+
+    llama_grammar_accept_token(ctx_ptr->ctx, grm_ptr->grammar, token);
+
+    return Qnil;
+  }
 };
 
 const rb_data_type_t RbLLaMAContext::llama_context_type = {
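Together, `sample_grammar` filters the candidate tokens in place and `grammar_accept_token` advances the grammar state once a token has been chosen. A rough sketch of a greedy, grammar-constrained decoding loop, assuming the `Context#eval`/`#logits` and `TokenData`/`TokenDataArray` APIs from earlier releases of this gem; `model`, `context`, and `grammar` are set up as in the examples above:

```ruby
tokens = model.tokenize(text: 'A digit: ', add_bos: true)
n_past = 0

8.times do
  context.eval(tokens: tokens, n_past: n_past)
  n_past += tokens.size

  candidates = LLaMACpp::TokenDataArray.new(
    context.logits.each_with_index.map { |v, i| LLaMACpp::TokenData.new(id: i, logit: v, p: 0.0) }
  )
  context.sample_grammar(candidates, grammar: grammar) # mask out tokens the grammar forbids
  id = context.sample_token_greedy(candidates)
  context.grammar_accept_token(grammar: grammar, token: id)

  tokens = [id]
end
```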
@@ -2062,6 +2477,10 @@ static VALUE rb_llama_mlock_supported(VALUE self) {
   return llama_mlock_supported() ? Qtrue : Qfalse;
 }
 
+static VALUE rb_llama_max_devices(VALUE self) {
+  return INT2NUM(llama_max_devices());
+}
+
 extern "C" void Init_llama_cpp(void) {
   rb_mLLaMACpp = rb_define_module("LLaMACpp");
 
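A small sketch of the new module function; on a CPU-only build llama.cpp reports a single device, and the value mirrors the `LLAMA_MAX_DEVICES` constant registered below:

```ruby
require 'llama_cpp'

p LLaMACpp.max_devices        # => 1 on a CPU-only build
p LLaMACpp::LLAMA_MAX_DEVICES # same value, exposed as a constant
```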
@@ -2072,6 +2491,8 @@ extern "C" void Init_llama_cpp(void) {
   RbLLaMAContext::define_class(rb_mLLaMACpp);
   RbLLaMAContextParams::define_class(rb_mLLaMACpp);
   RbLLaMAModelQuantizeParams::define_class(rb_mLLaMACpp);
+  RbLLaMAGrammarElement::define_class(rb_mLLaMACpp);
+  RbLLaMAGrammar::define_class(rb_mLLaMACpp);
 
   rb_define_module_function(rb_mLLaMACpp, "backend_init", rb_llama_llama_backend_init, -1);
   rb_define_module_function(rb_mLLaMACpp, "backend_free", rb_llama_llama_backend_free, 0);
@@ -2082,6 +2503,7 @@ extern "C" void Init_llama_cpp(void) {
   rb_define_module_function(rb_mLLaMACpp, "print_system_info", rb_llama_print_system_info, 0);
   rb_define_module_function(rb_mLLaMACpp, "mmap_supported?", rb_llama_mmap_supported, 0);
   rb_define_module_function(rb_mLLaMACpp, "mlock_supported?", rb_llama_mlock_supported, 0);
+  rb_define_module_function(rb_mLLaMACpp, "max_devices", rb_llama_max_devices, 0);
 
   rb_define_const(rb_mLLaMACpp, "LLAMA_MAX_DEVICES", INT2NUM(LLAMA_MAX_DEVICES));
 
@@ -2103,6 +2525,14 @@ extern "C" void Init_llama_cpp(void) {
   rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_K_M", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_K_M));
   rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q6_K", INT2NUM(LLAMA_FTYPE_MOSTLY_Q6_K));
 
+  rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_END", INT2NUM(LLAMA_GRETYPE_END));
+  rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_ALT", INT2NUM(LLAMA_GRETYPE_ALT));
+  rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_RULE_REF", INT2NUM(LLAMA_GRETYPE_RULE_REF));
+  rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR", INT2NUM(LLAMA_GRETYPE_CHAR));
+  rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_NOT", INT2NUM(LLAMA_GRETYPE_CHAR_NOT));
+  rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_RNG_UPPER", INT2NUM(LLAMA_GRETYPE_CHAR_RNG_UPPER));
+  rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_ALT", INT2NUM(LLAMA_GRETYPE_CHAR_ALT));
+
   std::stringstream ss_magic;
   ss_magic << std::showbase << std::hex << LLAMA_FILE_MAGIC_GGJT;
   rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_MAGIC_GGJT", rb_str_new2(ss_magic.str().c_str()));