llama_cpp 0.0.7 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +17 -0
- data/ext/llama_cpp/llama_cpp.cpp +829 -51
- data/ext/llama_cpp/src/ggml-cuda.h +9 -32
- data/ext/llama_cpp/src/ggml-opencl.c +169 -24
- data/ext/llama_cpp/src/ggml.c +6672 -4376
- data/ext/llama_cpp/src/ggml.h +250 -15
- data/ext/llama_cpp/src/{llama_util.h → llama-util.h} +76 -10
- data/ext/llama_cpp/src/llama.cpp +710 -217
- data/ext/llama_cpp/src/llama.h +75 -28
- data/lib/llama_cpp/client.rb +30 -9
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +27 -3
- data/sig/llama_cpp.rbs +41 -7
- metadata +3 -3
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -4,6 +4,255 @@
|
|
4
4
|
VALUE rb_mLLaMACpp;
|
5
5
|
VALUE rb_cLLaMAContext;
|
6
6
|
VALUE rb_cLLaMAContextParams;
|
7
|
+
VALUE rb_cLLaMATokenData;
|
8
|
+
VALUE rb_cLLaMATokenDataArray;
|
9
|
+
|
10
|
+
class LLaMATokenDataWrapper {
|
11
|
+
public:
|
12
|
+
llama_token_data data;
|
13
|
+
|
14
|
+
LLaMATokenDataWrapper() {
|
15
|
+
data.id = 0;
|
16
|
+
data.logit = 0.0;
|
17
|
+
data.p = 0.0;
|
18
|
+
};
|
19
|
+
|
20
|
+
~LLaMATokenDataWrapper(){};
|
21
|
+
};
|
22
|
+
|
23
|
+
class RbLLaMATokenData {
|
24
|
+
public:
|
25
|
+
static VALUE llama_token_data_alloc(VALUE self) {
|
26
|
+
LLaMATokenDataWrapper* ptr = (LLaMATokenDataWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataWrapper));
|
27
|
+
new (ptr) LLaMATokenDataWrapper();
|
28
|
+
return TypedData_Wrap_Struct(self, &llama_token_data_type, ptr);
|
29
|
+
};
|
30
|
+
|
31
|
+
static void llama_token_data_free(void* ptr) {
|
32
|
+
((LLaMATokenDataWrapper*)ptr)->~LLaMATokenDataWrapper();
|
33
|
+
ruby_xfree(ptr);
|
34
|
+
};
|
35
|
+
|
36
|
+
static size_t llama_token_data_size(const void* ptr) {
|
37
|
+
return sizeof(*((LLaMATokenDataWrapper*)ptr));
|
38
|
+
};
|
39
|
+
|
40
|
+
static LLaMATokenDataWrapper* get_llama_token_data(VALUE self) {
|
41
|
+
LLaMATokenDataWrapper* ptr;
|
42
|
+
TypedData_Get_Struct(self, LLaMATokenDataWrapper, &llama_token_data_type, ptr);
|
43
|
+
return ptr;
|
44
|
+
};
|
45
|
+
|
46
|
+
static void define_class(VALUE outer) {
|
47
|
+
rb_cLLaMATokenData = rb_define_class_under(outer, "TokenData", rb_cObject);
|
48
|
+
rb_define_alloc_func(rb_cLLaMATokenData, llama_token_data_alloc);
|
49
|
+
rb_define_method(rb_cLLaMATokenData, "initialize", RUBY_METHOD_FUNC(_llama_token_data_init), -1);
|
50
|
+
rb_define_method(rb_cLLaMATokenData, "id=", RUBY_METHOD_FUNC(_llama_token_data_set_id), 1);
|
51
|
+
rb_define_method(rb_cLLaMATokenData, "id", RUBY_METHOD_FUNC(_llama_token_data_get_id), 0);
|
52
|
+
rb_define_method(rb_cLLaMATokenData, "logit=", RUBY_METHOD_FUNC(_llama_token_data_set_logit), 1);
|
53
|
+
rb_define_method(rb_cLLaMATokenData, "logit", RUBY_METHOD_FUNC(_llama_token_data_get_logit), 0);
|
54
|
+
rb_define_method(rb_cLLaMATokenData, "p=", RUBY_METHOD_FUNC(_llama_token_data_set_p), 1);
|
55
|
+
rb_define_method(rb_cLLaMATokenData, "p", RUBY_METHOD_FUNC(_llama_token_data_get_p), 0);
|
56
|
+
}
|
57
|
+
|
58
|
+
private:
|
59
|
+
static const rb_data_type_t llama_token_data_type;
|
60
|
+
|
61
|
+
static VALUE _llama_token_data_init(int argc, VALUE* argv, VALUE self) {
|
62
|
+
VALUE kw_args = Qnil;
|
63
|
+
ID kw_table[3] = { rb_intern("id"), rb_intern("logit"), rb_intern("p") };
|
64
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
65
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
66
|
+
rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
|
67
|
+
|
68
|
+
if (!RB_INTEGER_TYPE_P(kw_values[0])) {
|
69
|
+
rb_raise(rb_eArgError, "id must be an integer");
|
70
|
+
return Qnil;
|
71
|
+
}
|
72
|
+
if (!RB_FLOAT_TYPE_P(kw_values[1])) {
|
73
|
+
rb_raise(rb_eArgError, "logit must be a float");
|
74
|
+
return Qnil;
|
75
|
+
}
|
76
|
+
if (!RB_FLOAT_TYPE_P(kw_values[2])) {
|
77
|
+
rb_raise(rb_eArgError, "p must be a float");
|
78
|
+
return Qnil;
|
79
|
+
}
|
80
|
+
|
81
|
+
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
82
|
+
new (ptr) LLaMATokenDataWrapper();
|
83
|
+
|
84
|
+
ptr->data.id = NUM2INT(kw_values[0]);
|
85
|
+
ptr->data.logit = NUM2DBL(kw_values[1]);
|
86
|
+
ptr->data.p = NUM2DBL(kw_values[2]);
|
87
|
+
|
88
|
+
return self;
|
89
|
+
}
|
90
|
+
|
91
|
+
// id
|
92
|
+
static VALUE _llama_token_data_set_id(VALUE self, VALUE id) {
|
93
|
+
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
94
|
+
ptr->data.id = NUM2INT(id);
|
95
|
+
return INT2NUM(ptr->data.id);
|
96
|
+
};
|
97
|
+
|
98
|
+
static VALUE _llama_token_data_get_id(VALUE self) {
|
99
|
+
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
100
|
+
return INT2NUM(ptr->data.id);
|
101
|
+
};
|
102
|
+
|
103
|
+
// logit
|
104
|
+
static VALUE _llama_token_data_set_logit(VALUE self, VALUE logit) {
|
105
|
+
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
106
|
+
ptr->data.logit = NUM2DBL(logit);
|
107
|
+
return DBL2NUM(ptr->data.logit);
|
108
|
+
};
|
109
|
+
|
110
|
+
static VALUE _llama_token_data_get_logit(VALUE self) {
|
111
|
+
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
112
|
+
return DBL2NUM(ptr->data.logit);
|
113
|
+
};
|
114
|
+
|
115
|
+
// p
|
116
|
+
static VALUE _llama_token_data_set_p(VALUE self, VALUE p) {
|
117
|
+
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
118
|
+
ptr->data.p = NUM2DBL(p);
|
119
|
+
return DBL2NUM(ptr->data.p);
|
120
|
+
};
|
121
|
+
|
122
|
+
static VALUE _llama_token_data_get_p(VALUE self) {
|
123
|
+
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
124
|
+
return DBL2NUM(ptr->data.p);
|
125
|
+
};
|
126
|
+
};
|
127
|
+
|
128
|
+
const rb_data_type_t RbLLaMATokenData::llama_token_data_type = {
|
129
|
+
"RbLLaMATokenData",
|
130
|
+
{ NULL,
|
131
|
+
RbLLaMATokenData::llama_token_data_free,
|
132
|
+
RbLLaMATokenData::llama_token_data_size },
|
133
|
+
NULL,
|
134
|
+
NULL,
|
135
|
+
RUBY_TYPED_FREE_IMMEDIATELY
|
136
|
+
};
|
137
|
+
|
138
|
+
class LLaMATokenDataArrayWrapper {
|
139
|
+
public:
|
140
|
+
llama_token_data_array array;
|
141
|
+
|
142
|
+
LLaMATokenDataArrayWrapper() {
|
143
|
+
array.data = nullptr;
|
144
|
+
array.size = 0;
|
145
|
+
array.sorted = false;
|
146
|
+
};
|
147
|
+
|
148
|
+
~LLaMATokenDataArrayWrapper() {
|
149
|
+
if (array.data) {
|
150
|
+
ruby_xfree(array.data);
|
151
|
+
array.data = nullptr;
|
152
|
+
}
|
153
|
+
};
|
154
|
+
};
|
155
|
+
|
156
|
+
class RbLLaMATokenDataArray {
|
157
|
+
public:
|
158
|
+
static VALUE llama_token_data_array_alloc(VALUE self) {
|
159
|
+
LLaMATokenDataArrayWrapper* ptr = (LLaMATokenDataArrayWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataArrayWrapper));
|
160
|
+
new (ptr) LLaMATokenDataArrayWrapper();
|
161
|
+
return TypedData_Wrap_Struct(self, &llama_token_data_array_type, ptr);
|
162
|
+
};
|
163
|
+
|
164
|
+
static void llama_token_data_array_free(void* ptr) {
|
165
|
+
((LLaMATokenDataArrayWrapper*)ptr)->~LLaMATokenDataArrayWrapper();
|
166
|
+
ruby_xfree(ptr);
|
167
|
+
};
|
168
|
+
|
169
|
+
static size_t llama_token_data_array_size(const void* ptr) {
|
170
|
+
return sizeof(*((LLaMATokenDataArrayWrapper*)ptr));
|
171
|
+
};
|
172
|
+
|
173
|
+
static LLaMATokenDataArrayWrapper* get_llama_token_data_array(VALUE self) {
|
174
|
+
LLaMATokenDataArrayWrapper* ptr;
|
175
|
+
TypedData_Get_Struct(self, LLaMATokenDataArrayWrapper, &llama_token_data_array_type, ptr);
|
176
|
+
return ptr;
|
177
|
+
};
|
178
|
+
|
179
|
+
static void define_class(VALUE outer) {
|
180
|
+
rb_cLLaMATokenDataArray = rb_define_class_under(outer, "TokenDataArray", rb_cObject);
|
181
|
+
rb_define_alloc_func(rb_cLLaMATokenDataArray, llama_token_data_array_alloc);
|
182
|
+
rb_define_method(rb_cLLaMATokenDataArray, "initialize", RUBY_METHOD_FUNC(_llama_token_data_array_init), -1);
|
183
|
+
rb_define_method(rb_cLLaMATokenDataArray, "size", RUBY_METHOD_FUNC(_llama_token_data_array_get_size), 0);
|
184
|
+
rb_define_method(rb_cLLaMATokenDataArray, "sorted", RUBY_METHOD_FUNC(_llama_token_data_array_get_sorted), 0);
|
185
|
+
};
|
186
|
+
|
187
|
+
private:
|
188
|
+
static const rb_data_type_t llama_token_data_array_type;
|
189
|
+
|
190
|
+
static VALUE _llama_token_data_array_init(int argc, VALUE* argv, VALUE self) {
|
191
|
+
VALUE kw_args = Qnil;
|
192
|
+
ID kw_table[1] = { rb_intern("sorted") };
|
193
|
+
VALUE kw_values[1] = { Qundef };
|
194
|
+
VALUE arr = Qnil;
|
195
|
+
rb_scan_args(argc, argv, "1:", &arr, &kw_args);
|
196
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
197
|
+
|
198
|
+
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
199
|
+
rb_raise(rb_eArgError, "1st argument must be an array");
|
200
|
+
return Qnil;
|
201
|
+
}
|
202
|
+
size_t sz_array = RARRAY_LEN(arr);
|
203
|
+
if (sz_array == 0) {
|
204
|
+
rb_raise(rb_eArgError, "array must not be empty");
|
205
|
+
return Qnil;
|
206
|
+
}
|
207
|
+
if (kw_values[0] != Qundef && !RB_TYPE_P(kw_values[0], T_TRUE) && !RB_TYPE_P(kw_values[0], T_FALSE)) {
|
208
|
+
rb_raise(rb_eArgError, "sorted must be a boolean");
|
209
|
+
return Qnil;
|
210
|
+
}
|
211
|
+
|
212
|
+
LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
|
213
|
+
new (ptr) LLaMATokenDataArrayWrapper();
|
214
|
+
|
215
|
+
ptr->array.data = (llama_token_data*)ruby_xmalloc(sizeof(llama_token_data) * sz_array);
|
216
|
+
for (size_t i = 0; i < sz_array; ++i) {
|
217
|
+
VALUE el = rb_ary_entry(arr, i);
|
218
|
+
if (!rb_obj_is_kind_of(el, rb_cLLaMATokenData)) {
|
219
|
+
rb_raise(rb_eArgError, "array element must be a TokenData");
|
220
|
+
xfree(ptr->array.data);
|
221
|
+
ptr->array.data = nullptr;
|
222
|
+
return Qnil;
|
223
|
+
}
|
224
|
+
llama_token_data token_data = RbLLaMATokenData::get_llama_token_data(el)->data;
|
225
|
+
ptr->array.data[i].id = token_data.id;
|
226
|
+
ptr->array.data[i].logit = token_data.logit;
|
227
|
+
ptr->array.data[i].p = token_data.p;
|
228
|
+
}
|
229
|
+
|
230
|
+
ptr->array.size = sz_array;
|
231
|
+
ptr->array.sorted = kw_values[0] == Qtrue;
|
232
|
+
|
233
|
+
return self;
|
234
|
+
};
|
235
|
+
|
236
|
+
static VALUE _llama_token_data_array_get_size(VALUE self) {
|
237
|
+
LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
|
238
|
+
return SIZET2NUM(ptr->array.size);
|
239
|
+
};
|
240
|
+
|
241
|
+
static VALUE _llama_token_data_array_get_sorted(VALUE self) {
|
242
|
+
LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
|
243
|
+
return ptr->array.sorted ? Qtrue : Qfalse;
|
244
|
+
};
|
245
|
+
};
|
246
|
+
|
247
|
+
const rb_data_type_t RbLLaMATokenDataArray::llama_token_data_array_type = {
|
248
|
+
"RbLLaMATokenDataArray",
|
249
|
+
{ NULL,
|
250
|
+
RbLLaMATokenDataArray::llama_token_data_array_free,
|
251
|
+
RbLLaMATokenDataArray::llama_token_data_array_size },
|
252
|
+
NULL,
|
253
|
+
NULL,
|
254
|
+
RUBY_TYPED_FREE_IMMEDIATELY
|
255
|
+
};
|
7
256
|
|
8
257
|
class LLaMAContextParamsWrapper {
|
9
258
|
public:
|
@@ -43,8 +292,6 @@ public:
|
|
43
292
|
// rb_define_method(rb_cLLaMAContextParams, "initialize", RUBY_METHOD_FUNC(_llama_context_params_init), 0);
|
44
293
|
rb_define_method(rb_cLLaMAContextParams, "n_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_n_ctx), 1);
|
45
294
|
rb_define_method(rb_cLLaMAContextParams, "n_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_n_ctx), 0);
|
46
|
-
rb_define_method(rb_cLLaMAContextParams, "n_parts=", RUBY_METHOD_FUNC(_llama_context_params_set_n_parts), 1);
|
47
|
-
rb_define_method(rb_cLLaMAContextParams, "n_parts", RUBY_METHOD_FUNC(_llama_context_params_get_n_parts), 0);
|
48
295
|
rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
|
49
296
|
rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
|
50
297
|
rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
|
@@ -82,18 +329,6 @@ private:
|
|
82
329
|
return INT2NUM(ptr->params.n_ctx);
|
83
330
|
};
|
84
331
|
|
85
|
-
// n_parts
|
86
|
-
static VALUE _llama_context_params_set_n_parts(VALUE self, VALUE n_parts) {
|
87
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
88
|
-
ptr->params.n_parts = NUM2INT(n_parts);
|
89
|
-
return INT2NUM(ptr->params.n_parts);
|
90
|
-
};
|
91
|
-
|
92
|
-
static VALUE _llama_context_params_get_n_parts(VALUE self) {
|
93
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
94
|
-
return INT2NUM(ptr->params.n_parts);
|
95
|
-
};
|
96
|
-
|
97
332
|
// seed
|
98
333
|
static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
|
99
334
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
@@ -234,7 +469,6 @@ public:
|
|
234
469
|
rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
|
235
470
|
rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
|
236
471
|
rb_define_method(rb_cLLaMAContext, "token_to_str", RUBY_METHOD_FUNC(_llama_context_token_to_str), 1);
|
237
|
-
rb_define_method(rb_cLLaMAContext, "sample_top_p_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_p_top_k), -1);
|
238
472
|
rb_define_method(rb_cLLaMAContext, "n_vocab", RUBY_METHOD_FUNC(_llama_context_n_vocab), 0);
|
239
473
|
rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
|
240
474
|
rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
|
@@ -244,6 +478,22 @@ public:
|
|
244
478
|
rb_define_method(rb_cLLaMAContext, "free", RUBY_METHOD_FUNC(_llama_context_free), 0);
|
245
479
|
rb_define_method(rb_cLLaMAContext, "load", RUBY_METHOD_FUNC(_llama_context_load), -1);
|
246
480
|
rb_define_method(rb_cLLaMAContext, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_context_apply_lora_from_file), -1);
|
481
|
+
rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
|
482
|
+
rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
|
483
|
+
rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
|
484
|
+
rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
|
485
|
+
rb_define_method(rb_cLLaMAContext, "sample_repetition_penalty", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalty), -1);
|
486
|
+
rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
|
487
|
+
rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
|
488
|
+
rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
|
489
|
+
rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
|
490
|
+
rb_define_method(rb_cLLaMAContext, "sample_tail_free", RUBY_METHOD_FUNC(_llama_context_sample_tail_free), -1);
|
491
|
+
rb_define_method(rb_cLLaMAContext, "sample_typical", RUBY_METHOD_FUNC(_llama_context_sample_typical), -1);
|
492
|
+
rb_define_method(rb_cLLaMAContext, "sample_temperature", RUBY_METHOD_FUNC(_llama_context_sample_temperature), -1);
|
493
|
+
rb_define_method(rb_cLLaMAContext, "sample_token_mirostat", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat), -1);
|
494
|
+
rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
|
495
|
+
rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
|
496
|
+
rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
|
247
497
|
};
|
248
498
|
|
249
499
|
private:
|
@@ -448,40 +698,6 @@ private:
|
|
448
698
|
return output;
|
449
699
|
};
|
450
700
|
|
451
|
-
static VALUE _llama_context_sample_top_p_top_k(int argc, VALUE* argv, VALUE self) {
|
452
|
-
VALUE last_n_tokens = Qnil;
|
453
|
-
VALUE kw_args = Qnil;
|
454
|
-
ID kw_table[4] = { rb_intern("top_k"), rb_intern("top_p"), rb_intern("temp"), rb_intern("penalty") };
|
455
|
-
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
456
|
-
rb_scan_args(argc, argv, "1:", &last_n_tokens, &kw_args);
|
457
|
-
rb_get_kwargs(kw_args, kw_table, 4, 0, kw_values);
|
458
|
-
|
459
|
-
if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
|
460
|
-
rb_raise(rb_eArgError, "last_n_tokens must be an Array");
|
461
|
-
return Qnil;
|
462
|
-
}
|
463
|
-
|
464
|
-
const int last_n_tokens_size = RARRAY_LEN(last_n_tokens);
|
465
|
-
const int top_k = NUM2INT(kw_values[0]);
|
466
|
-
const double top_p = NUM2DBL(kw_values[1]);
|
467
|
-
const double temp = NUM2DBL(kw_values[2]);
|
468
|
-
const double penalty = NUM2DBL(kw_values[3]);
|
469
|
-
|
470
|
-
std::vector<llama_token> last_n_tokens_data(last_n_tokens_size);
|
471
|
-
for (int i = 0; i < last_n_tokens_size; i++) {
|
472
|
-
last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
|
473
|
-
}
|
474
|
-
|
475
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
476
|
-
if (ptr->ctx == NULL) {
|
477
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
478
|
-
return Qnil;
|
479
|
-
}
|
480
|
-
llama_token token = llama_sample_top_p_top_k(ptr->ctx, last_n_tokens_data.data(), last_n_tokens_size, top_k, top_p, temp, penalty);
|
481
|
-
|
482
|
-
return INT2NUM(token);
|
483
|
-
}
|
484
|
-
|
485
701
|
static VALUE _llama_context_n_vocab(VALUE self) {
|
486
702
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
487
703
|
if (ptr->ctx == NULL) {
|
@@ -621,6 +837,562 @@ private:
|
|
621
837
|
}
|
622
838
|
return Qnil;
|
623
839
|
};
|
840
|
+
|
841
|
+
static VALUE _llama_context_kv_cache_token_count(VALUE self) {
|
842
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
843
|
+
if (ptr->ctx == NULL) {
|
844
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
845
|
+
return Qnil;
|
846
|
+
}
|
847
|
+
return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
|
848
|
+
};
|
849
|
+
|
850
|
+
static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
|
851
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
852
|
+
if (ptr->ctx == NULL) {
|
853
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
854
|
+
return Qnil;
|
855
|
+
}
|
856
|
+
const int seed = NUM2INT(seed_);
|
857
|
+
llama_set_rng_seed(ptr->ctx, seed);
|
858
|
+
return Qnil;
|
859
|
+
};
|
860
|
+
|
861
|
+
static VALUE _llama_context_load_session_file(int argc, VALUE* argv, VALUE self) {
|
862
|
+
VALUE kw_args = Qnil;
|
863
|
+
ID kw_table[1] = { rb_intern("session_path") };
|
864
|
+
VALUE kw_values[1] = { Qundef };
|
865
|
+
VALUE candidates = Qnil;
|
866
|
+
VALUE last_n_tokens = Qnil;
|
867
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
868
|
+
rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
|
869
|
+
|
870
|
+
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
871
|
+
rb_raise(rb_eArgError, "session_path must be a String");
|
872
|
+
return Qnil;
|
873
|
+
}
|
874
|
+
|
875
|
+
VALUE filename = kw_values[0];
|
876
|
+
|
877
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
878
|
+
if (ctx_ptr->ctx == NULL) {
|
879
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
880
|
+
return Qnil;
|
881
|
+
}
|
882
|
+
|
883
|
+
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(rb_iv_get(self, "@params"));
|
884
|
+
const int n_ctx = prms_ptr->params.n_ctx;
|
885
|
+
|
886
|
+
std::vector<llama_token> session_tokens(n_ctx);
|
887
|
+
size_t n_token_count_out = 0;
|
888
|
+
|
889
|
+
try {
|
890
|
+
bool res = llama_load_session_file(ctx_ptr->ctx, StringValueCStr(filename), session_tokens.data(), session_tokens.capacity(), &n_token_count_out);
|
891
|
+
if (!res) {
|
892
|
+
rb_raise(rb_eRuntimeError, "Failed to load session file");
|
893
|
+
return Qnil;
|
894
|
+
}
|
895
|
+
session_tokens.resize(n_token_count_out);
|
896
|
+
} catch (const std::runtime_error& e) {
|
897
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
898
|
+
return Qnil;
|
899
|
+
}
|
900
|
+
|
901
|
+
VALUE ary_session_tokens = rb_ary_new2(n_token_count_out);
|
902
|
+
for (size_t i = 0; i < n_token_count_out; i++) {
|
903
|
+
rb_ary_store(ary_session_tokens, i, INT2NUM(session_tokens[i]));
|
904
|
+
}
|
905
|
+
|
906
|
+
RB_GC_GUARD(filename);
|
907
|
+
return ary_session_tokens;
|
908
|
+
}
|
909
|
+
|
910
|
+
static VALUE _llama_context_save_session_file(int argc, VALUE* argv, VALUE self) {
|
911
|
+
VALUE kw_args = Qnil;
|
912
|
+
ID kw_table[2] = { rb_intern("session_path"), rb_intern("session_tokens") };
|
913
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
914
|
+
VALUE candidates = Qnil;
|
915
|
+
VALUE last_n_tokens = Qnil;
|
916
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
917
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
918
|
+
|
919
|
+
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
920
|
+
rb_raise(rb_eArgError, "session_path must be a String");
|
921
|
+
return Qnil;
|
922
|
+
}
|
923
|
+
if (!RB_TYPE_P(kw_values[1], T_ARRAY)) {
|
924
|
+
rb_raise(rb_eArgError, "session_tokens must be an Array");
|
925
|
+
return Qnil;
|
926
|
+
}
|
927
|
+
|
928
|
+
VALUE filename = kw_values[0];
|
929
|
+
const size_t sz_session_tokens = RARRAY_LEN(kw_values[1]);
|
930
|
+
std::vector<llama_token> session_tokens(sz_session_tokens);
|
931
|
+
for (size_t i = 0; i < sz_session_tokens; i++) {
|
932
|
+
session_tokens[i] = NUM2INT(rb_ary_entry(kw_values[1], i));
|
933
|
+
}
|
934
|
+
|
935
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
936
|
+
if (ctx_ptr->ctx == NULL) {
|
937
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
938
|
+
return Qnil;
|
939
|
+
}
|
940
|
+
|
941
|
+
bool res = llama_save_session_file(ctx_ptr->ctx, StringValueCStr(filename), session_tokens.data(), sz_session_tokens);
|
942
|
+
|
943
|
+
if (!res) {
|
944
|
+
rb_raise(rb_eRuntimeError, "Failed to save session file");
|
945
|
+
return Qnil;
|
946
|
+
}
|
947
|
+
|
948
|
+
RB_GC_GUARD(filename);
|
949
|
+
return Qnil;
|
950
|
+
}
|
951
|
+
|
952
|
+
static VALUE _llama_context_sample_repetition_penalty(int argc, VALUE* argv, VALUE self) {
|
953
|
+
VALUE kw_args = Qnil;
|
954
|
+
ID kw_table[1] = { rb_intern("penalty") };
|
955
|
+
VALUE kw_values[1] = { Qundef };
|
956
|
+
VALUE candidates = Qnil;
|
957
|
+
VALUE last_n_tokens = Qnil;
|
958
|
+
rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
|
959
|
+
rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
|
960
|
+
|
961
|
+
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
962
|
+
rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
|
963
|
+
return Qnil;
|
964
|
+
}
|
965
|
+
if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
|
966
|
+
rb_raise(rb_eArgError, "last_n_tokens must be an Array");
|
967
|
+
return Qnil;
|
968
|
+
}
|
969
|
+
if (!RB_FLOAT_TYPE_P(kw_values[0])) {
|
970
|
+
rb_raise(rb_eArgError, "penalty must be a float");
|
971
|
+
return Qnil;
|
972
|
+
}
|
973
|
+
|
974
|
+
const size_t last_tokens_size = RARRAY_LEN(last_n_tokens);
|
975
|
+
std::vector<llama_token> last_n_tokens_data(last_tokens_size);
|
976
|
+
for (size_t i = 0; i < last_tokens_size; i++) {
|
977
|
+
last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
|
978
|
+
}
|
979
|
+
|
980
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
981
|
+
if (ctx_ptr->ctx == NULL) {
|
982
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
983
|
+
return Qnil;
|
984
|
+
}
|
985
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
986
|
+
if (cnd_ptr->array.data == nullptr) {
|
987
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
988
|
+
return Qnil;
|
989
|
+
}
|
990
|
+
const float penalty = NUM2DBL(kw_values[0]);
|
991
|
+
|
992
|
+
llama_sample_repetition_penalty(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, penalty);
|
993
|
+
|
994
|
+
return Qnil;
|
995
|
+
};
|
996
|
+
|
997
|
+
static VALUE _llama_context_sample_frequency_and_presence_penalties(int argc, VALUE* argv, VALUE self) {
|
998
|
+
VALUE kw_args = Qnil;
|
999
|
+
ID kw_table[2] = { rb_intern("frequency"), rb_intern("presence") };
|
1000
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
1001
|
+
VALUE candidates = Qnil;
|
1002
|
+
VALUE last_n_tokens = Qnil;
|
1003
|
+
rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
|
1004
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
1005
|
+
|
1006
|
+
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
1007
|
+
rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
|
1008
|
+
return Qnil;
|
1009
|
+
}
|
1010
|
+
if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
|
1011
|
+
rb_raise(rb_eArgError, "last_n_tokens must be an Array");
|
1012
|
+
return Qnil;
|
1013
|
+
}
|
1014
|
+
if (!RB_FLOAT_TYPE_P(kw_values[0])) {
|
1015
|
+
rb_raise(rb_eArgError, "frequency must be a float");
|
1016
|
+
return Qnil;
|
1017
|
+
}
|
1018
|
+
if (!RB_FLOAT_TYPE_P(kw_values[1])) {
|
1019
|
+
rb_raise(rb_eArgError, "presence must be a float");
|
1020
|
+
return Qnil;
|
1021
|
+
}
|
1022
|
+
|
1023
|
+
const size_t last_tokens_size = RARRAY_LEN(last_n_tokens);
|
1024
|
+
std::vector<llama_token> last_n_tokens_data(last_tokens_size);
|
1025
|
+
for (size_t i = 0; i < last_tokens_size; i++) {
|
1026
|
+
last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
|
1027
|
+
}
|
1028
|
+
|
1029
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
1030
|
+
if (ctx_ptr->ctx == NULL) {
|
1031
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1032
|
+
return Qnil;
|
1033
|
+
}
|
1034
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
1035
|
+
if (cnd_ptr->array.data == nullptr) {
|
1036
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
1037
|
+
return Qnil;
|
1038
|
+
}
|
1039
|
+
|
1040
|
+
const float alpha_frequency = NUM2DBL(kw_values[0]);
|
1041
|
+
const float alpha_presence = NUM2DBL(kw_values[1]);
|
1042
|
+
|
1043
|
+
llama_sample_frequency_and_presence_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, alpha_frequency, alpha_presence);
|
1044
|
+
|
1045
|
+
return Qnil;
|
1046
|
+
};
|
1047
|
+
|
1048
|
+
static VALUE _llama_context_sample_softmax(VALUE self, VALUE candidates) {
|
1049
|
+
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
1050
|
+
rb_raise(rb_eArgError, "argument must be a TokenDataArray");
|
1051
|
+
return Qnil;
|
1052
|
+
}
|
1053
|
+
|
1054
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
1055
|
+
if (ctx_ptr->ctx == NULL) {
|
1056
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1057
|
+
return Qnil;
|
1058
|
+
}
|
1059
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
1060
|
+
if (cnd_ptr->array.data == nullptr) {
|
1061
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
1062
|
+
return Qnil;
|
1063
|
+
}
|
1064
|
+
|
1065
|
+
llama_sample_softmax(ctx_ptr->ctx, &(cnd_ptr->array));
|
1066
|
+
|
1067
|
+
return Qnil;
|
1068
|
+
};
|
1069
|
+
|
1070
|
+
static VALUE _llama_context_sample_top_k(int argc, VALUE* argv, VALUE self) {
|
1071
|
+
VALUE kw_args = Qnil;
|
1072
|
+
ID kw_table[2] = { rb_intern("k"), rb_intern("min_keep") };
|
1073
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
1074
|
+
VALUE candidates = Qnil;
|
1075
|
+
rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
|
1076
|
+
rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
|
1077
|
+
|
1078
|
+
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
1079
|
+
rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
|
1080
|
+
return Qnil;
|
1081
|
+
}
|
1082
|
+
if (!RB_INTEGER_TYPE_P(kw_values[0])) {
|
1083
|
+
rb_raise(rb_eArgError, "k must be an integer");
|
1084
|
+
return Qnil;
|
1085
|
+
}
|
1086
|
+
if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
|
1087
|
+
rb_raise(rb_eArgError, "min_keep must be an integer");
|
1088
|
+
return Qnil;
|
1089
|
+
}
|
1090
|
+
|
1091
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
1092
|
+
if (ctx_ptr->ctx == NULL) {
|
1093
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1094
|
+
return Qnil;
|
1095
|
+
}
|
1096
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
1097
|
+
if (cnd_ptr->array.data == nullptr) {
|
1098
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
1099
|
+
return Qnil;
|
1100
|
+
}
|
1101
|
+
const int k = NUM2DBL(kw_values[0]);
|
1102
|
+
const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
|
1103
|
+
|
1104
|
+
llama_sample_top_k(ctx_ptr->ctx, &(cnd_ptr->array), k, min_keep);
|
1105
|
+
|
1106
|
+
return Qnil;
|
1107
|
+
};
|
1108
|
+
|
1109
|
+
static VALUE _llama_context_sample_top_p(int argc, VALUE* argv, VALUE self) {
|
1110
|
+
VALUE kw_args = Qnil;
|
1111
|
+
ID kw_table[2] = { rb_intern("prob"), rb_intern("min_keep") };
|
1112
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
1113
|
+
VALUE candidates = Qnil;
|
1114
|
+
rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
|
1115
|
+
rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
|
1116
|
+
|
1117
|
+
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
1118
|
+
rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
|
1119
|
+
return Qnil;
|
1120
|
+
}
|
1121
|
+
if (!RB_FLOAT_TYPE_P(kw_values[0])) {
|
1122
|
+
rb_raise(rb_eArgError, "prob must be a float");
|
1123
|
+
return Qnil;
|
1124
|
+
}
|
1125
|
+
if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
|
1126
|
+
rb_raise(rb_eArgError, "min_keep must be an integer");
|
1127
|
+
return Qnil;
|
1128
|
+
}
|
1129
|
+
|
1130
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
1131
|
+
if (ctx_ptr->ctx == NULL) {
|
1132
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1133
|
+
return Qnil;
|
1134
|
+
}
|
1135
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
1136
|
+
if (cnd_ptr->array.data == nullptr) {
|
1137
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
1138
|
+
return Qnil;
|
1139
|
+
}
|
1140
|
+
const float prob = NUM2DBL(kw_values[0]);
|
1141
|
+
const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
|
1142
|
+
|
1143
|
+
llama_sample_top_p(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
|
1144
|
+
|
1145
|
+
return Qnil;
|
1146
|
+
};
|
1147
|
+
|
1148
|
+
// context.sample_tail_free(candidates, z:, min_keep: 1)
// Applies Tail-Free Sampling (TFS) in place to the candidate token array,
// delegating to llama_sample_tail_free. `z` is the TFS cutoff (Float),
// `min_keep` the minimum number of candidates to retain (Integer, default 1).
// Raises ArgumentError for invalid argument types and RuntimeError when the
// context or the candidate array is uninitialized. Returns nil.
static VALUE _llama_context_sample_tail_free(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[2] = { rb_intern("z"), rb_intern("min_keep") };
  VALUE kw_values[2] = { Qundef, Qundef };
  VALUE candidates = Qnil;
  rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
  rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);

  if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
    rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
    return Qnil;
  }
  if (!RB_FLOAT_TYPE_P(kw_values[0])) {
    // Fixed: the message previously said "prob" (copied from sample_top_p),
    // but this method's keyword argument is `z`.
    rb_raise(rb_eArgError, "z must be a float");
    return Qnil;
  }
  if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
    rb_raise(rb_eArgError, "min_keep must be an integer");
    return Qnil;
  }

  LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
  if (ctx_ptr->ctx == NULL) {
    rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
    return Qnil;
  }
  LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
  if (cnd_ptr->array.data == nullptr) {
    rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
    return Qnil;
  }
  const float z = NUM2DBL(kw_values[0]);
  const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;

  llama_sample_tail_free(ctx_ptr->ctx, &(cnd_ptr->array), z, min_keep);

  return Qnil;
};
|
1186
|
+
|
1187
|
+
// context.sample_typical(candidates, prob:, min_keep: 1)
// Performs locally typical sampling on the candidate tokens in place via
// llama_sample_typical. `prob` is the typicality threshold (Float) and
// `min_keep` the minimum number of candidates to keep (Integer, default 1).
// Raises ArgumentError on bad argument types and RuntimeError when the
// context or the candidate array is uninitialized. Returns nil.
static VALUE _llama_context_sample_typical(int argc, VALUE* argv, VALUE self) {
  VALUE candidates = Qnil;
  VALUE kw_args = Qnil;
  ID kw_table[2] = { rb_intern("prob"), rb_intern("min_keep") };
  VALUE kw_values[2] = { Qundef, Qundef };
  rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
  rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);

  // Validate arguments before touching any native state.
  if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
    rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
    return Qnil;
  }
  if (!RB_FLOAT_TYPE_P(kw_values[0])) {
    rb_raise(rb_eArgError, "prob must be a float");
    return Qnil;
  }
  if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
    rb_raise(rb_eArgError, "min_keep must be an integer");
    return Qnil;
  }

  LLaMAContextWrapper* ctx_wrap = get_llama_context(self);
  if (ctx_wrap->ctx == NULL) {
    rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
    return Qnil;
  }
  LLaMATokenDataArrayWrapper* arr_wrap = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
  if (arr_wrap->array.data == nullptr) {
    rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
    return Qnil;
  }

  const float prob = NUM2DBL(kw_values[0]);
  const size_t min_keep = kw_values[1] == Qundef ? 1 : NUM2SIZET(kw_values[1]);
  llama_sample_typical(ctx_wrap->ctx, &arr_wrap->array, prob, min_keep);

  return Qnil;
};
|
1225
|
+
|
1226
|
+
// context.sample_temperature(candidates, temperature:)
// Rescales the candidate logits in place by the given temperature via
// llama_sample_temperature. Raises ArgumentError on bad argument types and
// RuntimeError when the context or the candidate array is uninitialized.
// Returns nil.
static VALUE _llama_context_sample_temperature(int argc, VALUE* argv, VALUE self) {
  VALUE candidates = Qnil;
  VALUE kw_args = Qnil;
  ID kw_table[1] = { rb_intern("temperature") };
  VALUE kw_values[1] = { Qundef };
  rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
  rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);

  // Validate arguments before touching any native state.
  if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
    rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
    return Qnil;
  }
  if (!RB_FLOAT_TYPE_P(kw_values[0])) {
    rb_raise(rb_eArgError, "temperature must be a float");
    return Qnil;
  }

  LLaMAContextWrapper* ctx_wrap = get_llama_context(self);
  if (ctx_wrap->ctx == NULL) {
    rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
    return Qnil;
  }
  LLaMATokenDataArrayWrapper* arr_wrap = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
  if (arr_wrap->array.data == nullptr) {
    rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
    return Qnil;
  }

  const float temperature = NUM2DBL(kw_values[0]);
  llama_sample_temperature(ctx_wrap->ctx, &arr_wrap->array, temperature);

  return Qnil;
};
|
1259
|
+
|
1260
|
+
// context.sample_token_mirostat(candidates, tau:, eta:, m:, mu:)
// Samples one token with the Mirostat (v1) algorithm via
// llama_sample_token_mirostat. `tau` is the target cross-entropy, `eta` the
// learning rate, `m` the number of tokens used to estimate surprise, and
// `mu` the running state, which the native call updates. Returns a two-element
// Array [token_id, updated_mu]. Raises ArgumentError on bad argument types and
// RuntimeError when the context or the candidate array is uninitialized.
static VALUE _llama_context_sample_token_mirostat(int argc, VALUE* argv, VALUE self) {
  VALUE candidates = Qnil;
  VALUE kw_args = Qnil;
  ID kw_table[4] = { rb_intern("tau"), rb_intern("eta"), rb_intern("m"), rb_intern("mu") };
  VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
  rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
  rb_get_kwargs(kw_args, kw_table, 4, 0, kw_values);

  // Validate arguments before touching any native state.
  if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
    rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
    return Qnil;
  }
  if (!RB_FLOAT_TYPE_P(kw_values[0])) {
    rb_raise(rb_eArgError, "tau must be a float");
    return Qnil;
  }
  if (!RB_FLOAT_TYPE_P(kw_values[1])) {
    rb_raise(rb_eArgError, "eta must be a float");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(kw_values[2])) {
    rb_raise(rb_eArgError, "m must be an integer");
    return Qnil;
  }
  if (!RB_FLOAT_TYPE_P(kw_values[3])) {
    rb_raise(rb_eArgError, "mu must be a float");
    return Qnil;
  }

  LLaMAContextWrapper* ctx_wrap = get_llama_context(self);
  if (ctx_wrap->ctx == NULL) {
    rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
    return Qnil;
  }
  LLaMATokenDataArrayWrapper* arr_wrap = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
  if (arr_wrap->array.data == nullptr) {
    rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
    return Qnil;
  }

  const float tau = NUM2DBL(kw_values[0]);
  const float eta = NUM2DBL(kw_values[1]);
  const int m = NUM2INT(kw_values[2]);
  float mu = NUM2DBL(kw_values[3]);  // mutated by the native call below

  const llama_token id = llama_sample_token_mirostat(ctx_wrap->ctx, &arr_wrap->array, tau, eta, m, &mu);

  // Return both the sampled token and the updated mu state to the caller.
  return rb_ary_new_from_args(2, INT2NUM(id), DBL2NUM(mu));
};
|
1311
|
+
|
1312
|
+
// context.sample_token_mirostat_v2(candidates, tau:, eta:, mu:)
// Samples one token with the Mirostat v2 algorithm via
// llama_sample_token_mirostat_v2. `tau` is the target cross-entropy, `eta`
// the learning rate, and `mu` the running state, which the native call
// updates. Returns a two-element Array [token_id, updated_mu]. Raises
// ArgumentError on bad argument types and RuntimeError when the context or
// the candidate array is uninitialized.
static VALUE _llama_context_sample_token_mirostat_v2(int argc, VALUE* argv, VALUE self) {
  VALUE candidates = Qnil;
  VALUE kw_args = Qnil;
  ID kw_table[3] = { rb_intern("tau"), rb_intern("eta"), rb_intern("mu") };
  VALUE kw_values[3] = { Qundef, Qundef, Qundef };
  rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
  rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);

  // Validate arguments before touching any native state.
  if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
    rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
    return Qnil;
  }
  if (!RB_FLOAT_TYPE_P(kw_values[0])) {
    rb_raise(rb_eArgError, "tau must be a float");
    return Qnil;
  }
  if (!RB_FLOAT_TYPE_P(kw_values[1])) {
    rb_raise(rb_eArgError, "eta must be a float");
    return Qnil;
  }
  if (!RB_FLOAT_TYPE_P(kw_values[2])) {
    rb_raise(rb_eArgError, "mu must be a float");
    return Qnil;
  }

  LLaMAContextWrapper* ctx_wrap = get_llama_context(self);
  if (ctx_wrap->ctx == NULL) {
    rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
    return Qnil;
  }
  LLaMATokenDataArrayWrapper* arr_wrap = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
  if (arr_wrap->array.data == nullptr) {
    rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
    return Qnil;
  }

  const float tau = NUM2DBL(kw_values[0]);
  const float eta = NUM2DBL(kw_values[1]);
  float mu = NUM2DBL(kw_values[2]);  // mutated by the native call below

  const llama_token id = llama_sample_token_mirostat_v2(ctx_wrap->ctx, &arr_wrap->array, tau, eta, &mu);

  // Return both the sampled token and the updated mu state to the caller.
  return rb_ary_new_from_args(2, INT2NUM(id), DBL2NUM(mu));
};
|
1358
|
+
|
1359
|
+
// context.sample_token_greedy(candidates)
// Picks the highest-probability token from the candidate array via
// llama_sample_token_greedy and returns its id as an Integer. Raises
// RuntimeError when the context or the candidate array is uninitialized and
// ArgumentError when `candidates` is not a TokenDataArray.
static VALUE _llama_context_sample_token_greedy(VALUE self, VALUE candidates) {
  LLaMAContextWrapper* ctx_wrap = get_llama_context(self);
  if (ctx_wrap->ctx == NULL) {
    rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
    return Qnil;
  }
  if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
    rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
    return Qnil;
  }
  LLaMATokenDataArrayWrapper* arr_wrap = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
  if (arr_wrap->array.data == nullptr) {
    rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
    return Qnil;
  }
  const llama_token id = llama_sample_token_greedy(ctx_wrap->ctx, &arr_wrap->array);
  return INT2NUM(id);
};
|
1377
|
+
|
1378
|
+
// context.sample_token(candidates)
// Samples a token from the candidate distribution via llama_sample_token and
// returns its id as an Integer. Raises RuntimeError when the context or the
// candidate array is uninitialized and ArgumentError when `candidates` is not
// a TokenDataArray.
static VALUE _llama_context_sample_token(VALUE self, VALUE candidates) {
  LLaMAContextWrapper* ctx_wrap = get_llama_context(self);
  if (ctx_wrap->ctx == NULL) {
    rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
    return Qnil;
  }
  if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
    rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
    return Qnil;
  }
  LLaMATokenDataArrayWrapper* arr_wrap = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
  if (arr_wrap->array.data == nullptr) {
    rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
    return Qnil;
  }
  const llama_token id = llama_sample_token(ctx_wrap->ctx, &arr_wrap->array);
  return INT2NUM(id);
};
|
624
1396
|
};
|
625
1397
|
|
626
1398
|
const rb_data_type_t RbLLaMAContext::llama_context_type = {
|
@@ -680,6 +1452,10 @@ static VALUE rb_llama_token_eos(VALUE self) {
|
|
680
1452
|
return INT2NUM(llama_token_eos());
|
681
1453
|
}
|
682
1454
|
|
1455
|
+
// Module function LLaMACpp.token_nl: returns the newline token id reported by
// llama_token_nl() as a Ruby Integer.
static VALUE rb_llama_token_nl(VALUE self) {
  return INT2NUM(llama_token_nl());
}
|
1458
|
+
|
683
1459
|
static VALUE rb_llama_print_system_info(VALUE self) {
|
684
1460
|
const char* result = llama_print_system_info();
|
685
1461
|
return rb_utf8_str_new_cstr(result);
|
@@ -695,12 +1471,16 @@ static VALUE rb_llama_mlock_supported(VALUE self) {
|
|
695
1471
|
|
696
1472
|
extern "C" void Init_llama_cpp(void) {
|
697
1473
|
rb_mLLaMACpp = rb_define_module("LLaMACpp");
|
1474
|
+
|
1475
|
+
RbLLaMATokenData::define_class(rb_mLLaMACpp);
|
1476
|
+
RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
|
698
1477
|
RbLLaMAContext::define_class(rb_mLLaMACpp);
|
699
1478
|
RbLLaMAContextParams::define_class(rb_mLLaMACpp);
|
700
1479
|
|
701
1480
|
rb_define_module_function(rb_mLLaMACpp, "model_quantize", rb_llama_model_quantize, -1);
|
702
1481
|
rb_define_module_function(rb_mLLaMACpp, "token_bos", rb_llama_token_bos, 0);
|
703
1482
|
rb_define_module_function(rb_mLLaMACpp, "token_eos", rb_llama_token_eos, 0);
|
1483
|
+
rb_define_module_function(rb_mLLaMACpp, "token_nl", rb_llama_token_nl, 0);
|
704
1484
|
rb_define_module_function(rb_mLLaMACpp, "print_system_info", rb_llama_print_system_info, 0);
|
705
1485
|
rb_define_module_function(rb_mLLaMACpp, "mmap_supported?", rb_llama_mmap_supported, 0);
|
706
1486
|
rb_define_module_function(rb_mLLaMACpp, "mlock_supported?", rb_llama_mlock_supported, 0);
|
@@ -710,8 +1490,6 @@ extern "C" void Init_llama_cpp(void) {
|
|
710
1490
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_0));
|
711
1491
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_1", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_1));
|
712
1492
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16));
|
713
|
-
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_2", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_2));
|
714
|
-
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_3", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_3));
|
715
1493
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q8_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q8_0));
|
716
1494
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_0));
|
717
1495
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_1", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_1));
|