llama_cpp 0.5.3 → 0.6.0
This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +6 -5
- data/examples/chat.rb +13 -13
- data/examples/embedding.rb +9 -9
- data/ext/llama_cpp/llama_cpp.cpp +547 -272
- data/ext/llama_cpp/src/ggml-alloc.c +8 -2
- data/ext/llama_cpp/src/ggml-alloc.h +1 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +209 -82
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.m +163 -84
- data/ext/llama_cpp/src/ggml-metal.metal +121 -38
- data/ext/llama_cpp/src/ggml.c +1596 -842
- data/ext/llama_cpp/src/ggml.h +116 -35
- data/ext/llama_cpp/src/llama.cpp +1015 -586
- data/ext/llama_cpp/src/llama.h +304 -119
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +5 -9
- data/sig/llama_cpp.rbs +65 -34
- metadata +3 -3
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -1,7 +1,9 @@
 #include "llama_cpp.h"
 
 VALUE rb_mLLaMACpp;
+VALUE rb_cLLaMABatch;
 VALUE rb_cLLaMAModel;
+VALUE rb_cLLaMAModelParams;
 VALUE rb_cLLaMATimings;
 VALUE rb_cLLaMAContext;
 VALUE rb_cLLaMAContextParams;
@@ -11,6 +13,238 @@ VALUE rb_cLLaMATokenDataArray;
 VALUE rb_cLLaMAGrammarElement;
 VALUE rb_cLLaMAGrammar;
 
+class LLaMABatchWrapper {
+public:
+  llama_batch batch;
+
+  LLaMABatchWrapper() {}
+
+  ~LLaMABatchWrapper() {
+    llama_batch_free(batch);
+  }
+};
+
+class RbLLaMABatch {
+public:
+  static VALUE llama_batch_alloc(VALUE self) {
+    LLaMABatchWrapper* ptr = (LLaMABatchWrapper*)ruby_xmalloc(sizeof(LLaMABatchWrapper));
+    new (ptr) LLaMABatchWrapper();
+    return TypedData_Wrap_Struct(self, &llama_batch_type, ptr);
+  }
+
+  static void llama_batch_free(void* ptr) {
+    ((LLaMABatchWrapper*)ptr)->~LLaMABatchWrapper();
+    ruby_xfree(ptr);
+  }
+
+  static size_t llama_batch_size(const void* ptr) {
+    return sizeof(*((LLaMABatchWrapper*)ptr));
+  }
+
+  static LLaMABatchWrapper* get_llama_batch(VALUE self) {
+    LLaMABatchWrapper* ptr;
+    TypedData_Get_Struct(self, LLaMABatchWrapper, &llama_batch_type, ptr);
+    return ptr;
+  }
+
+  static void define_class(VALUE outer) {
+    rb_cLLaMABatch = rb_define_class_under(outer, "Batch", rb_cObject);
+    rb_define_alloc_func(rb_cLLaMABatch, llama_batch_alloc);
+    rb_define_method(rb_cLLaMABatch, "initialize", RUBY_METHOD_FUNC(_llama_batch_initialize), -1);
+    rb_define_method(rb_cLLaMABatch, "n_tokens=", RUBY_METHOD_FUNC(_llama_batch_set_n_tokens), 1);
+    rb_define_method(rb_cLLaMABatch, "n_tokens", RUBY_METHOD_FUNC(_llama_batch_get_n_tokens), 0);
+    rb_define_method(rb_cLLaMABatch, "all_pos_zero=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_zero), 1);
+    rb_define_method(rb_cLLaMABatch, "all_pos_zero", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_zero), 0);
+    rb_define_method(rb_cLLaMABatch, "all_pos_one=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_one), 1);
+    rb_define_method(rb_cLLaMABatch, "all_pos_one", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_one), 0);
+    rb_define_method(rb_cLLaMABatch, "all_seq_id=", RUBY_METHOD_FUNC(_llama_batch_set_all_seq_id), 1);
+    rb_define_method(rb_cLLaMABatch, "all_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_all_seq_id), 0);
+    rb_define_method(rb_cLLaMABatch, "set_token", RUBY_METHOD_FUNC(_llama_batch_set_token), 2);
+    rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
+    rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
+    rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
+    rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 2);
+    rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 1);
+    rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
+    rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
+  }
+
+private:
+  static const rb_data_type_t llama_batch_type;
+
+  static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[2] = { rb_intern("n_tokens"), rb_intern("embd") };
+    VALUE kw_values[2] = { Qundef, Qundef };
+    rb_scan_args(argc, argv, ":", &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
+
+    if (!RB_INTEGER_TYPE_P(kw_values[0])) {
+      rb_raise(rb_eArgError, "n_tokens must be an integer");
+      return Qnil;
+    }
+    if (!RB_INTEGER_TYPE_P(kw_values[1])) {
+      rb_raise(rb_eArgError, "embd must be an integer");
+      return Qnil;
+    }
+
+    const int32_t n_tokens = NUM2INT(kw_values[0]);
+    const int32_t embd = NUM2INT(kw_values[1]);
+
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    ptr->batch = llama_batch_init(n_tokens, embd);
+
+    return Qnil;
+  }
+
+  // n_tokens
+  static VALUE _llama_batch_set_n_tokens(VALUE self, VALUE n_tokens) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    ptr->batch.n_tokens = NUM2INT(n_tokens);
+    return INT2NUM(ptr->batch.n_tokens);
+  }
+
+  static VALUE _llama_batch_get_n_tokens(VALUE self) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    return INT2NUM(ptr->batch.n_tokens);
+  }
+
+  // all_pos_0
+  static VALUE _llama_batch_set_all_pos_zero(VALUE self, VALUE all_pos_0) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    ptr->batch.all_pos_0 = NUM2INT(all_pos_0);
+    return INT2NUM(ptr->batch.all_pos_0);
+  }
+
+  static VALUE _llama_batch_get_all_pos_zero(VALUE self) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    return INT2NUM(ptr->batch.all_pos_0);
+  }
+
+  // all_pos_1
+  static VALUE _llama_batch_set_all_pos_one(VALUE self, VALUE all_pos_1) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    ptr->batch.all_pos_1 = NUM2INT(all_pos_1);
+    return INT2NUM(ptr->batch.all_pos_1);
+  }
+
+  static VALUE _llama_batch_get_all_pos_one(VALUE self) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    return INT2NUM(ptr->batch.all_pos_1);
+  }
+
+  // all_seq_id
+  static VALUE _llama_batch_set_all_seq_id(VALUE self, VALUE all_seq_id) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    ptr->batch.all_seq_id = NUM2INT(all_seq_id);
+    return INT2NUM(ptr->batch.all_seq_id);
+  }
+
+  static VALUE _llama_batch_get_all_seq_id(VALUE self) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    return INT2NUM(ptr->batch.all_seq_id);
+  }
+
+  // token
+  static VALUE _llama_batch_set_token(VALUE self, VALUE idx, VALUE value) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    const int32_t id = NUM2INT(idx);
+    if (id < 0 || id >= ptr->batch.n_tokens) {
+      rb_raise(rb_eArgError, "idx must be in [0, n_tokens)");
+      return Qnil;
+    }
+    ptr->batch.token[id] = NUM2INT(value);
+    return INT2NUM(ptr->batch.token[id]);
+  }
+
+  static VALUE _llama_batch_get_token(VALUE self, VALUE idx) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    const int32_t id = NUM2INT(idx);
+    if (id < 0 || id >= ptr->batch.n_tokens) {
+      rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+      return Qnil;
+    }
+    return INT2NUM(ptr->batch.token[id]);
+  }
+
+  // pos
+  static VALUE _llama_batch_set_pos(VALUE self, VALUE idx, VALUE value) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    const int32_t id = NUM2INT(idx);
+    if (id < 0 || id >= ptr->batch.n_tokens) {
+      rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+      return Qnil;
+    }
+    ptr->batch.pos[id] = NUM2INT(value);
+    return INT2NUM(ptr->batch.pos[id]);
+  }
+
+  static VALUE _llama_batch_get_pos(VALUE self, VALUE idx) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    const int32_t id = NUM2INT(idx);
+    if (id < 0 || id >= ptr->batch.n_tokens) {
+      rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+      return Qnil;
+    }
+    return INT2NUM(ptr->batch.pos[id]);
+  }
+
+  // seq_id
+  static VALUE _llama_batch_set_seq_id(VALUE self, VALUE idx, VALUE value) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    const int32_t id = NUM2INT(idx);
+    if (id < 0 || id >= ptr->batch.n_tokens) {
+      rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+      return Qnil;
+    }
+    ptr->batch.seq_id[id] = NUM2INT(value);
+    return INT2NUM(ptr->batch.seq_id[id]);
+  }
+
+  static VALUE _llama_batch_get_seq_id(VALUE self, VALUE idx) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    const int32_t id = NUM2INT(idx);
+    if (id < 0 || id >= ptr->batch.n_tokens) {
+      rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+      return Qnil;
+    }
+    return INT2NUM(ptr->batch.seq_id[id]);
+  }
+
+  // logits
+  static VALUE _llama_batch_set_logits(VALUE self, VALUE idx, VALUE value) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    const int32_t id = NUM2INT(idx);
+    if (id < 0 || id >= ptr->batch.n_tokens) {
+      rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+      return Qnil;
+    }
+    ptr->batch.logits[id] = RTEST(value) ? true : false;
+    return ptr->batch.logits[id] ? Qtrue : Qfalse;
+  }
+
+  static VALUE _llama_batch_get_logits(VALUE self, VALUE idx) {
+    LLaMABatchWrapper* ptr = get_llama_batch(self);
+    const int32_t id = NUM2INT(idx);
+    if (id < 0 || id >= ptr->batch.n_tokens) {
+      rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+      return Qnil;
+    }
+    return ptr->batch.logits[id] ? Qtrue : Qfalse;
+  }
+};
+
+const rb_data_type_t RbLLaMABatch::llama_batch_type = {
+  "RbLLaMABatch",
+  { NULL,
+    RbLLaMABatch::llama_batch_free,
+    RbLLaMABatch::llama_batch_size },
+  NULL,
+  NULL,
+  RUBY_TYPED_FREE_IMMEDIATELY
+};
+
 class LLaMATokenDataWrapper {
 public:
   llama_token_data data;
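
The hunk above adds LLaMACpp::Batch, a Ruby wrapper around llama_batch with per-slot accessors. A minimal usage sketch based only on the bindings defined here — the token ids are placeholders, and n_tokens= is assigned before the indexed setters because they bounds-check against it:

    require 'llama_cpp'

    # Reserve capacity for 512 token slots; embd: 0 requests the token
    # (not embedding) layout of llama_batch_init.
    batch = LLaMACpp::Batch.new(n_tokens: 512, embd: 0)

    prompt_tokens = [1, 15043, 3186] # placeholder token ids
    batch.n_tokens = prompt_tokens.size
    prompt_tokens.each_with_index do |token, i|
      batch.set_token(i, token)                        # token id for slot i
      batch.set_pos(i, i)                              # position in the sequence
      batch.set_seq_id(i, 0)                           # every slot belongs to sequence 0
      batch.set_logits(i, i == prompt_tokens.size - 1) # request logits for the last slot only
    end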
@@ -363,6 +597,144 @@ const rb_data_type_t RbLLaMATimings::llama_timings_type = {
   RUBY_TYPED_FREE_IMMEDIATELY
 };
 
+class LLaMAModelParamsWrapper {
+public:
+  struct llama_model_params params;
+
+  LLaMAModelParamsWrapper() : params(llama_model_default_params()) {}
+
+  ~LLaMAModelParamsWrapper() {}
+};
+
+class RbLLaMAModelParams {
+public:
+  static VALUE llama_model_params_alloc(VALUE self) {
+    LLaMAModelParamsWrapper* ptr = (LLaMAModelParamsWrapper*)ruby_xmalloc(sizeof(LLaMAModelParamsWrapper));
+    new (ptr) LLaMAModelParamsWrapper();
+    return TypedData_Wrap_Struct(self, &llama_model_params_type, ptr);
+  }
+
+  static void llama_model_params_free(void* ptr) {
+    ((LLaMAModelParamsWrapper*)ptr)->~LLaMAModelParamsWrapper();
+    ruby_xfree(ptr);
+  }
+
+  static size_t llama_model_params_size(const void* ptr) {
+    return sizeof(*((LLaMAModelParamsWrapper*)ptr));
+  }
+
+  static LLaMAModelParamsWrapper* get_llama_model_params(VALUE self) {
+    LLaMAModelParamsWrapper* ptr;
+    TypedData_Get_Struct(self, LLaMAModelParamsWrapper, &llama_model_params_type, ptr);
+    return ptr;
+  }
+
+  static void define_class(VALUE outer) {
+    rb_cLLaMAModelParams = rb_define_class_under(outer, "ModelParams", rb_cObject);
+    rb_define_alloc_func(rb_cLLaMAModelParams, llama_model_params_alloc);
+    rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers=", RUBY_METHOD_FUNC(_llama_model_params_set_n_gpu_layers), 1);
+    rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers", RUBY_METHOD_FUNC(_llama_model_params_get_n_gpu_layers), 0);
+    rb_define_method(rb_cLLaMAModelParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_model_params_set_main_gpu), 1);
+    rb_define_method(rb_cLLaMAModelParams, "main_gpu", RUBY_METHOD_FUNC(_llama_model_params_get_main_gpu), 0);
+    rb_define_method(rb_cLLaMAModelParams, "tensor_split", RUBY_METHOD_FUNC(_llama_model_params_get_tensor_split), 0);
+    rb_define_method(rb_cLLaMAModelParams, "vocab_only=", RUBY_METHOD_FUNC(_llama_model_params_set_vocab_only), 1);
+    rb_define_method(rb_cLLaMAModelParams, "vocab_only", RUBY_METHOD_FUNC(_llama_model_params_get_vocab_only), 0);
+    rb_define_method(rb_cLLaMAModelParams, "use_mmap=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mmap), 1);
+    rb_define_method(rb_cLLaMAModelParams, "use_mmap", RUBY_METHOD_FUNC(_llama_model_params_get_use_mmap), 0);
+    rb_define_method(rb_cLLaMAModelParams, "use_mlock=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mlock), 1);
+    rb_define_method(rb_cLLaMAModelParams, "use_mlock", RUBY_METHOD_FUNC(_llama_model_params_get_use_mlock), 0);
+  }
+
+private:
+  static const rb_data_type_t llama_model_params_type;
+
+  // n_gpu_layers
+  static VALUE _llama_model_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
+    LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
+    ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
+    return INT2NUM(ptr->params.n_gpu_layers);
+  }
+
+  static VALUE _llama_model_params_get_n_gpu_layers(VALUE self) {
+    LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
+    return INT2NUM(ptr->params.n_gpu_layers);
+  }
+
+  // main_gpu
+  static VALUE _llama_model_params_set_main_gpu(VALUE self, VALUE main_gpu) {
+    LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
+    ptr->params.main_gpu = NUM2INT(main_gpu);
+    return INT2NUM(ptr->params.main_gpu);
+  }
+
+  static VALUE _llama_model_params_get_main_gpu(VALUE self) {
+    LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
+    return INT2NUM(ptr->params.main_gpu);
+  }
+
+  // tensor_split
+  static VALUE _llama_model_params_get_tensor_split(VALUE self) {
+    if (LLAMA_MAX_DEVICES < 1) {
+      return rb_ary_new();
+    }
+    VALUE ret = rb_ary_new2(LLAMA_MAX_DEVICES);
+    LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
+    if (ptr->params.tensor_split == nullptr) {
+      return rb_ary_new();
+    }
+    for (size_t i = 0; i < LLAMA_MAX_DEVICES; i++) {
+      rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
+    }
+    return ret;
+  }
+
+  // vocab_only
+  static VALUE _llama_model_params_set_vocab_only(VALUE self, VALUE vocab_only) {
+    LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
+    ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
+    return ptr->params.vocab_only ? Qtrue : Qfalse;
+  }
+
+  static VALUE _llama_model_params_get_vocab_only(VALUE self) {
+    LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
+    return ptr->params.vocab_only ? Qtrue : Qfalse;
+  }
+
+  // use_mmap
+  static VALUE _llama_model_params_set_use_mmap(VALUE self, VALUE use_mmap) {
+    LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
+    ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
+    return ptr->params.use_mmap ? Qtrue : Qfalse;
+  }
+
+  static VALUE _llama_model_params_get_use_mmap(VALUE self) {
+    LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
+    return ptr->params.use_mmap ? Qtrue : Qfalse;
+  }
+
+  // use_mlock
+  static VALUE _llama_model_params_set_use_mlock(VALUE self, VALUE use_mlock) {
+    LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
+    ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
+    return ptr->params.use_mlock ? Qtrue : Qfalse;
+  }
+
+  static VALUE _llama_model_params_get_use_mlock(VALUE self) {
+    LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
+    return ptr->params.use_mlock ? Qtrue : Qfalse;
+  }
+};
+
+const rb_data_type_t RbLLaMAModelParams::llama_model_params_type = {
+  "RbLLaMAModelParams",
+  { NULL,
+    RbLLaMAModelParams::llama_model_params_free,
+    RbLLaMAModelParams::llama_model_params_size },
+  NULL,
+  NULL,
+  RUBY_TYPED_FREE_IMMEDIATELY
+};
+
 class LLaMAContextParamsWrapper {
 public:
   struct llama_context_params params;
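
With this hunk, model-level options (GPU offload, tensor split, mmap/mlock, vocab_only) move out of ContextParams into the new LLaMACpp::ModelParams, mirroring llama.cpp's llama_model_params split. A sketch of the resulting load path — the GGUF path is a placeholder, and Model.new now requires both keywords, as a later hunk in this file shows:

    model_params = LLaMACpp::ModelParams.new
    model_params.n_gpu_layers = 32 # offloaded layers, only meaningful in GPU builds
    model_params.use_mmap = true

    model = LLaMACpp::Model.new(model_path: 'path/to/model.gguf', params: model_params)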
@@ -399,35 +771,26 @@ public:
     rb_cLLaMAContextParams = rb_define_class_under(outer, "ContextParams", rb_cObject);
     rb_define_alloc_func(rb_cLLaMAContextParams, llama_context_params_alloc);
     // rb_define_method(rb_cLLaMAContextParams, "initialize", RUBY_METHOD_FUNC(_llama_context_params_init), 0);
+    rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
+    rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
     rb_define_method(rb_cLLaMAContextParams, "n_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_n_ctx), 1);
     rb_define_method(rb_cLLaMAContextParams, "n_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_n_ctx), 0);
     rb_define_method(rb_cLLaMAContextParams, "n_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_batch), 1);
     rb_define_method(rb_cLLaMAContextParams, "n_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_batch), 0);
-    rb_define_method(rb_cLLaMAContextParams, "n_gpu_layers=", RUBY_METHOD_FUNC(_llama_context_params_set_n_gpu_layers), 1);
-    rb_define_method(rb_cLLaMAContextParams, "n_gpu_layers", RUBY_METHOD_FUNC(_llama_context_params_get_n_gpu_layers), 0);
-    rb_define_method(rb_cLLaMAContextParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_context_params_set_main_gpu), 1);
-    rb_define_method(rb_cLLaMAContextParams, "main_gpu", RUBY_METHOD_FUNC(_llama_context_params_get_main_gpu), 0);
-    rb_define_method(rb_cLLaMAContextParams, "tensor_split", RUBY_METHOD_FUNC(_llama_context_params_get_tensor_split), 0);
+    rb_define_method(rb_cLLaMAContextParams, "n_threads=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads), 1);
+    rb_define_method(rb_cLLaMAContextParams, "n_threads", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads), 0);
+    rb_define_method(rb_cLLaMAContextParams, "n_threads_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads_batch), 1);
+    rb_define_method(rb_cLLaMAContextParams, "n_threads_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads_batch), 0);
     rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
     rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
     rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
     rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
-    rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
-    rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
     rb_define_method(rb_cLLaMAContextParams, "mul_mat_q=", RUBY_METHOD_FUNC(_llama_context_params_set_mul_mat_q), 1);
     rb_define_method(rb_cLLaMAContextParams, "mul_mat_q", RUBY_METHOD_FUNC(_llama_context_params_get_mul_mat_q), 0);
-    rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
-    rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
     rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
     rb_define_method(rb_cLLaMAContextParams, "f16_kv", RUBY_METHOD_FUNC(_llama_context_params_get_f16_kv), 0);
     rb_define_method(rb_cLLaMAContextParams, "logits_all=", RUBY_METHOD_FUNC(_llama_context_params_set_logits_all), 1);
     rb_define_method(rb_cLLaMAContextParams, "logits_all", RUBY_METHOD_FUNC(_llama_context_params_get_logits_all), 0);
-    rb_define_method(rb_cLLaMAContextParams, "vocab_only=", RUBY_METHOD_FUNC(_llama_context_params_set_vocab_only), 1);
-    rb_define_method(rb_cLLaMAContextParams, "vocab_only", RUBY_METHOD_FUNC(_llama_context_params_get_vocab_only), 0);
-    rb_define_method(rb_cLLaMAContextParams, "use_mmap=", RUBY_METHOD_FUNC(_llama_context_params_set_use_mmap), 1);
-    rb_define_method(rb_cLLaMAContextParams, "use_mmap", RUBY_METHOD_FUNC(_llama_context_params_get_use_mmap), 0);
-    rb_define_method(rb_cLLaMAContextParams, "use_mlock=", RUBY_METHOD_FUNC(_llama_context_params_set_use_mlock), 1);
-    rb_define_method(rb_cLLaMAContextParams, "use_mlock", RUBY_METHOD_FUNC(_llama_context_params_get_use_mlock), 0);
     rb_define_method(rb_cLLaMAContextParams, "embedding=", RUBY_METHOD_FUNC(_llama_context_params_set_embedding), 1);
     rb_define_method(rb_cLLaMAContextParams, "embedding", RUBY_METHOD_FUNC(_llama_context_params_get_embedding), 0);
   }
@@ -441,6 +804,22 @@ private:
   // return self;
   // }
 
+  // seed
+  static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
+    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+    if (NUM2INT(seed) < 0) {
+      rb_raise(rb_eArgError, "seed must be positive");
+      return Qnil;
+    }
+    ptr->params.seed = NUM2INT(seed);
+    return INT2NUM(ptr->params.seed);
+  }
+
+  static VALUE _llama_context_params_get_seed(VALUE self) {
+    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+    return INT2NUM(ptr->params.seed);
+  }
+
   // n_ctx
   static VALUE _llama_context_params_set_n_ctx(VALUE self, VALUE n_ctx) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -465,41 +844,28 @@ private:
     return INT2NUM(ptr->params.n_batch);
   }
 
-  // n_gpu_layers
-  static VALUE _llama_context_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
+  // n_threads
+  static VALUE _llama_context_params_set_n_threads(VALUE self, VALUE n_threads) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-    ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
-    return INT2NUM(ptr->params.n_gpu_layers);
+    ptr->params.n_threads = NUM2INT(n_threads);
+    return INT2NUM(ptr->params.n_threads);
   }
 
-  static VALUE _llama_context_params_get_n_gpu_layers(VALUE self) {
+  static VALUE _llama_context_params_get_n_threads(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-    return INT2NUM(ptr->params.n_gpu_layers);
+    return INT2NUM(ptr->params.n_threads);
   }
 
-  // main_gpu
-  static VALUE _llama_context_params_set_main_gpu(VALUE self, VALUE main_gpu) {
+  // n_threads_batch
+  static VALUE _llama_context_params_set_n_threads_batch(VALUE self, VALUE n_threads_batch) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-    ptr->params.main_gpu = NUM2INT(main_gpu);
-    return INT2NUM(ptr->params.main_gpu);
+    ptr->params.n_threads_batch = NUM2INT(n_threads_batch);
+    return INT2NUM(ptr->params.n_threads_batch);
   }
 
-  static VALUE _llama_context_params_get_main_gpu(VALUE self) {
+  static VALUE _llama_context_params_get_n_threads_batch(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-    return INT2NUM(ptr->params.main_gpu);
-  }
-
-  // tensor_split
-  static VALUE _llama_context_params_get_tensor_split(VALUE self) {
-    if (LLAMA_MAX_DEVICES < 1) {
-      return rb_ary_new();
-    }
-    VALUE ret = rb_ary_new2(LLAMA_MAX_DEVICES);
-    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-    for (size_t i = 0; i < LLAMA_MAX_DEVICES; i++) {
-      rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
-    }
-    return ret;
+    return INT2NUM(ptr->params.n_threads_batch);
   }
 
   // rope_freq_base
@@ -526,18 +892,6 @@ private:
     return DBL2NUM(ptr->params.rope_freq_scale);
   }
 
-  // low_vram
-  static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
-    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-    ptr->params.low_vram = RTEST(low_vram) ? true : false;
-    return ptr->params.low_vram ? Qtrue : Qfalse;
-  }
-
-  static VALUE _llama_context_params_get_low_vram(VALUE self) {
-    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-    return ptr->params.low_vram ? Qtrue : Qfalse;
-  }
-
   // mul_mat_q
   static VALUE _llama_context_params_set_mul_mat_q(VALUE self, VALUE mul_mat_q) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -550,22 +904,6 @@ private:
     return ptr->params.mul_mat_q ? Qtrue : Qfalse;
   }
 
-  // seed
-  static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
-    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-    if (NUM2INT(seed) < 0) {
-      rb_raise(rb_eArgError, "seed must be positive");
-      return Qnil;
-    }
-    ptr->params.seed = NUM2INT(seed);
-    return INT2NUM(ptr->params.seed);
-  }
-
-  static VALUE _llama_context_params_get_seed(VALUE self) {
-    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-    return INT2NUM(ptr->params.seed);
-  }
-
   // f16_kv
   static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -590,42 +928,6 @@ private:
     return ptr->params.logits_all ? Qtrue : Qfalse;
   }
 
-  // vocab_only
-  static VALUE _llama_context_params_set_vocab_only(VALUE self, VALUE vocab_only) {
-    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-    ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
-    return ptr->params.vocab_only ? Qtrue : Qfalse;
-  }
-
-  static VALUE _llama_context_params_get_vocab_only(VALUE self) {
-    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-    return ptr->params.vocab_only ? Qtrue : Qfalse;
-  }
-
-  // use_mmap
-  static VALUE _llama_context_params_set_use_mmap(VALUE self, VALUE use_mmap) {
-    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-    ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
-    return ptr->params.use_mmap ? Qtrue : Qfalse;
-  }
-
-  static VALUE _llama_context_params_get_use_mmap(VALUE self) {
-    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-    return ptr->params.use_mmap ? Qtrue : Qfalse;
-  }
-
-  // use_mlock
-  static VALUE _llama_context_params_set_use_mlock(VALUE self, VALUE use_mlock) {
-    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-    ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
-    return ptr->params.use_mlock ? Qtrue : Qfalse;
-  }
-
-  static VALUE _llama_context_params_get_use_mlock(VALUE self) {
-    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-    return ptr->params.use_mlock ? Qtrue : Qfalse;
-  }
-
   // embedding
   static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
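
After this reshuffle, ContextParams carries only per-context settings: seed moves up next to n_ctx/n_batch, the GPU and loading fields are gone (now on ModelParams), and the new n_threads/n_threads_batch accessors replace the n_threads arguments that eval used to take. A sketch:

    ctx_params = LLaMACpp::ContextParams.new
    ctx_params.seed = 42
    ctx_params.n_ctx = 2048
    ctx_params.n_threads = 4       # threads for single-token generation
    ctx_params.n_threads_batch = 4 # threads for batch/prompt processing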
@@ -823,11 +1125,10 @@ public:
     rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
     rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
     rb_define_method(rb_cLLaMAModel, "n_vocab", RUBY_METHOD_FUNC(_llama_model_get_model_n_vocab), 0);
-    rb_define_method(rb_cLLaMAModel, "n_ctx", RUBY_METHOD_FUNC(_llama_model_get_model_n_ctx), 0);
     rb_define_method(rb_cLLaMAModel, "n_ctx_train", RUBY_METHOD_FUNC(_llama_model_get_model_n_ctx_train), 0);
     rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_model_n_embd), 0);
-    rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece_with_model), 1);
-    rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize_with_model), -1);
+    rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece), 1);
+    rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize), -1);
     rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
     rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
     rb_define_method(rb_cLLaMAModel, "n_params", RUBY_METHOD_FUNC(_llama_model_get_model_n_params), 0);
@@ -841,30 +1142,21 @@ private:
     ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
     VALUE kw_values[2] = { Qundef, Qundef };
     rb_scan_args(argc, argv, ":", &kw_args);
-    rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
-
-    if (kw_values[0] == Qundef && kw_values[1] == Qundef) {
-      rb_iv_set(self, "@params", Qnil);
-      return Qnil;
-    }
+    rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
 
     if (!RB_TYPE_P(kw_values[0], T_STRING)) {
       rb_raise(rb_eArgError, "model_path must be a string");
       return Qnil;
     }
-    if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
-      rb_raise(rb_eArgError, "params must be a ContextParams");
+    if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
+      rb_raise(rb_eArgError, "params must be a ModelParams");
       return Qnil;
     }
 
     VALUE filename = kw_values[0];
-    LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
+    LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
     LLaMAModelWrapper* model_ptr = get_llama_model(self);
 
-    if (prms_ptr->params.seed == LLAMA_DEFAULT_SEED) {
-      prms_ptr->params.seed = time(NULL);
-    }
-
     try {
       model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
     } catch (const std::runtime_error& e) {
@@ -912,8 +1204,8 @@ private:
       rb_raise(rb_eArgError, "model_path must be a string");
       return Qnil;
     }
-    if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
-      rb_raise(rb_eArgError, "params must be a ContextParams");
+    if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
+      rb_raise(rb_eArgError, "params must be a LLaMAModelParams");
       return Qnil;
     }
 
@@ -924,7 +1216,7 @@ private:
     }
 
     VALUE filename = kw_values[0];
-    LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
+    LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
 
     try {
       model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
@@ -946,10 +1238,10 @@ private:
 
   static VALUE _llama_model_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
-    ID kw_table[3] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads") };
-    VALUE kw_values[3] = { Qundef, Qundef, Qundef };
+    ID kw_table[4] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads"), rb_intern("scale") };
+    VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
     rb_scan_args(argc, argv, ":", &kw_args);
-    rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
+    rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);
 
     if (!RB_TYPE_P(kw_values[0], T_STRING)) {
       rb_raise(rb_eArgError, "lora_path must be a string");
@@ -963,13 +1255,18 @@ private:
       rb_raise(rb_eArgError, "n_threads must be an integer");
       return Qnil;
     }
+    if (kw_values[3] != Qundef && !RB_FLOAT_TYPE_P(kw_values[3])) {
+      rb_raise(rb_eArgError, "scale must be a float");
+      return Qnil;
+    }
 
     const char* lora_path = StringValueCStr(kw_values[0]);
    const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
     const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
+    const float scale = kw_values[3] == Qundef ? 1.0 : NUM2DBL(kw_values[3]);
 
     LLaMAModelWrapper* ptr = get_llama_model(self);
-    if (llama_model_apply_lora_from_file(ptr->model, lora_path, base_model_path, n_threads) != 0) {
+    if (llama_model_apply_lora_from_file(ptr->model, lora_path, scale, base_model_path, n_threads) != 0) {
       rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
       return Qnil;
     }
@@ -978,25 +1275,20 @@ private:
 
   static VALUE _llama_model_get_model_n_vocab(VALUE self) {
     LLaMAModelWrapper* ptr = get_llama_model(self);
-    return INT2NUM(llama_model_n_vocab(ptr->model));
-  }
-
-  static VALUE _llama_model_get_model_n_ctx(VALUE self) {
-    LLaMAModelWrapper* ptr = get_llama_model(self);
-    return INT2NUM(llama_model_n_ctx(ptr->model));
+    return INT2NUM(llama_n_vocab(ptr->model));
   }
 
   static VALUE _llama_model_get_model_n_ctx_train(VALUE self) {
     LLaMAModelWrapper* ptr = get_llama_model(self);
-    return INT2NUM(llama_model_n_ctx_train(ptr->model));
+    return INT2NUM(llama_n_ctx_train(ptr->model));
   }
 
   static VALUE _llama_model_get_model_n_embd(VALUE self) {
     LLaMAModelWrapper* ptr = get_llama_model(self);
-    return INT2NUM(llama_model_n_embd(ptr->model));
+    return INT2NUM(llama_n_embd(ptr->model));
   }
 
-  static VALUE _llama_model_token_to_piece_with_model(VALUE self, VALUE token_) {
+  static VALUE _llama_model_token_to_piece(VALUE self, VALUE token_) {
     if (!RB_INTEGER_TYPE_P(token_)) {
       rb_raise(rb_eArgError, "token must be an integer");
       return Qnil;
@@ -1004,10 +1296,10 @@ private:
     const llama_token token = NUM2INT(token_);
     LLaMAModelWrapper* ptr = get_llama_model(self);
     std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_piece_with_model(ptr->model, token, result.data(), result.size());
+    const int n_tokens = llama_token_to_piece(ptr->model, token, result.data(), result.size());
     if (n_tokens < 0) {
       result.resize(-n_tokens);
-      const int check = llama_token_to_piece_with_model(ptr->model, token, result.data(), result.size());
+      const int check = llama_token_to_piece(ptr->model, token, result.data(), result.size());
       if (check != -n_tokens) {
         rb_raise(rb_eRuntimeError, "failed to convert");
         return Qnil;
@@ -1019,7 +1311,7 @@ private:
     return rb_str_new_cstr(ret.c_str());
   }
 
-  static VALUE _llama_model_tokenize_with_model(int argc, VALUE* argv, VALUE self) {
+  static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
     ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
     VALUE kw_values[3] = { Qundef, Qundef, Qundef };
@@ -1046,7 +1338,7 @@ private:
 
     llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
     LLaMAModelWrapper* ptr = get_llama_model(self);
-    const int n_tokens = llama_tokenize_with_model(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
+    const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
 
     if (n_tokens < 0) {
       rb_raise(rb_eRuntimeError, "failed to tokenize. The numebr of tokens (%d) is greater than n_max_tokens.", -n_tokens);
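
Tokenization and piece conversion are now model-level operations (llama_tokenize and llama_token_to_piece take the model rather than a context), and apply_lora_from_file gains a scale: keyword. A sketch with placeholder paths and text:

    tokens = model.tokenize(text: 'Hello, world!', add_bos: true)
    piece  = model.token_to_piece(tokens.last)

    model.apply_lora_from_file(lora_path: 'path/to/lora.bin', scale: 1.0)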
@@ -1345,11 +1637,11 @@ public:
   static void define_class(VALUE outer) {
     rb_cLLaMAContext = rb_define_class_under(outer, "Context", rb_cObject);
     rb_define_alloc_func(rb_cLLaMAContext, llama_context_alloc);
+    rb_define_attr(rb_cLLaMAContext, "model", 1, 0);
     rb_define_method(rb_cLLaMAContext, "initialize", RUBY_METHOD_FUNC(_llama_context_initialize), -1);
     rb_define_method(rb_cLLaMAContext, "eval", RUBY_METHOD_FUNC(_llama_context_eval), -1);
     rb_define_method(rb_cLLaMAContext, "eval_embd", RUBY_METHOD_FUNC(_llama_context_eval_embd), -1);
-    rb_define_method(rb_cLLaMAContext, "eval_export", RUBY_METHOD_FUNC(_llama_context_eval_export), 1);
-    rb_define_method(rb_cLLaMAContext, "tokenize", RUBY_METHOD_FUNC(_llama_context_tokenize), -1);
+    rb_define_method(rb_cLLaMAContext, "decode", RUBY_METHOD_FUNC(_llama_context_decode), 1);
     rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
     rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
     rb_define_method(rb_cLLaMAContext, "text", RUBY_METHOD_FUNC(_llama_context_text), 1);
@@ -1358,15 +1650,16 @@ public:
     rb_define_method(rb_cLLaMAContext, "token_bos", RUBY_METHOD_FUNC(_llama_context_token_bos), 0);
     rb_define_method(rb_cLLaMAContext, "token_eos", RUBY_METHOD_FUNC(_llama_context_token_eos), 0);
     rb_define_method(rb_cLLaMAContext, "token_nl", RUBY_METHOD_FUNC(_llama_context_token_nl), 0);
-    rb_define_method(rb_cLLaMAContext, "token_to_piece", RUBY_METHOD_FUNC(_llama_context_token_to_piece), 1);
-    rb_define_method(rb_cLLaMAContext, "n_vocab", RUBY_METHOD_FUNC(_llama_context_n_vocab), 0);
     rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
-    rb_define_method(rb_cLLaMAContext, "n_ctx_train", RUBY_METHOD_FUNC(_llama_context_n_ctx_train), 0);
-    rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
     rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
     rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
     rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
     rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
+    rb_define_method(rb_cLLaMAContext, "kv_cache_tokens_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_tokens_rm), 2);
+    rb_define_method(rb_cLLaMAContext, "kv_cache_seq_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_rm), 3);
+    rb_define_method(rb_cLLaMAContext, "kv_cache_seq_cp", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_cp), 4);
+    rb_define_method(rb_cLLaMAContext, "kv_cache_seq_keep", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_keep), 1);
+    rb_define_method(rb_cLLaMAContext, "kv_cache_seq_shift", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_shift), 4);
     rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
     rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
     rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
@@ -1378,6 +1671,7 @@ public:
     rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
     rb_define_method(rb_cLLaMAContext, "sample_tail_free", RUBY_METHOD_FUNC(_llama_context_sample_tail_free), -1);
     rb_define_method(rb_cLLaMAContext, "sample_typical", RUBY_METHOD_FUNC(_llama_context_sample_typical), -1);
+    rb_define_method(rb_cLLaMAContext, "sample_temp", RUBY_METHOD_FUNC(_llama_context_sample_temp), -1);
     rb_define_method(rb_cLLaMAContext, "sample_temperature", RUBY_METHOD_FUNC(_llama_context_sample_temperature), -1);
     rb_define_method(rb_cLLaMAContext, "sample_token_mirostat", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat), -1);
     rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
@@ -1392,24 +1686,27 @@ private:
 
   static VALUE _llama_context_initialize(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
-    ID kw_table[1] = { rb_intern("model") };
-    VALUE kw_values[1] = { Qundef };
+    ID kw_table[2] = { rb_intern("model"), rb_intern("params") };
+    VALUE kw_values[2] = { Qundef, Qundef };
     rb_scan_args(argc, argv, ":", &kw_args);
-    rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
+    rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
 
     VALUE model = kw_values[0];
     if (!rb_obj_is_kind_of(model, rb_cLLaMAModel)) {
       rb_raise(rb_eArgError, "model must be a Model");
       return Qnil;
     }
+    VALUE params = kw_values[1];
+    if (!rb_obj_is_kind_of(params, rb_cLLaMAContextParams)) {
+      rb_raise(rb_eArgError, "params must be a ContextParams");
+      return Qnil;
+    }
 
     LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
     if (model_ptr->model == NULL) {
       rb_raise(rb_eRuntimeError, "Model is empty");
       return Qnil;
     }
-
-    VALUE params = rb_iv_get(model, "@params");
     LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
 
@@ -1421,6 +1718,7 @@ private:
     }
 
     rb_iv_set(self, "@model", model);
+    rb_iv_set(self, "@params", params);
     rb_iv_set(self, "@has_evaluated", Qfalse);
 
     return Qnil;
@@ -1428,8 +1726,8 @@ private:
 
   static VALUE _llama_context_eval(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
-    ID kw_table[4] = { rb_intern("tokens"), rb_intern("n_past"), rb_intern("n_tokens"), rb_intern("n_threads") };
-    VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
+    ID kw_table[3] = { rb_intern("tokens"), rb_intern("n_past"), rb_intern("n_tokens") };
+    VALUE kw_values[3] = { Qundef, Qundef, Qundef };
     rb_scan_args(argc, argv, ":", &kw_args);
     rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
 
@@ -1445,10 +1743,6 @@ private:
       rb_raise(rb_eArgError, "n_tokens must be an integer");
       return Qnil;
     }
-    if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
-      rb_raise(rb_eArgError, "n_threads must be an integer");
-      return Qnil;
-    }
 
     const size_t tokens_len = RARRAY_LEN(kw_values[0]);
     std::vector<llama_token> embd(tokens_len);
@@ -1463,14 +1757,13 @@ private:
 
     const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
     const int n_past = NUM2INT(kw_values[1]);
-    const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
 
     LLaMAContextWrapper* ptr = get_llama_context(self);
     if (ptr->ctx == NULL) {
       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
       return Qnil;
     }
-    if (llama_eval(ptr->ctx, embd.data(), n_tokens, n_past, n_threads) != 0) {
+    if (llama_eval(ptr->ctx, embd.data(), n_tokens, n_past) != 0) {
       rb_raise(rb_eRuntimeError, "Failed to evaluate");
       return Qnil;
     }
@@ -1483,8 +1776,8 @@ private:
 
   static VALUE _llama_context_eval_embd(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
-    ID kw_table[4] = { rb_intern("embd"), rb_intern("n_past"), rb_intern("n_tokens"), rb_intern("n_threads") };
-    VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
+    ID kw_table[3] = { rb_intern("embd"), rb_intern("n_past"), rb_intern("n_tokens") };
+    VALUE kw_values[3] = { Qundef, Qundef, Qundef };
     rb_scan_args(argc, argv, ":", &kw_args);
     rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
 
@@ -1500,10 +1793,6 @@ private:
       rb_raise(rb_eArgError, "n_tokens must be an integer");
       return Qnil;
     }
-    if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
-      rb_raise(rb_eArgError, "n_threads must be an integer");
-      return Qnil;
-    }
 
     const size_t tokens_len = RARRAY_LEN(kw_values[0]);
     std::vector<float> embd(tokens_len);
@@ -1518,14 +1807,13 @@ private:
 
     const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
     const int n_past = NUM2INT(kw_values[1]);
-    const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
 
     LLaMAContextWrapper* ptr = get_llama_context(self);
     if (ptr->ctx == NULL) {
       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
       return Qnil;
     }
-    if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past, n_threads) != 0) {
+    if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past) != 0) {
       rb_raise(rb_eRuntimeError, "Failed to evaluate");
       return Qnil;
     }
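
Context.new now takes its parameters explicitly instead of reading @params off the model, and eval/eval_embd lose their n_threads: keyword (thread counts now live in ContextParams). A sketch continuing from the objects built above:

    context = LLaMACpp::Context.new(model: model, params: ctx_params)
    context.eval(tokens: tokens, n_past: 0)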
@@ -1536,91 +1824,22 @@ private:
     return Qnil;
   }
 
-  static VALUE _llama_context_eval_export(VALUE self, VALUE fname_) {
+  static VALUE _llama_context_decode(VALUE self, VALUE batch) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
     if (ptr->ctx == NULL) {
       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
       return Qnil;
     }
-    if (!RB_TYPE_P(fname_, T_STRING)) {
-      rb_raise(rb_eArgError, "fname must be a string");
+    if (!rb_obj_is_kind_of(batch, rb_cLLaMABatch)) {
+      rb_raise(rb_eArgError, "batch must be a Batch");
       return Qnil;
     }
-
-    if (llama_eval_export(ptr->ctx, StringValueCStr(fname_)) != 0) {
-      return Qfalse;
-    }
-    RB_GC_GUARD(fname_);
-    return Qtrue;
-  }
-
-  static VALUE _llama_context_tokenize(int argc, VALUE* argv, VALUE self) {
-    VALUE kw_args = Qnil;
-    ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
-    VALUE kw_values[3] = { Qundef, Qundef, Qundef };
-    rb_scan_args(argc, argv, ":", &kw_args);
-    rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
-
-    if (!RB_TYPE_P(kw_values[0], T_STRING)) {
-      rb_raise(rb_eArgError, "text must be a String");
+    LLaMABatchWrapper* batch_ptr = RbLLaMABatch::get_llama_batch(batch);
+    if (llama_decode(ptr->ctx, batch_ptr->batch) < 0) {
+      rb_raise(rb_eRuntimeError, "Failed to decode");
       return Qnil;
     }
-    if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
-      rb_raise(rb_eArgError, "n_max_tokens must be an integer");
-      return Qnil;
-    }
-    if (kw_values[2] != Qundef && (kw_values[2] != Qtrue && kw_values[2] != Qfalse)) {
-      rb_raise(rb_eArgError, "add_bos must be a boolean");
-      return Qnil;
-    }
-
-    VALUE text_ = kw_values[0];
-    std::string text = StringValueCStr(text_);
-    const bool add_bos = kw_values[2] == Qtrue ? true : false;
-    const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
-
-    std::vector<llama_token> tokens(n_max_tokens);
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    const int n = llama_tokenize(ptr->ctx, text.c_str(), text.size(), tokens.data(), n_max_tokens, add_bos);
-    if (n < 0) {
-      rb_raise(rb_eRuntimeError, "Failed to tokenize");
-      return Qnil;
-    }
-
-    VALUE output = rb_ary_new();
-    for (int i = 0; i < n; i++) {
-      rb_ary_push(output, INT2NUM(tokens[i]));
-    }
-
-    RB_GC_GUARD(text_);
-    return output;
-  }
-
-  static VALUE _llama_context_token_to_piece(VALUE self, VALUE token_) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    const llama_token token = NUM2INT(token_);
-    std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_piece(ptr->ctx, token, result.data(), result.size());
-    if (n_tokens < 0) {
-      result.resize(-n_tokens);
-      const int check = llama_token_to_piece(ptr->ctx, token, result.data(), result.size());
-      if (check != -n_tokens) {
-        rb_raise(rb_eRuntimeError, "failed to convert");
-        return Qnil;
-      }
-    } else {
-      result.resize(n_tokens);
-    }
-    std::string ret(result.data(), result.size());
-    return rb_str_new_cstr(ret.c_str());
+    return Qnil;
   }
 
   static VALUE _llama_context_logits(VALUE self) {
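
The removals above retire eval_export and the context-level tokenize/token_to_piece; the new decode entry point consumes a Batch instead. A sketch of a single step — with logits_all left false, logits is assumed to return one vocabulary-sized row for the slot(s) marked in the batch:

    context.decode(batch)
    logits = context.logits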
@@ -1635,10 +1854,11 @@ private:
     }
 
     VALUE model = rb_iv_get(self, "@model");
-    VALUE params = rb_iv_get(model, "@params");
+    LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
+    VALUE params = rb_iv_get(self, "@params");
     LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
     const int n_tokens = prms_ptr->params.logits_all ? NUM2INT(rb_iv_get(self, "@n_tokens")) : 1;
-    const int n_vocab = llama_n_vocab(ptr->ctx);
+    const int n_vocab = llama_n_vocab(model_ptr->model);
     const float* logits = llama_get_logits(ptr->ctx);
     VALUE output = rb_ary_new();
     for (int i = 0; i < n_tokens * n_vocab; i++) {
@@ -1655,7 +1875,8 @@ private:
       return Qnil;
     }
     VALUE model = rb_iv_get(self, "@model");
-    VALUE params = rb_iv_get(model, "@params");
+    LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
+    VALUE params = rb_iv_get(self, "@params");
     LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
     if (!prms_ptr->params.embedding) {
       rb_raise(rb_eRuntimeError, "embedding parameter is false");
@@ -1666,7 +1887,7 @@ private:
       return Qnil;
     }
 
-    const int n_embd = llama_n_embd(ptr->ctx);
+    const int n_embd = llama_n_embd(model_ptr->model);
     const float* embd = llama_get_embeddings(ptr->ctx);
     VALUE output = rb_ary_new();
     for (int i = 0; i < n_embd; i++) {
@@ -1736,81 +1957,104 @@ private:
     return INT2NUM(llama_token_nl(ptr->ctx));
   }
 
-  static VALUE _llama_context_n_vocab(VALUE self) {
+  static VALUE _llama_context_n_ctx(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
     if (ptr->ctx == NULL) {
       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
       return Qnil;
     }
-    return INT2NUM(llama_n_vocab(ptr->ctx));
+    return INT2NUM(llama_n_ctx(ptr->ctx));
   }
 
-  static VALUE _llama_context_n_ctx(VALUE self) {
+  static VALUE _llama_context_get_timings(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
     if (ptr->ctx == NULL) {
       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
       return Qnil;
     }
-    return INT2NUM(llama_n_ctx(ptr->ctx));
+    VALUE tm_obj = rb_funcall(rb_cLLaMATimings, rb_intern("new"), 0);
+    LLaMATimingsWrapper* tm_ptr = RbLLaMATimings::get_llama_timings(tm_obj);
+    tm_ptr->timings = llama_get_timings(ptr->ctx);
+    return tm_obj;
   }
 
-  static VALUE _llama_context_n_ctx_train(VALUE self) {
+  static VALUE _llama_context_print_timings(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
     if (ptr->ctx == NULL) {
       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
       return Qnil;
     }
-    return INT2NUM(llama_n_ctx_train(ptr->ctx));
+    llama_print_timings(ptr->ctx);
+    return Qnil;
   }
 
-  static VALUE _llama_context_n_embd(VALUE self) {
+  static VALUE _llama_context_reset_timings(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
     if (ptr->ctx == NULL) {
       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
       return Qnil;
     }
-    return INT2NUM(llama_n_embd(ptr->ctx));
+    llama_reset_timings(ptr->ctx);
+    return Qnil;
   }
 
-  static VALUE _llama_context_get_timings(VALUE self) {
+  static VALUE _llama_context_kv_cache_token_count(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
     if (ptr->ctx == NULL) {
       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
       return Qnil;
     }
-    VALUE tm_obj = rb_funcall(rb_cLLaMATimings, rb_intern("new"), 0);
-    LLaMATimingsWrapper* tm_ptr = RbLLaMATimings::get_llama_timings(tm_obj);
-    tm_ptr->timings = llama_get_timings(ptr->ctx);
-    return tm_obj;
+    return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
   }
 
-  static VALUE _llama_context_print_timings(VALUE self) {
+  static VALUE _llama_context_kv_cache_tokens_rm(VALUE self, VALUE c0, VALUE c1) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
     if (ptr->ctx == NULL) {
       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
       return Qnil;
     }
-    llama_print_timings(ptr->ctx);
+    llama_kv_cache_tokens_rm(ptr->ctx, NUM2INT(c0), NUM2INT(c1));
     return Qnil;
   }
 
-  static VALUE _llama_context_reset_timings(VALUE self) {
+  static VALUE _llama_context_kv_cache_seq_rm(VALUE self, VALUE seq_id, VALUE p0, VALUE p1) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
     if (ptr->ctx == NULL) {
       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
       return Qnil;
     }
-    llama_reset_timings(ptr->ctx);
+    llama_kv_cache_seq_rm(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1));
     return Qnil;
   }
 
-  static VALUE _llama_context_kv_cache_token_count(VALUE self) {
+  static VALUE _llama_context_kv_cache_seq_cp(VALUE self, VALUE seq_id_src, VALUE seq_id_dst, VALUE p0, VALUE p1) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
     if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+      rb_raise(rb_eArgError, "LLaMA context is not initialized");
       return Qnil;
     }
-    return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
+    llama_kv_cache_seq_cp(ptr->ctx, NUM2INT(seq_id_src), NUM2INT(seq_id_dst), NUM2INT(p0), NUM2INT(p1));
+    return Qnil;
+  }
+
+  static VALUE _llama_context_kv_cache_seq_keep(VALUE self, VALUE seq_id) {
+    LLaMAContextWrapper* ptr = get_llama_context(self);
+    if (ptr->ctx == NULL) {
+      rb_raise(rb_eArgError, "LLaMA context is not initialized");
+      return Qnil;
+    }
+    llama_kv_cache_seq_keep(ptr->ctx, NUM2INT(seq_id));
+    return Qnil;
+  }
+
+  static VALUE _llama_context_kv_cache_seq_shift(VALUE self, VALUE seq_id, VALUE p0, VALUE p1, VALUE delta) {
+    LLaMAContextWrapper* ptr = get_llama_context(self);
+    if (ptr->ctx == NULL) {
+      rb_raise(rb_eArgError, "LLaMA context is not initialized");
+      return Qnil;
+    }
+    llama_kv_cache_seq_shift(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1), NUM2INT(delta));
+    return Qnil;
   }
 
   static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
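
The KV-cache surface grows from a single token-count query to the per-sequence operations above, which is what makes the multi-sequence Batch/decode path useful. A sketch — the p0/p1 position ranges are assumed half-open, following llama.cpp's llama_kv_cache_seq_* convention:

    context.kv_cache_seq_cp(0, 1, 0, 16)     # copy seq 0, positions 0...16, into seq 1
    context.kv_cache_seq_rm(1, 8, 16)        # drop positions 8...16 from seq 1
    context.kv_cache_seq_keep(0)             # discard every sequence except 0
    context.kv_cache_seq_shift(0, 8, 16, -8) # slide positions 8...16 of seq 0 left by 8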
@@ -1851,7 +2095,7 @@ private:
     }
 
     VALUE model = rb_iv_get(self, "@model");
-    VALUE params = rb_iv_get(model, "@params");
+    VALUE params = rb_iv_get(self, "@params");
     LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
     const int n_ctx = prms_ptr->params.n_ctx;
 
@@ -2235,6 +2479,40 @@ private:
     return Qnil;
   }
 
+  static VALUE _llama_context_sample_temp(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[1] = { rb_intern("temp") };
+    VALUE kw_values[1] = { Qundef };
+    VALUE candidates = Qnil;
+    rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
+
+    if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
+      rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
+      return Qnil;
+    }
+    if (!RB_FLOAT_TYPE_P(kw_values[0])) {
+      rb_raise(rb_eArgError, "temp must be a float");
+      return Qnil;
+    }
+
+    LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+    if (ctx_ptr->ctx == NULL) {
+      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+      return Qnil;
+    }
+    LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+    if (cnd_ptr->array.data == nullptr) {
+      rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+      return Qnil;
+    }
+    const float temp = NUM2DBL(kw_values[0]);
+
+    llama_sample_temp(ctx_ptr->ctx, &(cnd_ptr->array), temp);
+
+    return Qnil;
+  }
+
   static VALUE _llama_context_sample_temperature(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
     ID kw_table[1] = { rb_intern("temperature") };
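
sample_temp mirrors llama.cpp's rename of llama_sample_temperature to llama_sample_temp; the old sample_temperature binding is kept alongside it. A sketch, assuming candidates is a LLaMACpp::TokenDataArray built from the current logits:

    context.sample_temp(candidates, temp: 0.8)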
@@ -2560,6 +2838,7 @@ extern "C" void Init_llama_cpp(void) {
   RbLLaMATokenData::define_class(rb_mLLaMACpp);
   RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
   RbLLaMAModel::define_class(rb_mLLaMACpp);
+  RbLLaMAModelParams::define_class(rb_mLLaMACpp);
   RbLLaMATimings::define_class(rb_mLLaMACpp);
   RbLLaMAContext::define_class(rb_mLLaMACpp);
   RbLLaMAContextParams::define_class(rb_mLLaMACpp);
@@ -2578,10 +2857,6 @@ extern "C" void Init_llama_cpp(void) {
 
   rb_define_const(rb_mLLaMACpp, "LLAMA_MAX_DEVICES", INT2NUM(LLAMA_MAX_DEVICES));
 
-  rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_ERROR", INT2NUM(LLAMA_LOG_LEVEL_ERROR));
-  rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_WARN", INT2NUM(LLAMA_LOG_LEVEL_WARN));
-  rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_INFO", INT2NUM(LLAMA_LOG_LEVEL_INFO));
-
   rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_SPM", INT2NUM(LLAMA_VOCAB_TYPE_SPM));
   rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_BPE", INT2NUM(LLAMA_VOCAB_TYPE_BPE));
 