llama_cpp 0.5.3 → 0.7.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +17 -0
- data/README.md +6 -5
- data/examples/chat.rb +13 -13
- data/examples/embedding.rb +9 -9
- data/ext/llama_cpp/llama_cpp.cpp +583 -262
- data/ext/llama_cpp/src/ggml-alloc.c +8 -2
- data/ext/llama_cpp/src/ggml-alloc.h +1 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +326 -149
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.m +167 -89
- data/ext/llama_cpp/src/ggml-metal.metal +130 -40
- data/ext/llama_cpp/src/ggml-opencl.cpp +119 -53
- data/ext/llama_cpp/src/ggml.c +2355 -1166
- data/ext/llama_cpp/src/ggml.h +129 -35
- data/ext/llama_cpp/src/k_quants.c +744 -2
- data/ext/llama_cpp/src/llama.cpp +1766 -671
- data/ext/llama_cpp/src/llama.h +321 -120
- data/ext/llama_cpp/src/unicode.h +462 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +6 -10
- data/sig/llama_cpp.rbs +70 -34
- metadata +4 -3
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -1,7 +1,9 @@
|
|
1
1
|
#include "llama_cpp.h"
|
2
2
|
|
3
3
|
VALUE rb_mLLaMACpp;
|
4
|
+
VALUE rb_cLLaMABatch;
|
4
5
|
VALUE rb_cLLaMAModel;
|
6
|
+
VALUE rb_cLLaMAModelParams;
|
5
7
|
VALUE rb_cLLaMATimings;
|
6
8
|
VALUE rb_cLLaMAContext;
|
7
9
|
VALUE rb_cLLaMAContextParams;
|
@@ -11,6 +13,238 @@ VALUE rb_cLLaMATokenDataArray;
|
|
11
13
|
VALUE rb_cLLaMAGrammarElement;
|
12
14
|
VALUE rb_cLLaMAGrammar;
|
13
15
|
|
16
|
+
class LLaMABatchWrapper {
|
17
|
+
public:
|
18
|
+
llama_batch batch;
|
19
|
+
|
20
|
+
LLaMABatchWrapper() {}
|
21
|
+
|
22
|
+
~LLaMABatchWrapper() {
|
23
|
+
llama_batch_free(batch);
|
24
|
+
}
|
25
|
+
};
|
26
|
+
|
27
|
+
class RbLLaMABatch {
|
28
|
+
public:
|
29
|
+
static VALUE llama_batch_alloc(VALUE self) {
|
30
|
+
LLaMABatchWrapper* ptr = (LLaMABatchWrapper*)ruby_xmalloc(sizeof(LLaMABatchWrapper));
|
31
|
+
new (ptr) LLaMABatchWrapper();
|
32
|
+
return TypedData_Wrap_Struct(self, &llama_batch_type, ptr);
|
33
|
+
}
|
34
|
+
|
35
|
+
static void llama_batch_free(void* ptr) {
|
36
|
+
((LLaMABatchWrapper*)ptr)->~LLaMABatchWrapper();
|
37
|
+
ruby_xfree(ptr);
|
38
|
+
}
|
39
|
+
|
40
|
+
static size_t llama_batch_size(const void* ptr) {
|
41
|
+
return sizeof(*((LLaMABatchWrapper*)ptr));
|
42
|
+
}
|
43
|
+
|
44
|
+
static LLaMABatchWrapper* get_llama_batch(VALUE self) {
|
45
|
+
LLaMABatchWrapper* ptr;
|
46
|
+
TypedData_Get_Struct(self, LLaMABatchWrapper, &llama_batch_type, ptr);
|
47
|
+
return ptr;
|
48
|
+
}
|
49
|
+
|
50
|
+
static void define_class(VALUE outer) {
|
51
|
+
rb_cLLaMABatch = rb_define_class_under(outer, "Batch", rb_cObject);
|
52
|
+
rb_define_alloc_func(rb_cLLaMABatch, llama_batch_alloc);
|
53
|
+
rb_define_method(rb_cLLaMABatch, "initialize", RUBY_METHOD_FUNC(_llama_batch_initialize), -1);
|
54
|
+
rb_define_method(rb_cLLaMABatch, "n_tokens=", RUBY_METHOD_FUNC(_llama_batch_set_n_tokens), 1);
|
55
|
+
rb_define_method(rb_cLLaMABatch, "n_tokens", RUBY_METHOD_FUNC(_llama_batch_get_n_tokens), 0);
|
56
|
+
rb_define_method(rb_cLLaMABatch, "all_pos_zero=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_zero), 1);
|
57
|
+
rb_define_method(rb_cLLaMABatch, "all_pos_zero", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_zero), 0);
|
58
|
+
rb_define_method(rb_cLLaMABatch, "all_pos_one=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_one), 1);
|
59
|
+
rb_define_method(rb_cLLaMABatch, "all_pos_one", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_one), 0);
|
60
|
+
rb_define_method(rb_cLLaMABatch, "all_seq_id=", RUBY_METHOD_FUNC(_llama_batch_set_all_seq_id), 1);
|
61
|
+
rb_define_method(rb_cLLaMABatch, "all_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_all_seq_id), 0);
|
62
|
+
rb_define_method(rb_cLLaMABatch, "set_token", RUBY_METHOD_FUNC(_llama_batch_set_token), 2);
|
63
|
+
rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
|
64
|
+
rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
|
65
|
+
rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
|
66
|
+
rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 2);
|
67
|
+
rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 1);
|
68
|
+
rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
|
69
|
+
rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
|
70
|
+
}
|
71
|
+
|
72
|
+
private:
|
73
|
+
static const rb_data_type_t llama_batch_type;
|
74
|
+
|
75
|
+
static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
|
76
|
+
VALUE kw_args = Qnil;
|
77
|
+
ID kw_table[2] = { rb_intern("n_tokens"), rb_intern("embd") };
|
78
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
79
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
80
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
81
|
+
|
82
|
+
if (!RB_INTEGER_TYPE_P(kw_values[0])) {
|
83
|
+
rb_raise(rb_eArgError, "n_tokens must be an integer");
|
84
|
+
return Qnil;
|
85
|
+
}
|
86
|
+
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
87
|
+
rb_raise(rb_eArgError, "embd must be an integer");
|
88
|
+
return Qnil;
|
89
|
+
}
|
90
|
+
|
91
|
+
const int32_t n_tokens = NUM2INT(kw_values[0]);
|
92
|
+
const int32_t embd = NUM2INT(kw_values[1]);
|
93
|
+
|
94
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
95
|
+
ptr->batch = llama_batch_init(n_tokens, embd);
|
96
|
+
|
97
|
+
return Qnil;
|
98
|
+
}
|
99
|
+
|
100
|
+
// n_tokens
|
101
|
+
static VALUE _llama_batch_set_n_tokens(VALUE self, VALUE n_tokens) {
|
102
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
103
|
+
ptr->batch.n_tokens = NUM2INT(n_tokens);
|
104
|
+
return INT2NUM(ptr->batch.n_tokens);
|
105
|
+
}
|
106
|
+
|
107
|
+
static VALUE _llama_batch_get_n_tokens(VALUE self) {
|
108
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
109
|
+
return INT2NUM(ptr->batch.n_tokens);
|
110
|
+
}
|
111
|
+
|
112
|
+
// all_pos_0
|
113
|
+
static VALUE _llama_batch_set_all_pos_zero(VALUE self, VALUE all_pos_0) {
|
114
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
115
|
+
ptr->batch.all_pos_0 = NUM2INT(all_pos_0);
|
116
|
+
return INT2NUM(ptr->batch.all_pos_0);
|
117
|
+
}
|
118
|
+
|
119
|
+
static VALUE _llama_batch_get_all_pos_zero(VALUE self) {
|
120
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
121
|
+
return INT2NUM(ptr->batch.all_pos_0);
|
122
|
+
}
|
123
|
+
|
124
|
+
// all_pos_1
|
125
|
+
static VALUE _llama_batch_set_all_pos_one(VALUE self, VALUE all_pos_1) {
|
126
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
127
|
+
ptr->batch.all_pos_1 = NUM2INT(all_pos_1);
|
128
|
+
return INT2NUM(ptr->batch.all_pos_1);
|
129
|
+
}
|
130
|
+
|
131
|
+
static VALUE _llama_batch_get_all_pos_one(VALUE self) {
|
132
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
133
|
+
return INT2NUM(ptr->batch.all_pos_1);
|
134
|
+
}
|
135
|
+
|
136
|
+
// all_seq_id
|
137
|
+
static VALUE _llama_batch_set_all_seq_id(VALUE self, VALUE all_seq_id) {
|
138
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
139
|
+
ptr->batch.all_seq_id = NUM2INT(all_seq_id);
|
140
|
+
return INT2NUM(ptr->batch.all_seq_id);
|
141
|
+
}
|
142
|
+
|
143
|
+
static VALUE _llama_batch_get_all_seq_id(VALUE self) {
|
144
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
145
|
+
return INT2NUM(ptr->batch.all_seq_id);
|
146
|
+
}
|
147
|
+
|
148
|
+
// token
|
149
|
+
static VALUE _llama_batch_set_token(VALUE self, VALUE idx, VALUE value) {
|
150
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
151
|
+
const int32_t id = NUM2INT(idx);
|
152
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
153
|
+
rb_raise(rb_eArgError, "idx must be in [0, n_tokens)");
|
154
|
+
return Qnil;
|
155
|
+
}
|
156
|
+
ptr->batch.token[id] = NUM2INT(value);
|
157
|
+
return INT2NUM(ptr->batch.token[id]);
|
158
|
+
}
|
159
|
+
|
160
|
+
static VALUE _llama_batch_get_token(VALUE self, VALUE idx) {
|
161
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
162
|
+
const int32_t id = NUM2INT(idx);
|
163
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
164
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
165
|
+
return Qnil;
|
166
|
+
}
|
167
|
+
return INT2NUM(ptr->batch.token[id]);
|
168
|
+
}
|
169
|
+
|
170
|
+
// pos
|
171
|
+
static VALUE _llama_batch_set_pos(VALUE self, VALUE idx, VALUE value) {
|
172
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
173
|
+
const int32_t id = NUM2INT(idx);
|
174
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
175
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
176
|
+
return Qnil;
|
177
|
+
}
|
178
|
+
ptr->batch.pos[id] = NUM2INT(value);
|
179
|
+
return INT2NUM(ptr->batch.pos[id]);
|
180
|
+
}
|
181
|
+
|
182
|
+
static VALUE _llama_batch_get_pos(VALUE self, VALUE idx) {
|
183
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
184
|
+
const int32_t id = NUM2INT(idx);
|
185
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
186
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
187
|
+
return Qnil;
|
188
|
+
}
|
189
|
+
return INT2NUM(ptr->batch.pos[id]);
|
190
|
+
}
|
191
|
+
|
192
|
+
// seq_id
|
193
|
+
static VALUE _llama_batch_set_seq_id(VALUE self, VALUE idx, VALUE value) {
|
194
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
195
|
+
const int32_t id = NUM2INT(idx);
|
196
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
197
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
198
|
+
return Qnil;
|
199
|
+
}
|
200
|
+
ptr->batch.seq_id[id] = NUM2INT(value);
|
201
|
+
return INT2NUM(ptr->batch.seq_id[id]);
|
202
|
+
}
|
203
|
+
|
204
|
+
static VALUE _llama_batch_get_seq_id(VALUE self, VALUE idx) {
|
205
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
206
|
+
const int32_t id = NUM2INT(idx);
|
207
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
208
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
209
|
+
return Qnil;
|
210
|
+
}
|
211
|
+
return INT2NUM(ptr->batch.seq_id[id]);
|
212
|
+
}
|
213
|
+
|
214
|
+
// logits
|
215
|
+
static VALUE _llama_batch_set_logits(VALUE self, VALUE idx, VALUE value) {
|
216
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
217
|
+
const int32_t id = NUM2INT(idx);
|
218
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
219
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
220
|
+
return Qnil;
|
221
|
+
}
|
222
|
+
ptr->batch.logits[id] = RTEST(value) ? true : false;
|
223
|
+
return ptr->batch.logits[id] ? Qtrue : Qfalse;
|
224
|
+
}
|
225
|
+
|
226
|
+
static VALUE _llama_batch_get_logits(VALUE self, VALUE idx) {
|
227
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
228
|
+
const int32_t id = NUM2INT(idx);
|
229
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
230
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
231
|
+
return Qnil;
|
232
|
+
}
|
233
|
+
return ptr->batch.logits[id] ? Qtrue : Qfalse;
|
234
|
+
}
|
235
|
+
};
|
236
|
+
|
237
|
+
const rb_data_type_t RbLLaMABatch::llama_batch_type = {
|
238
|
+
"RbLLaMABatch",
|
239
|
+
{ NULL,
|
240
|
+
RbLLaMABatch::llama_batch_free,
|
241
|
+
RbLLaMABatch::llama_batch_size },
|
242
|
+
NULL,
|
243
|
+
NULL,
|
244
|
+
RUBY_TYPED_FREE_IMMEDIATELY
|
245
|
+
|
246
|
+
};
|
247
|
+
|
14
248
|
class LLaMATokenDataWrapper {
|
15
249
|
public:
|
16
250
|
llama_token_data data;
|
@@ -363,6 +597,144 @@ const rb_data_type_t RbLLaMATimings::llama_timings_type = {
|
|
363
597
|
RUBY_TYPED_FREE_IMMEDIATELY
|
364
598
|
};
|
365
599
|
|
600
|
+
class LLaMAModelParamsWrapper {
|
601
|
+
public:
|
602
|
+
struct llama_model_params params;
|
603
|
+
|
604
|
+
LLaMAModelParamsWrapper() : params(llama_model_default_params()) {}
|
605
|
+
|
606
|
+
~LLaMAModelParamsWrapper() {}
|
607
|
+
};
|
608
|
+
|
609
|
+
class RbLLaMAModelParams {
|
610
|
+
public:
|
611
|
+
static VALUE llama_model_params_alloc(VALUE self) {
|
612
|
+
LLaMAModelParamsWrapper* ptr = (LLaMAModelParamsWrapper*)ruby_xmalloc(sizeof(LLaMAModelParamsWrapper));
|
613
|
+
new (ptr) LLaMAModelParamsWrapper();
|
614
|
+
return TypedData_Wrap_Struct(self, &llama_model_params_type, ptr);
|
615
|
+
}
|
616
|
+
|
617
|
+
static void llama_model_params_free(void* ptr) {
|
618
|
+
((LLaMAModelParamsWrapper*)ptr)->~LLaMAModelParamsWrapper();
|
619
|
+
ruby_xfree(ptr);
|
620
|
+
}
|
621
|
+
|
622
|
+
static size_t llama_model_params_size(const void* ptr) {
|
623
|
+
return sizeof(*((LLaMAModelParamsWrapper*)ptr));
|
624
|
+
}
|
625
|
+
|
626
|
+
static LLaMAModelParamsWrapper* get_llama_model_params(VALUE self) {
|
627
|
+
LLaMAModelParamsWrapper* ptr;
|
628
|
+
TypedData_Get_Struct(self, LLaMAModelParamsWrapper, &llama_model_params_type, ptr);
|
629
|
+
return ptr;
|
630
|
+
}
|
631
|
+
|
632
|
+
static void define_class(VALUE outer) {
|
633
|
+
rb_cLLaMAModelParams = rb_define_class_under(outer, "ModelParams", rb_cObject);
|
634
|
+
rb_define_alloc_func(rb_cLLaMAModelParams, llama_model_params_alloc);
|
635
|
+
rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers=", RUBY_METHOD_FUNC(_llama_model_params_set_n_gpu_layers), 1);
|
636
|
+
rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers", RUBY_METHOD_FUNC(_llama_model_params_get_n_gpu_layers), 0);
|
637
|
+
rb_define_method(rb_cLLaMAModelParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_model_params_set_main_gpu), 1);
|
638
|
+
rb_define_method(rb_cLLaMAModelParams, "main_gpu", RUBY_METHOD_FUNC(_llama_model_params_get_main_gpu), 0);
|
639
|
+
rb_define_method(rb_cLLaMAModelParams, "tensor_split", RUBY_METHOD_FUNC(_llama_model_params_get_tensor_split), 0);
|
640
|
+
rb_define_method(rb_cLLaMAModelParams, "vocab_only=", RUBY_METHOD_FUNC(_llama_model_params_set_vocab_only), 1);
|
641
|
+
rb_define_method(rb_cLLaMAModelParams, "vocab_only", RUBY_METHOD_FUNC(_llama_model_params_get_vocab_only), 0);
|
642
|
+
rb_define_method(rb_cLLaMAModelParams, "use_mmap=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mmap), 1);
|
643
|
+
rb_define_method(rb_cLLaMAModelParams, "use_mmap", RUBY_METHOD_FUNC(_llama_model_params_get_use_mmap), 0);
|
644
|
+
rb_define_method(rb_cLLaMAModelParams, "use_mlock=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mlock), 1);
|
645
|
+
rb_define_method(rb_cLLaMAModelParams, "use_mlock", RUBY_METHOD_FUNC(_llama_model_params_get_use_mlock), 0);
|
646
|
+
}
|
647
|
+
|
648
|
+
private:
|
649
|
+
static const rb_data_type_t llama_model_params_type;
|
650
|
+
|
651
|
+
// n_gpu_layers
|
652
|
+
static VALUE _llama_model_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
|
653
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
654
|
+
ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
|
655
|
+
return INT2NUM(ptr->params.n_gpu_layers);
|
656
|
+
}
|
657
|
+
|
658
|
+
static VALUE _llama_model_params_get_n_gpu_layers(VALUE self) {
|
659
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
660
|
+
return INT2NUM(ptr->params.n_gpu_layers);
|
661
|
+
}
|
662
|
+
|
663
|
+
// main_gpu
|
664
|
+
static VALUE _llama_model_params_set_main_gpu(VALUE self, VALUE main_gpu) {
|
665
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
666
|
+
ptr->params.main_gpu = NUM2INT(main_gpu);
|
667
|
+
return INT2NUM(ptr->params.main_gpu);
|
668
|
+
}
|
669
|
+
|
670
|
+
static VALUE _llama_model_params_get_main_gpu(VALUE self) {
|
671
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
672
|
+
return INT2NUM(ptr->params.main_gpu);
|
673
|
+
}
|
674
|
+
|
675
|
+
// tensor_split
|
676
|
+
static VALUE _llama_model_params_get_tensor_split(VALUE self) {
|
677
|
+
if (LLAMA_MAX_DEVICES < 1) {
|
678
|
+
return rb_ary_new();
|
679
|
+
}
|
680
|
+
VALUE ret = rb_ary_new2(LLAMA_MAX_DEVICES);
|
681
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
682
|
+
if (ptr->params.tensor_split == nullptr) {
|
683
|
+
return rb_ary_new();
|
684
|
+
}
|
685
|
+
for (size_t i = 0; i < LLAMA_MAX_DEVICES; i++) {
|
686
|
+
rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
|
687
|
+
}
|
688
|
+
return ret;
|
689
|
+
}
|
690
|
+
|
691
|
+
// vocab_only
|
692
|
+
static VALUE _llama_model_params_set_vocab_only(VALUE self, VALUE vocab_only) {
|
693
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
694
|
+
ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
|
695
|
+
return ptr->params.vocab_only ? Qtrue : Qfalse;
|
696
|
+
}
|
697
|
+
|
698
|
+
static VALUE _llama_model_params_get_vocab_only(VALUE self) {
|
699
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
700
|
+
return ptr->params.vocab_only ? Qtrue : Qfalse;
|
701
|
+
}
|
702
|
+
|
703
|
+
// use_mmap
|
704
|
+
static VALUE _llama_model_params_set_use_mmap(VALUE self, VALUE use_mmap) {
|
705
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
706
|
+
ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
|
707
|
+
return ptr->params.use_mmap ? Qtrue : Qfalse;
|
708
|
+
}
|
709
|
+
|
710
|
+
static VALUE _llama_model_params_get_use_mmap(VALUE self) {
|
711
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
712
|
+
return ptr->params.use_mmap ? Qtrue : Qfalse;
|
713
|
+
}
|
714
|
+
|
715
|
+
// use_mlock
|
716
|
+
static VALUE _llama_model_params_set_use_mlock(VALUE self, VALUE use_mlock) {
|
717
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
718
|
+
ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
|
719
|
+
return ptr->params.use_mlock ? Qtrue : Qfalse;
|
720
|
+
}
|
721
|
+
|
722
|
+
static VALUE _llama_model_params_get_use_mlock(VALUE self) {
|
723
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
724
|
+
return ptr->params.use_mlock ? Qtrue : Qfalse;
|
725
|
+
}
|
726
|
+
};
|
727
|
+
|
728
|
+
const rb_data_type_t RbLLaMAModelParams::llama_model_params_type = {
|
729
|
+
"RbLLaMAModelParams",
|
730
|
+
{ NULL,
|
731
|
+
RbLLaMAModelParams::llama_model_params_free,
|
732
|
+
RbLLaMAModelParams::llama_model_params_size },
|
733
|
+
NULL,
|
734
|
+
NULL,
|
735
|
+
RUBY_TYPED_FREE_IMMEDIATELY
|
736
|
+
};
|
737
|
+
|
366
738
|
class LLaMAContextParamsWrapper {
|
367
739
|
public:
|
368
740
|
struct llama_context_params params;
|
@@ -399,35 +771,26 @@ public:
|
|
399
771
|
rb_cLLaMAContextParams = rb_define_class_under(outer, "ContextParams", rb_cObject);
|
400
772
|
rb_define_alloc_func(rb_cLLaMAContextParams, llama_context_params_alloc);
|
401
773
|
// rb_define_method(rb_cLLaMAContextParams, "initialize", RUBY_METHOD_FUNC(_llama_context_params_init), 0);
|
774
|
+
rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
|
775
|
+
rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
|
402
776
|
rb_define_method(rb_cLLaMAContextParams, "n_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_n_ctx), 1);
|
403
777
|
rb_define_method(rb_cLLaMAContextParams, "n_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_n_ctx), 0);
|
404
778
|
rb_define_method(rb_cLLaMAContextParams, "n_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_batch), 1);
|
405
779
|
rb_define_method(rb_cLLaMAContextParams, "n_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_batch), 0);
|
406
|
-
rb_define_method(rb_cLLaMAContextParams, "
|
407
|
-
rb_define_method(rb_cLLaMAContextParams, "
|
408
|
-
rb_define_method(rb_cLLaMAContextParams, "
|
409
|
-
rb_define_method(rb_cLLaMAContextParams, "
|
410
|
-
rb_define_method(rb_cLLaMAContextParams, "tensor_split", RUBY_METHOD_FUNC(_llama_context_params_get_tensor_split), 0);
|
780
|
+
rb_define_method(rb_cLLaMAContextParams, "n_threads=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads), 1);
|
781
|
+
rb_define_method(rb_cLLaMAContextParams, "n_threads", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads), 0);
|
782
|
+
rb_define_method(rb_cLLaMAContextParams, "n_threads_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads_batch), 1);
|
783
|
+
rb_define_method(rb_cLLaMAContextParams, "n_threads_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads_batch), 0);
|
411
784
|
rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
|
412
785
|
rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
|
413
786
|
rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
|
414
787
|
rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
|
415
|
-
rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
|
416
|
-
rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
|
417
788
|
rb_define_method(rb_cLLaMAContextParams, "mul_mat_q=", RUBY_METHOD_FUNC(_llama_context_params_set_mul_mat_q), 1);
|
418
789
|
rb_define_method(rb_cLLaMAContextParams, "mul_mat_q", RUBY_METHOD_FUNC(_llama_context_params_get_mul_mat_q), 0);
|
419
|
-
rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
|
420
|
-
rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
|
421
790
|
rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
|
422
791
|
rb_define_method(rb_cLLaMAContextParams, "f16_kv", RUBY_METHOD_FUNC(_llama_context_params_get_f16_kv), 0);
|
423
792
|
rb_define_method(rb_cLLaMAContextParams, "logits_all=", RUBY_METHOD_FUNC(_llama_context_params_set_logits_all), 1);
|
424
793
|
rb_define_method(rb_cLLaMAContextParams, "logits_all", RUBY_METHOD_FUNC(_llama_context_params_get_logits_all), 0);
|
425
|
-
rb_define_method(rb_cLLaMAContextParams, "vocab_only=", RUBY_METHOD_FUNC(_llama_context_params_set_vocab_only), 1);
|
426
|
-
rb_define_method(rb_cLLaMAContextParams, "vocab_only", RUBY_METHOD_FUNC(_llama_context_params_get_vocab_only), 0);
|
427
|
-
rb_define_method(rb_cLLaMAContextParams, "use_mmap=", RUBY_METHOD_FUNC(_llama_context_params_set_use_mmap), 1);
|
428
|
-
rb_define_method(rb_cLLaMAContextParams, "use_mmap", RUBY_METHOD_FUNC(_llama_context_params_get_use_mmap), 0);
|
429
|
-
rb_define_method(rb_cLLaMAContextParams, "use_mlock=", RUBY_METHOD_FUNC(_llama_context_params_set_use_mlock), 1);
|
430
|
-
rb_define_method(rb_cLLaMAContextParams, "use_mlock", RUBY_METHOD_FUNC(_llama_context_params_get_use_mlock), 0);
|
431
794
|
rb_define_method(rb_cLLaMAContextParams, "embedding=", RUBY_METHOD_FUNC(_llama_context_params_set_embedding), 1);
|
432
795
|
rb_define_method(rb_cLLaMAContextParams, "embedding", RUBY_METHOD_FUNC(_llama_context_params_get_embedding), 0);
|
433
796
|
}
|
@@ -441,6 +804,22 @@ private:
|
|
441
804
|
// return self;
|
442
805
|
// }
|
443
806
|
|
807
|
+
// seed
|
808
|
+
static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
|
809
|
+
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
810
|
+
if (NUM2INT(seed) < 0) {
|
811
|
+
rb_raise(rb_eArgError, "seed must be positive");
|
812
|
+
return Qnil;
|
813
|
+
}
|
814
|
+
ptr->params.seed = NUM2INT(seed);
|
815
|
+
return INT2NUM(ptr->params.seed);
|
816
|
+
}
|
817
|
+
|
818
|
+
static VALUE _llama_context_params_get_seed(VALUE self) {
|
819
|
+
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
820
|
+
return INT2NUM(ptr->params.seed);
|
821
|
+
}
|
822
|
+
|
444
823
|
// n_ctx
|
445
824
|
static VALUE _llama_context_params_set_n_ctx(VALUE self, VALUE n_ctx) {
|
446
825
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
@@ -465,41 +844,28 @@ private:
|
|
465
844
|
return INT2NUM(ptr->params.n_batch);
|
466
845
|
}
|
467
846
|
|
468
|
-
//
|
469
|
-
static VALUE
|
847
|
+
// n_threads
|
848
|
+
static VALUE _llama_context_params_set_n_threads(VALUE self, VALUE n_threads) {
|
470
849
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
471
|
-
ptr->params.
|
472
|
-
return INT2NUM(ptr->params.
|
850
|
+
ptr->params.n_threads = NUM2INT(n_threads);
|
851
|
+
return INT2NUM(ptr->params.n_threads);
|
473
852
|
}
|
474
853
|
|
475
|
-
static VALUE
|
854
|
+
static VALUE _llama_context_params_get_n_threads(VALUE self) {
|
476
855
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
477
|
-
return INT2NUM(ptr->params.
|
856
|
+
return INT2NUM(ptr->params.n_threads);
|
478
857
|
}
|
479
858
|
|
480
|
-
//
|
481
|
-
static VALUE
|
859
|
+
// n_threads_batch
|
860
|
+
static VALUE _llama_context_params_set_n_threads_batch(VALUE self, VALUE n_threads_batch) {
|
482
861
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
483
|
-
ptr->params.
|
484
|
-
return INT2NUM(ptr->params.
|
862
|
+
ptr->params.n_threads_batch = NUM2INT(n_threads_batch);
|
863
|
+
return INT2NUM(ptr->params.n_threads_batch);
|
485
864
|
}
|
486
865
|
|
487
|
-
static VALUE
|
866
|
+
static VALUE _llama_context_params_get_n_threads_batch(VALUE self) {
|
488
867
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
489
|
-
return INT2NUM(ptr->params.
|
490
|
-
}
|
491
|
-
|
492
|
-
// tensor_split
|
493
|
-
static VALUE _llama_context_params_get_tensor_split(VALUE self) {
|
494
|
-
if (LLAMA_MAX_DEVICES < 1) {
|
495
|
-
return rb_ary_new();
|
496
|
-
}
|
497
|
-
VALUE ret = rb_ary_new2(LLAMA_MAX_DEVICES);
|
498
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
499
|
-
for (size_t i = 0; i < LLAMA_MAX_DEVICES; i++) {
|
500
|
-
rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
|
501
|
-
}
|
502
|
-
return ret;
|
868
|
+
return INT2NUM(ptr->params.n_threads_batch);
|
503
869
|
}
|
504
870
|
|
505
871
|
// rope_freq_base
|
@@ -526,18 +892,6 @@ private:
|
|
526
892
|
return DBL2NUM(ptr->params.rope_freq_scale);
|
527
893
|
}
|
528
894
|
|
529
|
-
// low_vram
|
530
|
-
static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
|
531
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
532
|
-
ptr->params.low_vram = RTEST(low_vram) ? true : false;
|
533
|
-
return ptr->params.low_vram ? Qtrue : Qfalse;
|
534
|
-
}
|
535
|
-
|
536
|
-
static VALUE _llama_context_params_get_low_vram(VALUE self) {
|
537
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
538
|
-
return ptr->params.low_vram ? Qtrue : Qfalse;
|
539
|
-
}
|
540
|
-
|
541
895
|
// mul_mat_q
|
542
896
|
static VALUE _llama_context_params_set_mul_mat_q(VALUE self, VALUE mul_mat_q) {
|
543
897
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
@@ -550,22 +904,6 @@ private:
|
|
550
904
|
return ptr->params.mul_mat_q ? Qtrue : Qfalse;
|
551
905
|
}
|
552
906
|
|
553
|
-
// seed
|
554
|
-
static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
|
555
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
556
|
-
if (NUM2INT(seed) < 0) {
|
557
|
-
rb_raise(rb_eArgError, "seed must be positive");
|
558
|
-
return Qnil;
|
559
|
-
}
|
560
|
-
ptr->params.seed = NUM2INT(seed);
|
561
|
-
return INT2NUM(ptr->params.seed);
|
562
|
-
}
|
563
|
-
|
564
|
-
static VALUE _llama_context_params_get_seed(VALUE self) {
|
565
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
566
|
-
return INT2NUM(ptr->params.seed);
|
567
|
-
}
|
568
|
-
|
569
907
|
// f16_kv
|
570
908
|
static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
|
571
909
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
@@ -590,42 +928,6 @@ private:
|
|
590
928
|
return ptr->params.logits_all ? Qtrue : Qfalse;
|
591
929
|
}
|
592
930
|
|
593
|
-
// vocab_only
|
594
|
-
static VALUE _llama_context_params_set_vocab_only(VALUE self, VALUE vocab_only) {
|
595
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
596
|
-
ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
|
597
|
-
return ptr->params.vocab_only ? Qtrue : Qfalse;
|
598
|
-
}
|
599
|
-
|
600
|
-
static VALUE _llama_context_params_get_vocab_only(VALUE self) {
|
601
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
602
|
-
return ptr->params.vocab_only ? Qtrue : Qfalse;
|
603
|
-
}
|
604
|
-
|
605
|
-
// use_mmap
|
606
|
-
static VALUE _llama_context_params_set_use_mmap(VALUE self, VALUE use_mmap) {
|
607
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
608
|
-
ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
|
609
|
-
return ptr->params.use_mmap ? Qtrue : Qfalse;
|
610
|
-
}
|
611
|
-
|
612
|
-
static VALUE _llama_context_params_get_use_mmap(VALUE self) {
|
613
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
614
|
-
return ptr->params.use_mmap ? Qtrue : Qfalse;
|
615
|
-
}
|
616
|
-
|
617
|
-
// use_mlock
|
618
|
-
static VALUE _llama_context_params_set_use_mlock(VALUE self, VALUE use_mlock) {
|
619
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
620
|
-
ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
|
621
|
-
return ptr->params.use_mlock ? Qtrue : Qfalse;
|
622
|
-
}
|
623
|
-
|
624
|
-
static VALUE _llama_context_params_get_use_mlock(VALUE self) {
|
625
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
626
|
-
return ptr->params.use_mlock ? Qtrue : Qfalse;
|
627
|
-
}
|
628
|
-
|
629
931
|
// embedding
|
630
932
|
static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
|
631
933
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
@@ -823,11 +1125,11 @@ public:
|
|
823
1125
|
rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
|
824
1126
|
rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
|
825
1127
|
rb_define_method(rb_cLLaMAModel, "n_vocab", RUBY_METHOD_FUNC(_llama_model_get_model_n_vocab), 0);
|
826
|
-
rb_define_method(rb_cLLaMAModel, "n_ctx", RUBY_METHOD_FUNC(_llama_model_get_model_n_ctx), 0);
|
827
1128
|
rb_define_method(rb_cLLaMAModel, "n_ctx_train", RUBY_METHOD_FUNC(_llama_model_get_model_n_ctx_train), 0);
|
828
1129
|
rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_model_n_embd), 0);
|
829
|
-
rb_define_method(rb_cLLaMAModel, "
|
830
|
-
rb_define_method(rb_cLLaMAModel, "
|
1130
|
+
rb_define_method(rb_cLLaMAModel, "rope_freq_scale_train", RUBY_METHOD_FUNC(_llama_model_rope_freq_scale_train), 0);
|
1131
|
+
rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece), 1);
|
1132
|
+
rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize), -1);
|
831
1133
|
rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
|
832
1134
|
rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
|
833
1135
|
rb_define_method(rb_cLLaMAModel, "n_params", RUBY_METHOD_FUNC(_llama_model_get_model_n_params), 0);
|
@@ -841,30 +1143,21 @@ private:
|
|
841
1143
|
ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
|
842
1144
|
VALUE kw_values[2] = { Qundef, Qundef };
|
843
1145
|
rb_scan_args(argc, argv, ":", &kw_args);
|
844
|
-
rb_get_kwargs(kw_args, kw_table,
|
845
|
-
|
846
|
-
if (kw_values[0] == Qundef && kw_values[1] == Qundef) {
|
847
|
-
rb_iv_set(self, "@params", Qnil);
|
848
|
-
return Qnil;
|
849
|
-
}
|
1146
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
850
1147
|
|
851
1148
|
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
852
1149
|
rb_raise(rb_eArgError, "model_path must be a string");
|
853
1150
|
return Qnil;
|
854
1151
|
}
|
855
|
-
if (!rb_obj_is_kind_of(kw_values[1],
|
856
|
-
rb_raise(rb_eArgError, "params must be a
|
1152
|
+
if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
|
1153
|
+
rb_raise(rb_eArgError, "params must be a ModelParams");
|
857
1154
|
return Qnil;
|
858
1155
|
}
|
859
1156
|
|
860
1157
|
VALUE filename = kw_values[0];
|
861
|
-
|
1158
|
+
LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
|
862
1159
|
LLaMAModelWrapper* model_ptr = get_llama_model(self);
|
863
1160
|
|
864
|
-
if (prms_ptr->params.seed == LLAMA_DEFAULT_SEED) {
|
865
|
-
prms_ptr->params.seed = time(NULL);
|
866
|
-
}
|
867
|
-
|
868
1161
|
try {
|
869
1162
|
model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
|
870
1163
|
} catch (const std::runtime_error& e) {
|
@@ -912,8 +1205,8 @@ private:
|
|
912
1205
|
rb_raise(rb_eArgError, "model_path must be a string");
|
913
1206
|
return Qnil;
|
914
1207
|
}
|
915
|
-
if (!rb_obj_is_kind_of(kw_values[1],
|
916
|
-
rb_raise(rb_eArgError, "params must be a
|
1208
|
+
if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
|
1209
|
+
rb_raise(rb_eArgError, "params must be a LLaMAModelParams");
|
917
1210
|
return Qnil;
|
918
1211
|
}
|
919
1212
|
|
@@ -924,7 +1217,7 @@ private:
|
|
924
1217
|
}
|
925
1218
|
|
926
1219
|
VALUE filename = kw_values[0];
|
927
|
-
|
1220
|
+
LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
|
928
1221
|
|
929
1222
|
try {
|
930
1223
|
model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
|
@@ -946,10 +1239,10 @@ private:
|
|
946
1239
|
|
947
1240
|
static VALUE _llama_model_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
|
948
1241
|
VALUE kw_args = Qnil;
|
949
|
-
ID kw_table[
|
950
|
-
VALUE kw_values[
|
1242
|
+
ID kw_table[4] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads"), rb_intern("scale") };
|
1243
|
+
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
951
1244
|
rb_scan_args(argc, argv, ":", &kw_args);
|
952
|
-
rb_get_kwargs(kw_args, kw_table, 1,
|
1245
|
+
rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);
|
953
1246
|
|
954
1247
|
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
955
1248
|
rb_raise(rb_eArgError, "lora_path must be a string");
|
@@ -963,13 +1256,18 @@ private:
|
|
963
1256
|
rb_raise(rb_eArgError, "n_threads must be an integer");
|
964
1257
|
return Qnil;
|
965
1258
|
}
|
1259
|
+
if (kw_values[3] != Qundef && !RB_FLOAT_TYPE_P(kw_values[3])) {
|
1260
|
+
rb_raise(rb_eArgError, "scale must be a float");
|
1261
|
+
return Qnil;
|
1262
|
+
}
|
966
1263
|
|
967
1264
|
const char* lora_path = StringValueCStr(kw_values[0]);
|
968
1265
|
const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
|
969
1266
|
const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
|
1267
|
+
const float scale = kw_values[3] == Qundef ? 1.0 : NUM2DBL(kw_values[3]);
|
970
1268
|
|
971
1269
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
972
|
-
if (llama_model_apply_lora_from_file(ptr->model, lora_path, base_model_path, n_threads) != 0) {
|
1270
|
+
if (llama_model_apply_lora_from_file(ptr->model, lora_path, scale, base_model_path, n_threads) != 0) {
|
973
1271
|
rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
|
974
1272
|
return Qnil;
|
975
1273
|
}
|
@@ -978,25 +1276,25 @@ private:
|
|
978
1276
|
|
979
1277
|
static VALUE _llama_model_get_model_n_vocab(VALUE self) {
|
980
1278
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
981
|
-
return INT2NUM(
|
1279
|
+
return INT2NUM(llama_n_vocab(ptr->model));
|
982
1280
|
}
|
983
1281
|
|
984
|
-
static VALUE
|
1282
|
+
static VALUE _llama_model_get_model_n_ctx_train(VALUE self) {
|
985
1283
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
986
|
-
return INT2NUM(
|
1284
|
+
return INT2NUM(llama_n_ctx_train(ptr->model));
|
987
1285
|
}
|
988
1286
|
|
989
|
-
static VALUE
|
1287
|
+
static VALUE _llama_model_get_model_n_embd(VALUE self) {
|
990
1288
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
991
|
-
return INT2NUM(
|
1289
|
+
return INT2NUM(llama_n_embd(ptr->model));
|
992
1290
|
}
|
993
1291
|
|
994
|
-
static VALUE
|
1292
|
+
static VALUE _llama_model_rope_freq_scale_train(VALUE self) {
|
995
1293
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
996
|
-
return
|
1294
|
+
return DBL2NUM(llama_rope_freq_scale_train(ptr->model));
|
997
1295
|
}
|
998
1296
|
|
999
|
-
static VALUE
|
1297
|
+
static VALUE _llama_model_token_to_piece(VALUE self, VALUE token_) {
|
1000
1298
|
if (!RB_INTEGER_TYPE_P(token_)) {
|
1001
1299
|
rb_raise(rb_eArgError, "token must be an integer");
|
1002
1300
|
return Qnil;
|
@@ -1004,10 +1302,10 @@ private:
|
|
1004
1302
|
const llama_token token = NUM2INT(token_);
|
1005
1303
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1006
1304
|
std::vector<char> result(8, 0);
|
1007
|
-
const int n_tokens =
|
1305
|
+
const int n_tokens = llama_token_to_piece(ptr->model, token, result.data(), result.size());
|
1008
1306
|
if (n_tokens < 0) {
|
1009
1307
|
result.resize(-n_tokens);
|
1010
|
-
const int check =
|
1308
|
+
const int check = llama_token_to_piece(ptr->model, token, result.data(), result.size());
|
1011
1309
|
if (check != -n_tokens) {
|
1012
1310
|
rb_raise(rb_eRuntimeError, "failed to convert");
|
1013
1311
|
return Qnil;
|
@@ -1016,10 +1314,10 @@ private:
|
|
1016
1314
|
result.resize(n_tokens);
|
1017
1315
|
}
|
1018
1316
|
std::string ret(result.data(), result.size());
|
1019
|
-
return
|
1317
|
+
return rb_utf8_str_new_cstr(ret.c_str());
|
1020
1318
|
}
|
1021
1319
|
|
1022
|
-
static VALUE
|
1320
|
+
static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
|
1023
1321
|
VALUE kw_args = Qnil;
|
1024
1322
|
ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
|
1025
1323
|
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
@@ -1046,7 +1344,7 @@ private:
|
|
1046
1344
|
|
1047
1345
|
llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
|
1048
1346
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1049
|
-
const int n_tokens =
|
1347
|
+
const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
|
1050
1348
|
|
1051
1349
|
if (n_tokens < 0) {
|
1052
1350
|
rb_raise(rb_eRuntimeError, "failed to tokenize. The numebr of tokens (%d) is greater than n_max_tokens.", -n_tokens);
|
@@ -1066,7 +1364,7 @@ private:
|
|
1066
1364
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1067
1365
|
char buf[128];
|
1068
1366
|
llama_model_desc(ptr->model, buf, sizeof(buf));
|
1069
|
-
return
|
1367
|
+
return rb_utf8_str_new_cstr(buf);
|
1070
1368
|
}
|
1071
1369
|
|
1072
1370
|
static VALUE _llama_model_get_model_size(VALUE self) {
|
@@ -1345,11 +1643,11 @@ public:
|
|
1345
1643
|
static void define_class(VALUE outer) {
|
1346
1644
|
rb_cLLaMAContext = rb_define_class_under(outer, "Context", rb_cObject);
|
1347
1645
|
rb_define_alloc_func(rb_cLLaMAContext, llama_context_alloc);
|
1646
|
+
rb_define_attr(rb_cLLaMAContext, "model", 1, 0);
|
1348
1647
|
rb_define_method(rb_cLLaMAContext, "initialize", RUBY_METHOD_FUNC(_llama_context_initialize), -1);
|
1349
1648
|
rb_define_method(rb_cLLaMAContext, "eval", RUBY_METHOD_FUNC(_llama_context_eval), -1);
|
1350
1649
|
rb_define_method(rb_cLLaMAContext, "eval_embd", RUBY_METHOD_FUNC(_llama_context_eval_embd), -1);
|
1351
|
-
rb_define_method(rb_cLLaMAContext, "
|
1352
|
-
rb_define_method(rb_cLLaMAContext, "tokenize", RUBY_METHOD_FUNC(_llama_context_tokenize), -1);
|
1650
|
+
rb_define_method(rb_cLLaMAContext, "decode", RUBY_METHOD_FUNC(_llama_context_decode), 1);
|
1353
1651
|
rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
|
1354
1652
|
rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
|
1355
1653
|
rb_define_method(rb_cLLaMAContext, "text", RUBY_METHOD_FUNC(_llama_context_text), 1);
|
@@ -1358,15 +1656,20 @@ public:
|
|
1358
1656
|
rb_define_method(rb_cLLaMAContext, "token_bos", RUBY_METHOD_FUNC(_llama_context_token_bos), 0);
|
1359
1657
|
rb_define_method(rb_cLLaMAContext, "token_eos", RUBY_METHOD_FUNC(_llama_context_token_eos), 0);
|
1360
1658
|
rb_define_method(rb_cLLaMAContext, "token_nl", RUBY_METHOD_FUNC(_llama_context_token_nl), 0);
|
1361
|
-
rb_define_method(rb_cLLaMAContext, "
|
1362
|
-
rb_define_method(rb_cLLaMAContext, "
|
1659
|
+
rb_define_method(rb_cLLaMAContext, "token_prefix", RUBY_METHOD_FUNC(_llama_context_token_prefix), 0);
|
1660
|
+
rb_define_method(rb_cLLaMAContext, "token_middle", RUBY_METHOD_FUNC(_llama_context_token_middle), 0);
|
1661
|
+
rb_define_method(rb_cLLaMAContext, "token_suffix", RUBY_METHOD_FUNC(_llama_context_token_suffix), 0);
|
1662
|
+
rb_define_method(rb_cLLaMAContext, "token_eot", RUBY_METHOD_FUNC(_llama_context_token_eot), 0);
|
1363
1663
|
rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
|
1364
|
-
rb_define_method(rb_cLLaMAContext, "n_ctx_train", RUBY_METHOD_FUNC(_llama_context_n_ctx_train), 0);
|
1365
|
-
rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
|
1366
1664
|
rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
|
1367
1665
|
rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
|
1368
1666
|
rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
|
1369
1667
|
rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
|
1668
|
+
rb_define_method(rb_cLLaMAContext, "kv_cache_tokens_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_tokens_rm), 2);
|
1669
|
+
rb_define_method(rb_cLLaMAContext, "kv_cache_seq_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_rm), 3);
|
1670
|
+
rb_define_method(rb_cLLaMAContext, "kv_cache_seq_cp", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_cp), 4);
|
1671
|
+
rb_define_method(rb_cLLaMAContext, "kv_cache_seq_keep", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_keep), 1);
|
1672
|
+
rb_define_method(rb_cLLaMAContext, "kv_cache_seq_shift", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_shift), 4);
|
1370
1673
|
rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
|
1371
1674
|
rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
|
1372
1675
|
rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
|
@@ -1378,6 +1681,7 @@ public:
|
|
1378
1681
|
rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
|
1379
1682
|
rb_define_method(rb_cLLaMAContext, "sample_tail_free", RUBY_METHOD_FUNC(_llama_context_sample_tail_free), -1);
|
1380
1683
|
rb_define_method(rb_cLLaMAContext, "sample_typical", RUBY_METHOD_FUNC(_llama_context_sample_typical), -1);
|
1684
|
+
rb_define_method(rb_cLLaMAContext, "sample_temp", RUBY_METHOD_FUNC(_llama_context_sample_temp), -1);
|
1381
1685
|
rb_define_method(rb_cLLaMAContext, "sample_temperature", RUBY_METHOD_FUNC(_llama_context_sample_temperature), -1);
|
1382
1686
|
rb_define_method(rb_cLLaMAContext, "sample_token_mirostat", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat), -1);
|
1383
1687
|
rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
|
@@ -1392,24 +1696,27 @@ private:
|
|
1392
1696
|
|
1393
1697
|
static VALUE _llama_context_initialize(int argc, VALUE* argv, VALUE self) {
|
1394
1698
|
VALUE kw_args = Qnil;
|
1395
|
-
ID kw_table[
|
1396
|
-
VALUE kw_values[
|
1699
|
+
ID kw_table[2] = { rb_intern("model"), rb_intern("params") };
|
1700
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
1397
1701
|
rb_scan_args(argc, argv, ":", &kw_args);
|
1398
|
-
rb_get_kwargs(kw_args, kw_table,
|
1702
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
1399
1703
|
|
1400
1704
|
VALUE model = kw_values[0];
|
1401
1705
|
if (!rb_obj_is_kind_of(model, rb_cLLaMAModel)) {
|
1402
1706
|
rb_raise(rb_eArgError, "model must be a Model");
|
1403
1707
|
return Qnil;
|
1404
1708
|
}
|
1709
|
+
VALUE params = kw_values[1];
|
1710
|
+
if (!rb_obj_is_kind_of(params, rb_cLLaMAContextParams)) {
|
1711
|
+
rb_raise(rb_eArgError, "params must be a ContextParams");
|
1712
|
+
return Qnil;
|
1713
|
+
}
|
1405
1714
|
|
1406
1715
|
LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
|
1407
1716
|
if (model_ptr->model == NULL) {
|
1408
1717
|
rb_raise(rb_eRuntimeError, "Model is empty");
|
1409
1718
|
return Qnil;
|
1410
1719
|
}
|
1411
|
-
|
1412
|
-
VALUE params = rb_iv_get(model, "@params");
|
1413
1720
|
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
|
1414
1721
|
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
1415
1722
|
|
@@ -1421,6 +1728,7 @@ private:
|
|
1421
1728
|
}
|
1422
1729
|
|
1423
1730
|
rb_iv_set(self, "@model", model);
|
1731
|
+
rb_iv_set(self, "@params", params);
|
1424
1732
|
rb_iv_set(self, "@has_evaluated", Qfalse);
|
1425
1733
|
|
1426
1734
|
return Qnil;
|
@@ -1428,8 +1736,8 @@ private:
|
|
1428
1736
|
|
1429
1737
|
static VALUE _llama_context_eval(int argc, VALUE* argv, VALUE self) {
|
1430
1738
|
VALUE kw_args = Qnil;
|
1431
|
-
ID kw_table[
|
1432
|
-
VALUE kw_values[
|
1739
|
+
ID kw_table[3] = { rb_intern("tokens"), rb_intern("n_past"), rb_intern("n_tokens") };
|
1740
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
1433
1741
|
rb_scan_args(argc, argv, ":", &kw_args);
|
1434
1742
|
rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
|
1435
1743
|
|
@@ -1445,10 +1753,6 @@ private:
|
|
1445
1753
|
rb_raise(rb_eArgError, "n_tokens must be an integer");
|
1446
1754
|
return Qnil;
|
1447
1755
|
}
|
1448
|
-
if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
|
1449
|
-
rb_raise(rb_eArgError, "n_threads must be an integer");
|
1450
|
-
return Qnil;
|
1451
|
-
}
|
1452
1756
|
|
1453
1757
|
const size_t tokens_len = RARRAY_LEN(kw_values[0]);
|
1454
1758
|
std::vector<llama_token> embd(tokens_len);
|
@@ -1463,14 +1767,13 @@ private:
|
|
1463
1767
|
|
1464
1768
|
const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
|
1465
1769
|
const int n_past = NUM2INT(kw_values[1]);
|
1466
|
-
const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
|
1467
1770
|
|
1468
1771
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1469
1772
|
if (ptr->ctx == NULL) {
|
1470
1773
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1471
1774
|
return Qnil;
|
1472
1775
|
}
|
1473
|
-
if (llama_eval(ptr->ctx, embd.data(), n_tokens, n_past
|
1776
|
+
if (llama_eval(ptr->ctx, embd.data(), n_tokens, n_past) != 0) {
|
1474
1777
|
rb_raise(rb_eRuntimeError, "Failed to evaluate");
|
1475
1778
|
return Qnil;
|
1476
1779
|
}
|
@@ -1483,8 +1786,8 @@ private:
|
|
1483
1786
|
|
1484
1787
|
static VALUE _llama_context_eval_embd(int argc, VALUE* argv, VALUE self) {
|
1485
1788
|
VALUE kw_args = Qnil;
|
1486
|
-
ID kw_table[
|
1487
|
-
VALUE kw_values[
|
1789
|
+
ID kw_table[3] = { rb_intern("embd"), rb_intern("n_past"), rb_intern("n_tokens") };
|
1790
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
1488
1791
|
rb_scan_args(argc, argv, ":", &kw_args);
|
1489
1792
|
rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
|
1490
1793
|
|
@@ -1500,10 +1803,6 @@ private:
|
|
1500
1803
|
rb_raise(rb_eArgError, "n_tokens must be an integer");
|
1501
1804
|
return Qnil;
|
1502
1805
|
}
|
1503
|
-
if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
|
1504
|
-
rb_raise(rb_eArgError, "n_threads must be an integer");
|
1505
|
-
return Qnil;
|
1506
|
-
}
|
1507
1806
|
|
1508
1807
|
const size_t tokens_len = RARRAY_LEN(kw_values[0]);
|
1509
1808
|
std::vector<float> embd(tokens_len);
|
@@ -1518,14 +1817,13 @@ private:
|
|
1518
1817
|
|
1519
1818
|
const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
|
1520
1819
|
const int n_past = NUM2INT(kw_values[1]);
|
1521
|
-
const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
|
1522
1820
|
|
1523
1821
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1524
1822
|
if (ptr->ctx == NULL) {
|
1525
1823
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1526
1824
|
return Qnil;
|
1527
1825
|
}
|
1528
|
-
if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past
|
1826
|
+
if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past) != 0) {
|
1529
1827
|
rb_raise(rb_eRuntimeError, "Failed to evaluate");
|
1530
1828
|
return Qnil;
|
1531
1829
|
}
|
@@ -1536,91 +1834,22 @@ private:
|
|
1536
1834
|
return Qnil;
|
1537
1835
|
}
|
1538
1836
|
|
1539
|
-
static VALUE
|
1837
|
+
static VALUE _llama_context_decode(VALUE self, VALUE batch) {
|
1540
1838
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1541
1839
|
if (ptr->ctx == NULL) {
|
1542
1840
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1543
1841
|
return Qnil;
|
1544
1842
|
}
|
1545
|
-
if (!
|
1546
|
-
rb_raise(rb_eArgError, "
|
1547
|
-
return Qnil;
|
1548
|
-
}
|
1549
|
-
const char* fname = StringValueCStr(fname_);
|
1550
|
-
if (llama_eval_export(ptr->ctx, fname) != 0) {
|
1551
|
-
return Qfalse;
|
1552
|
-
}
|
1553
|
-
RB_GC_GUARD(fname_);
|
1554
|
-
return Qtrue;
|
1555
|
-
}
|
1556
|
-
|
1557
|
-
static VALUE _llama_context_tokenize(int argc, VALUE* argv, VALUE self) {
|
1558
|
-
VALUE kw_args = Qnil;
|
1559
|
-
ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
|
1560
|
-
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
1561
|
-
rb_scan_args(argc, argv, ":", &kw_args);
|
1562
|
-
rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
|
1563
|
-
|
1564
|
-
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
1565
|
-
rb_raise(rb_eArgError, "text must be a String");
|
1566
|
-
return Qnil;
|
1567
|
-
}
|
1568
|
-
if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
|
1569
|
-
rb_raise(rb_eArgError, "n_max_tokens must be an integer");
|
1843
|
+
if (!rb_obj_is_kind_of(batch, rb_cLLaMABatch)) {
|
1844
|
+
rb_raise(rb_eArgError, "batch must be a Batch");
|
1570
1845
|
return Qnil;
|
1571
1846
|
}
|
1572
|
-
|
1573
|
-
|
1847
|
+
LLaMABatchWrapper* batch_ptr = RbLLaMABatch::get_llama_batch(batch);
|
1848
|
+
if (llama_decode(ptr->ctx, batch_ptr->batch) < 0) {
|
1849
|
+
rb_raise(rb_eRuntimeError, "Failed to decode");
|
1574
1850
|
return Qnil;
|
1575
1851
|
}
|
1576
|
-
|
1577
|
-
VALUE text_ = kw_values[0];
|
1578
|
-
std::string text = StringValueCStr(text_);
|
1579
|
-
const bool add_bos = kw_values[2] == Qtrue ? true : false;
|
1580
|
-
const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
|
1581
|
-
|
1582
|
-
std::vector<llama_token> tokens(n_max_tokens);
|
1583
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1584
|
-
if (ptr->ctx == NULL) {
|
1585
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1586
|
-
return Qnil;
|
1587
|
-
}
|
1588
|
-
const int n = llama_tokenize(ptr->ctx, text.c_str(), text.size(), tokens.data(), n_max_tokens, add_bos);
|
1589
|
-
if (n < 0) {
|
1590
|
-
rb_raise(rb_eRuntimeError, "Failed to tokenize");
|
1591
|
-
return Qnil;
|
1592
|
-
}
|
1593
|
-
|
1594
|
-
VALUE output = rb_ary_new();
|
1595
|
-
for (int i = 0; i < n; i++) {
|
1596
|
-
rb_ary_push(output, INT2NUM(tokens[i]));
|
1597
|
-
}
|
1598
|
-
|
1599
|
-
RB_GC_GUARD(text_);
|
1600
|
-
return output;
|
1601
|
-
}
|
1602
|
-
|
1603
|
-
static VALUE _llama_context_token_to_piece(VALUE self, VALUE token_) {
|
1604
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1605
|
-
if (ptr->ctx == NULL) {
|
1606
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1607
|
-
return Qnil;
|
1608
|
-
}
|
1609
|
-
const llama_token token = NUM2INT(token_);
|
1610
|
-
std::vector<char> result(8, 0);
|
1611
|
-
const int n_tokens = llama_token_to_piece(ptr->ctx, token, result.data(), result.size());
|
1612
|
-
if (n_tokens < 0) {
|
1613
|
-
result.resize(-n_tokens);
|
1614
|
-
const int check = llama_token_to_piece(ptr->ctx, token, result.data(), result.size());
|
1615
|
-
if (check != -n_tokens) {
|
1616
|
-
rb_raise(rb_eRuntimeError, "failed to convert");
|
1617
|
-
return Qnil;
|
1618
|
-
}
|
1619
|
-
} else {
|
1620
|
-
result.resize(n_tokens);
|
1621
|
-
}
|
1622
|
-
std::string ret(result.data(), result.size());
|
1623
|
-
return rb_str_new_cstr(ret.c_str());
|
1852
|
+
return Qnil;
|
1624
1853
|
}
|
1625
1854
|
|
1626
1855
|
static VALUE _llama_context_logits(VALUE self) {
|
@@ -1635,10 +1864,11 @@ private:
|
|
1635
1864
|
}
|
1636
1865
|
|
1637
1866
|
VALUE model = rb_iv_get(self, "@model");
|
1638
|
-
|
1867
|
+
LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
|
1868
|
+
VALUE params = rb_iv_get(self, "@params");
|
1639
1869
|
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
|
1640
1870
|
const int n_tokens = prms_ptr->params.logits_all ? NUM2INT(rb_iv_get(self, "@n_tokens")) : 1;
|
1641
|
-
const int n_vocab = llama_n_vocab(
|
1871
|
+
const int n_vocab = llama_n_vocab(model_ptr->model);
|
1642
1872
|
const float* logits = llama_get_logits(ptr->ctx);
|
1643
1873
|
VALUE output = rb_ary_new();
|
1644
1874
|
for (int i = 0; i < n_tokens * n_vocab; i++) {
|
@@ -1655,7 +1885,8 @@ private:
|
|
1655
1885
|
return Qnil;
|
1656
1886
|
}
|
1657
1887
|
VALUE model = rb_iv_get(self, "@model");
|
1658
|
-
|
1888
|
+
LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
|
1889
|
+
VALUE params = rb_iv_get(self, "@params");
|
1659
1890
|
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
|
1660
1891
|
if (!prms_ptr->params.embedding) {
|
1661
1892
|
rb_raise(rb_eRuntimeError, "embedding parameter is false");
|
@@ -1666,7 +1897,7 @@ private:
|
|
1666
1897
|
return Qnil;
|
1667
1898
|
}
|
1668
1899
|
|
1669
|
-
const int n_embd = llama_n_embd(
|
1900
|
+
const int n_embd = llama_n_embd(model_ptr->model);
|
1670
1901
|
const float* embd = llama_get_embeddings(ptr->ctx);
|
1671
1902
|
VALUE output = rb_ary_new();
|
1672
1903
|
for (int i = 0; i < n_embd; i++) {
|
@@ -1684,7 +1915,7 @@ private:
|
|
1684
1915
|
}
|
1685
1916
|
const llama_token token = NUM2INT(token_);
|
1686
1917
|
const char* text = llama_token_get_text(ptr->ctx, token);
|
1687
|
-
return
|
1918
|
+
return rb_utf8_str_new_cstr(text);
|
1688
1919
|
}
|
1689
1920
|
|
1690
1921
|
static VALUE _llama_context_score(VALUE self, VALUE token_) {
|
@@ -1736,40 +1967,49 @@ private:
|
|
1736
1967
|
return INT2NUM(llama_token_nl(ptr->ctx));
|
1737
1968
|
}
|
1738
1969
|
|
1739
|
-
static VALUE
|
1970
|
+
static VALUE _llama_context_token_prefix(VALUE self) {
|
1740
1971
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1741
1972
|
if (ptr->ctx == NULL) {
|
1742
1973
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1743
1974
|
return Qnil;
|
1744
1975
|
}
|
1745
|
-
return INT2NUM(
|
1976
|
+
return INT2NUM(llama_token_prefix(ptr->ctx));
|
1746
1977
|
}
|
1747
1978
|
|
1748
|
-
static VALUE
|
1979
|
+
static VALUE _llama_context_token_middle(VALUE self) {
|
1749
1980
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1750
1981
|
if (ptr->ctx == NULL) {
|
1751
1982
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1752
1983
|
return Qnil;
|
1753
1984
|
}
|
1754
|
-
return INT2NUM(
|
1985
|
+
return INT2NUM(llama_token_middle(ptr->ctx));
|
1986
|
+
}
|
1987
|
+
|
1988
|
+
static VALUE _llama_context_token_suffix(VALUE self) {
|
1989
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1990
|
+
if (ptr->ctx == NULL) {
|
1991
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1992
|
+
return Qnil;
|
1993
|
+
}
|
1994
|
+
return INT2NUM(llama_token_suffix(ptr->ctx));
|
1755
1995
|
}
|
1756
1996
|
|
1757
|
-
static VALUE
|
1997
|
+
static VALUE _llama_context_token_eot(VALUE self) {
|
1758
1998
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1759
1999
|
if (ptr->ctx == NULL) {
|
1760
2000
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1761
2001
|
return Qnil;
|
1762
2002
|
}
|
1763
|
-
return INT2NUM(
|
2003
|
+
return INT2NUM(llama_token_eot(ptr->ctx));
|
1764
2004
|
}
|
1765
2005
|
|
1766
|
-
static VALUE
|
2006
|
+
static VALUE _llama_context_n_ctx(VALUE self) {
|
1767
2007
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1768
2008
|
if (ptr->ctx == NULL) {
|
1769
2009
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1770
2010
|
return Qnil;
|
1771
2011
|
}
|
1772
|
-
return INT2NUM(
|
2012
|
+
return INT2NUM(llama_n_ctx(ptr->ctx));
|
1773
2013
|
}
|
1774
2014
|
|
1775
2015
|
static VALUE _llama_context_get_timings(VALUE self) {
|
@@ -1813,6 +2053,56 @@ private:
|
|
1813
2053
|
return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
|
1814
2054
|
}
|
1815
2055
|
|
2056
|
+
static VALUE _llama_context_kv_cache_tokens_rm(VALUE self, VALUE c0, VALUE c1) {
|
2057
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2058
|
+
if (ptr->ctx == NULL) {
|
2059
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
2060
|
+
return Qnil;
|
2061
|
+
}
|
2062
|
+
llama_kv_cache_tokens_rm(ptr->ctx, NUM2INT(c0), NUM2INT(c1));
|
2063
|
+
return Qnil;
|
2064
|
+
}
|
2065
|
+
|
2066
|
+
static VALUE _llama_context_kv_cache_seq_rm(VALUE self, VALUE seq_id, VALUE p0, VALUE p1) {
|
2067
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2068
|
+
if (ptr->ctx == NULL) {
|
2069
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
2070
|
+
return Qnil;
|
2071
|
+
}
|
2072
|
+
llama_kv_cache_seq_rm(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1));
|
2073
|
+
return Qnil;
|
2074
|
+
}
|
2075
|
+
|
2076
|
+
static VALUE _llama_context_kv_cache_seq_cp(VALUE self, VALUE seq_id_src, VALUE seq_id_dst, VALUE p0, VALUE p1) {
|
2077
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2078
|
+
if (ptr->ctx == NULL) {
|
2079
|
+
rb_raise(rb_eArgError, "LLaMA context is not initialized");
|
2080
|
+
return Qnil;
|
2081
|
+
}
|
2082
|
+
llama_kv_cache_seq_cp(ptr->ctx, NUM2INT(seq_id_src), NUM2INT(seq_id_dst), NUM2INT(p0), NUM2INT(p1));
|
2083
|
+
return Qnil;
|
2084
|
+
}
|
2085
|
+
|
2086
|
+
static VALUE _llama_context_kv_cache_seq_keep(VALUE self, VALUE seq_id) {
|
2087
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2088
|
+
if (ptr->ctx == NULL) {
|
2089
|
+
rb_raise(rb_eArgError, "LLaMA context is not initialized");
|
2090
|
+
return Qnil;
|
2091
|
+
}
|
2092
|
+
llama_kv_cache_seq_keep(ptr->ctx, NUM2INT(seq_id));
|
2093
|
+
return Qnil;
|
2094
|
+
}
|
2095
|
+
|
2096
|
+
static VALUE _llama_context_kv_cache_seq_shift(VALUE self, VALUE seq_id, VALUE p0, VALUE p1, VALUE delta) {
|
2097
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2098
|
+
if (ptr->ctx == NULL) {
|
2099
|
+
rb_raise(rb_eArgError, "LLaMA context is not initialized");
|
2100
|
+
return Qnil;
|
2101
|
+
}
|
2102
|
+
llama_kv_cache_seq_shift(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1), NUM2INT(delta));
|
2103
|
+
return Qnil;
|
2104
|
+
}
|
2105
|
+
|
1816
2106
|
static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
|
1817
2107
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1818
2108
|
if (ptr->ctx == NULL) {
|
@@ -1851,7 +2141,7 @@ private:
|
|
1851
2141
|
}
|
1852
2142
|
|
1853
2143
|
VALUE model = rb_iv_get(self, "@model");
|
1854
|
-
VALUE params = rb_iv_get(
|
2144
|
+
VALUE params = rb_iv_get(self, "@params");
|
1855
2145
|
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
|
1856
2146
|
const int n_ctx = prms_ptr->params.n_ctx;
|
1857
2147
|
|
@@ -2235,6 +2525,40 @@ private:
|
|
2235
2525
|
return Qnil;
|
2236
2526
|
}
|
2237
2527
|
|
2528
|
+
static VALUE _llama_context_sample_temp(int argc, VALUE* argv, VALUE self) {
|
2529
|
+
VALUE kw_args = Qnil;
|
2530
|
+
ID kw_table[1] = { rb_intern("temp") };
|
2531
|
+
VALUE kw_values[1] = { Qundef };
|
2532
|
+
VALUE candidates = Qnil;
|
2533
|
+
rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
|
2534
|
+
rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
|
2535
|
+
|
2536
|
+
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
2537
|
+
rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
|
2538
|
+
return Qnil;
|
2539
|
+
}
|
2540
|
+
if (!RB_FLOAT_TYPE_P(kw_values[0])) {
|
2541
|
+
rb_raise(rb_eArgError, "temp must be a float");
|
2542
|
+
return Qnil;
|
2543
|
+
}
|
2544
|
+
|
2545
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
2546
|
+
if (ctx_ptr->ctx == NULL) {
|
2547
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
2548
|
+
return Qnil;
|
2549
|
+
}
|
2550
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
2551
|
+
if (cnd_ptr->array.data == nullptr) {
|
2552
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
2553
|
+
return Qnil;
|
2554
|
+
}
|
2555
|
+
const float temp = NUM2DBL(kw_values[0]);
|
2556
|
+
|
2557
|
+
llama_sample_temp(ctx_ptr->ctx, &(cnd_ptr->array), temp);
|
2558
|
+
|
2559
|
+
return Qnil;
|
2560
|
+
}
|
2561
|
+
|
2238
2562
|
static VALUE _llama_context_sample_temperature(int argc, VALUE* argv, VALUE self) {
|
2239
2563
|
VALUE kw_args = Qnil;
|
2240
2564
|
ID kw_table[1] = { rb_intern("temperature") };
|
@@ -2560,6 +2884,7 @@ extern "C" void Init_llama_cpp(void) {
|
|
2560
2884
|
RbLLaMATokenData::define_class(rb_mLLaMACpp);
|
2561
2885
|
RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
|
2562
2886
|
RbLLaMAModel::define_class(rb_mLLaMACpp);
|
2887
|
+
RbLLaMAModelParams::define_class(rb_mLLaMACpp);
|
2563
2888
|
RbLLaMATimings::define_class(rb_mLLaMACpp);
|
2564
2889
|
RbLLaMAContext::define_class(rb_mLLaMACpp);
|
2565
2890
|
RbLLaMAContextParams::define_class(rb_mLLaMACpp);
|
@@ -2578,10 +2903,6 @@ extern "C" void Init_llama_cpp(void) {
|
|
2578
2903
|
|
2579
2904
|
rb_define_const(rb_mLLaMACpp, "LLAMA_MAX_DEVICES", INT2NUM(LLAMA_MAX_DEVICES));
|
2580
2905
|
|
2581
|
-
rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_ERROR", INT2NUM(LLAMA_LOG_LEVEL_ERROR));
|
2582
|
-
rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_WARN", INT2NUM(LLAMA_LOG_LEVEL_WARN));
|
2583
|
-
rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_INFO", INT2NUM(LLAMA_LOG_LEVEL_INFO));
|
2584
|
-
|
2585
2906
|
rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_SPM", INT2NUM(LLAMA_VOCAB_TYPE_SPM));
|
2586
2907
|
rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_BPE", INT2NUM(LLAMA_VOCAB_TYPE_BPE));
|
2587
2908
|
|