llama_cpp 0.0.7 → 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -0
- data/ext/llama_cpp/llama_cpp.cpp +736 -36
- data/ext/llama_cpp/src/ggml-cuda.h +8 -33
- data/ext/llama_cpp/src/ggml-opencl.c +202 -20
- data/ext/llama_cpp/src/ggml.c +732 -496
- data/ext/llama_cpp/src/ggml.h +47 -5
- data/ext/llama_cpp/src/{llama_util.h → llama-util.h} +76 -10
- data/ext/llama_cpp/src/llama.cpp +560 -147
- data/ext/llama_cpp/src/llama.h +71 -24
- data/lib/llama_cpp/client.rb +29 -6
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +27 -3
- data/sig/llama_cpp.rbs +38 -3
- metadata +3 -3
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -4,6 +4,255 @@
|
|
4
4
|
VALUE rb_mLLaMACpp;
|
5
5
|
VALUE rb_cLLaMAContext;
|
6
6
|
VALUE rb_cLLaMAContextParams;
|
7
|
+
VALUE rb_cLLaMATokenData;
|
8
|
+
VALUE rb_cLLaMATokenDataArray;
|
9
|
+
|
10
|
+
class LLaMATokenDataWrapper {
|
11
|
+
public:
|
12
|
+
llama_token_data data;
|
13
|
+
|
14
|
+
LLaMATokenDataWrapper() {
|
15
|
+
data.id = 0;
|
16
|
+
data.logit = 0.0;
|
17
|
+
data.p = 0.0;
|
18
|
+
};
|
19
|
+
|
20
|
+
~LLaMATokenDataWrapper(){};
|
21
|
+
};
|
22
|
+
|
23
|
+
class RbLLaMATokenData {
|
24
|
+
public:
|
25
|
+
static VALUE llama_token_data_alloc(VALUE self) {
|
26
|
+
LLaMATokenDataWrapper* ptr = (LLaMATokenDataWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataWrapper));
|
27
|
+
new (ptr) LLaMATokenDataWrapper();
|
28
|
+
return TypedData_Wrap_Struct(self, &llama_token_data_type, ptr);
|
29
|
+
};
|
30
|
+
|
31
|
+
static void llama_token_data_free(void* ptr) {
|
32
|
+
((LLaMATokenDataWrapper*)ptr)->~LLaMATokenDataWrapper();
|
33
|
+
ruby_xfree(ptr);
|
34
|
+
};
|
35
|
+
|
36
|
+
static size_t llama_token_data_size(const void* ptr) {
|
37
|
+
return sizeof(*((LLaMATokenDataWrapper*)ptr));
|
38
|
+
};
|
39
|
+
|
40
|
+
static LLaMATokenDataWrapper* get_llama_token_data(VALUE self) {
|
41
|
+
LLaMATokenDataWrapper* ptr;
|
42
|
+
TypedData_Get_Struct(self, LLaMATokenDataWrapper, &llama_token_data_type, ptr);
|
43
|
+
return ptr;
|
44
|
+
};
|
45
|
+
|
46
|
+
static void define_class(VALUE outer) {
|
47
|
+
rb_cLLaMATokenData = rb_define_class_under(outer, "TokenData", rb_cObject);
|
48
|
+
rb_define_alloc_func(rb_cLLaMATokenData, llama_token_data_alloc);
|
49
|
+
rb_define_method(rb_cLLaMATokenData, "initialize", RUBY_METHOD_FUNC(_llama_token_data_init), -1);
|
50
|
+
rb_define_method(rb_cLLaMATokenData, "id=", RUBY_METHOD_FUNC(_llama_token_data_set_id), 1);
|
51
|
+
rb_define_method(rb_cLLaMATokenData, "id", RUBY_METHOD_FUNC(_llama_token_data_get_id), 0);
|
52
|
+
rb_define_method(rb_cLLaMATokenData, "logit=", RUBY_METHOD_FUNC(_llama_token_data_set_logit), 1);
|
53
|
+
rb_define_method(rb_cLLaMATokenData, "logit", RUBY_METHOD_FUNC(_llama_token_data_get_logit), 0);
|
54
|
+
rb_define_method(rb_cLLaMATokenData, "p=", RUBY_METHOD_FUNC(_llama_token_data_set_p), 1);
|
55
|
+
rb_define_method(rb_cLLaMATokenData, "p", RUBY_METHOD_FUNC(_llama_token_data_get_p), 0);
|
56
|
+
}
|
57
|
+
|
58
|
+
private:
|
59
|
+
static const rb_data_type_t llama_token_data_type;
|
60
|
+
|
61
|
+
static VALUE _llama_token_data_init(int argc, VALUE* argv, VALUE self) {
|
62
|
+
VALUE kw_args = Qnil;
|
63
|
+
ID kw_table[3] = { rb_intern("id"), rb_intern("logit"), rb_intern("p") };
|
64
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
65
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
66
|
+
rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
|
67
|
+
|
68
|
+
if (!RB_INTEGER_TYPE_P(kw_values[0])) {
|
69
|
+
rb_raise(rb_eArgError, "id must be an integer");
|
70
|
+
return Qnil;
|
71
|
+
}
|
72
|
+
if (!RB_FLOAT_TYPE_P(kw_values[1])) {
|
73
|
+
rb_raise(rb_eArgError, "logit must be a float");
|
74
|
+
return Qnil;
|
75
|
+
}
|
76
|
+
if (!RB_FLOAT_TYPE_P(kw_values[2])) {
|
77
|
+
rb_raise(rb_eArgError, "p must be a float");
|
78
|
+
return Qnil;
|
79
|
+
}
|
80
|
+
|
81
|
+
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
82
|
+
new (ptr) LLaMATokenDataWrapper();
|
83
|
+
|
84
|
+
ptr->data.id = NUM2INT(kw_values[0]);
|
85
|
+
ptr->data.logit = NUM2DBL(kw_values[1]);
|
86
|
+
ptr->data.p = NUM2DBL(kw_values[2]);
|
87
|
+
|
88
|
+
return self;
|
89
|
+
}
|
90
|
+
|
91
|
+
// id
|
92
|
+
static VALUE _llama_token_data_set_id(VALUE self, VALUE id) {
|
93
|
+
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
94
|
+
ptr->data.id = NUM2INT(id);
|
95
|
+
return INT2NUM(ptr->data.id);
|
96
|
+
};
|
97
|
+
|
98
|
+
static VALUE _llama_token_data_get_id(VALUE self) {
|
99
|
+
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
100
|
+
return INT2NUM(ptr->data.id);
|
101
|
+
};
|
102
|
+
|
103
|
+
// logit
|
104
|
+
static VALUE _llama_token_data_set_logit(VALUE self, VALUE logit) {
|
105
|
+
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
106
|
+
ptr->data.logit = NUM2DBL(logit);
|
107
|
+
return DBL2NUM(ptr->data.logit);
|
108
|
+
};
|
109
|
+
|
110
|
+
static VALUE _llama_token_data_get_logit(VALUE self) {
|
111
|
+
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
112
|
+
return DBL2NUM(ptr->data.logit);
|
113
|
+
};
|
114
|
+
|
115
|
+
// p
|
116
|
+
static VALUE _llama_token_data_set_p(VALUE self, VALUE p) {
|
117
|
+
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
118
|
+
ptr->data.p = NUM2DBL(p);
|
119
|
+
return DBL2NUM(ptr->data.p);
|
120
|
+
};
|
121
|
+
|
122
|
+
static VALUE _llama_token_data_get_p(VALUE self) {
|
123
|
+
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
124
|
+
return DBL2NUM(ptr->data.p);
|
125
|
+
};
|
126
|
+
};
|
127
|
+
|
128
|
+
const rb_data_type_t RbLLaMATokenData::llama_token_data_type = {
|
129
|
+
"RbLLaMATokenData",
|
130
|
+
{ NULL,
|
131
|
+
RbLLaMATokenData::llama_token_data_free,
|
132
|
+
RbLLaMATokenData::llama_token_data_size },
|
133
|
+
NULL,
|
134
|
+
NULL,
|
135
|
+
RUBY_TYPED_FREE_IMMEDIATELY
|
136
|
+
};
|
137
|
+
|
138
|
+
class LLaMATokenDataArrayWrapper {
|
139
|
+
public:
|
140
|
+
llama_token_data_array array;
|
141
|
+
|
142
|
+
LLaMATokenDataArrayWrapper() {
|
143
|
+
array.data = nullptr;
|
144
|
+
array.size = 0;
|
145
|
+
array.sorted = false;
|
146
|
+
};
|
147
|
+
|
148
|
+
~LLaMATokenDataArrayWrapper() {
|
149
|
+
if (array.data) {
|
150
|
+
ruby_xfree(array.data);
|
151
|
+
array.data = nullptr;
|
152
|
+
}
|
153
|
+
};
|
154
|
+
};
|
155
|
+
|
156
|
+
class RbLLaMATokenDataArray {
|
157
|
+
public:
|
158
|
+
static VALUE llama_token_data_array_alloc(VALUE self) {
|
159
|
+
LLaMATokenDataArrayWrapper* ptr = (LLaMATokenDataArrayWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataArrayWrapper));
|
160
|
+
new (ptr) LLaMATokenDataArrayWrapper();
|
161
|
+
return TypedData_Wrap_Struct(self, &llama_token_data_array_type, ptr);
|
162
|
+
};
|
163
|
+
|
164
|
+
static void llama_token_data_array_free(void* ptr) {
|
165
|
+
((LLaMATokenDataArrayWrapper*)ptr)->~LLaMATokenDataArrayWrapper();
|
166
|
+
ruby_xfree(ptr);
|
167
|
+
};
|
168
|
+
|
169
|
+
static size_t llama_token_data_array_size(const void* ptr) {
|
170
|
+
return sizeof(*((LLaMATokenDataArrayWrapper*)ptr));
|
171
|
+
};
|
172
|
+
|
173
|
+
static LLaMATokenDataArrayWrapper* get_llama_token_data_array(VALUE self) {
|
174
|
+
LLaMATokenDataArrayWrapper* ptr;
|
175
|
+
TypedData_Get_Struct(self, LLaMATokenDataArrayWrapper, &llama_token_data_array_type, ptr);
|
176
|
+
return ptr;
|
177
|
+
};
|
178
|
+
|
179
|
+
static void define_class(VALUE outer) {
|
180
|
+
rb_cLLaMATokenDataArray = rb_define_class_under(outer, "TokenDataArray", rb_cObject);
|
181
|
+
rb_define_alloc_func(rb_cLLaMATokenDataArray, llama_token_data_array_alloc);
|
182
|
+
rb_define_method(rb_cLLaMATokenDataArray, "initialize", RUBY_METHOD_FUNC(_llama_token_data_array_init), -1);
|
183
|
+
rb_define_method(rb_cLLaMATokenDataArray, "size", RUBY_METHOD_FUNC(_llama_token_data_array_get_size), 0);
|
184
|
+
rb_define_method(rb_cLLaMATokenDataArray, "sorted", RUBY_METHOD_FUNC(_llama_token_data_array_get_sorted), 0);
|
185
|
+
};
|
186
|
+
|
187
|
+
private:
|
188
|
+
static const rb_data_type_t llama_token_data_array_type;
|
189
|
+
|
190
|
+
static VALUE _llama_token_data_array_init(int argc, VALUE* argv, VALUE self) {
|
191
|
+
VALUE kw_args = Qnil;
|
192
|
+
ID kw_table[1] = { rb_intern("sorted") };
|
193
|
+
VALUE kw_values[1] = { Qundef };
|
194
|
+
VALUE arr = Qnil;
|
195
|
+
rb_scan_args(argc, argv, "1:", &arr, &kw_args);
|
196
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
197
|
+
|
198
|
+
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
199
|
+
rb_raise(rb_eArgError, "1st argument must be an array");
|
200
|
+
return Qnil;
|
201
|
+
}
|
202
|
+
size_t sz_array = RARRAY_LEN(arr);
|
203
|
+
if (sz_array == 0) {
|
204
|
+
rb_raise(rb_eArgError, "array must not be empty");
|
205
|
+
return Qnil;
|
206
|
+
}
|
207
|
+
if (kw_values[0] != Qundef && !RB_TYPE_P(kw_values[0], T_TRUE) && !RB_TYPE_P(kw_values[0], T_FALSE)) {
|
208
|
+
rb_raise(rb_eArgError, "sorted must be a boolean");
|
209
|
+
return Qnil;
|
210
|
+
}
|
211
|
+
|
212
|
+
LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
|
213
|
+
new (ptr) LLaMATokenDataArrayWrapper();
|
214
|
+
|
215
|
+
ptr->array.data = (llama_token_data*)ruby_xmalloc(sizeof(llama_token_data) * sz_array);
|
216
|
+
for (size_t i = 0; i < sz_array; ++i) {
|
217
|
+
VALUE el = rb_ary_entry(arr, i);
|
218
|
+
if (!rb_obj_is_kind_of(el, rb_cLLaMATokenData)) {
|
219
|
+
rb_raise(rb_eArgError, "array element must be a TokenData");
|
220
|
+
xfree(ptr->array.data);
|
221
|
+
ptr->array.data = nullptr;
|
222
|
+
return Qnil;
|
223
|
+
}
|
224
|
+
llama_token_data token_data = RbLLaMATokenData::get_llama_token_data(el)->data;
|
225
|
+
ptr->array.data[i].id = token_data.id;
|
226
|
+
ptr->array.data[i].logit = token_data.logit;
|
227
|
+
ptr->array.data[i].p = token_data.p;
|
228
|
+
}
|
229
|
+
|
230
|
+
ptr->array.size = sz_array;
|
231
|
+
ptr->array.sorted = kw_values[0] == Qtrue;
|
232
|
+
|
233
|
+
return self;
|
234
|
+
};
|
235
|
+
|
236
|
+
static VALUE _llama_token_data_array_get_size(VALUE self) {
|
237
|
+
LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
|
238
|
+
return SIZET2NUM(ptr->array.size);
|
239
|
+
};
|
240
|
+
|
241
|
+
static VALUE _llama_token_data_array_get_sorted(VALUE self) {
|
242
|
+
LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
|
243
|
+
return ptr->array.sorted ? Qtrue : Qfalse;
|
244
|
+
};
|
245
|
+
};
|
246
|
+
|
247
|
+
const rb_data_type_t RbLLaMATokenDataArray::llama_token_data_array_type = {
|
248
|
+
"RbLLaMATokenDataArray",
|
249
|
+
{ NULL,
|
250
|
+
RbLLaMATokenDataArray::llama_token_data_array_free,
|
251
|
+
RbLLaMATokenDataArray::llama_token_data_array_size },
|
252
|
+
NULL,
|
253
|
+
NULL,
|
254
|
+
RUBY_TYPED_FREE_IMMEDIATELY
|
255
|
+
};
|
7
256
|
|
8
257
|
class LLaMAContextParamsWrapper {
|
9
258
|
public:
|
@@ -234,7 +483,6 @@ public:
|
|
234
483
|
rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
|
235
484
|
rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
|
236
485
|
rb_define_method(rb_cLLaMAContext, "token_to_str", RUBY_METHOD_FUNC(_llama_context_token_to_str), 1);
|
237
|
-
rb_define_method(rb_cLLaMAContext, "sample_top_p_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_p_top_k), -1);
|
238
486
|
rb_define_method(rb_cLLaMAContext, "n_vocab", RUBY_METHOD_FUNC(_llama_context_n_vocab), 0);
|
239
487
|
rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
|
240
488
|
rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
|
@@ -244,6 +492,20 @@ public:
|
|
244
492
|
rb_define_method(rb_cLLaMAContext, "free", RUBY_METHOD_FUNC(_llama_context_free), 0);
|
245
493
|
rb_define_method(rb_cLLaMAContext, "load", RUBY_METHOD_FUNC(_llama_context_load), -1);
|
246
494
|
rb_define_method(rb_cLLaMAContext, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_context_apply_lora_from_file), -1);
|
495
|
+
rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
|
496
|
+
rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
|
497
|
+
rb_define_method(rb_cLLaMAContext, "sample_repetition_penalty", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalty), -1);
|
498
|
+
rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
|
499
|
+
rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
|
500
|
+
rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
|
501
|
+
rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
|
502
|
+
rb_define_method(rb_cLLaMAContext, "sample_tail_free", RUBY_METHOD_FUNC(_llama_context_sample_tail_free), -1);
|
503
|
+
rb_define_method(rb_cLLaMAContext, "sample_typical", RUBY_METHOD_FUNC(_llama_context_sample_typical), -1);
|
504
|
+
rb_define_method(rb_cLLaMAContext, "sample_temperature", RUBY_METHOD_FUNC(_llama_context_sample_temperature), -1);
|
505
|
+
rb_define_method(rb_cLLaMAContext, "sample_token_mirostat", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat), -1);
|
506
|
+
rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
|
507
|
+
rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
|
508
|
+
rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
|
247
509
|
};
|
248
510
|
|
249
511
|
private:
|
@@ -448,40 +710,6 @@ private:
|
|
448
710
|
return output;
|
449
711
|
};
|
450
712
|
|
451
|
-
static VALUE _llama_context_sample_top_p_top_k(int argc, VALUE* argv, VALUE self) {
|
452
|
-
VALUE last_n_tokens = Qnil;
|
453
|
-
VALUE kw_args = Qnil;
|
454
|
-
ID kw_table[4] = { rb_intern("top_k"), rb_intern("top_p"), rb_intern("temp"), rb_intern("penalty") };
|
455
|
-
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
456
|
-
rb_scan_args(argc, argv, "1:", &last_n_tokens, &kw_args);
|
457
|
-
rb_get_kwargs(kw_args, kw_table, 4, 0, kw_values);
|
458
|
-
|
459
|
-
if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
|
460
|
-
rb_raise(rb_eArgError, "last_n_tokens must be an Array");
|
461
|
-
return Qnil;
|
462
|
-
}
|
463
|
-
|
464
|
-
const int last_n_tokens_size = RARRAY_LEN(last_n_tokens);
|
465
|
-
const int top_k = NUM2INT(kw_values[0]);
|
466
|
-
const double top_p = NUM2DBL(kw_values[1]);
|
467
|
-
const double temp = NUM2DBL(kw_values[2]);
|
468
|
-
const double penalty = NUM2DBL(kw_values[3]);
|
469
|
-
|
470
|
-
std::vector<llama_token> last_n_tokens_data(last_n_tokens_size);
|
471
|
-
for (int i = 0; i < last_n_tokens_size; i++) {
|
472
|
-
last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
|
473
|
-
}
|
474
|
-
|
475
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
476
|
-
if (ptr->ctx == NULL) {
|
477
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
478
|
-
return Qnil;
|
479
|
-
}
|
480
|
-
llama_token token = llama_sample_top_p_top_k(ptr->ctx, last_n_tokens_data.data(), last_n_tokens_size, top_k, top_p, temp, penalty);
|
481
|
-
|
482
|
-
return INT2NUM(token);
|
483
|
-
}
|
484
|
-
|
485
713
|
static VALUE _llama_context_n_vocab(VALUE self) {
|
486
714
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
487
715
|
if (ptr->ctx == NULL) {
|
@@ -621,6 +849,471 @@ private:
|
|
621
849
|
}
|
622
850
|
return Qnil;
|
623
851
|
};
|
852
|
+
|
853
|
+
static VALUE _llama_context_kv_cache_token_count(VALUE self) {
|
854
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
855
|
+
if (ptr->ctx == NULL) {
|
856
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
857
|
+
return Qnil;
|
858
|
+
}
|
859
|
+
return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
|
860
|
+
};
|
861
|
+
|
862
|
+
static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
|
863
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
864
|
+
if (ptr->ctx == NULL) {
|
865
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
866
|
+
return Qnil;
|
867
|
+
}
|
868
|
+
const int seed = NUM2INT(seed_);
|
869
|
+
llama_set_rng_seed(ptr->ctx, seed);
|
870
|
+
return Qnil;
|
871
|
+
};
|
872
|
+
|
873
|
+
static VALUE _llama_context_sample_repetition_penalty(int argc, VALUE* argv, VALUE self) {
|
874
|
+
VALUE kw_args = Qnil;
|
875
|
+
ID kw_table[1] = { rb_intern("penalty") };
|
876
|
+
VALUE kw_values[1] = { Qundef };
|
877
|
+
VALUE candidates = Qnil;
|
878
|
+
VALUE last_n_tokens = Qnil;
|
879
|
+
rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
|
880
|
+
rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
|
881
|
+
|
882
|
+
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
883
|
+
rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
|
884
|
+
return Qnil;
|
885
|
+
}
|
886
|
+
if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
|
887
|
+
rb_raise(rb_eArgError, "last_n_tokens must be an Array");
|
888
|
+
return Qnil;
|
889
|
+
}
|
890
|
+
if (!RB_FLOAT_TYPE_P(kw_values[0])) {
|
891
|
+
rb_raise(rb_eArgError, "penalty must be a float");
|
892
|
+
return Qnil;
|
893
|
+
}
|
894
|
+
|
895
|
+
const size_t last_tokens_size = RARRAY_LEN(last_n_tokens);
|
896
|
+
std::vector<llama_token> last_n_tokens_data(last_tokens_size);
|
897
|
+
for (size_t i = 0; i < last_tokens_size; i++) {
|
898
|
+
last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
|
899
|
+
}
|
900
|
+
|
901
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
902
|
+
if (ctx_ptr->ctx == NULL) {
|
903
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
904
|
+
return Qnil;
|
905
|
+
}
|
906
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
907
|
+
if (cnd_ptr->array.data == nullptr) {
|
908
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
909
|
+
return Qnil;
|
910
|
+
}
|
911
|
+
const float penalty = NUM2DBL(kw_values[0]);
|
912
|
+
|
913
|
+
llama_sample_repetition_penalty(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, penalty);
|
914
|
+
|
915
|
+
return Qnil;
|
916
|
+
};
|
917
|
+
|
918
|
+
static VALUE _llama_context_sample_frequency_and_presence_penalties(int argc, VALUE* argv, VALUE self) {
|
919
|
+
VALUE kw_args = Qnil;
|
920
|
+
ID kw_table[2] = { rb_intern("frequency"), rb_intern("presence") };
|
921
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
922
|
+
VALUE candidates = Qnil;
|
923
|
+
VALUE last_n_tokens = Qnil;
|
924
|
+
rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
|
925
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
926
|
+
|
927
|
+
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
928
|
+
rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
|
929
|
+
return Qnil;
|
930
|
+
}
|
931
|
+
if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
|
932
|
+
rb_raise(rb_eArgError, "last_n_tokens must be an Array");
|
933
|
+
return Qnil;
|
934
|
+
}
|
935
|
+
if (!RB_FLOAT_TYPE_P(kw_values[0])) {
|
936
|
+
rb_raise(rb_eArgError, "frequency must be a float");
|
937
|
+
return Qnil;
|
938
|
+
}
|
939
|
+
if (!RB_FLOAT_TYPE_P(kw_values[1])) {
|
940
|
+
rb_raise(rb_eArgError, "presence must be a float");
|
941
|
+
return Qnil;
|
942
|
+
}
|
943
|
+
|
944
|
+
const size_t last_tokens_size = RARRAY_LEN(last_n_tokens);
|
945
|
+
std::vector<llama_token> last_n_tokens_data(last_tokens_size);
|
946
|
+
for (size_t i = 0; i < last_tokens_size; i++) {
|
947
|
+
last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
|
948
|
+
}
|
949
|
+
|
950
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
951
|
+
if (ctx_ptr->ctx == NULL) {
|
952
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
953
|
+
return Qnil;
|
954
|
+
}
|
955
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
956
|
+
if (cnd_ptr->array.data == nullptr) {
|
957
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
958
|
+
return Qnil;
|
959
|
+
}
|
960
|
+
|
961
|
+
const float alpha_frequency = NUM2DBL(kw_values[0]);
|
962
|
+
const float alpha_presence = NUM2DBL(kw_values[1]);
|
963
|
+
|
964
|
+
llama_sample_frequency_and_presence_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, alpha_frequency, alpha_presence);
|
965
|
+
|
966
|
+
return Qnil;
|
967
|
+
};
|
968
|
+
|
969
|
+
static VALUE _llama_context_sample_softmax(VALUE self, VALUE candidates) {
|
970
|
+
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
971
|
+
rb_raise(rb_eArgError, "argument must be a TokenDataArray");
|
972
|
+
return Qnil;
|
973
|
+
}
|
974
|
+
|
975
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
976
|
+
if (ctx_ptr->ctx == NULL) {
|
977
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
978
|
+
return Qnil;
|
979
|
+
}
|
980
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
981
|
+
if (cnd_ptr->array.data == nullptr) {
|
982
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
983
|
+
return Qnil;
|
984
|
+
}
|
985
|
+
|
986
|
+
llama_sample_softmax(ctx_ptr->ctx, &(cnd_ptr->array));
|
987
|
+
|
988
|
+
return Qnil;
|
989
|
+
};
|
990
|
+
|
991
|
+
static VALUE _llama_context_sample_top_k(int argc, VALUE* argv, VALUE self) {
|
992
|
+
VALUE kw_args = Qnil;
|
993
|
+
ID kw_table[2] = { rb_intern("k"), rb_intern("min_keep") };
|
994
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
995
|
+
VALUE candidates = Qnil;
|
996
|
+
rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
|
997
|
+
rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
|
998
|
+
|
999
|
+
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
1000
|
+
rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
|
1001
|
+
return Qnil;
|
1002
|
+
}
|
1003
|
+
if (!RB_INTEGER_TYPE_P(kw_values[0])) {
|
1004
|
+
rb_raise(rb_eArgError, "k must be an integer");
|
1005
|
+
return Qnil;
|
1006
|
+
}
|
1007
|
+
if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
|
1008
|
+
rb_raise(rb_eArgError, "min_keep must be an integer");
|
1009
|
+
return Qnil;
|
1010
|
+
}
|
1011
|
+
|
1012
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
1013
|
+
if (ctx_ptr->ctx == NULL) {
|
1014
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1015
|
+
return Qnil;
|
1016
|
+
}
|
1017
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
1018
|
+
if (cnd_ptr->array.data == nullptr) {
|
1019
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
1020
|
+
return Qnil;
|
1021
|
+
}
|
1022
|
+
const int k = NUM2DBL(kw_values[0]);
|
1023
|
+
const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
|
1024
|
+
|
1025
|
+
llama_sample_top_k(ctx_ptr->ctx, &(cnd_ptr->array), k, min_keep);
|
1026
|
+
|
1027
|
+
return Qnil;
|
1028
|
+
};
|
1029
|
+
|
1030
|
+
static VALUE _llama_context_sample_top_p(int argc, VALUE* argv, VALUE self) {
|
1031
|
+
VALUE kw_args = Qnil;
|
1032
|
+
ID kw_table[2] = { rb_intern("prob"), rb_intern("min_keep") };
|
1033
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
1034
|
+
VALUE candidates = Qnil;
|
1035
|
+
rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
|
1036
|
+
rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
|
1037
|
+
|
1038
|
+
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
1039
|
+
rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
|
1040
|
+
return Qnil;
|
1041
|
+
}
|
1042
|
+
if (!RB_FLOAT_TYPE_P(kw_values[0])) {
|
1043
|
+
rb_raise(rb_eArgError, "prob must be a float");
|
1044
|
+
return Qnil;
|
1045
|
+
}
|
1046
|
+
if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
|
1047
|
+
rb_raise(rb_eArgError, "min_keep must be an integer");
|
1048
|
+
return Qnil;
|
1049
|
+
}
|
1050
|
+
|
1051
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
1052
|
+
if (ctx_ptr->ctx == NULL) {
|
1053
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1054
|
+
return Qnil;
|
1055
|
+
}
|
1056
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
1057
|
+
if (cnd_ptr->array.data == nullptr) {
|
1058
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
1059
|
+
return Qnil;
|
1060
|
+
}
|
1061
|
+
const float prob = NUM2DBL(kw_values[0]);
|
1062
|
+
const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
|
1063
|
+
|
1064
|
+
llama_sample_top_p(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
|
1065
|
+
|
1066
|
+
return Qnil;
|
1067
|
+
};
|
1068
|
+
|
1069
|
+
static VALUE _llama_context_sample_tail_free(int argc, VALUE* argv, VALUE self) {
|
1070
|
+
VALUE kw_args = Qnil;
|
1071
|
+
ID kw_table[2] = { rb_intern("z"), rb_intern("min_keep") };
|
1072
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
1073
|
+
VALUE candidates = Qnil;
|
1074
|
+
rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
|
1075
|
+
rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
|
1076
|
+
|
1077
|
+
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
1078
|
+
rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
|
1079
|
+
return Qnil;
|
1080
|
+
}
|
1081
|
+
if (!RB_FLOAT_TYPE_P(kw_values[0])) {
|
1082
|
+
rb_raise(rb_eArgError, "prob must be a float");
|
1083
|
+
return Qnil;
|
1084
|
+
}
|
1085
|
+
if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
|
1086
|
+
rb_raise(rb_eArgError, "min_keep must be an integer");
|
1087
|
+
return Qnil;
|
1088
|
+
}
|
1089
|
+
|
1090
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
1091
|
+
if (ctx_ptr->ctx == NULL) {
|
1092
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1093
|
+
return Qnil;
|
1094
|
+
}
|
1095
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
1096
|
+
if (cnd_ptr->array.data == nullptr) {
|
1097
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
1098
|
+
return Qnil;
|
1099
|
+
}
|
1100
|
+
const float z = NUM2DBL(kw_values[0]);
|
1101
|
+
const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
|
1102
|
+
|
1103
|
+
llama_sample_tail_free(ctx_ptr->ctx, &(cnd_ptr->array), z, min_keep);
|
1104
|
+
|
1105
|
+
return Qnil;
|
1106
|
+
};
|
1107
|
+
|
1108
|
+
static VALUE _llama_context_sample_typical(int argc, VALUE* argv, VALUE self) {
|
1109
|
+
VALUE kw_args = Qnil;
|
1110
|
+
ID kw_table[2] = { rb_intern("prob"), rb_intern("min_keep") };
|
1111
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
1112
|
+
VALUE candidates = Qnil;
|
1113
|
+
rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
|
1114
|
+
rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
|
1115
|
+
|
1116
|
+
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
1117
|
+
rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
|
1118
|
+
return Qnil;
|
1119
|
+
}
|
1120
|
+
if (!RB_FLOAT_TYPE_P(kw_values[0])) {
|
1121
|
+
rb_raise(rb_eArgError, "prob must be a float");
|
1122
|
+
return Qnil;
|
1123
|
+
}
|
1124
|
+
if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
|
1125
|
+
rb_raise(rb_eArgError, "min_keep must be an integer");
|
1126
|
+
return Qnil;
|
1127
|
+
}
|
1128
|
+
|
1129
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
1130
|
+
if (ctx_ptr->ctx == NULL) {
|
1131
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1132
|
+
return Qnil;
|
1133
|
+
}
|
1134
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
1135
|
+
if (cnd_ptr->array.data == nullptr) {
|
1136
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
1137
|
+
return Qnil;
|
1138
|
+
}
|
1139
|
+
const float prob = NUM2DBL(kw_values[0]);
|
1140
|
+
const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
|
1141
|
+
|
1142
|
+
llama_sample_typical(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
|
1143
|
+
|
1144
|
+
return Qnil;
|
1145
|
+
};
|
1146
|
+
|
1147
|
+
static VALUE _llama_context_sample_temperature(int argc, VALUE* argv, VALUE self) {
|
1148
|
+
VALUE kw_args = Qnil;
|
1149
|
+
ID kw_table[1] = { rb_intern("temperature") };
|
1150
|
+
VALUE kw_values[1] = { Qundef };
|
1151
|
+
VALUE candidates = Qnil;
|
1152
|
+
rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
|
1153
|
+
rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
|
1154
|
+
|
1155
|
+
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
1156
|
+
rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
|
1157
|
+
return Qnil;
|
1158
|
+
}
|
1159
|
+
if (!RB_FLOAT_TYPE_P(kw_values[0])) {
|
1160
|
+
rb_raise(rb_eArgError, "temperature must be a float");
|
1161
|
+
return Qnil;
|
1162
|
+
}
|
1163
|
+
|
1164
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
1165
|
+
if (ctx_ptr->ctx == NULL) {
|
1166
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1167
|
+
return Qnil;
|
1168
|
+
}
|
1169
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
1170
|
+
if (cnd_ptr->array.data == nullptr) {
|
1171
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
1172
|
+
return Qnil;
|
1173
|
+
}
|
1174
|
+
const float temperature = NUM2DBL(kw_values[0]);
|
1175
|
+
|
1176
|
+
llama_sample_temperature(ctx_ptr->ctx, &(cnd_ptr->array), temperature);
|
1177
|
+
|
1178
|
+
return Qnil;
|
1179
|
+
};
|
1180
|
+
|
1181
|
+
// Samples a token with the Mirostat (v1) algorithm.
// Positional argument: a TokenDataArray of candidates.
// Keyword arguments: :tau (float), :eta (float), :m (integer), :mu (float).
// Returns a two-element array [token_id, updated_mu]; mu is passed by
// reference to llama_sample_token_mirostat and comes back mutated.
static VALUE _llama_context_sample_token_mirostat(int argc, VALUE* argv, VALUE self) {
  VALUE candidates = Qnil;
  VALUE kwargs = Qnil;
  ID kw_keys[4] = { rb_intern("tau"), rb_intern("eta"), rb_intern("m"), rb_intern("mu") };
  VALUE kw_vals[4] = { Qundef, Qundef, Qundef, Qundef };
  rb_scan_args(argc, argv, "1:", &candidates, &kwargs);
  rb_get_kwargs(kwargs, kw_keys, 4, 0, kw_vals);

  // Argument validation. rb_raise does not return, so no fallthrough occurs.
  if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
    rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
  }
  if (!RB_FLOAT_TYPE_P(kw_vals[0])) {
    rb_raise(rb_eArgError, "tau must be a float");
  }
  if (!RB_FLOAT_TYPE_P(kw_vals[1])) {
    rb_raise(rb_eArgError, "eta must be a float");
  }
  if (!RB_INTEGER_TYPE_P(kw_vals[2])) {
    rb_raise(rb_eArgError, "m must be an integer");
  }
  if (!RB_FLOAT_TYPE_P(kw_vals[3])) {
    rb_raise(rb_eArgError, "mu must be a float");
  }

  LLaMAContextWrapper* ctx_wrap = get_llama_context(self);
  if (ctx_wrap->ctx == NULL) {
    rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
  }
  LLaMATokenDataArrayWrapper* tda_wrap = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
  if (tda_wrap->array.data == nullptr) {
    rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
  }

  const float tau = NUM2DBL(kw_vals[0]);
  const float eta = NUM2DBL(kw_vals[1]);
  const int m = NUM2INT(kw_vals[2]);
  float mu = NUM2DBL(kw_vals[3]);  // mutable: updated in place by the sampler

  const llama_token token_id = llama_sample_token_mirostat(ctx_wrap->ctx, &(tda_wrap->array), tau, eta, m, &mu);

  VALUE pair = rb_ary_new2(2);
  rb_ary_store(pair, 0, INT2NUM(token_id));
  rb_ary_store(pair, 1, DBL2NUM(mu));
  return pair;
};
|
1232
|
+
|
1233
|
+
// Samples a token with the Mirostat v2 algorithm.
// Positional argument: a TokenDataArray of candidates.
// Keyword arguments: :tau (float), :eta (float), :mu (float).
// Returns a two-element array [token_id, updated_mu]; mu is passed by
// reference to llama_sample_token_mirostat_v2 and comes back mutated.
static VALUE _llama_context_sample_token_mirostat_v2(int argc, VALUE* argv, VALUE self) {
  VALUE candidates = Qnil;
  VALUE kwargs = Qnil;
  ID kw_keys[3] = { rb_intern("tau"), rb_intern("eta"), rb_intern("mu") };
  VALUE kw_vals[3] = { Qundef, Qundef, Qundef };
  rb_scan_args(argc, argv, "1:", &candidates, &kwargs);
  rb_get_kwargs(kwargs, kw_keys, 3, 0, kw_vals);

  // Argument validation. rb_raise does not return, so no fallthrough occurs.
  if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
    rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
  }
  if (!RB_FLOAT_TYPE_P(kw_vals[0])) {
    rb_raise(rb_eArgError, "tau must be a float");
  }
  if (!RB_FLOAT_TYPE_P(kw_vals[1])) {
    rb_raise(rb_eArgError, "eta must be a float");
  }
  if (!RB_FLOAT_TYPE_P(kw_vals[2])) {
    rb_raise(rb_eArgError, "mu must be a float");
  }

  LLaMAContextWrapper* ctx_wrap = get_llama_context(self);
  if (ctx_wrap->ctx == NULL) {
    rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
  }
  LLaMATokenDataArrayWrapper* tda_wrap = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
  if (tda_wrap->array.data == nullptr) {
    rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
  }

  const float tau = NUM2DBL(kw_vals[0]);
  const float eta = NUM2DBL(kw_vals[1]);
  float mu = NUM2DBL(kw_vals[2]);  // mutable: updated in place by the sampler

  const llama_token token_id = llama_sample_token_mirostat_v2(ctx_wrap->ctx, &(tda_wrap->array), tau, eta, &mu);

  VALUE pair = rb_ary_new2(2);
  rb_ary_store(pair, 0, INT2NUM(token_id));
  rb_ary_store(pair, 1, DBL2NUM(mu));
  return pair;
};
|
1279
|
+
|
1280
|
+
// Greedy decoding: returns the id of the most probable token from the
// candidate list. Raises if the context is uninitialized, the argument is
// not a TokenDataArray, or the array carries no data.
static VALUE _llama_context_sample_token_greedy(VALUE self, VALUE candidates) {
  LLaMAContextWrapper* ctx_wrap = get_llama_context(self);
  if (ctx_wrap->ctx == NULL) {
    rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
  }
  if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
    rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
  }
  LLaMATokenDataArrayWrapper* tda_wrap = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
  if (tda_wrap->array.data == nullptr) {
    rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
  }
  return INT2NUM(llama_sample_token_greedy(ctx_wrap->ctx, &(tda_wrap->array)));
};
|
1298
|
+
|
1299
|
+
// Samples a token id from the candidate distribution using the context's
// RNG (llama_sample_token). Raises if the context is uninitialized, the
// argument is not a TokenDataArray, or the array carries no data.
static VALUE _llama_context_sample_token(VALUE self, VALUE candidates) {
  LLaMAContextWrapper* ctx_wrap = get_llama_context(self);
  if (ctx_wrap->ctx == NULL) {
    rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
  }
  if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
    rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
  }
  LLaMATokenDataArrayWrapper* tda_wrap = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
  if (tda_wrap->array.data == nullptr) {
    rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
  }
  return INT2NUM(llama_sample_token(ctx_wrap->ctx, &(tda_wrap->array)));
};
|
624
1317
|
};
|
625
1318
|
|
626
1319
|
const rb_data_type_t RbLLaMAContext::llama_context_type = {
|
@@ -680,6 +1373,10 @@ static VALUE rb_llama_token_eos(VALUE self) {
|
|
680
1373
|
return INT2NUM(llama_token_eos());
|
681
1374
|
}
|
682
1375
|
|
1376
|
+
// Module function LLaMACpp.token_nl: returns the newline token id as a
// Ruby Integer.
static VALUE rb_llama_token_nl(VALUE self) {
  const llama_token nl_id = llama_token_nl();
  return INT2NUM(nl_id);
}
|
1379
|
+
|
683
1380
|
static VALUE rb_llama_print_system_info(VALUE self) {
|
684
1381
|
const char* result = llama_print_system_info();
|
685
1382
|
return rb_utf8_str_new_cstr(result);
|
@@ -695,12 +1392,16 @@ static VALUE rb_llama_mlock_supported(VALUE self) {
|
|
695
1392
|
|
696
1393
|
extern "C" void Init_llama_cpp(void) {
|
697
1394
|
rb_mLLaMACpp = rb_define_module("LLaMACpp");
|
1395
|
+
|
1396
|
+
RbLLaMATokenData::define_class(rb_mLLaMACpp);
|
1397
|
+
RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
|
698
1398
|
RbLLaMAContext::define_class(rb_mLLaMACpp);
|
699
1399
|
RbLLaMAContextParams::define_class(rb_mLLaMACpp);
|
700
1400
|
|
701
1401
|
rb_define_module_function(rb_mLLaMACpp, "model_quantize", rb_llama_model_quantize, -1);
|
702
1402
|
rb_define_module_function(rb_mLLaMACpp, "token_bos", rb_llama_token_bos, 0);
|
703
1403
|
rb_define_module_function(rb_mLLaMACpp, "token_eos", rb_llama_token_eos, 0);
|
1404
|
+
rb_define_module_function(rb_mLLaMACpp, "token_nl", rb_llama_token_nl, 0);
|
704
1405
|
rb_define_module_function(rb_mLLaMACpp, "print_system_info", rb_llama_print_system_info, 0);
|
705
1406
|
rb_define_module_function(rb_mLLaMACpp, "mmap_supported?", rb_llama_mmap_supported, 0);
|
706
1407
|
rb_define_module_function(rb_mLLaMACpp, "mlock_supported?", rb_llama_mlock_supported, 0);
|
@@ -711,7 +1412,6 @@ extern "C" void Init_llama_cpp(void) {
|
|
711
1412
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_1", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_1));
|
712
1413
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16));
|
713
1414
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_2", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_2));
|
714
|
-
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_3", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_3));
|
715
1415
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q8_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q8_0));
|
716
1416
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_0));
|
717
1417
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_1", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_1));
|