llama_cpp 0.5.3 → 0.7.0
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between those versions.
- checksums.yaml +4 -4
- data/CHANGELOG.md +17 -0
- data/README.md +6 -5
- data/examples/chat.rb +13 -13
- data/examples/embedding.rb +9 -9
- data/ext/llama_cpp/llama_cpp.cpp +583 -262
- data/ext/llama_cpp/src/ggml-alloc.c +8 -2
- data/ext/llama_cpp/src/ggml-alloc.h +1 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +326 -149
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.m +167 -89
- data/ext/llama_cpp/src/ggml-metal.metal +130 -40
- data/ext/llama_cpp/src/ggml-opencl.cpp +119 -53
- data/ext/llama_cpp/src/ggml.c +2355 -1166
- data/ext/llama_cpp/src/ggml.h +129 -35
- data/ext/llama_cpp/src/k_quants.c +744 -2
- data/ext/llama_cpp/src/llama.cpp +1766 -671
- data/ext/llama_cpp/src/llama.h +321 -120
- data/ext/llama_cpp/src/unicode.h +462 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +6 -10
- data/sig/llama_cpp.rbs +70 -34
- metadata +4 -3
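The headline change in 0.7.0 is that model-level and context-level options are split into separate parameter objects, and a batch-decoding API is exposed (see the Batch class and Context#decode added in data/ext/llama_cpp/llama_cpp.cpp below). A minimal, hypothetical setup sketch in Ruby, inferred only from the method names defined in this diff; the LLaMACpp module and class names come from the C++ bindings, and the model path is a placeholder:

```ruby
require 'llama_cpp'

# Model-level options (GPU offload, mmap/mlock, vocab_only) now live in ModelParams.
model_params = LLaMACpp::ModelParams.new
model_params.n_gpu_layers = 32
model_params.use_mmap = true

model = LLaMACpp::Model.new(model_path: 'path/to/model.gguf', params: model_params)

# Context-level options (seed, n_ctx, thread counts) stay on ContextParams,
# and Context.new now takes them explicitly instead of reading @params off the model.
context_params = LLaMACpp::ContextParams.new
context_params.seed = 42
context_params.n_ctx = 2048
context_params.n_threads = 4

context = LLaMACpp::Context.new(model: model, params: context_params)

# Tokenization and token_to_piece moved from Context to Model in this release.
tokens = model.tokenize(text: 'Hello', n_max_tokens: 32, add_bos: true)
```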
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -1,7 +1,9 @@
|
|
1
1
|
#include "llama_cpp.h"
|
2
2
|
|
3
3
|
VALUE rb_mLLaMACpp;
|
4
|
+
VALUE rb_cLLaMABatch;
|
4
5
|
VALUE rb_cLLaMAModel;
|
6
|
+
VALUE rb_cLLaMAModelParams;
|
5
7
|
VALUE rb_cLLaMATimings;
|
6
8
|
VALUE rb_cLLaMAContext;
|
7
9
|
VALUE rb_cLLaMAContextParams;
|
@@ -11,6 +13,238 @@ VALUE rb_cLLaMATokenDataArray;
|
|
11
13
|
VALUE rb_cLLaMAGrammarElement;
|
12
14
|
VALUE rb_cLLaMAGrammar;
|
13
15
|
|
16
|
+
class LLaMABatchWrapper {
|
17
|
+
public:
|
18
|
+
llama_batch batch;
|
19
|
+
|
20
|
+
LLaMABatchWrapper() {}
|
21
|
+
|
22
|
+
~LLaMABatchWrapper() {
|
23
|
+
llama_batch_free(batch);
|
24
|
+
}
|
25
|
+
};
|
26
|
+
|
27
|
+
class RbLLaMABatch {
|
28
|
+
public:
|
29
|
+
static VALUE llama_batch_alloc(VALUE self) {
|
30
|
+
LLaMABatchWrapper* ptr = (LLaMABatchWrapper*)ruby_xmalloc(sizeof(LLaMABatchWrapper));
|
31
|
+
new (ptr) LLaMABatchWrapper();
|
32
|
+
return TypedData_Wrap_Struct(self, &llama_batch_type, ptr);
|
33
|
+
}
|
34
|
+
|
35
|
+
static void llama_batch_free(void* ptr) {
|
36
|
+
((LLaMABatchWrapper*)ptr)->~LLaMABatchWrapper();
|
37
|
+
ruby_xfree(ptr);
|
38
|
+
}
|
39
|
+
|
40
|
+
static size_t llama_batch_size(const void* ptr) {
|
41
|
+
return sizeof(*((LLaMABatchWrapper*)ptr));
|
42
|
+
}
|
43
|
+
|
44
|
+
static LLaMABatchWrapper* get_llama_batch(VALUE self) {
|
45
|
+
LLaMABatchWrapper* ptr;
|
46
|
+
TypedData_Get_Struct(self, LLaMABatchWrapper, &llama_batch_type, ptr);
|
47
|
+
return ptr;
|
48
|
+
}
|
49
|
+
|
50
|
+
static void define_class(VALUE outer) {
|
51
|
+
rb_cLLaMABatch = rb_define_class_under(outer, "Batch", rb_cObject);
|
52
|
+
rb_define_alloc_func(rb_cLLaMABatch, llama_batch_alloc);
|
53
|
+
rb_define_method(rb_cLLaMABatch, "initialize", RUBY_METHOD_FUNC(_llama_batch_initialize), -1);
|
54
|
+
rb_define_method(rb_cLLaMABatch, "n_tokens=", RUBY_METHOD_FUNC(_llama_batch_set_n_tokens), 1);
|
55
|
+
rb_define_method(rb_cLLaMABatch, "n_tokens", RUBY_METHOD_FUNC(_llama_batch_get_n_tokens), 0);
|
56
|
+
rb_define_method(rb_cLLaMABatch, "all_pos_zero=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_zero), 1);
|
57
|
+
rb_define_method(rb_cLLaMABatch, "all_pos_zero", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_zero), 0);
|
58
|
+
rb_define_method(rb_cLLaMABatch, "all_pos_one=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_one), 1);
|
59
|
+
rb_define_method(rb_cLLaMABatch, "all_pos_one", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_one), 0);
|
60
|
+
rb_define_method(rb_cLLaMABatch, "all_seq_id=", RUBY_METHOD_FUNC(_llama_batch_set_all_seq_id), 1);
|
61
|
+
rb_define_method(rb_cLLaMABatch, "all_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_all_seq_id), 0);
|
62
|
+
rb_define_method(rb_cLLaMABatch, "set_token", RUBY_METHOD_FUNC(_llama_batch_set_token), 2);
|
63
|
+
rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
|
64
|
+
rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
|
65
|
+
rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
|
66
|
+
rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 2);
|
67
|
+
rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 1);
|
68
|
+
rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
|
69
|
+
rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
|
70
|
+
}
|
71
|
+
|
72
|
+
private:
|
73
|
+
static const rb_data_type_t llama_batch_type;
|
74
|
+
|
75
|
+
static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
|
76
|
+
VALUE kw_args = Qnil;
|
77
|
+
ID kw_table[2] = { rb_intern("n_tokens"), rb_intern("embd") };
|
78
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
79
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
80
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
81
|
+
|
82
|
+
if (!RB_INTEGER_TYPE_P(kw_values[0])) {
|
83
|
+
rb_raise(rb_eArgError, "n_tokens must be an integer");
|
84
|
+
return Qnil;
|
85
|
+
}
|
86
|
+
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
87
|
+
rb_raise(rb_eArgError, "embd must be an integer");
|
88
|
+
return Qnil;
|
89
|
+
}
|
90
|
+
|
91
|
+
const int32_t n_tokens = NUM2INT(kw_values[0]);
|
92
|
+
const int32_t embd = NUM2INT(kw_values[1]);
|
93
|
+
|
94
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
95
|
+
ptr->batch = llama_batch_init(n_tokens, embd);
|
96
|
+
|
97
|
+
return Qnil;
|
98
|
+
}
|
99
|
+
|
100
|
+
// n_tokens
|
101
|
+
static VALUE _llama_batch_set_n_tokens(VALUE self, VALUE n_tokens) {
|
102
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
103
|
+
ptr->batch.n_tokens = NUM2INT(n_tokens);
|
104
|
+
return INT2NUM(ptr->batch.n_tokens);
|
105
|
+
}
|
106
|
+
|
107
|
+
static VALUE _llama_batch_get_n_tokens(VALUE self) {
|
108
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
109
|
+
return INT2NUM(ptr->batch.n_tokens);
|
110
|
+
}
|
111
|
+
|
112
|
+
// all_pos_0
|
113
|
+
static VALUE _llama_batch_set_all_pos_zero(VALUE self, VALUE all_pos_0) {
|
114
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
115
|
+
ptr->batch.all_pos_0 = NUM2INT(all_pos_0);
|
116
|
+
return INT2NUM(ptr->batch.all_pos_0);
|
117
|
+
}
|
118
|
+
|
119
|
+
static VALUE _llama_batch_get_all_pos_zero(VALUE self) {
|
120
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
121
|
+
return INT2NUM(ptr->batch.all_pos_0);
|
122
|
+
}
|
123
|
+
|
124
|
+
// all_pos_1
|
125
|
+
static VALUE _llama_batch_set_all_pos_one(VALUE self, VALUE all_pos_1) {
|
126
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
127
|
+
ptr->batch.all_pos_1 = NUM2INT(all_pos_1);
|
128
|
+
return INT2NUM(ptr->batch.all_pos_1);
|
129
|
+
}
|
130
|
+
|
131
|
+
static VALUE _llama_batch_get_all_pos_one(VALUE self) {
|
132
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
133
|
+
return INT2NUM(ptr->batch.all_pos_1);
|
134
|
+
}
|
135
|
+
|
136
|
+
// all_seq_id
|
137
|
+
static VALUE _llama_batch_set_all_seq_id(VALUE self, VALUE all_seq_id) {
|
138
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
139
|
+
ptr->batch.all_seq_id = NUM2INT(all_seq_id);
|
140
|
+
return INT2NUM(ptr->batch.all_seq_id);
|
141
|
+
}
|
142
|
+
|
143
|
+
static VALUE _llama_batch_get_all_seq_id(VALUE self) {
|
144
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
145
|
+
return INT2NUM(ptr->batch.all_seq_id);
|
146
|
+
}
|
147
|
+
|
148
|
+
// token
|
149
|
+
static VALUE _llama_batch_set_token(VALUE self, VALUE idx, VALUE value) {
|
150
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
151
|
+
const int32_t id = NUM2INT(idx);
|
152
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
153
|
+
rb_raise(rb_eArgError, "idx must be in [0, n_tokens)");
|
154
|
+
return Qnil;
|
155
|
+
}
|
156
|
+
ptr->batch.token[id] = NUM2INT(value);
|
157
|
+
return INT2NUM(ptr->batch.token[id]);
|
158
|
+
}
|
159
|
+
|
160
|
+
static VALUE _llama_batch_get_token(VALUE self, VALUE idx) {
|
161
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
162
|
+
const int32_t id = NUM2INT(idx);
|
163
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
164
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
165
|
+
return Qnil;
|
166
|
+
}
|
167
|
+
return INT2NUM(ptr->batch.token[id]);
|
168
|
+
}
|
169
|
+
|
170
|
+
// pos
|
171
|
+
static VALUE _llama_batch_set_pos(VALUE self, VALUE idx, VALUE value) {
|
172
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
173
|
+
const int32_t id = NUM2INT(idx);
|
174
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
175
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
176
|
+
return Qnil;
|
177
|
+
}
|
178
|
+
ptr->batch.pos[id] = NUM2INT(value);
|
179
|
+
return INT2NUM(ptr->batch.pos[id]);
|
180
|
+
}
|
181
|
+
|
182
|
+
static VALUE _llama_batch_get_pos(VALUE self, VALUE idx) {
|
183
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
184
|
+
const int32_t id = NUM2INT(idx);
|
185
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
186
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
187
|
+
return Qnil;
|
188
|
+
}
|
189
|
+
return INT2NUM(ptr->batch.pos[id]);
|
190
|
+
}
|
191
|
+
|
192
|
+
// seq_id
|
193
|
+
static VALUE _llama_batch_set_seq_id(VALUE self, VALUE idx, VALUE value) {
|
194
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
195
|
+
const int32_t id = NUM2INT(idx);
|
196
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
197
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
198
|
+
return Qnil;
|
199
|
+
}
|
200
|
+
ptr->batch.seq_id[id] = NUM2INT(value);
|
201
|
+
return INT2NUM(ptr->batch.seq_id[id]);
|
202
|
+
}
|
203
|
+
|
204
|
+
static VALUE _llama_batch_get_seq_id(VALUE self, VALUE idx) {
|
205
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
206
|
+
const int32_t id = NUM2INT(idx);
|
207
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
208
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
209
|
+
return Qnil;
|
210
|
+
}
|
211
|
+
return INT2NUM(ptr->batch.seq_id[id]);
|
212
|
+
}
|
213
|
+
|
214
|
+
// logits
|
215
|
+
static VALUE _llama_batch_set_logits(VALUE self, VALUE idx, VALUE value) {
|
216
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
217
|
+
const int32_t id = NUM2INT(idx);
|
218
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
219
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
220
|
+
return Qnil;
|
221
|
+
}
|
222
|
+
ptr->batch.logits[id] = RTEST(value) ? true : false;
|
223
|
+
return ptr->batch.logits[id] ? Qtrue : Qfalse;
|
224
|
+
}
|
225
|
+
|
226
|
+
static VALUE _llama_batch_get_logits(VALUE self, VALUE idx) {
|
227
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
228
|
+
const int32_t id = NUM2INT(idx);
|
229
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
230
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
231
|
+
return Qnil;
|
232
|
+
}
|
233
|
+
return ptr->batch.logits[id] ? Qtrue : Qfalse;
|
234
|
+
}
|
235
|
+
};
|
236
|
+
|
237
|
+
const rb_data_type_t RbLLaMABatch::llama_batch_type = {
|
238
|
+
"RbLLaMABatch",
|
239
|
+
{ NULL,
|
240
|
+
RbLLaMABatch::llama_batch_free,
|
241
|
+
RbLLaMABatch::llama_batch_size },
|
242
|
+
NULL,
|
243
|
+
NULL,
|
244
|
+
RUBY_TYPED_FREE_IMMEDIATELY
|
245
|
+
|
246
|
+
};
|
247
|
+
|
14
248
|
class LLaMATokenDataWrapper {
|
15
249
|
public:
|
16
250
|
llama_token_data data;
|
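The Batch wrapper added above exposes llama_batch through per-token setters. A rough usage sketch, assuming the model/context setup shown in the earlier example and the Context#decode binding added later in this file:

```ruby
# Hypothetical prompt evaluation with the new batch API.
tokens = model.tokenize(text: 'Hello', n_max_tokens: 32, add_bos: true)

batch = LLaMACpp::Batch.new(n_tokens: tokens.size, embd: 0)
batch.n_tokens = tokens.size
tokens.each_with_index do |token, i|
  batch.set_token(i, token)
  batch.set_pos(i, i)
  batch.set_seq_id(i, 0)
  batch.set_logits(i, i == tokens.size - 1) # request logits only for the last position
end

context.decode(batch)   # batch-based alternative to eval(tokens:, n_past:)
logits = context.logits
```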
@@ -363,6 +597,144 @@ const rb_data_type_t RbLLaMATimings::llama_timings_type = {
|
|
363
597
|
RUBY_TYPED_FREE_IMMEDIATELY
|
364
598
|
};
|
365
599
|
|
600
|
+
class LLaMAModelParamsWrapper {
|
601
|
+
public:
|
602
|
+
struct llama_model_params params;
|
603
|
+
|
604
|
+
LLaMAModelParamsWrapper() : params(llama_model_default_params()) {}
|
605
|
+
|
606
|
+
~LLaMAModelParamsWrapper() {}
|
607
|
+
};
|
608
|
+
|
609
|
+
class RbLLaMAModelParams {
|
610
|
+
public:
|
611
|
+
static VALUE llama_model_params_alloc(VALUE self) {
|
612
|
+
LLaMAModelParamsWrapper* ptr = (LLaMAModelParamsWrapper*)ruby_xmalloc(sizeof(LLaMAModelParamsWrapper));
|
613
|
+
new (ptr) LLaMAModelParamsWrapper();
|
614
|
+
return TypedData_Wrap_Struct(self, &llama_model_params_type, ptr);
|
615
|
+
}
|
616
|
+
|
617
|
+
static void llama_model_params_free(void* ptr) {
|
618
|
+
((LLaMAModelParamsWrapper*)ptr)->~LLaMAModelParamsWrapper();
|
619
|
+
ruby_xfree(ptr);
|
620
|
+
}
|
621
|
+
|
622
|
+
static size_t llama_model_params_size(const void* ptr) {
|
623
|
+
return sizeof(*((LLaMAModelParamsWrapper*)ptr));
|
624
|
+
}
|
625
|
+
|
626
|
+
static LLaMAModelParamsWrapper* get_llama_model_params(VALUE self) {
|
627
|
+
LLaMAModelParamsWrapper* ptr;
|
628
|
+
TypedData_Get_Struct(self, LLaMAModelParamsWrapper, &llama_model_params_type, ptr);
|
629
|
+
return ptr;
|
630
|
+
}
|
631
|
+
|
632
|
+
static void define_class(VALUE outer) {
|
633
|
+
rb_cLLaMAModelParams = rb_define_class_under(outer, "ModelParams", rb_cObject);
|
634
|
+
rb_define_alloc_func(rb_cLLaMAModelParams, llama_model_params_alloc);
|
635
|
+
rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers=", RUBY_METHOD_FUNC(_llama_model_params_set_n_gpu_layers), 1);
|
636
|
+
rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers", RUBY_METHOD_FUNC(_llama_model_params_get_n_gpu_layers), 0);
|
637
|
+
rb_define_method(rb_cLLaMAModelParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_model_params_set_main_gpu), 1);
|
638
|
+
rb_define_method(rb_cLLaMAModelParams, "main_gpu", RUBY_METHOD_FUNC(_llama_model_params_get_main_gpu), 0);
|
639
|
+
rb_define_method(rb_cLLaMAModelParams, "tensor_split", RUBY_METHOD_FUNC(_llama_model_params_get_tensor_split), 0);
|
640
|
+
rb_define_method(rb_cLLaMAModelParams, "vocab_only=", RUBY_METHOD_FUNC(_llama_model_params_set_vocab_only), 1);
|
641
|
+
rb_define_method(rb_cLLaMAModelParams, "vocab_only", RUBY_METHOD_FUNC(_llama_model_params_get_vocab_only), 0);
|
642
|
+
rb_define_method(rb_cLLaMAModelParams, "use_mmap=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mmap), 1);
|
643
|
+
rb_define_method(rb_cLLaMAModelParams, "use_mmap", RUBY_METHOD_FUNC(_llama_model_params_get_use_mmap), 0);
|
644
|
+
rb_define_method(rb_cLLaMAModelParams, "use_mlock=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mlock), 1);
|
645
|
+
rb_define_method(rb_cLLaMAModelParams, "use_mlock", RUBY_METHOD_FUNC(_llama_model_params_get_use_mlock), 0);
|
646
|
+
}
|
647
|
+
|
648
|
+
private:
|
649
|
+
static const rb_data_type_t llama_model_params_type;
|
650
|
+
|
651
|
+
// n_gpu_layers
|
652
|
+
static VALUE _llama_model_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
|
653
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
654
|
+
ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
|
655
|
+
return INT2NUM(ptr->params.n_gpu_layers);
|
656
|
+
}
|
657
|
+
|
658
|
+
static VALUE _llama_model_params_get_n_gpu_layers(VALUE self) {
|
659
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
660
|
+
return INT2NUM(ptr->params.n_gpu_layers);
|
661
|
+
}
|
662
|
+
|
663
|
+
// main_gpu
|
664
|
+
static VALUE _llama_model_params_set_main_gpu(VALUE self, VALUE main_gpu) {
|
665
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
666
|
+
ptr->params.main_gpu = NUM2INT(main_gpu);
|
667
|
+
return INT2NUM(ptr->params.main_gpu);
|
668
|
+
}
|
669
|
+
|
670
|
+
static VALUE _llama_model_params_get_main_gpu(VALUE self) {
|
671
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
672
|
+
return INT2NUM(ptr->params.main_gpu);
|
673
|
+
}
|
674
|
+
|
675
|
+
// tensor_split
|
676
|
+
static VALUE _llama_model_params_get_tensor_split(VALUE self) {
|
677
|
+
if (LLAMA_MAX_DEVICES < 1) {
|
678
|
+
return rb_ary_new();
|
679
|
+
}
|
680
|
+
VALUE ret = rb_ary_new2(LLAMA_MAX_DEVICES);
|
681
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
682
|
+
if (ptr->params.tensor_split == nullptr) {
|
683
|
+
return rb_ary_new();
|
684
|
+
}
|
685
|
+
for (size_t i = 0; i < LLAMA_MAX_DEVICES; i++) {
|
686
|
+
rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
|
687
|
+
}
|
688
|
+
return ret;
|
689
|
+
}
|
690
|
+
|
691
|
+
// vocab_only
|
692
|
+
static VALUE _llama_model_params_set_vocab_only(VALUE self, VALUE vocab_only) {
|
693
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
694
|
+
ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
|
695
|
+
return ptr->params.vocab_only ? Qtrue : Qfalse;
|
696
|
+
}
|
697
|
+
|
698
|
+
static VALUE _llama_model_params_get_vocab_only(VALUE self) {
|
699
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
700
|
+
return ptr->params.vocab_only ? Qtrue : Qfalse;
|
701
|
+
}
|
702
|
+
|
703
|
+
// use_mmap
|
704
|
+
static VALUE _llama_model_params_set_use_mmap(VALUE self, VALUE use_mmap) {
|
705
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
706
|
+
ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
|
707
|
+
return ptr->params.use_mmap ? Qtrue : Qfalse;
|
708
|
+
}
|
709
|
+
|
710
|
+
static VALUE _llama_model_params_get_use_mmap(VALUE self) {
|
711
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
712
|
+
return ptr->params.use_mmap ? Qtrue : Qfalse;
|
713
|
+
}
|
714
|
+
|
715
|
+
// use_mlock
|
716
|
+
static VALUE _llama_model_params_set_use_mlock(VALUE self, VALUE use_mlock) {
|
717
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
718
|
+
ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
|
719
|
+
return ptr->params.use_mlock ? Qtrue : Qfalse;
|
720
|
+
}
|
721
|
+
|
722
|
+
static VALUE _llama_model_params_get_use_mlock(VALUE self) {
|
723
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
724
|
+
return ptr->params.use_mlock ? Qtrue : Qfalse;
|
725
|
+
}
|
726
|
+
};
|
727
|
+
|
728
|
+
const rb_data_type_t RbLLaMAModelParams::llama_model_params_type = {
|
729
|
+
"RbLLaMAModelParams",
|
730
|
+
{ NULL,
|
731
|
+
RbLLaMAModelParams::llama_model_params_free,
|
732
|
+
RbLLaMAModelParams::llama_model_params_size },
|
733
|
+
NULL,
|
734
|
+
NULL,
|
735
|
+
RUBY_TYPED_FREE_IMMEDIATELY
|
736
|
+
};
|
737
|
+
|
366
738
|
class LLaMAContextParamsWrapper {
|
367
739
|
public:
|
368
740
|
struct llama_context_params params;
|
@@ -399,35 +771,26 @@ public:
|
|
399
771
|
rb_cLLaMAContextParams = rb_define_class_under(outer, "ContextParams", rb_cObject);
|
400
772
|
rb_define_alloc_func(rb_cLLaMAContextParams, llama_context_params_alloc);
|
401
773
|
// rb_define_method(rb_cLLaMAContextParams, "initialize", RUBY_METHOD_FUNC(_llama_context_params_init), 0);
|
774
|
+
rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
|
775
|
+
rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
|
402
776
|
rb_define_method(rb_cLLaMAContextParams, "n_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_n_ctx), 1);
|
403
777
|
rb_define_method(rb_cLLaMAContextParams, "n_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_n_ctx), 0);
|
404
778
|
rb_define_method(rb_cLLaMAContextParams, "n_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_batch), 1);
|
405
779
|
rb_define_method(rb_cLLaMAContextParams, "n_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_batch), 0);
|
406
|
-
rb_define_method(rb_cLLaMAContextParams, "
|
407
|
-
rb_define_method(rb_cLLaMAContextParams, "
|
408
|
-
rb_define_method(rb_cLLaMAContextParams, "
|
409
|
-
rb_define_method(rb_cLLaMAContextParams, "
|
410
|
-
rb_define_method(rb_cLLaMAContextParams, "tensor_split", RUBY_METHOD_FUNC(_llama_context_params_get_tensor_split), 0);
|
780
|
+
rb_define_method(rb_cLLaMAContextParams, "n_threads=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads), 1);
|
781
|
+
rb_define_method(rb_cLLaMAContextParams, "n_threads", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads), 0);
|
782
|
+
rb_define_method(rb_cLLaMAContextParams, "n_threads_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads_batch), 1);
|
783
|
+
rb_define_method(rb_cLLaMAContextParams, "n_threads_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads_batch), 0);
|
411
784
|
rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
|
412
785
|
rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
|
413
786
|
rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
|
414
787
|
rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
|
415
|
-
rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
|
416
|
-
rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
|
417
788
|
rb_define_method(rb_cLLaMAContextParams, "mul_mat_q=", RUBY_METHOD_FUNC(_llama_context_params_set_mul_mat_q), 1);
|
418
789
|
rb_define_method(rb_cLLaMAContextParams, "mul_mat_q", RUBY_METHOD_FUNC(_llama_context_params_get_mul_mat_q), 0);
|
419
|
-
rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
|
420
|
-
rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
|
421
790
|
rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
|
422
791
|
rb_define_method(rb_cLLaMAContextParams, "f16_kv", RUBY_METHOD_FUNC(_llama_context_params_get_f16_kv), 0);
|
423
792
|
rb_define_method(rb_cLLaMAContextParams, "logits_all=", RUBY_METHOD_FUNC(_llama_context_params_set_logits_all), 1);
|
424
793
|
rb_define_method(rb_cLLaMAContextParams, "logits_all", RUBY_METHOD_FUNC(_llama_context_params_get_logits_all), 0);
|
425
|
-
rb_define_method(rb_cLLaMAContextParams, "vocab_only=", RUBY_METHOD_FUNC(_llama_context_params_set_vocab_only), 1);
|
426
|
-
rb_define_method(rb_cLLaMAContextParams, "vocab_only", RUBY_METHOD_FUNC(_llama_context_params_get_vocab_only), 0);
|
427
|
-
rb_define_method(rb_cLLaMAContextParams, "use_mmap=", RUBY_METHOD_FUNC(_llama_context_params_set_use_mmap), 1);
|
428
|
-
rb_define_method(rb_cLLaMAContextParams, "use_mmap", RUBY_METHOD_FUNC(_llama_context_params_get_use_mmap), 0);
|
429
|
-
rb_define_method(rb_cLLaMAContextParams, "use_mlock=", RUBY_METHOD_FUNC(_llama_context_params_set_use_mlock), 1);
|
430
|
-
rb_define_method(rb_cLLaMAContextParams, "use_mlock", RUBY_METHOD_FUNC(_llama_context_params_get_use_mlock), 0);
|
431
794
|
rb_define_method(rb_cLLaMAContextParams, "embedding=", RUBY_METHOD_FUNC(_llama_context_params_set_embedding), 1);
|
432
795
|
rb_define_method(rb_cLLaMAContextParams, "embedding", RUBY_METHOD_FUNC(_llama_context_params_get_embedding), 0);
|
433
796
|
}
|
@@ -441,6 +804,22 @@ private:
|
|
441
804
|
// return self;
|
442
805
|
// }
|
443
806
|
|
807
|
+
// seed
|
808
|
+
static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
|
809
|
+
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
810
|
+
if (NUM2INT(seed) < 0) {
|
811
|
+
rb_raise(rb_eArgError, "seed must be positive");
|
812
|
+
return Qnil;
|
813
|
+
}
|
814
|
+
ptr->params.seed = NUM2INT(seed);
|
815
|
+
return INT2NUM(ptr->params.seed);
|
816
|
+
}
|
817
|
+
|
818
|
+
static VALUE _llama_context_params_get_seed(VALUE self) {
|
819
|
+
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
820
|
+
return INT2NUM(ptr->params.seed);
|
821
|
+
}
|
822
|
+
|
444
823
|
// n_ctx
|
445
824
|
static VALUE _llama_context_params_set_n_ctx(VALUE self, VALUE n_ctx) {
|
446
825
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
@@ -465,41 +844,28 @@ private:
|
|
465
844
|
return INT2NUM(ptr->params.n_batch);
|
466
845
|
}
|
467
846
|
|
468
|
-
//
|
469
|
-
static VALUE
|
847
|
+
// n_threads
|
848
|
+
static VALUE _llama_context_params_set_n_threads(VALUE self, VALUE n_threads) {
|
470
849
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
471
|
-
ptr->params.
|
472
|
-
return INT2NUM(ptr->params.
|
850
|
+
ptr->params.n_threads = NUM2INT(n_threads);
|
851
|
+
return INT2NUM(ptr->params.n_threads);
|
473
852
|
}
|
474
853
|
|
475
|
-
static VALUE
|
854
|
+
static VALUE _llama_context_params_get_n_threads(VALUE self) {
|
476
855
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
477
|
-
return INT2NUM(ptr->params.
|
856
|
+
return INT2NUM(ptr->params.n_threads);
|
478
857
|
}
|
479
858
|
|
480
|
-
//
|
481
|
-
static VALUE
|
859
|
+
// n_threads_batch
|
860
|
+
static VALUE _llama_context_params_set_n_threads_batch(VALUE self, VALUE n_threads_batch) {
|
482
861
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
483
|
-
ptr->params.
|
484
|
-
return INT2NUM(ptr->params.
|
862
|
+
ptr->params.n_threads_batch = NUM2INT(n_threads_batch);
|
863
|
+
return INT2NUM(ptr->params.n_threads_batch);
|
485
864
|
}
|
486
865
|
|
487
|
-
static VALUE
|
866
|
+
static VALUE _llama_context_params_get_n_threads_batch(VALUE self) {
|
488
867
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
489
|
-
return INT2NUM(ptr->params.
|
490
|
-
}
|
491
|
-
|
492
|
-
// tensor_split
|
493
|
-
static VALUE _llama_context_params_get_tensor_split(VALUE self) {
|
494
|
-
if (LLAMA_MAX_DEVICES < 1) {
|
495
|
-
return rb_ary_new();
|
496
|
-
}
|
497
|
-
VALUE ret = rb_ary_new2(LLAMA_MAX_DEVICES);
|
498
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
499
|
-
for (size_t i = 0; i < LLAMA_MAX_DEVICES; i++) {
|
500
|
-
rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
|
501
|
-
}
|
502
|
-
return ret;
|
868
|
+
return INT2NUM(ptr->params.n_threads_batch);
|
503
869
|
}
|
504
870
|
|
505
871
|
// rope_freq_base
|
@@ -526,18 +892,6 @@ private:
|
|
526
892
|
return DBL2NUM(ptr->params.rope_freq_scale);
|
527
893
|
}
|
528
894
|
|
529
|
-
// low_vram
|
530
|
-
static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
|
531
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
532
|
-
ptr->params.low_vram = RTEST(low_vram) ? true : false;
|
533
|
-
return ptr->params.low_vram ? Qtrue : Qfalse;
|
534
|
-
}
|
535
|
-
|
536
|
-
static VALUE _llama_context_params_get_low_vram(VALUE self) {
|
537
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
538
|
-
return ptr->params.low_vram ? Qtrue : Qfalse;
|
539
|
-
}
|
540
|
-
|
541
895
|
// mul_mat_q
|
542
896
|
static VALUE _llama_context_params_set_mul_mat_q(VALUE self, VALUE mul_mat_q) {
|
543
897
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
@@ -550,22 +904,6 @@ private:
|
|
550
904
|
return ptr->params.mul_mat_q ? Qtrue : Qfalse;
|
551
905
|
}
|
552
906
|
|
553
|
-
// seed
|
554
|
-
static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
|
555
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
556
|
-
if (NUM2INT(seed) < 0) {
|
557
|
-
rb_raise(rb_eArgError, "seed must be positive");
|
558
|
-
return Qnil;
|
559
|
-
}
|
560
|
-
ptr->params.seed = NUM2INT(seed);
|
561
|
-
return INT2NUM(ptr->params.seed);
|
562
|
-
}
|
563
|
-
|
564
|
-
static VALUE _llama_context_params_get_seed(VALUE self) {
|
565
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
566
|
-
return INT2NUM(ptr->params.seed);
|
567
|
-
}
|
568
|
-
|
569
907
|
// f16_kv
|
570
908
|
static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
|
571
909
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
@@ -590,42 +928,6 @@ private:
|
|
590
928
|
return ptr->params.logits_all ? Qtrue : Qfalse;
|
591
929
|
}
|
592
930
|
|
593
|
-
// vocab_only
|
594
|
-
static VALUE _llama_context_params_set_vocab_only(VALUE self, VALUE vocab_only) {
|
595
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
596
|
-
ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
|
597
|
-
return ptr->params.vocab_only ? Qtrue : Qfalse;
|
598
|
-
}
|
599
|
-
|
600
|
-
static VALUE _llama_context_params_get_vocab_only(VALUE self) {
|
601
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
602
|
-
return ptr->params.vocab_only ? Qtrue : Qfalse;
|
603
|
-
}
|
604
|
-
|
605
|
-
// use_mmap
|
606
|
-
static VALUE _llama_context_params_set_use_mmap(VALUE self, VALUE use_mmap) {
|
607
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
608
|
-
ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
|
609
|
-
return ptr->params.use_mmap ? Qtrue : Qfalse;
|
610
|
-
}
|
611
|
-
|
612
|
-
static VALUE _llama_context_params_get_use_mmap(VALUE self) {
|
613
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
614
|
-
return ptr->params.use_mmap ? Qtrue : Qfalse;
|
615
|
-
}
|
616
|
-
|
617
|
-
// use_mlock
|
618
|
-
static VALUE _llama_context_params_set_use_mlock(VALUE self, VALUE use_mlock) {
|
619
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
620
|
-
ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
|
621
|
-
return ptr->params.use_mlock ? Qtrue : Qfalse;
|
622
|
-
}
|
623
|
-
|
624
|
-
static VALUE _llama_context_params_get_use_mlock(VALUE self) {
|
625
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
626
|
-
return ptr->params.use_mlock ? Qtrue : Qfalse;
|
627
|
-
}
|
628
|
-
|
629
931
|
// embedding
|
630
932
|
static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
|
631
933
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
@@ -823,11 +1125,11 @@ public:
|
|
823
1125
|
rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
|
824
1126
|
rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
|
825
1127
|
rb_define_method(rb_cLLaMAModel, "n_vocab", RUBY_METHOD_FUNC(_llama_model_get_model_n_vocab), 0);
|
826
|
-
rb_define_method(rb_cLLaMAModel, "n_ctx", RUBY_METHOD_FUNC(_llama_model_get_model_n_ctx), 0);
|
827
1128
|
rb_define_method(rb_cLLaMAModel, "n_ctx_train", RUBY_METHOD_FUNC(_llama_model_get_model_n_ctx_train), 0);
|
828
1129
|
rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_model_n_embd), 0);
|
829
|
-
rb_define_method(rb_cLLaMAModel, "
|
830
|
-
rb_define_method(rb_cLLaMAModel, "
|
1130
|
+
rb_define_method(rb_cLLaMAModel, "rope_freq_scale_train", RUBY_METHOD_FUNC(_llama_model_rope_freq_scale_train), 0);
|
1131
|
+
rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece), 1);
|
1132
|
+
rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize), -1);
|
831
1133
|
rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
|
832
1134
|
rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
|
833
1135
|
rb_define_method(rb_cLLaMAModel, "n_params", RUBY_METHOD_FUNC(_llama_model_get_model_n_params), 0);
|
@@ -841,30 +1143,21 @@ private:
|
|
841
1143
|
ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
|
842
1144
|
VALUE kw_values[2] = { Qundef, Qundef };
|
843
1145
|
rb_scan_args(argc, argv, ":", &kw_args);
|
844
|
-
rb_get_kwargs(kw_args, kw_table,
|
845
|
-
|
846
|
-
if (kw_values[0] == Qundef && kw_values[1] == Qundef) {
|
847
|
-
rb_iv_set(self, "@params", Qnil);
|
848
|
-
return Qnil;
|
849
|
-
}
|
1146
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
850
1147
|
|
851
1148
|
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
852
1149
|
rb_raise(rb_eArgError, "model_path must be a string");
|
853
1150
|
return Qnil;
|
854
1151
|
}
|
855
|
-
if (!rb_obj_is_kind_of(kw_values[1],
|
856
|
-
rb_raise(rb_eArgError, "params must be a
|
1152
|
+
if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
|
1153
|
+
rb_raise(rb_eArgError, "params must be a ModelParams");
|
857
1154
|
return Qnil;
|
858
1155
|
}
|
859
1156
|
|
860
1157
|
VALUE filename = kw_values[0];
|
861
|
-
|
1158
|
+
LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
|
862
1159
|
LLaMAModelWrapper* model_ptr = get_llama_model(self);
|
863
1160
|
|
864
|
-
if (prms_ptr->params.seed == LLAMA_DEFAULT_SEED) {
|
865
|
-
prms_ptr->params.seed = time(NULL);
|
866
|
-
}
|
867
|
-
|
868
1161
|
try {
|
869
1162
|
model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
|
870
1163
|
} catch (const std::runtime_error& e) {
|
@@ -912,8 +1205,8 @@ private:
|
|
912
1205
|
rb_raise(rb_eArgError, "model_path must be a string");
|
913
1206
|
return Qnil;
|
914
1207
|
}
|
915
|
-
if (!rb_obj_is_kind_of(kw_values[1],
|
916
|
-
rb_raise(rb_eArgError, "params must be a
|
1208
|
+
if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
|
1209
|
+
rb_raise(rb_eArgError, "params must be a LLaMAModelParams");
|
917
1210
|
return Qnil;
|
918
1211
|
}
|
919
1212
|
|
@@ -924,7 +1217,7 @@ private:
|
|
924
1217
|
}
|
925
1218
|
|
926
1219
|
VALUE filename = kw_values[0];
|
927
|
-
|
1220
|
+
LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
|
928
1221
|
|
929
1222
|
try {
|
930
1223
|
model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
|
@@ -946,10 +1239,10 @@ private:
|
|
946
1239
|
|
947
1240
|
static VALUE _llama_model_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
|
948
1241
|
VALUE kw_args = Qnil;
|
949
|
-
ID kw_table[
|
950
|
-
VALUE kw_values[
|
1242
|
+
ID kw_table[4] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads"), rb_intern("scale") };
|
1243
|
+
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
951
1244
|
rb_scan_args(argc, argv, ":", &kw_args);
|
952
|
-
rb_get_kwargs(kw_args, kw_table, 1,
|
1245
|
+
rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);
|
953
1246
|
|
954
1247
|
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
955
1248
|
rb_raise(rb_eArgError, "lora_path must be a string");
|
@@ -963,13 +1256,18 @@ private:
|
|
963
1256
|
rb_raise(rb_eArgError, "n_threads must be an integer");
|
964
1257
|
return Qnil;
|
965
1258
|
}
|
1259
|
+
if (kw_values[3] != Qundef && !RB_FLOAT_TYPE_P(kw_values[3])) {
|
1260
|
+
rb_raise(rb_eArgError, "scale must be a float");
|
1261
|
+
return Qnil;
|
1262
|
+
}
|
966
1263
|
|
967
1264
|
const char* lora_path = StringValueCStr(kw_values[0]);
|
968
1265
|
const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
|
969
1266
|
const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
|
1267
|
+
const float scale = kw_values[3] == Qundef ? 1.0 : NUM2DBL(kw_values[3]);
|
970
1268
|
|
971
1269
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
972
|
-
if (llama_model_apply_lora_from_file(ptr->model, lora_path, base_model_path, n_threads) != 0) {
|
1270
|
+
if (llama_model_apply_lora_from_file(ptr->model, lora_path, scale, base_model_path, n_threads) != 0) {
|
973
1271
|
rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
|
974
1272
|
return Qnil;
|
975
1273
|
}
|
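apply_lora_from_file gains an optional scale keyword here, forwarded to llama_model_apply_lora_from_file. A hypothetical call with placeholder paths:

```ruby
model.apply_lora_from_file(lora_path: 'path/to/lora.bin', n_threads: 4, scale: 0.5)
```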
@@ -978,25 +1276,25 @@ private:
|
|
978
1276
|
|
979
1277
|
static VALUE _llama_model_get_model_n_vocab(VALUE self) {
|
980
1278
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
981
|
-
return INT2NUM(
|
1279
|
+
return INT2NUM(llama_n_vocab(ptr->model));
|
982
1280
|
}
|
983
1281
|
|
984
|
-
static VALUE
|
1282
|
+
static VALUE _llama_model_get_model_n_ctx_train(VALUE self) {
|
985
1283
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
986
|
-
return INT2NUM(
|
1284
|
+
return INT2NUM(llama_n_ctx_train(ptr->model));
|
987
1285
|
}
|
988
1286
|
|
989
|
-
static VALUE
|
1287
|
+
static VALUE _llama_model_get_model_n_embd(VALUE self) {
|
990
1288
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
991
|
-
return INT2NUM(
|
1289
|
+
return INT2NUM(llama_n_embd(ptr->model));
|
992
1290
|
}
|
993
1291
|
|
994
|
-
static VALUE
|
1292
|
+
static VALUE _llama_model_rope_freq_scale_train(VALUE self) {
|
995
1293
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
996
|
-
return
|
1294
|
+
return DBL2NUM(llama_rope_freq_scale_train(ptr->model));
|
997
1295
|
}
|
998
1296
|
|
999
|
-
static VALUE
|
1297
|
+
static VALUE _llama_model_token_to_piece(VALUE self, VALUE token_) {
|
1000
1298
|
if (!RB_INTEGER_TYPE_P(token_)) {
|
1001
1299
|
rb_raise(rb_eArgError, "token must be an integer");
|
1002
1300
|
return Qnil;
|
@@ -1004,10 +1302,10 @@ private:
|
|
1004
1302
|
const llama_token token = NUM2INT(token_);
|
1005
1303
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1006
1304
|
std::vector<char> result(8, 0);
|
1007
|
-
const int n_tokens =
|
1305
|
+
const int n_tokens = llama_token_to_piece(ptr->model, token, result.data(), result.size());
|
1008
1306
|
if (n_tokens < 0) {
|
1009
1307
|
result.resize(-n_tokens);
|
1010
|
-
const int check =
|
1308
|
+
const int check = llama_token_to_piece(ptr->model, token, result.data(), result.size());
|
1011
1309
|
if (check != -n_tokens) {
|
1012
1310
|
rb_raise(rb_eRuntimeError, "failed to convert");
|
1013
1311
|
return Qnil;
|
@@ -1016,10 +1314,10 @@ private:
|
|
1016
1314
|
result.resize(n_tokens);
|
1017
1315
|
}
|
1018
1316
|
std::string ret(result.data(), result.size());
|
1019
|
-
return
|
1317
|
+
return rb_utf8_str_new_cstr(ret.c_str());
|
1020
1318
|
}
|
1021
1319
|
|
1022
|
-
static VALUE
|
1320
|
+
static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
|
1023
1321
|
VALUE kw_args = Qnil;
|
1024
1322
|
ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
|
1025
1323
|
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
@@ -1046,7 +1344,7 @@ private:
|
|
1046
1344
|
|
1047
1345
|
llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
|
1048
1346
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1049
|
-
const int n_tokens =
|
1347
|
+
const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
|
1050
1348
|
|
1051
1349
|
if (n_tokens < 0) {
|
1052
1350
|
rb_raise(rb_eRuntimeError, "failed to tokenize. The numebr of tokens (%d) is greater than n_max_tokens.", -n_tokens);
|
@@ -1066,7 +1364,7 @@ private:
|
|
1066
1364
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1067
1365
|
char buf[128];
|
1068
1366
|
llama_model_desc(ptr->model, buf, sizeof(buf));
|
1069
|
-
return
|
1367
|
+
return rb_utf8_str_new_cstr(buf);
|
1070
1368
|
}
|
1071
1369
|
|
1072
1370
|
static VALUE _llama_model_get_model_size(VALUE self) {
|
@@ -1345,11 +1643,11 @@ public:
|
|
1345
1643
|
static void define_class(VALUE outer) {
|
1346
1644
|
rb_cLLaMAContext = rb_define_class_under(outer, "Context", rb_cObject);
|
1347
1645
|
rb_define_alloc_func(rb_cLLaMAContext, llama_context_alloc);
|
1646
|
+
rb_define_attr(rb_cLLaMAContext, "model", 1, 0);
|
1348
1647
|
rb_define_method(rb_cLLaMAContext, "initialize", RUBY_METHOD_FUNC(_llama_context_initialize), -1);
|
1349
1648
|
rb_define_method(rb_cLLaMAContext, "eval", RUBY_METHOD_FUNC(_llama_context_eval), -1);
|
1350
1649
|
rb_define_method(rb_cLLaMAContext, "eval_embd", RUBY_METHOD_FUNC(_llama_context_eval_embd), -1);
|
1351
|
-
rb_define_method(rb_cLLaMAContext, "
|
1352
|
-
rb_define_method(rb_cLLaMAContext, "tokenize", RUBY_METHOD_FUNC(_llama_context_tokenize), -1);
|
1650
|
+
rb_define_method(rb_cLLaMAContext, "decode", RUBY_METHOD_FUNC(_llama_context_decode), 1);
|
1353
1651
|
rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
|
1354
1652
|
rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
|
1355
1653
|
rb_define_method(rb_cLLaMAContext, "text", RUBY_METHOD_FUNC(_llama_context_text), 1);
|
@@ -1358,15 +1656,20 @@ public:
|
|
1358
1656
|
rb_define_method(rb_cLLaMAContext, "token_bos", RUBY_METHOD_FUNC(_llama_context_token_bos), 0);
|
1359
1657
|
rb_define_method(rb_cLLaMAContext, "token_eos", RUBY_METHOD_FUNC(_llama_context_token_eos), 0);
|
1360
1658
|
rb_define_method(rb_cLLaMAContext, "token_nl", RUBY_METHOD_FUNC(_llama_context_token_nl), 0);
|
1361
|
-
rb_define_method(rb_cLLaMAContext, "
|
1362
|
-
rb_define_method(rb_cLLaMAContext, "
|
1659
|
+
rb_define_method(rb_cLLaMAContext, "token_prefix", RUBY_METHOD_FUNC(_llama_context_token_prefix), 0);
|
1660
|
+
rb_define_method(rb_cLLaMAContext, "token_middle", RUBY_METHOD_FUNC(_llama_context_token_middle), 0);
|
1661
|
+
rb_define_method(rb_cLLaMAContext, "token_suffix", RUBY_METHOD_FUNC(_llama_context_token_suffix), 0);
|
1662
|
+
rb_define_method(rb_cLLaMAContext, "token_eot", RUBY_METHOD_FUNC(_llama_context_token_eot), 0);
|
1363
1663
|
rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
|
1364
|
-
rb_define_method(rb_cLLaMAContext, "n_ctx_train", RUBY_METHOD_FUNC(_llama_context_n_ctx_train), 0);
|
1365
|
-
rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
|
1366
1664
|
rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
|
1367
1665
|
rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
|
1368
1666
|
rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
|
1369
1667
|
rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
|
1668
|
+
rb_define_method(rb_cLLaMAContext, "kv_cache_tokens_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_tokens_rm), 2);
|
1669
|
+
rb_define_method(rb_cLLaMAContext, "kv_cache_seq_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_rm), 3);
|
1670
|
+
rb_define_method(rb_cLLaMAContext, "kv_cache_seq_cp", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_cp), 4);
|
1671
|
+
rb_define_method(rb_cLLaMAContext, "kv_cache_seq_keep", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_keep), 1);
|
1672
|
+
rb_define_method(rb_cLLaMAContext, "kv_cache_seq_shift", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_shift), 4);
|
1370
1673
|
rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
|
1371
1674
|
rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
|
1372
1675
|
rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
|
@@ -1378,6 +1681,7 @@ public:
|
|
1378
1681
|
rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
|
1379
1682
|
rb_define_method(rb_cLLaMAContext, "sample_tail_free", RUBY_METHOD_FUNC(_llama_context_sample_tail_free), -1);
|
1380
1683
|
rb_define_method(rb_cLLaMAContext, "sample_typical", RUBY_METHOD_FUNC(_llama_context_sample_typical), -1);
|
1684
|
+
rb_define_method(rb_cLLaMAContext, "sample_temp", RUBY_METHOD_FUNC(_llama_context_sample_temp), -1);
|
1381
1685
|
rb_define_method(rb_cLLaMAContext, "sample_temperature", RUBY_METHOD_FUNC(_llama_context_sample_temperature), -1);
|
1382
1686
|
rb_define_method(rb_cLLaMAContext, "sample_token_mirostat", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat), -1);
|
1383
1687
|
rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
|
@@ -1392,24 +1696,27 @@ private:
|
|
1392
1696
|
|
1393
1697
|
static VALUE _llama_context_initialize(int argc, VALUE* argv, VALUE self) {
|
1394
1698
|
VALUE kw_args = Qnil;
|
1395
|
-
ID kw_table[
|
1396
|
-
VALUE kw_values[
|
1699
|
+
ID kw_table[2] = { rb_intern("model"), rb_intern("params") };
|
1700
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
1397
1701
|
rb_scan_args(argc, argv, ":", &kw_args);
|
1398
|
-
rb_get_kwargs(kw_args, kw_table,
|
1702
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
1399
1703
|
|
1400
1704
|
VALUE model = kw_values[0];
|
1401
1705
|
if (!rb_obj_is_kind_of(model, rb_cLLaMAModel)) {
|
1402
1706
|
rb_raise(rb_eArgError, "model must be a Model");
|
1403
1707
|
return Qnil;
|
1404
1708
|
}
|
1709
|
+
VALUE params = kw_values[1];
|
1710
|
+
if (!rb_obj_is_kind_of(params, rb_cLLaMAContextParams)) {
|
1711
|
+
rb_raise(rb_eArgError, "params must be a ContextParams");
|
1712
|
+
return Qnil;
|
1713
|
+
}
|
1405
1714
|
|
1406
1715
|
LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
|
1407
1716
|
if (model_ptr->model == NULL) {
|
1408
1717
|
rb_raise(rb_eRuntimeError, "Model is empty");
|
1409
1718
|
return Qnil;
|
1410
1719
|
}
|
1411
|
-
|
1412
|
-
VALUE params = rb_iv_get(model, "@params");
|
1413
1720
|
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
|
1414
1721
|
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
1415
1722
|
|
@@ -1421,6 +1728,7 @@ private:
|
|
1421
1728
|
}
|
1422
1729
|
|
1423
1730
|
rb_iv_set(self, "@model", model);
|
1731
|
+
rb_iv_set(self, "@params", params);
|
1424
1732
|
rb_iv_set(self, "@has_evaluated", Qfalse);
|
1425
1733
|
|
1426
1734
|
return Qnil;
|
@@ -1428,8 +1736,8 @@ private:
|
|
1428
1736
|
|
1429
1737
|
static VALUE _llama_context_eval(int argc, VALUE* argv, VALUE self) {
|
1430
1738
|
VALUE kw_args = Qnil;
|
1431
|
-
ID kw_table[
|
1432
|
-
VALUE kw_values[
|
1739
|
+
ID kw_table[3] = { rb_intern("tokens"), rb_intern("n_past"), rb_intern("n_tokens") };
|
1740
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
1433
1741
|
rb_scan_args(argc, argv, ":", &kw_args);
|
1434
1742
|
rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
|
1435
1743
|
|
@@ -1445,10 +1753,6 @@ private:
|
|
1445
1753
|
rb_raise(rb_eArgError, "n_tokens must be an integer");
|
1446
1754
|
return Qnil;
|
1447
1755
|
}
|
1448
|
-
if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
|
1449
|
-
rb_raise(rb_eArgError, "n_threads must be an integer");
|
1450
|
-
return Qnil;
|
1451
|
-
}
|
1452
1756
|
|
1453
1757
|
const size_t tokens_len = RARRAY_LEN(kw_values[0]);
|
1454
1758
|
std::vector<llama_token> embd(tokens_len);
|
@@ -1463,14 +1767,13 @@ private:
|
|
1463
1767
|
|
1464
1768
|
const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
|
1465
1769
|
const int n_past = NUM2INT(kw_values[1]);
|
1466
|
-
const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
|
1467
1770
|
|
1468
1771
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1469
1772
|
if (ptr->ctx == NULL) {
|
1470
1773
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1471
1774
|
return Qnil;
|
1472
1775
|
}
|
1473
|
-
if (llama_eval(ptr->ctx, embd.data(), n_tokens, n_past
|
1776
|
+
if (llama_eval(ptr->ctx, embd.data(), n_tokens, n_past) != 0) {
|
1474
1777
|
rb_raise(rb_eRuntimeError, "Failed to evaluate");
|
1475
1778
|
return Qnil;
|
1476
1779
|
}
|
@@ -1483,8 +1786,8 @@ private:
|
|
1483
1786
|
|
1484
1787
|
static VALUE _llama_context_eval_embd(int argc, VALUE* argv, VALUE self) {
|
1485
1788
|
VALUE kw_args = Qnil;
|
1486
|
-
ID kw_table[
|
1487
|
-
VALUE kw_values[
|
1789
|
+
ID kw_table[3] = { rb_intern("embd"), rb_intern("n_past"), rb_intern("n_tokens") };
|
1790
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
1488
1791
|
rb_scan_args(argc, argv, ":", &kw_args);
|
1489
1792
|
rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
|
1490
1793
|
|
@@ -1500,10 +1803,6 @@ private:
|
|
1500
1803
|
rb_raise(rb_eArgError, "n_tokens must be an integer");
|
1501
1804
|
return Qnil;
|
1502
1805
|
}
|
1503
|
-
if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
|
1504
|
-
rb_raise(rb_eArgError, "n_threads must be an integer");
|
1505
|
-
return Qnil;
|
1506
|
-
}
|
1507
1806
|
|
1508
1807
|
const size_t tokens_len = RARRAY_LEN(kw_values[0]);
|
1509
1808
|
std::vector<float> embd(tokens_len);
|
@@ -1518,14 +1817,13 @@ private:
|
|
1518
1817
|
|
1519
1818
|
const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
|
1520
1819
|
const int n_past = NUM2INT(kw_values[1]);
|
1521
|
-
const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
|
1522
1820
|
|
1523
1821
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1524
1822
|
if (ptr->ctx == NULL) {
|
1525
1823
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1526
1824
|
return Qnil;
|
1527
1825
|
}
|
1528
|
-
if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past
|
1826
|
+
if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past) != 0) {
|
1529
1827
|
rb_raise(rb_eRuntimeError, "Failed to evaluate");
|
1530
1828
|
return Qnil;
|
1531
1829
|
}
|
@@ -1536,91 +1834,22 @@ private:
|
|
1536
1834
|
return Qnil;
|
1537
1835
|
}
|
1538
1836
|
|
1539
|
-
static VALUE
|
1837
|
+
static VALUE _llama_context_decode(VALUE self, VALUE batch) {
|
1540
1838
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1541
1839
|
if (ptr->ctx == NULL) {
|
1542
1840
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1543
1841
|
return Qnil;
|
1544
1842
|
}
|
1545
|
-
if (!
|
1546
|
-
rb_raise(rb_eArgError, "
|
1547
|
-
return Qnil;
|
1548
|
-
}
|
1549
|
-
const char* fname = StringValueCStr(fname_);
|
1550
|
-
if (llama_eval_export(ptr->ctx, fname) != 0) {
|
1551
|
-
return Qfalse;
|
1552
|
-
}
|
1553
|
-
RB_GC_GUARD(fname_);
|
1554
|
-
return Qtrue;
|
1555
|
-
}
|
1556
|
-
|
1557
|
-
static VALUE _llama_context_tokenize(int argc, VALUE* argv, VALUE self) {
|
1558
|
-
VALUE kw_args = Qnil;
|
1559
|
-
ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
|
1560
|
-
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
1561
|
-
rb_scan_args(argc, argv, ":", &kw_args);
|
1562
|
-
rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
|
1563
|
-
|
1564
|
-
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
1565
|
-
rb_raise(rb_eArgError, "text must be a String");
|
1566
|
-
return Qnil;
|
1567
|
-
}
|
1568
|
-
if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
|
1569
|
-
rb_raise(rb_eArgError, "n_max_tokens must be an integer");
|
1843
|
+
if (!rb_obj_is_kind_of(batch, rb_cLLaMABatch)) {
|
1844
|
+
rb_raise(rb_eArgError, "batch must be a Batch");
|
1570
1845
|
return Qnil;
|
1571
1846
|
}
|
1572
|
-
|
1573
|
-
|
1847
|
+
LLaMABatchWrapper* batch_ptr = RbLLaMABatch::get_llama_batch(batch);
|
1848
|
+
if (llama_decode(ptr->ctx, batch_ptr->batch) < 0) {
|
1849
|
+
rb_raise(rb_eRuntimeError, "Failed to decode");
|
1574
1850
|
return Qnil;
|
1575
1851
|
}
|
1576
|
-
|
1577
|
-
VALUE text_ = kw_values[0];
|
1578
|
-
std::string text = StringValueCStr(text_);
|
1579
|
-
const bool add_bos = kw_values[2] == Qtrue ? true : false;
|
1580
|
-
const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
|
1581
|
-
|
1582
|
-
std::vector<llama_token> tokens(n_max_tokens);
|
1583
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1584
|
-
if (ptr->ctx == NULL) {
|
1585
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1586
|
-
return Qnil;
|
1587
|
-
}
|
1588
|
-
const int n = llama_tokenize(ptr->ctx, text.c_str(), text.size(), tokens.data(), n_max_tokens, add_bos);
|
1589
|
-
if (n < 0) {
|
1590
|
-
rb_raise(rb_eRuntimeError, "Failed to tokenize");
|
1591
|
-
return Qnil;
|
1592
|
-
}
|
1593
|
-
|
1594
|
-
VALUE output = rb_ary_new();
|
1595
|
-
for (int i = 0; i < n; i++) {
|
1596
|
-
rb_ary_push(output, INT2NUM(tokens[i]));
|
1597
|
-
}
|
1598
|
-
|
1599
|
-
RB_GC_GUARD(text_);
|
1600
|
-
return output;
|
1601
|
-
}
|
1602
|
-
|
1603
|
-
static VALUE _llama_context_token_to_piece(VALUE self, VALUE token_) {
|
1604
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1605
|
-
if (ptr->ctx == NULL) {
|
1606
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1607
|
-
return Qnil;
|
1608
|
-
}
|
1609
|
-
const llama_token token = NUM2INT(token_);
|
1610
|
-
std::vector<char> result(8, 0);
|
1611
|
-
const int n_tokens = llama_token_to_piece(ptr->ctx, token, result.data(), result.size());
|
1612
|
-
if (n_tokens < 0) {
|
1613
|
-
result.resize(-n_tokens);
|
1614
|
-
const int check = llama_token_to_piece(ptr->ctx, token, result.data(), result.size());
|
1615
|
-
if (check != -n_tokens) {
|
1616
|
-
rb_raise(rb_eRuntimeError, "failed to convert");
|
1617
|
-
return Qnil;
|
1618
|
-
}
|
1619
|
-
} else {
|
1620
|
-
result.resize(n_tokens);
|
1621
|
-
}
|
1622
|
-
std::string ret(result.data(), result.size());
|
1623
|
-
return rb_str_new_cstr(ret.c_str());
|
1852
|
+
return Qnil;
|
1624
1853
|
}
|
1625
1854
|
|
1626
1855
|
static VALUE _llama_context_logits(VALUE self) {
|
@@ -1635,10 +1864,11 @@ private:
|
|
1635
1864
|
}
|
1636
1865
|
|
1637
1866
|
VALUE model = rb_iv_get(self, "@model");
|
1638
|
-
|
1867
|
+
LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
|
1868
|
+
VALUE params = rb_iv_get(self, "@params");
|
1639
1869
|
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
|
1640
1870
|
const int n_tokens = prms_ptr->params.logits_all ? NUM2INT(rb_iv_get(self, "@n_tokens")) : 1;
|
1641
|
-
const int n_vocab = llama_n_vocab(
|
1871
|
+
const int n_vocab = llama_n_vocab(model_ptr->model);
|
1642
1872
|
const float* logits = llama_get_logits(ptr->ctx);
|
1643
1873
|
VALUE output = rb_ary_new();
|
1644
1874
|
for (int i = 0; i < n_tokens * n_vocab; i++) {
|
@@ -1655,7 +1885,8 @@ private:
|
|
1655
1885
|
return Qnil;
|
1656
1886
|
}
|
1657
1887
|
VALUE model = rb_iv_get(self, "@model");
|
1658
|
-
|
1888
|
+
LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
|
1889
|
+
VALUE params = rb_iv_get(self, "@params");
|
1659
1890
|
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
|
1660
1891
|
if (!prms_ptr->params.embedding) {
|
1661
1892
|
rb_raise(rb_eRuntimeError, "embedding parameter is false");
|
@@ -1666,7 +1897,7 @@ private:
|
|
1666
1897
|
return Qnil;
|
1667
1898
|
}
|
1668
1899
|
|
1669
|
-
const int n_embd = llama_n_embd(
|
1900
|
+
const int n_embd = llama_n_embd(model_ptr->model);
|
1670
1901
|
const float* embd = llama_get_embeddings(ptr->ctx);
|
1671
1902
|
VALUE output = rb_ary_new();
|
1672
1903
|
for (int i = 0; i < n_embd; i++) {
|
@@ -1684,7 +1915,7 @@ private:
|
|
1684
1915
|
}
|
1685
1916
|
const llama_token token = NUM2INT(token_);
|
1686
1917
|
const char* text = llama_token_get_text(ptr->ctx, token);
|
1687
|
-
return
|
1918
|
+
return rb_utf8_str_new_cstr(text);
|
1688
1919
|
}
|
1689
1920
|
|
1690
1921
|
static VALUE _llama_context_score(VALUE self, VALUE token_) {
|
@@ -1736,40 +1967,49 @@ private:
|
|
1736
1967
|
return INT2NUM(llama_token_nl(ptr->ctx));
|
1737
1968
|
}
|
1738
1969
|
|
1739
|
-
static VALUE
|
1970
|
+
static VALUE _llama_context_token_prefix(VALUE self) {
|
1740
1971
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1741
1972
|
if (ptr->ctx == NULL) {
|
1742
1973
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1743
1974
|
return Qnil;
|
1744
1975
|
}
|
1745
|
-
return INT2NUM(
|
1976
|
+
return INT2NUM(llama_token_prefix(ptr->ctx));
|
1746
1977
|
}
|
1747
1978
|
|
1748
|
-
static VALUE
|
1979
|
+
static VALUE _llama_context_token_middle(VALUE self) {
|
1749
1980
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1750
1981
|
if (ptr->ctx == NULL) {
|
1751
1982
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1752
1983
|
return Qnil;
|
1753
1984
|
}
|
1754
|
-
return INT2NUM(
|
1985
|
+
return INT2NUM(llama_token_middle(ptr->ctx));
|
1986
|
+
}
|
1987
|
+
|
1988
|
+
static VALUE _llama_context_token_suffix(VALUE self) {
|
1989
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1990
|
+
if (ptr->ctx == NULL) {
|
1991
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1992
|
+
return Qnil;
|
1993
|
+
}
|
1994
|
+
return INT2NUM(llama_token_suffix(ptr->ctx));
|
1755
1995
|
}
|
1756
1996
|
|
1757
|
-
static VALUE
|
1997
|
+
static VALUE _llama_context_token_eot(VALUE self) {
|
1758
1998
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1759
1999
|
if (ptr->ctx == NULL) {
|
1760
2000
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1761
2001
|
return Qnil;
|
1762
2002
|
}
|
1763
|
-
return INT2NUM(
|
2003
|
+
return INT2NUM(llama_token_eot(ptr->ctx));
|
1764
2004
|
}
|
1765
2005
|
|
1766
|
-
static VALUE
|
2006
|
+
static VALUE _llama_context_n_ctx(VALUE self) {
|
1767
2007
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1768
2008
|
if (ptr->ctx == NULL) {
|
1769
2009
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1770
2010
|
return Qnil;
|
1771
2011
|
}
|
1772
|
-
return INT2NUM(
|
2012
|
+
return INT2NUM(llama_n_ctx(ptr->ctx));
|
1773
2013
|
}
|
1774
2014
|
|
1775
2015
|
static VALUE _llama_context_get_timings(VALUE self) {
|
@@ -1813,6 +2053,56 @@ private:
|
|
1813
2053
|
return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
|
1814
2054
|
}
|
1815
2055
|
|
2056
|
+
static VALUE _llama_context_kv_cache_tokens_rm(VALUE self, VALUE c0, VALUE c1) {
|
2057
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2058
|
+
if (ptr->ctx == NULL) {
|
2059
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
2060
|
+
return Qnil;
|
2061
|
+
}
|
2062
|
+
llama_kv_cache_tokens_rm(ptr->ctx, NUM2INT(c0), NUM2INT(c1));
|
2063
|
+
return Qnil;
|
2064
|
+
}
|
2065
|
+
|
2066
|
+
static VALUE _llama_context_kv_cache_seq_rm(VALUE self, VALUE seq_id, VALUE p0, VALUE p1) {
|
2067
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2068
|
+
if (ptr->ctx == NULL) {
|
2069
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
2070
|
+
return Qnil;
|
2071
|
+
}
|
2072
|
+
llama_kv_cache_seq_rm(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1));
|
2073
|
+
return Qnil;
|
2074
|
+
}
|
2075
|
+
|
2076
|
+
static VALUE _llama_context_kv_cache_seq_cp(VALUE self, VALUE seq_id_src, VALUE seq_id_dst, VALUE p0, VALUE p1) {
|
2077
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2078
|
+
if (ptr->ctx == NULL) {
|
2079
|
+
rb_raise(rb_eArgError, "LLaMA context is not initialized");
|
2080
|
+
return Qnil;
|
2081
|
+
}
|
2082
|
+
llama_kv_cache_seq_cp(ptr->ctx, NUM2INT(seq_id_src), NUM2INT(seq_id_dst), NUM2INT(p0), NUM2INT(p1));
|
2083
|
+
return Qnil;
|
2084
|
+
}
|
2085
|
+
|
2086
|
+
static VALUE _llama_context_kv_cache_seq_keep(VALUE self, VALUE seq_id) {
|
2087
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2088
|
+
if (ptr->ctx == NULL) {
|
2089
|
+
rb_raise(rb_eArgError, "LLaMA context is not initialized");
|
2090
|
+
return Qnil;
|
2091
|
+
}
|
2092
|
+
llama_kv_cache_seq_keep(ptr->ctx, NUM2INT(seq_id));
|
2093
|
+
return Qnil;
|
2094
|
+
}
|
2095
|
+
|
2096
|
+
static VALUE _llama_context_kv_cache_seq_shift(VALUE self, VALUE seq_id, VALUE p0, VALUE p1, VALUE delta) {
|
2097
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2098
|
+
if (ptr->ctx == NULL) {
|
2099
|
+
rb_raise(rb_eArgError, "LLaMA context is not initialized");
|
2100
|
+
return Qnil;
|
2101
|
+
}
|
2102
|
+
llama_kv_cache_seq_shift(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1), NUM2INT(delta));
|
2103
|
+
return Qnil;
|
2104
|
+
}
|
2105
|
+
|
1816
2106
|
static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
|
1817
2107
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1818
2108
|
if (ptr->ctx == NULL) {
|
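The KV-cache helpers added above are thin wrappers over the corresponding llama.cpp calls. A hedged sketch with illustrative sequence IDs and position ranges:

```ruby
context.kv_cache_seq_rm(1, 0, 32)        # (seq_id, p0, p1) remove part of sequence 1's cache
context.kv_cache_seq_cp(0, 1, 0, 32)     # (src, dst, p0, p1) copy that range of sequence 0 into sequence 1
context.kv_cache_seq_shift(0, 8, 32, -8) # (seq_id, p0, p1, delta) shift those positions back by 8
context.kv_cache_seq_keep(0)             # keep only sequence 0 in the cache
```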
@@ -1851,7 +2141,7 @@ private:
|
|
1851
2141
|
}
|
1852
2142
|
|
1853
2143
|
VALUE model = rb_iv_get(self, "@model");
|
1854
|
-
VALUE params = rb_iv_get(
|
2144
|
+
VALUE params = rb_iv_get(self, "@params");
|
1855
2145
|
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
|
1856
2146
|
const int n_ctx = prms_ptr->params.n_ctx;
|
1857
2147
|
|
@@ -2235,6 +2525,40 @@ private:
|
|
2235
2525
|
return Qnil;
|
2236
2526
|
}
|
2237
2527
|
|
2528
|
+
static VALUE _llama_context_sample_temp(int argc, VALUE* argv, VALUE self) {
|
2529
|
+
VALUE kw_args = Qnil;
|
2530
|
+
ID kw_table[1] = { rb_intern("temp") };
|
2531
|
+
VALUE kw_values[1] = { Qundef };
|
2532
|
+
VALUE candidates = Qnil;
|
2533
|
+
rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
|
2534
|
+
rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
|
2535
|
+
|
2536
|
+
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
2537
|
+
rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
|
2538
|
+
return Qnil;
|
2539
|
+
}
|
2540
|
+
if (!RB_FLOAT_TYPE_P(kw_values[0])) {
|
2541
|
+
rb_raise(rb_eArgError, "temp must be a float");
|
2542
|
+
return Qnil;
|
2543
|
+
}
|
2544
|
+
|
2545
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
2546
|
+
if (ctx_ptr->ctx == NULL) {
|
2547
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
2548
|
+
return Qnil;
|
2549
|
+
}
|
2550
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
2551
|
+
if (cnd_ptr->array.data == nullptr) {
|
2552
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
2553
|
+
return Qnil;
|
2554
|
+
}
|
2555
|
+
const float temp = NUM2DBL(kw_values[0]);
|
2556
|
+
|
2557
|
+
llama_sample_temp(ctx_ptr->ctx, &(cnd_ptr->array), temp);
|
2558
|
+
|
2559
|
+
return Qnil;
|
2560
|
+
}
|
2561
|
+
|
2238
2562
|
static VALUE _llama_context_sample_temperature(int argc, VALUE* argv, VALUE self) {
|
2239
2563
|
VALUE kw_args = Qnil;
|
2240
2564
|
ID kw_table[1] = { rb_intern("temperature") };
|
@@ -2560,6 +2884,7 @@ extern "C" void Init_llama_cpp(void) {
|
|
2560
2884
|
RbLLaMATokenData::define_class(rb_mLLaMACpp);
|
2561
2885
|
RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
|
2562
2886
|
RbLLaMAModel::define_class(rb_mLLaMACpp);
|
2887
|
+
RbLLaMAModelParams::define_class(rb_mLLaMACpp);
|
2563
2888
|
RbLLaMATimings::define_class(rb_mLLaMACpp);
|
2564
2889
|
RbLLaMAContext::define_class(rb_mLLaMACpp);
|
2565
2890
|
RbLLaMAContextParams::define_class(rb_mLLaMACpp);
|
@@ -2578,10 +2903,6 @@ extern "C" void Init_llama_cpp(void) {
|
|
2578
2903
|
|
2579
2904
|
rb_define_const(rb_mLLaMACpp, "LLAMA_MAX_DEVICES", INT2NUM(LLAMA_MAX_DEVICES));
|
2580
2905
|
|
2581
|
-
rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_ERROR", INT2NUM(LLAMA_LOG_LEVEL_ERROR));
|
2582
|
-
rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_WARN", INT2NUM(LLAMA_LOG_LEVEL_WARN));
|
2583
|
-
rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_INFO", INT2NUM(LLAMA_LOG_LEVEL_INFO));
|
2584
|
-
|
2585
2906
|
rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_SPM", INT2NUM(LLAMA_VOCAB_TYPE_SPM));
|
2586
2907
|
rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_BPE", INT2NUM(LLAMA_VOCAB_TYPE_BPE));
|
2587
2908
|
|