llama_cpp 0.5.3 → 0.6.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +6 -5
- data/examples/chat.rb +13 -13
- data/examples/embedding.rb +9 -9
- data/ext/llama_cpp/llama_cpp.cpp +547 -272
- data/ext/llama_cpp/src/ggml-alloc.c +8 -2
- data/ext/llama_cpp/src/ggml-alloc.h +1 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +209 -82
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.m +163 -84
- data/ext/llama_cpp/src/ggml-metal.metal +121 -38
- data/ext/llama_cpp/src/ggml.c +1596 -842
- data/ext/llama_cpp/src/ggml.h +116 -35
- data/ext/llama_cpp/src/llama.cpp +1015 -586
- data/ext/llama_cpp/src/llama.h +304 -119
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +5 -9
- data/sig/llama_cpp.rbs +65 -34
- metadata +3 -3
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -1,7 +1,9 @@
|
|
1
1
|
#include "llama_cpp.h"
|
2
2
|
|
3
3
|
// Ruby module/class handles for the extension. The rb_c* handles are filled in
// via rb_define_class_under (e.g. RbLLaMABatch::define_class assigns
// rb_cLLaMABatch, RbLLaMAModelParams::define_class assigns rb_cLLaMAModelParams)
// and then serve as the receiver for rb_define_method. rb_mLLaMACpp is presumably
// the top-level LLaMACpp module; its assignment is not visible here.
VALUE rb_mLLaMACpp;
|
4
|
+
VALUE rb_cLLaMABatch;
|
4
5
|
VALUE rb_cLLaMAModel;
|
6
|
+
VALUE rb_cLLaMAModelParams;
|
5
7
|
VALUE rb_cLLaMATimings;
|
6
8
|
VALUE rb_cLLaMAContext;
|
7
9
|
VALUE rb_cLLaMAContextParams;
|
@@ -11,6 +13,238 @@ VALUE rb_cLLaMATokenDataArray;
|
|
11
13
|
VALUE rb_cLLaMAGrammarElement;
|
12
14
|
VALUE rb_cLLaMAGrammar;
|
13
15
|
|
16
|
+
class LLaMABatchWrapper {
|
17
|
+
public:
|
18
|
+
llama_batch batch;
|
19
|
+
|
20
|
+
LLaMABatchWrapper() {}
|
21
|
+
|
22
|
+
~LLaMABatchWrapper() {
|
23
|
+
llama_batch_free(batch);
|
24
|
+
}
|
25
|
+
};
|
26
|
+
|
27
|
+
class RbLLaMABatch {
|
28
|
+
public:
|
29
|
+
static VALUE llama_batch_alloc(VALUE self) {
|
30
|
+
LLaMABatchWrapper* ptr = (LLaMABatchWrapper*)ruby_xmalloc(sizeof(LLaMABatchWrapper));
|
31
|
+
new (ptr) LLaMABatchWrapper();
|
32
|
+
return TypedData_Wrap_Struct(self, &llama_batch_type, ptr);
|
33
|
+
}
|
34
|
+
|
35
|
+
static void llama_batch_free(void* ptr) {
|
36
|
+
((LLaMABatchWrapper*)ptr)->~LLaMABatchWrapper();
|
37
|
+
ruby_xfree(ptr);
|
38
|
+
}
|
39
|
+
|
40
|
+
static size_t llama_batch_size(const void* ptr) {
|
41
|
+
return sizeof(*((LLaMABatchWrapper*)ptr));
|
42
|
+
}
|
43
|
+
|
44
|
+
static LLaMABatchWrapper* get_llama_batch(VALUE self) {
|
45
|
+
LLaMABatchWrapper* ptr;
|
46
|
+
TypedData_Get_Struct(self, LLaMABatchWrapper, &llama_batch_type, ptr);
|
47
|
+
return ptr;
|
48
|
+
}
|
49
|
+
|
50
|
+
static void define_class(VALUE outer) {
|
51
|
+
rb_cLLaMABatch = rb_define_class_under(outer, "Batch", rb_cObject);
|
52
|
+
rb_define_alloc_func(rb_cLLaMABatch, llama_batch_alloc);
|
53
|
+
rb_define_method(rb_cLLaMABatch, "initialize", RUBY_METHOD_FUNC(_llama_batch_initialize), -1);
|
54
|
+
rb_define_method(rb_cLLaMABatch, "n_tokens=", RUBY_METHOD_FUNC(_llama_batch_set_n_tokens), 1);
|
55
|
+
rb_define_method(rb_cLLaMABatch, "n_tokens", RUBY_METHOD_FUNC(_llama_batch_get_n_tokens), 0);
|
56
|
+
rb_define_method(rb_cLLaMABatch, "all_pos_zero=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_zero), 1);
|
57
|
+
rb_define_method(rb_cLLaMABatch, "all_pos_zero", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_zero), 0);
|
58
|
+
rb_define_method(rb_cLLaMABatch, "all_pos_one=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_one), 1);
|
59
|
+
rb_define_method(rb_cLLaMABatch, "all_pos_one", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_one), 0);
|
60
|
+
rb_define_method(rb_cLLaMABatch, "all_seq_id=", RUBY_METHOD_FUNC(_llama_batch_set_all_seq_id), 1);
|
61
|
+
rb_define_method(rb_cLLaMABatch, "all_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_all_seq_id), 0);
|
62
|
+
rb_define_method(rb_cLLaMABatch, "set_token", RUBY_METHOD_FUNC(_llama_batch_set_token), 2);
|
63
|
+
rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
|
64
|
+
rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
|
65
|
+
rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
|
66
|
+
rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 2);
|
67
|
+
rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 1);
|
68
|
+
rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
|
69
|
+
rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
|
70
|
+
}
|
71
|
+
|
72
|
+
private:
|
73
|
+
static const rb_data_type_t llama_batch_type;
|
74
|
+
|
75
|
+
static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
|
76
|
+
VALUE kw_args = Qnil;
|
77
|
+
ID kw_table[2] = { rb_intern("n_tokens"), rb_intern("embd") };
|
78
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
79
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
80
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
81
|
+
|
82
|
+
if (!RB_INTEGER_TYPE_P(kw_values[0])) {
|
83
|
+
rb_raise(rb_eArgError, "n_tokens must be an integer");
|
84
|
+
return Qnil;
|
85
|
+
}
|
86
|
+
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
87
|
+
rb_raise(rb_eArgError, "embd must be an integer");
|
88
|
+
return Qnil;
|
89
|
+
}
|
90
|
+
|
91
|
+
const int32_t n_tokens = NUM2INT(kw_values[0]);
|
92
|
+
const int32_t embd = NUM2INT(kw_values[1]);
|
93
|
+
|
94
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
95
|
+
ptr->batch = llama_batch_init(n_tokens, embd);
|
96
|
+
|
97
|
+
return Qnil;
|
98
|
+
}
|
99
|
+
|
100
|
+
// n_tokens
|
101
|
+
static VALUE _llama_batch_set_n_tokens(VALUE self, VALUE n_tokens) {
|
102
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
103
|
+
ptr->batch.n_tokens = NUM2INT(n_tokens);
|
104
|
+
return INT2NUM(ptr->batch.n_tokens);
|
105
|
+
}
|
106
|
+
|
107
|
+
static VALUE _llama_batch_get_n_tokens(VALUE self) {
|
108
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
109
|
+
return INT2NUM(ptr->batch.n_tokens);
|
110
|
+
}
|
111
|
+
|
112
|
+
// all_pos_0
|
113
|
+
static VALUE _llama_batch_set_all_pos_zero(VALUE self, VALUE all_pos_0) {
|
114
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
115
|
+
ptr->batch.all_pos_0 = NUM2INT(all_pos_0);
|
116
|
+
return INT2NUM(ptr->batch.all_pos_0);
|
117
|
+
}
|
118
|
+
|
119
|
+
static VALUE _llama_batch_get_all_pos_zero(VALUE self) {
|
120
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
121
|
+
return INT2NUM(ptr->batch.all_pos_0);
|
122
|
+
}
|
123
|
+
|
124
|
+
// all_pos_1
|
125
|
+
static VALUE _llama_batch_set_all_pos_one(VALUE self, VALUE all_pos_1) {
|
126
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
127
|
+
ptr->batch.all_pos_1 = NUM2INT(all_pos_1);
|
128
|
+
return INT2NUM(ptr->batch.all_pos_1);
|
129
|
+
}
|
130
|
+
|
131
|
+
static VALUE _llama_batch_get_all_pos_one(VALUE self) {
|
132
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
133
|
+
return INT2NUM(ptr->batch.all_pos_1);
|
134
|
+
}
|
135
|
+
|
136
|
+
// all_seq_id
|
137
|
+
static VALUE _llama_batch_set_all_seq_id(VALUE self, VALUE all_seq_id) {
|
138
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
139
|
+
ptr->batch.all_seq_id = NUM2INT(all_seq_id);
|
140
|
+
return INT2NUM(ptr->batch.all_seq_id);
|
141
|
+
}
|
142
|
+
|
143
|
+
static VALUE _llama_batch_get_all_seq_id(VALUE self) {
|
144
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
145
|
+
return INT2NUM(ptr->batch.all_seq_id);
|
146
|
+
}
|
147
|
+
|
148
|
+
// token
|
149
|
+
static VALUE _llama_batch_set_token(VALUE self, VALUE idx, VALUE value) {
|
150
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
151
|
+
const int32_t id = NUM2INT(idx);
|
152
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
153
|
+
rb_raise(rb_eArgError, "idx must be in [0, n_tokens)");
|
154
|
+
return Qnil;
|
155
|
+
}
|
156
|
+
ptr->batch.token[id] = NUM2INT(value);
|
157
|
+
return INT2NUM(ptr->batch.token[id]);
|
158
|
+
}
|
159
|
+
|
160
|
+
static VALUE _llama_batch_get_token(VALUE self, VALUE idx) {
|
161
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
162
|
+
const int32_t id = NUM2INT(idx);
|
163
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
164
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
165
|
+
return Qnil;
|
166
|
+
}
|
167
|
+
return INT2NUM(ptr->batch.token[id]);
|
168
|
+
}
|
169
|
+
|
170
|
+
// pos
|
171
|
+
static VALUE _llama_batch_set_pos(VALUE self, VALUE idx, VALUE value) {
|
172
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
173
|
+
const int32_t id = NUM2INT(idx);
|
174
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
175
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
176
|
+
return Qnil;
|
177
|
+
}
|
178
|
+
ptr->batch.pos[id] = NUM2INT(value);
|
179
|
+
return INT2NUM(ptr->batch.pos[id]);
|
180
|
+
}
|
181
|
+
|
182
|
+
static VALUE _llama_batch_get_pos(VALUE self, VALUE idx) {
|
183
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
184
|
+
const int32_t id = NUM2INT(idx);
|
185
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
186
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
187
|
+
return Qnil;
|
188
|
+
}
|
189
|
+
return INT2NUM(ptr->batch.pos[id]);
|
190
|
+
}
|
191
|
+
|
192
|
+
// seq_id
|
193
|
+
static VALUE _llama_batch_set_seq_id(VALUE self, VALUE idx, VALUE value) {
|
194
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
195
|
+
const int32_t id = NUM2INT(idx);
|
196
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
197
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
198
|
+
return Qnil;
|
199
|
+
}
|
200
|
+
ptr->batch.seq_id[id] = NUM2INT(value);
|
201
|
+
return INT2NUM(ptr->batch.seq_id[id]);
|
202
|
+
}
|
203
|
+
|
204
|
+
static VALUE _llama_batch_get_seq_id(VALUE self, VALUE idx) {
|
205
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
206
|
+
const int32_t id = NUM2INT(idx);
|
207
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
208
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
209
|
+
return Qnil;
|
210
|
+
}
|
211
|
+
return INT2NUM(ptr->batch.seq_id[id]);
|
212
|
+
}
|
213
|
+
|
214
|
+
// logits
|
215
|
+
static VALUE _llama_batch_set_logits(VALUE self, VALUE idx, VALUE value) {
|
216
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
217
|
+
const int32_t id = NUM2INT(idx);
|
218
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
219
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
220
|
+
return Qnil;
|
221
|
+
}
|
222
|
+
ptr->batch.logits[id] = RTEST(value) ? true : false;
|
223
|
+
return ptr->batch.logits[id] ? Qtrue : Qfalse;
|
224
|
+
}
|
225
|
+
|
226
|
+
static VALUE _llama_batch_get_logits(VALUE self, VALUE idx) {
|
227
|
+
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
228
|
+
const int32_t id = NUM2INT(idx);
|
229
|
+
if (id < 0 || id >= ptr->batch.n_tokens) {
|
230
|
+
rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
|
231
|
+
return Qnil;
|
232
|
+
}
|
233
|
+
return ptr->batch.logits[id] ? Qtrue : Qfalse;
|
234
|
+
}
|
235
|
+
};
|
236
|
+
|
237
|
+
const rb_data_type_t RbLLaMABatch::llama_batch_type = {
|
238
|
+
"RbLLaMABatch",
|
239
|
+
{ NULL,
|
240
|
+
RbLLaMABatch::llama_batch_free,
|
241
|
+
RbLLaMABatch::llama_batch_size },
|
242
|
+
NULL,
|
243
|
+
NULL,
|
244
|
+
RUBY_TYPED_FREE_IMMEDIATELY
|
245
|
+
|
246
|
+
};
|
247
|
+
|
14
248
|
class LLaMATokenDataWrapper {
|
15
249
|
public:
|
16
250
|
llama_token_data data;
|
@@ -363,6 +597,144 @@ const rb_data_type_t RbLLaMATimings::llama_timings_type = {
|
|
363
597
|
RUBY_TYPED_FREE_IMMEDIATELY
|
364
598
|
};
|
365
599
|
|
600
|
+
class LLaMAModelParamsWrapper {
|
601
|
+
public:
|
602
|
+
struct llama_model_params params;
|
603
|
+
|
604
|
+
LLaMAModelParamsWrapper() : params(llama_model_default_params()) {}
|
605
|
+
|
606
|
+
~LLaMAModelParamsWrapper() {}
|
607
|
+
};
|
608
|
+
|
609
|
+
class RbLLaMAModelParams {
|
610
|
+
public:
|
611
|
+
static VALUE llama_model_params_alloc(VALUE self) {
|
612
|
+
LLaMAModelParamsWrapper* ptr = (LLaMAModelParamsWrapper*)ruby_xmalloc(sizeof(LLaMAModelParamsWrapper));
|
613
|
+
new (ptr) LLaMAModelParamsWrapper();
|
614
|
+
return TypedData_Wrap_Struct(self, &llama_model_params_type, ptr);
|
615
|
+
}
|
616
|
+
|
617
|
+
static void llama_model_params_free(void* ptr) {
|
618
|
+
((LLaMAModelParamsWrapper*)ptr)->~LLaMAModelParamsWrapper();
|
619
|
+
ruby_xfree(ptr);
|
620
|
+
}
|
621
|
+
|
622
|
+
static size_t llama_model_params_size(const void* ptr) {
|
623
|
+
return sizeof(*((LLaMAModelParamsWrapper*)ptr));
|
624
|
+
}
|
625
|
+
|
626
|
+
static LLaMAModelParamsWrapper* get_llama_model_params(VALUE self) {
|
627
|
+
LLaMAModelParamsWrapper* ptr;
|
628
|
+
TypedData_Get_Struct(self, LLaMAModelParamsWrapper, &llama_model_params_type, ptr);
|
629
|
+
return ptr;
|
630
|
+
}
|
631
|
+
|
632
|
+
static void define_class(VALUE outer) {
|
633
|
+
rb_cLLaMAModelParams = rb_define_class_under(outer, "ModelParams", rb_cObject);
|
634
|
+
rb_define_alloc_func(rb_cLLaMAModelParams, llama_model_params_alloc);
|
635
|
+
rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers=", RUBY_METHOD_FUNC(_llama_model_params_set_n_gpu_layers), 1);
|
636
|
+
rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers", RUBY_METHOD_FUNC(_llama_model_params_get_n_gpu_layers), 0);
|
637
|
+
rb_define_method(rb_cLLaMAModelParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_model_params_set_main_gpu), 1);
|
638
|
+
rb_define_method(rb_cLLaMAModelParams, "main_gpu", RUBY_METHOD_FUNC(_llama_model_params_get_main_gpu), 0);
|
639
|
+
rb_define_method(rb_cLLaMAModelParams, "tensor_split", RUBY_METHOD_FUNC(_llama_model_params_get_tensor_split), 0);
|
640
|
+
rb_define_method(rb_cLLaMAModelParams, "vocab_only=", RUBY_METHOD_FUNC(_llama_model_params_set_vocab_only), 1);
|
641
|
+
rb_define_method(rb_cLLaMAModelParams, "vocab_only", RUBY_METHOD_FUNC(_llama_model_params_get_vocab_only), 0);
|
642
|
+
rb_define_method(rb_cLLaMAModelParams, "use_mmap=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mmap), 1);
|
643
|
+
rb_define_method(rb_cLLaMAModelParams, "use_mmap", RUBY_METHOD_FUNC(_llama_model_params_get_use_mmap), 0);
|
644
|
+
rb_define_method(rb_cLLaMAModelParams, "use_mlock=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mlock), 1);
|
645
|
+
rb_define_method(rb_cLLaMAModelParams, "use_mlock", RUBY_METHOD_FUNC(_llama_model_params_get_use_mlock), 0);
|
646
|
+
}
|
647
|
+
|
648
|
+
private:
|
649
|
+
static const rb_data_type_t llama_model_params_type;
|
650
|
+
|
651
|
+
// n_gpu_layers
|
652
|
+
static VALUE _llama_model_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
|
653
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
654
|
+
ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
|
655
|
+
return INT2NUM(ptr->params.n_gpu_layers);
|
656
|
+
}
|
657
|
+
|
658
|
+
static VALUE _llama_model_params_get_n_gpu_layers(VALUE self) {
|
659
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
660
|
+
return INT2NUM(ptr->params.n_gpu_layers);
|
661
|
+
}
|
662
|
+
|
663
|
+
// main_gpu
|
664
|
+
static VALUE _llama_model_params_set_main_gpu(VALUE self, VALUE main_gpu) {
|
665
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
666
|
+
ptr->params.main_gpu = NUM2INT(main_gpu);
|
667
|
+
return INT2NUM(ptr->params.main_gpu);
|
668
|
+
}
|
669
|
+
|
670
|
+
static VALUE _llama_model_params_get_main_gpu(VALUE self) {
|
671
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
672
|
+
return INT2NUM(ptr->params.main_gpu);
|
673
|
+
}
|
674
|
+
|
675
|
+
// tensor_split
|
676
|
+
static VALUE _llama_model_params_get_tensor_split(VALUE self) {
|
677
|
+
if (LLAMA_MAX_DEVICES < 1) {
|
678
|
+
return rb_ary_new();
|
679
|
+
}
|
680
|
+
VALUE ret = rb_ary_new2(LLAMA_MAX_DEVICES);
|
681
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
682
|
+
if (ptr->params.tensor_split == nullptr) {
|
683
|
+
return rb_ary_new();
|
684
|
+
}
|
685
|
+
for (size_t i = 0; i < LLAMA_MAX_DEVICES; i++) {
|
686
|
+
rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
|
687
|
+
}
|
688
|
+
return ret;
|
689
|
+
}
|
690
|
+
|
691
|
+
// vocab_only
|
692
|
+
static VALUE _llama_model_params_set_vocab_only(VALUE self, VALUE vocab_only) {
|
693
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
694
|
+
ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
|
695
|
+
return ptr->params.vocab_only ? Qtrue : Qfalse;
|
696
|
+
}
|
697
|
+
|
698
|
+
static VALUE _llama_model_params_get_vocab_only(VALUE self) {
|
699
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
700
|
+
return ptr->params.vocab_only ? Qtrue : Qfalse;
|
701
|
+
}
|
702
|
+
|
703
|
+
// use_mmap
|
704
|
+
static VALUE _llama_model_params_set_use_mmap(VALUE self, VALUE use_mmap) {
|
705
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
706
|
+
ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
|
707
|
+
return ptr->params.use_mmap ? Qtrue : Qfalse;
|
708
|
+
}
|
709
|
+
|
710
|
+
static VALUE _llama_model_params_get_use_mmap(VALUE self) {
|
711
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
712
|
+
return ptr->params.use_mmap ? Qtrue : Qfalse;
|
713
|
+
}
|
714
|
+
|
715
|
+
// use_mlock
|
716
|
+
static VALUE _llama_model_params_set_use_mlock(VALUE self, VALUE use_mlock) {
|
717
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
718
|
+
ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
|
719
|
+
return ptr->params.use_mlock ? Qtrue : Qfalse;
|
720
|
+
}
|
721
|
+
|
722
|
+
static VALUE _llama_model_params_get_use_mlock(VALUE self) {
|
723
|
+
LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
|
724
|
+
return ptr->params.use_mlock ? Qtrue : Qfalse;
|
725
|
+
}
|
726
|
+
};
|
727
|
+
|
728
|
+
const rb_data_type_t RbLLaMAModelParams::llama_model_params_type = {
|
729
|
+
"RbLLaMAModelParams",
|
730
|
+
{ NULL,
|
731
|
+
RbLLaMAModelParams::llama_model_params_free,
|
732
|
+
RbLLaMAModelParams::llama_model_params_size },
|
733
|
+
NULL,
|
734
|
+
NULL,
|
735
|
+
RUBY_TYPED_FREE_IMMEDIATELY
|
736
|
+
};
|
737
|
+
|
366
738
|
class LLaMAContextParamsWrapper {
|
367
739
|
public:
|
368
740
|
struct llama_context_params params;
|
@@ -399,35 +771,26 @@ public:
|
|
399
771
|
rb_cLLaMAContextParams = rb_define_class_under(outer, "ContextParams", rb_cObject);
|
400
772
|
rb_define_alloc_func(rb_cLLaMAContextParams, llama_context_params_alloc);
|
401
773
|
// rb_define_method(rb_cLLaMAContextParams, "initialize", RUBY_METHOD_FUNC(_llama_context_params_init), 0);
|
774
|
+
rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
|
775
|
+
rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
|
402
776
|
rb_define_method(rb_cLLaMAContextParams, "n_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_n_ctx), 1);
|
403
777
|
rb_define_method(rb_cLLaMAContextParams, "n_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_n_ctx), 0);
|
404
778
|
rb_define_method(rb_cLLaMAContextParams, "n_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_batch), 1);
|
405
779
|
rb_define_method(rb_cLLaMAContextParams, "n_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_batch), 0);
|
406
|
-
rb_define_method(rb_cLLaMAContextParams, "
|
407
|
-
rb_define_method(rb_cLLaMAContextParams, "
|
408
|
-
rb_define_method(rb_cLLaMAContextParams, "
|
409
|
-
rb_define_method(rb_cLLaMAContextParams, "
|
410
|
-
rb_define_method(rb_cLLaMAContextParams, "tensor_split", RUBY_METHOD_FUNC(_llama_context_params_get_tensor_split), 0);
|
780
|
+
rb_define_method(rb_cLLaMAContextParams, "n_threads=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads), 1);
|
781
|
+
rb_define_method(rb_cLLaMAContextParams, "n_threads", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads), 0);
|
782
|
+
rb_define_method(rb_cLLaMAContextParams, "n_threads_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads_batch), 1);
|
783
|
+
rb_define_method(rb_cLLaMAContextParams, "n_threads_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads_batch), 0);
|
411
784
|
rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
|
412
785
|
rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
|
413
786
|
rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
|
414
787
|
rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
|
415
|
-
rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
|
416
|
-
rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
|
417
788
|
rb_define_method(rb_cLLaMAContextParams, "mul_mat_q=", RUBY_METHOD_FUNC(_llama_context_params_set_mul_mat_q), 1);
|
418
789
|
rb_define_method(rb_cLLaMAContextParams, "mul_mat_q", RUBY_METHOD_FUNC(_llama_context_params_get_mul_mat_q), 0);
|
419
|
-
rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
|
420
|
-
rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
|
421
790
|
rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
|
422
791
|
rb_define_method(rb_cLLaMAContextParams, "f16_kv", RUBY_METHOD_FUNC(_llama_context_params_get_f16_kv), 0);
|
423
792
|
rb_define_method(rb_cLLaMAContextParams, "logits_all=", RUBY_METHOD_FUNC(_llama_context_params_set_logits_all), 1);
|
424
793
|
rb_define_method(rb_cLLaMAContextParams, "logits_all", RUBY_METHOD_FUNC(_llama_context_params_get_logits_all), 0);
|
425
|
-
rb_define_method(rb_cLLaMAContextParams, "vocab_only=", RUBY_METHOD_FUNC(_llama_context_params_set_vocab_only), 1);
|
426
|
-
rb_define_method(rb_cLLaMAContextParams, "vocab_only", RUBY_METHOD_FUNC(_llama_context_params_get_vocab_only), 0);
|
427
|
-
rb_define_method(rb_cLLaMAContextParams, "use_mmap=", RUBY_METHOD_FUNC(_llama_context_params_set_use_mmap), 1);
|
428
|
-
rb_define_method(rb_cLLaMAContextParams, "use_mmap", RUBY_METHOD_FUNC(_llama_context_params_get_use_mmap), 0);
|
429
|
-
rb_define_method(rb_cLLaMAContextParams, "use_mlock=", RUBY_METHOD_FUNC(_llama_context_params_set_use_mlock), 1);
|
430
|
-
rb_define_method(rb_cLLaMAContextParams, "use_mlock", RUBY_METHOD_FUNC(_llama_context_params_get_use_mlock), 0);
|
431
794
|
rb_define_method(rb_cLLaMAContextParams, "embedding=", RUBY_METHOD_FUNC(_llama_context_params_set_embedding), 1);
|
432
795
|
rb_define_method(rb_cLLaMAContextParams, "embedding", RUBY_METHOD_FUNC(_llama_context_params_get_embedding), 0);
|
433
796
|
}
|
@@ -441,6 +804,22 @@ private:
|
|
441
804
|
// return self;
|
442
805
|
// }
|
443
806
|
|
807
|
+
// seed
|
808
|
+
// ContextParams#seed= — llama_context_params::seed is a uint32_t, so accept the
// full range 0..4294967295 (which includes LLAMA_DEFAULT_SEED, 0xFFFFFFFF).
// NUM2INT previously rejected every seed >= 2**31 with a RangeError.
// Raises ArgumentError for negative values, RangeError above 2**32 - 1.
static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
  const long long value = NUM2LL(seed);
  if (value < 0) {
    rb_raise(rb_eArgError, "seed must be positive");
    return Qnil;
  }
  if (value > 0xFFFFFFFFLL) {
    rb_raise(rb_eRangeError, "seed must be at most 4294967295");
    return Qnil;
  }
  ptr->params.seed = (uint32_t)value;
  // Unsigned conversion so large seeds are not reported as negative numbers.
  return UINT2NUM(ptr->params.seed);
}
|
817
|
+
|
818
|
+
// ContextParams#seed — returns the uint32_t seed as a non-negative Integer.
// INT2NUM previously reported seeds >= 2**31 (including LLAMA_DEFAULT_SEED)
// as negative numbers, which the setter then refused to accept back.
static VALUE _llama_context_params_get_seed(VALUE self) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
  return UINT2NUM(ptr->params.seed);
}
|
822
|
+
|
444
823
|
// n_ctx
|
445
824
|
static VALUE _llama_context_params_set_n_ctx(VALUE self, VALUE n_ctx) {
|
446
825
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
@@ -465,41 +844,28 @@ private:
|
|
465
844
|
return INT2NUM(ptr->params.n_batch);
|
466
845
|
}
|
467
846
|
|
468
|
-
//
|
469
|
-
static VALUE
|
847
|
+
// n_threads
|
848
|
+
// ContextParams#n_threads= — stores the Integer (via NUM2INT) and echoes it back.
static VALUE _llama_context_params_set_n_threads(VALUE self, VALUE n_threads) {
  auto& params = get_llama_context_params(self)->params;
  params.n_threads = NUM2INT(n_threads);
  return INT2NUM(params.n_threads);
}
|
474
853
|
|
475
|
-
static VALUE
|
854
|
+
// ContextParams#n_threads — current value as an Integer.
static VALUE _llama_context_params_get_n_threads(VALUE self) {
  const auto& params = get_llama_context_params(self)->params;
  return INT2NUM(params.n_threads);
}
|
479
858
|
|
480
|
-
//
|
481
|
-
static VALUE
|
859
|
+
// n_threads_batch
|
860
|
+
// ContextParams#n_threads_batch= — stores the Integer (via NUM2INT) and echoes it back.
static VALUE _llama_context_params_set_n_threads_batch(VALUE self, VALUE n_threads_batch) {
  auto& params = get_llama_context_params(self)->params;
  params.n_threads_batch = NUM2INT(n_threads_batch);
  return INT2NUM(params.n_threads_batch);
}
|
486
865
|
|
487
|
-
static VALUE
|
866
|
+
// ContextParams#n_threads_batch — current value as an Integer.
static VALUE _llama_context_params_get_n_threads_batch(VALUE self) {
  const auto& params = get_llama_context_params(self)->params;
  return INT2NUM(params.n_threads_batch);
}
|
504
870
|
|
505
871
|
// rope_freq_base
|
@@ -526,18 +892,6 @@ private:
|
|
526
892
|
return DBL2NUM(ptr->params.rope_freq_scale);
|
527
893
|
}
|
528
894
|
|
529
|
-
// low_vram
|
530
|
-
static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
|
531
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
532
|
-
ptr->params.low_vram = RTEST(low_vram) ? true : false;
|
533
|
-
return ptr->params.low_vram ? Qtrue : Qfalse;
|
534
|
-
}
|
535
|
-
|
536
|
-
static VALUE _llama_context_params_get_low_vram(VALUE self) {
|
537
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
538
|
-
return ptr->params.low_vram ? Qtrue : Qfalse;
|
539
|
-
}
|
540
|
-
|
541
895
|
// mul_mat_q
|
542
896
|
static VALUE _llama_context_params_set_mul_mat_q(VALUE self, VALUE mul_mat_q) {
|
543
897
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
@@ -550,22 +904,6 @@ private:
|
|
550
904
|
return ptr->params.mul_mat_q ? Qtrue : Qfalse;
|
551
905
|
}
|
552
906
|
|
553
|
-
// seed
|
554
|
-
static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
|
555
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
556
|
-
if (NUM2INT(seed) < 0) {
|
557
|
-
rb_raise(rb_eArgError, "seed must be positive");
|
558
|
-
return Qnil;
|
559
|
-
}
|
560
|
-
ptr->params.seed = NUM2INT(seed);
|
561
|
-
return INT2NUM(ptr->params.seed);
|
562
|
-
}
|
563
|
-
|
564
|
-
static VALUE _llama_context_params_get_seed(VALUE self) {
|
565
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
566
|
-
return INT2NUM(ptr->params.seed);
|
567
|
-
}
|
568
|
-
|
569
907
|
// f16_kv
|
570
908
|
static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
|
571
909
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
@@ -590,42 +928,6 @@ private:
|
|
590
928
|
return ptr->params.logits_all ? Qtrue : Qfalse;
|
591
929
|
}
|
592
930
|
|
593
|
-
// vocab_only
|
594
|
-
static VALUE _llama_context_params_set_vocab_only(VALUE self, VALUE vocab_only) {
|
595
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
596
|
-
ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
|
597
|
-
return ptr->params.vocab_only ? Qtrue : Qfalse;
|
598
|
-
}
|
599
|
-
|
600
|
-
static VALUE _llama_context_params_get_vocab_only(VALUE self) {
|
601
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
602
|
-
return ptr->params.vocab_only ? Qtrue : Qfalse;
|
603
|
-
}
|
604
|
-
|
605
|
-
// use_mmap
|
606
|
-
static VALUE _llama_context_params_set_use_mmap(VALUE self, VALUE use_mmap) {
|
607
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
608
|
-
ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
|
609
|
-
return ptr->params.use_mmap ? Qtrue : Qfalse;
|
610
|
-
}
|
611
|
-
|
612
|
-
static VALUE _llama_context_params_get_use_mmap(VALUE self) {
|
613
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
614
|
-
return ptr->params.use_mmap ? Qtrue : Qfalse;
|
615
|
-
}
|
616
|
-
|
617
|
-
// use_mlock
|
618
|
-
static VALUE _llama_context_params_set_use_mlock(VALUE self, VALUE use_mlock) {
|
619
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
620
|
-
ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
|
621
|
-
return ptr->params.use_mlock ? Qtrue : Qfalse;
|
622
|
-
}
|
623
|
-
|
624
|
-
static VALUE _llama_context_params_get_use_mlock(VALUE self) {
|
625
|
-
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
626
|
-
return ptr->params.use_mlock ? Qtrue : Qfalse;
|
627
|
-
}
|
628
|
-
|
629
931
|
// embedding
|
630
932
|
static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
|
631
933
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
@@ -823,11 +1125,10 @@ public:
|
|
823
1125
|
rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
|
824
1126
|
rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
|
825
1127
|
rb_define_method(rb_cLLaMAModel, "n_vocab", RUBY_METHOD_FUNC(_llama_model_get_model_n_vocab), 0);
|
826
|
-
rb_define_method(rb_cLLaMAModel, "n_ctx", RUBY_METHOD_FUNC(_llama_model_get_model_n_ctx), 0);
|
827
1128
|
rb_define_method(rb_cLLaMAModel, "n_ctx_train", RUBY_METHOD_FUNC(_llama_model_get_model_n_ctx_train), 0);
|
828
1129
|
rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_model_n_embd), 0);
|
829
|
-
rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(
|
830
|
-
rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(
|
1130
|
+
rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece), 1);
|
1131
|
+
rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize), -1);
|
831
1132
|
rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
|
832
1133
|
rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
|
833
1134
|
rb_define_method(rb_cLLaMAModel, "n_params", RUBY_METHOD_FUNC(_llama_model_get_model_n_params), 0);
|
@@ -841,30 +1142,21 @@ private:
|
|
841
1142
|
ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
|
842
1143
|
VALUE kw_values[2] = { Qundef, Qundef };
|
843
1144
|
rb_scan_args(argc, argv, ":", &kw_args);
|
844
|
-
rb_get_kwargs(kw_args, kw_table,
|
845
|
-
|
846
|
-
if (kw_values[0] == Qundef && kw_values[1] == Qundef) {
|
847
|
-
rb_iv_set(self, "@params", Qnil);
|
848
|
-
return Qnil;
|
849
|
-
}
|
1145
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
850
1146
|
|
851
1147
|
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
852
1148
|
rb_raise(rb_eArgError, "model_path must be a string");
|
853
1149
|
return Qnil;
|
854
1150
|
}
|
855
|
-
if (!rb_obj_is_kind_of(kw_values[1],
|
856
|
-
rb_raise(rb_eArgError, "params must be a
|
1151
|
+
if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
|
1152
|
+
rb_raise(rb_eArgError, "params must be a ModelParams");
|
857
1153
|
return Qnil;
|
858
1154
|
}
|
859
1155
|
|
860
1156
|
VALUE filename = kw_values[0];
|
861
|
-
|
1157
|
+
LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
|
862
1158
|
LLaMAModelWrapper* model_ptr = get_llama_model(self);
|
863
1159
|
|
864
|
-
if (prms_ptr->params.seed == LLAMA_DEFAULT_SEED) {
|
865
|
-
prms_ptr->params.seed = time(NULL);
|
866
|
-
}
|
867
|
-
|
868
1160
|
try {
|
869
1161
|
model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
|
870
1162
|
} catch (const std::runtime_error& e) {
|
@@ -912,8 +1204,8 @@ private:
|
|
912
1204
|
rb_raise(rb_eArgError, "model_path must be a string");
|
913
1205
|
return Qnil;
|
914
1206
|
}
|
915
|
-
if (!rb_obj_is_kind_of(kw_values[1],
|
916
|
-
rb_raise(rb_eArgError, "params must be a
|
1207
|
+
if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
|
1208
|
+
rb_raise(rb_eArgError, "params must be a LLaMAModelParams");
|
917
1209
|
return Qnil;
|
918
1210
|
}
|
919
1211
|
|
@@ -924,7 +1216,7 @@ private:
|
|
924
1216
|
}
|
925
1217
|
|
926
1218
|
VALUE filename = kw_values[0];
|
927
|
-
|
1219
|
+
LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
|
928
1220
|
|
929
1221
|
try {
|
930
1222
|
model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
|
@@ -946,10 +1238,10 @@ private:
|
|
946
1238
|
|
947
1239
|
static VALUE _llama_model_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
|
948
1240
|
VALUE kw_args = Qnil;
|
949
|
-
ID kw_table[
|
950
|
-
VALUE kw_values[
|
1241
|
+
ID kw_table[4] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads"), rb_intern("scale") };
|
1242
|
+
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
951
1243
|
rb_scan_args(argc, argv, ":", &kw_args);
|
952
|
-
rb_get_kwargs(kw_args, kw_table, 1,
|
1244
|
+
rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);
|
953
1245
|
|
954
1246
|
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
955
1247
|
rb_raise(rb_eArgError, "lora_path must be a string");
|
@@ -963,13 +1255,18 @@ private:
|
|
963
1255
|
rb_raise(rb_eArgError, "n_threads must be an integer");
|
964
1256
|
return Qnil;
|
965
1257
|
}
|
1258
|
+
if (kw_values[3] != Qundef && !RB_FLOAT_TYPE_P(kw_values[3])) {
|
1259
|
+
rb_raise(rb_eArgError, "scale must be a float");
|
1260
|
+
return Qnil;
|
1261
|
+
}
|
966
1262
|
|
967
1263
|
const char* lora_path = StringValueCStr(kw_values[0]);
|
968
1264
|
const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
|
969
1265
|
const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
|
1266
|
+
const float scale = kw_values[3] == Qundef ? 1.0 : NUM2DBL(kw_values[3]);
|
970
1267
|
|
971
1268
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
972
|
-
if (llama_model_apply_lora_from_file(ptr->model, lora_path, base_model_path, n_threads) != 0) {
|
1269
|
+
if (llama_model_apply_lora_from_file(ptr->model, lora_path, scale, base_model_path, n_threads) != 0) {
|
973
1270
|
rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
|
974
1271
|
return Qnil;
|
975
1272
|
}
|
@@ -978,25 +1275,20 @@ private:
|
|
978
1275
|
|
979
1276
|
static VALUE _llama_model_get_model_n_vocab(VALUE self) {
|
980
1277
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
981
|
-
return INT2NUM(
|
982
|
-
}
|
983
|
-
|
984
|
-
static VALUE _llama_model_get_model_n_ctx(VALUE self) {
|
985
|
-
LLaMAModelWrapper* ptr = get_llama_model(self);
|
986
|
-
return INT2NUM(llama_model_n_ctx(ptr->model));
|
1278
|
+
return INT2NUM(llama_n_vocab(ptr->model));
|
987
1279
|
}
|
988
1280
|
|
989
1281
|
static VALUE _llama_model_get_model_n_ctx_train(VALUE self) {
|
990
1282
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
991
|
-
return INT2NUM(
|
1283
|
+
return INT2NUM(llama_n_ctx_train(ptr->model));
|
992
1284
|
}
|
993
1285
|
|
994
1286
|
static VALUE _llama_model_get_model_n_embd(VALUE self) {
|
995
1287
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
996
|
-
return INT2NUM(
|
1288
|
+
return INT2NUM(llama_n_embd(ptr->model));
|
997
1289
|
}
|
998
1290
|
|
999
|
-
static VALUE
|
1291
|
+
static VALUE _llama_model_token_to_piece(VALUE self, VALUE token_) {
|
1000
1292
|
if (!RB_INTEGER_TYPE_P(token_)) {
|
1001
1293
|
rb_raise(rb_eArgError, "token must be an integer");
|
1002
1294
|
return Qnil;
|
@@ -1004,10 +1296,10 @@ private:
|
|
1004
1296
|
const llama_token token = NUM2INT(token_);
|
1005
1297
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1006
1298
|
std::vector<char> result(8, 0);
|
1007
|
-
const int n_tokens =
|
1299
|
+
const int n_tokens = llama_token_to_piece(ptr->model, token, result.data(), result.size());
|
1008
1300
|
if (n_tokens < 0) {
|
1009
1301
|
result.resize(-n_tokens);
|
1010
|
-
const int check =
|
1302
|
+
const int check = llama_token_to_piece(ptr->model, token, result.data(), result.size());
|
1011
1303
|
if (check != -n_tokens) {
|
1012
1304
|
rb_raise(rb_eRuntimeError, "failed to convert");
|
1013
1305
|
return Qnil;
|
@@ -1019,7 +1311,7 @@ private:
|
|
1019
1311
|
return rb_str_new_cstr(ret.c_str());
|
1020
1312
|
}
|
1021
1313
|
|
1022
|
-
static VALUE
|
1314
|
+
static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
|
1023
1315
|
VALUE kw_args = Qnil;
|
1024
1316
|
ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
|
1025
1317
|
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
@@ -1046,7 +1338,7 @@ private:
|
|
1046
1338
|
|
1047
1339
|
llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
|
1048
1340
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1049
|
-
const int n_tokens =
|
1341
|
+
const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
|
1050
1342
|
|
1051
1343
|
if (n_tokens < 0) {
|
1052
1344
|
rb_raise(rb_eRuntimeError, "failed to tokenize. The numebr of tokens (%d) is greater than n_max_tokens.", -n_tokens);
|
@@ -1345,11 +1637,11 @@ public:
|
|
1345
1637
|
static void define_class(VALUE outer) {
|
1346
1638
|
rb_cLLaMAContext = rb_define_class_under(outer, "Context", rb_cObject);
|
1347
1639
|
rb_define_alloc_func(rb_cLLaMAContext, llama_context_alloc);
|
1640
|
+
rb_define_attr(rb_cLLaMAContext, "model", 1, 0);
|
1348
1641
|
rb_define_method(rb_cLLaMAContext, "initialize", RUBY_METHOD_FUNC(_llama_context_initialize), -1);
|
1349
1642
|
rb_define_method(rb_cLLaMAContext, "eval", RUBY_METHOD_FUNC(_llama_context_eval), -1);
|
1350
1643
|
rb_define_method(rb_cLLaMAContext, "eval_embd", RUBY_METHOD_FUNC(_llama_context_eval_embd), -1);
|
1351
|
-
rb_define_method(rb_cLLaMAContext, "
|
1352
|
-
rb_define_method(rb_cLLaMAContext, "tokenize", RUBY_METHOD_FUNC(_llama_context_tokenize), -1);
|
1644
|
+
rb_define_method(rb_cLLaMAContext, "decode", RUBY_METHOD_FUNC(_llama_context_decode), 1);
|
1353
1645
|
rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
|
1354
1646
|
rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
|
1355
1647
|
rb_define_method(rb_cLLaMAContext, "text", RUBY_METHOD_FUNC(_llama_context_text), 1);
|
@@ -1358,15 +1650,16 @@ public:
|
|
1358
1650
|
rb_define_method(rb_cLLaMAContext, "token_bos", RUBY_METHOD_FUNC(_llama_context_token_bos), 0);
|
1359
1651
|
rb_define_method(rb_cLLaMAContext, "token_eos", RUBY_METHOD_FUNC(_llama_context_token_eos), 0);
|
1360
1652
|
rb_define_method(rb_cLLaMAContext, "token_nl", RUBY_METHOD_FUNC(_llama_context_token_nl), 0);
|
1361
|
-
rb_define_method(rb_cLLaMAContext, "token_to_piece", RUBY_METHOD_FUNC(_llama_context_token_to_piece), 1);
|
1362
|
-
rb_define_method(rb_cLLaMAContext, "n_vocab", RUBY_METHOD_FUNC(_llama_context_n_vocab), 0);
|
1363
1653
|
rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
|
1364
|
-
rb_define_method(rb_cLLaMAContext, "n_ctx_train", RUBY_METHOD_FUNC(_llama_context_n_ctx_train), 0);
|
1365
|
-
rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
|
1366
1654
|
rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
|
1367
1655
|
rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
|
1368
1656
|
rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
|
1369
1657
|
rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
|
1658
|
+
rb_define_method(rb_cLLaMAContext, "kv_cache_tokens_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_tokens_rm), 2);
|
1659
|
+
rb_define_method(rb_cLLaMAContext, "kv_cache_seq_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_rm), 3);
|
1660
|
+
rb_define_method(rb_cLLaMAContext, "kv_cache_seq_cp", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_cp), 4);
|
1661
|
+
rb_define_method(rb_cLLaMAContext, "kv_cache_seq_keep", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_keep), 1);
|
1662
|
+
rb_define_method(rb_cLLaMAContext, "kv_cache_seq_shift", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_shift), 4);
|
1370
1663
|
rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
|
1371
1664
|
rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
|
1372
1665
|
rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
|
@@ -1378,6 +1671,7 @@ public:
|
|
1378
1671
|
rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
|
1379
1672
|
rb_define_method(rb_cLLaMAContext, "sample_tail_free", RUBY_METHOD_FUNC(_llama_context_sample_tail_free), -1);
|
1380
1673
|
rb_define_method(rb_cLLaMAContext, "sample_typical", RUBY_METHOD_FUNC(_llama_context_sample_typical), -1);
|
1674
|
+
rb_define_method(rb_cLLaMAContext, "sample_temp", RUBY_METHOD_FUNC(_llama_context_sample_temp), -1);
|
1381
1675
|
rb_define_method(rb_cLLaMAContext, "sample_temperature", RUBY_METHOD_FUNC(_llama_context_sample_temperature), -1);
|
1382
1676
|
rb_define_method(rb_cLLaMAContext, "sample_token_mirostat", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat), -1);
|
1383
1677
|
rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
|
@@ -1392,24 +1686,27 @@ private:
|
|
1392
1686
|
|
1393
1687
|
static VALUE _llama_context_initialize(int argc, VALUE* argv, VALUE self) {
|
1394
1688
|
VALUE kw_args = Qnil;
|
1395
|
-
ID kw_table[
|
1396
|
-
VALUE kw_values[
|
1689
|
+
ID kw_table[2] = { rb_intern("model"), rb_intern("params") };
|
1690
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
1397
1691
|
rb_scan_args(argc, argv, ":", &kw_args);
|
1398
|
-
rb_get_kwargs(kw_args, kw_table,
|
1692
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
1399
1693
|
|
1400
1694
|
VALUE model = kw_values[0];
|
1401
1695
|
if (!rb_obj_is_kind_of(model, rb_cLLaMAModel)) {
|
1402
1696
|
rb_raise(rb_eArgError, "model must be a Model");
|
1403
1697
|
return Qnil;
|
1404
1698
|
}
|
1699
|
+
VALUE params = kw_values[1];
|
1700
|
+
if (!rb_obj_is_kind_of(params, rb_cLLaMAContextParams)) {
|
1701
|
+
rb_raise(rb_eArgError, "params must be a ContextParams");
|
1702
|
+
return Qnil;
|
1703
|
+
}
|
1405
1704
|
|
1406
1705
|
LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
|
1407
1706
|
if (model_ptr->model == NULL) {
|
1408
1707
|
rb_raise(rb_eRuntimeError, "Model is empty");
|
1409
1708
|
return Qnil;
|
1410
1709
|
}
|
1411
|
-
|
1412
|
-
VALUE params = rb_iv_get(model, "@params");
|
1413
1710
|
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
|
1414
1711
|
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
1415
1712
|
|
@@ -1421,6 +1718,7 @@ private:
|
|
1421
1718
|
}
|
1422
1719
|
|
1423
1720
|
rb_iv_set(self, "@model", model);
|
1721
|
+
rb_iv_set(self, "@params", params);
|
1424
1722
|
rb_iv_set(self, "@has_evaluated", Qfalse);
|
1425
1723
|
|
1426
1724
|
return Qnil;
|
@@ -1428,8 +1726,8 @@ private:
|
|
1428
1726
|
|
1429
1727
|
static VALUE _llama_context_eval(int argc, VALUE* argv, VALUE self) {
|
1430
1728
|
VALUE kw_args = Qnil;
|
1431
|
-
ID kw_table[
|
1432
|
-
VALUE kw_values[
|
1729
|
+
ID kw_table[3] = { rb_intern("tokens"), rb_intern("n_past"), rb_intern("n_tokens") };
|
1730
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
1433
1731
|
rb_scan_args(argc, argv, ":", &kw_args);
|
1434
1732
|
rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
|
1435
1733
|
|
@@ -1445,10 +1743,6 @@ private:
|
|
1445
1743
|
rb_raise(rb_eArgError, "n_tokens must be an integer");
|
1446
1744
|
return Qnil;
|
1447
1745
|
}
|
1448
|
-
if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
|
1449
|
-
rb_raise(rb_eArgError, "n_threads must be an integer");
|
1450
|
-
return Qnil;
|
1451
|
-
}
|
1452
1746
|
|
1453
1747
|
const size_t tokens_len = RARRAY_LEN(kw_values[0]);
|
1454
1748
|
std::vector<llama_token> embd(tokens_len);
|
@@ -1463,14 +1757,13 @@ private:
|
|
1463
1757
|
|
1464
1758
|
const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
|
1465
1759
|
const int n_past = NUM2INT(kw_values[1]);
|
1466
|
-
const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
|
1467
1760
|
|
1468
1761
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1469
1762
|
if (ptr->ctx == NULL) {
|
1470
1763
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1471
1764
|
return Qnil;
|
1472
1765
|
}
|
1473
|
-
if (llama_eval(ptr->ctx, embd.data(), n_tokens, n_past
|
1766
|
+
if (llama_eval(ptr->ctx, embd.data(), n_tokens, n_past) != 0) {
|
1474
1767
|
rb_raise(rb_eRuntimeError, "Failed to evaluate");
|
1475
1768
|
return Qnil;
|
1476
1769
|
}
|
@@ -1483,8 +1776,8 @@ private:
|
|
1483
1776
|
|
1484
1777
|
static VALUE _llama_context_eval_embd(int argc, VALUE* argv, VALUE self) {
|
1485
1778
|
VALUE kw_args = Qnil;
|
1486
|
-
ID kw_table[
|
1487
|
-
VALUE kw_values[
|
1779
|
+
ID kw_table[3] = { rb_intern("embd"), rb_intern("n_past"), rb_intern("n_tokens") };
|
1780
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
1488
1781
|
rb_scan_args(argc, argv, ":", &kw_args);
|
1489
1782
|
rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
|
1490
1783
|
|
@@ -1500,10 +1793,6 @@ private:
|
|
1500
1793
|
rb_raise(rb_eArgError, "n_tokens must be an integer");
|
1501
1794
|
return Qnil;
|
1502
1795
|
}
|
1503
|
-
if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
|
1504
|
-
rb_raise(rb_eArgError, "n_threads must be an integer");
|
1505
|
-
return Qnil;
|
1506
|
-
}
|
1507
1796
|
|
1508
1797
|
const size_t tokens_len = RARRAY_LEN(kw_values[0]);
|
1509
1798
|
std::vector<float> embd(tokens_len);
|
@@ -1518,14 +1807,13 @@ private:
|
|
1518
1807
|
|
1519
1808
|
const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
|
1520
1809
|
const int n_past = NUM2INT(kw_values[1]);
|
1521
|
-
const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
|
1522
1810
|
|
1523
1811
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1524
1812
|
if (ptr->ctx == NULL) {
|
1525
1813
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1526
1814
|
return Qnil;
|
1527
1815
|
}
|
1528
|
-
if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past
|
1816
|
+
if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past) != 0) {
|
1529
1817
|
rb_raise(rb_eRuntimeError, "Failed to evaluate");
|
1530
1818
|
return Qnil;
|
1531
1819
|
}
|
@@ -1536,91 +1824,22 @@ private:
|
|
1536
1824
|
return Qnil;
|
1537
1825
|
}
|
1538
1826
|
|
1539
|
-
static VALUE
|
1827
|
+
static VALUE _llama_context_decode(VALUE self, VALUE batch) {
|
1540
1828
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1541
1829
|
if (ptr->ctx == NULL) {
|
1542
1830
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1543
1831
|
return Qnil;
|
1544
1832
|
}
|
1545
|
-
if (!
|
1546
|
-
rb_raise(rb_eArgError, "
|
1833
|
+
if (!rb_obj_is_kind_of(batch, rb_cLLaMABatch)) {
|
1834
|
+
rb_raise(rb_eArgError, "batch must be a Batch");
|
1547
1835
|
return Qnil;
|
1548
1836
|
}
|
1549
|
-
|
1550
|
-
if (
|
1551
|
-
|
1552
|
-
}
|
1553
|
-
RB_GC_GUARD(fname_);
|
1554
|
-
return Qtrue;
|
1555
|
-
}
|
1556
|
-
|
1557
|
-
static VALUE _llama_context_tokenize(int argc, VALUE* argv, VALUE self) {
|
1558
|
-
VALUE kw_args = Qnil;
|
1559
|
-
ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
|
1560
|
-
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
1561
|
-
rb_scan_args(argc, argv, ":", &kw_args);
|
1562
|
-
rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
|
1563
|
-
|
1564
|
-
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
1565
|
-
rb_raise(rb_eArgError, "text must be a String");
|
1837
|
+
LLaMABatchWrapper* batch_ptr = RbLLaMABatch::get_llama_batch(batch);
|
1838
|
+
if (llama_decode(ptr->ctx, batch_ptr->batch) < 0) {
|
1839
|
+
rb_raise(rb_eRuntimeError, "Failed to decode");
|
1566
1840
|
return Qnil;
|
1567
1841
|
}
|
1568
|
-
|
1569
|
-
rb_raise(rb_eArgError, "n_max_tokens must be an integer");
|
1570
|
-
return Qnil;
|
1571
|
-
}
|
1572
|
-
if (kw_values[2] != Qundef && (kw_values[2] != Qtrue && kw_values[2] != Qfalse)) {
|
1573
|
-
rb_raise(rb_eArgError, "add_bos must be a boolean");
|
1574
|
-
return Qnil;
|
1575
|
-
}
|
1576
|
-
|
1577
|
-
VALUE text_ = kw_values[0];
|
1578
|
-
std::string text = StringValueCStr(text_);
|
1579
|
-
const bool add_bos = kw_values[2] == Qtrue ? true : false;
|
1580
|
-
const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
|
1581
|
-
|
1582
|
-
std::vector<llama_token> tokens(n_max_tokens);
|
1583
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1584
|
-
if (ptr->ctx == NULL) {
|
1585
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1586
|
-
return Qnil;
|
1587
|
-
}
|
1588
|
-
const int n = llama_tokenize(ptr->ctx, text.c_str(), text.size(), tokens.data(), n_max_tokens, add_bos);
|
1589
|
-
if (n < 0) {
|
1590
|
-
rb_raise(rb_eRuntimeError, "Failed to tokenize");
|
1591
|
-
return Qnil;
|
1592
|
-
}
|
1593
|
-
|
1594
|
-
VALUE output = rb_ary_new();
|
1595
|
-
for (int i = 0; i < n; i++) {
|
1596
|
-
rb_ary_push(output, INT2NUM(tokens[i]));
|
1597
|
-
}
|
1598
|
-
|
1599
|
-
RB_GC_GUARD(text_);
|
1600
|
-
return output;
|
1601
|
-
}
|
1602
|
-
|
1603
|
-
static VALUE _llama_context_token_to_piece(VALUE self, VALUE token_) {
|
1604
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1605
|
-
if (ptr->ctx == NULL) {
|
1606
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1607
|
-
return Qnil;
|
1608
|
-
}
|
1609
|
-
const llama_token token = NUM2INT(token_);
|
1610
|
-
std::vector<char> result(8, 0);
|
1611
|
-
const int n_tokens = llama_token_to_piece(ptr->ctx, token, result.data(), result.size());
|
1612
|
-
if (n_tokens < 0) {
|
1613
|
-
result.resize(-n_tokens);
|
1614
|
-
const int check = llama_token_to_piece(ptr->ctx, token, result.data(), result.size());
|
1615
|
-
if (check != -n_tokens) {
|
1616
|
-
rb_raise(rb_eRuntimeError, "failed to convert");
|
1617
|
-
return Qnil;
|
1618
|
-
}
|
1619
|
-
} else {
|
1620
|
-
result.resize(n_tokens);
|
1621
|
-
}
|
1622
|
-
std::string ret(result.data(), result.size());
|
1623
|
-
return rb_str_new_cstr(ret.c_str());
|
1842
|
+
return Qnil;
|
1624
1843
|
}
|
1625
1844
|
|
1626
1845
|
static VALUE _llama_context_logits(VALUE self) {
|
@@ -1635,10 +1854,11 @@ private:
|
|
1635
1854
|
}
|
1636
1855
|
|
1637
1856
|
VALUE model = rb_iv_get(self, "@model");
|
1638
|
-
|
1857
|
+
LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
|
1858
|
+
VALUE params = rb_iv_get(self, "@params");
|
1639
1859
|
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
|
1640
1860
|
const int n_tokens = prms_ptr->params.logits_all ? NUM2INT(rb_iv_get(self, "@n_tokens")) : 1;
|
1641
|
-
const int n_vocab = llama_n_vocab(
|
1861
|
+
const int n_vocab = llama_n_vocab(model_ptr->model);
|
1642
1862
|
const float* logits = llama_get_logits(ptr->ctx);
|
1643
1863
|
VALUE output = rb_ary_new();
|
1644
1864
|
for (int i = 0; i < n_tokens * n_vocab; i++) {
|
@@ -1655,7 +1875,8 @@ private:
|
|
1655
1875
|
return Qnil;
|
1656
1876
|
}
|
1657
1877
|
VALUE model = rb_iv_get(self, "@model");
|
1658
|
-
|
1878
|
+
LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
|
1879
|
+
VALUE params = rb_iv_get(self, "@params");
|
1659
1880
|
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
|
1660
1881
|
if (!prms_ptr->params.embedding) {
|
1661
1882
|
rb_raise(rb_eRuntimeError, "embedding parameter is false");
|
@@ -1666,7 +1887,7 @@ private:
|
|
1666
1887
|
return Qnil;
|
1667
1888
|
}
|
1668
1889
|
|
1669
|
-
const int n_embd = llama_n_embd(
|
1890
|
+
const int n_embd = llama_n_embd(model_ptr->model);
|
1670
1891
|
const float* embd = llama_get_embeddings(ptr->ctx);
|
1671
1892
|
VALUE output = rb_ary_new();
|
1672
1893
|
for (int i = 0; i < n_embd; i++) {
|
@@ -1736,81 +1957,104 @@ private:
|
|
1736
1957
|
return INT2NUM(llama_token_nl(ptr->ctx));
|
1737
1958
|
}
|
1738
1959
|
|
1739
|
-
static VALUE
|
1960
|
+
static VALUE _llama_context_n_ctx(VALUE self) {
|
1740
1961
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1741
1962
|
if (ptr->ctx == NULL) {
|
1742
1963
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1743
1964
|
return Qnil;
|
1744
1965
|
}
|
1745
|
-
return INT2NUM(
|
1966
|
+
return INT2NUM(llama_n_ctx(ptr->ctx));
|
1746
1967
|
}
|
1747
1968
|
|
1748
|
-
static VALUE
|
1969
|
+
static VALUE _llama_context_get_timings(VALUE self) {
|
1749
1970
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1750
1971
|
if (ptr->ctx == NULL) {
|
1751
1972
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1752
1973
|
return Qnil;
|
1753
1974
|
}
|
1754
|
-
|
1975
|
+
VALUE tm_obj = rb_funcall(rb_cLLaMATimings, rb_intern("new"), 0);
|
1976
|
+
LLaMATimingsWrapper* tm_ptr = RbLLaMATimings::get_llama_timings(tm_obj);
|
1977
|
+
tm_ptr->timings = llama_get_timings(ptr->ctx);
|
1978
|
+
return tm_obj;
|
1755
1979
|
}
|
1756
1980
|
|
1757
|
-
static VALUE
|
1981
|
+
static VALUE _llama_context_print_timings(VALUE self) {
|
1758
1982
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1759
1983
|
if (ptr->ctx == NULL) {
|
1760
1984
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1761
1985
|
return Qnil;
|
1762
1986
|
}
|
1763
|
-
|
1987
|
+
llama_print_timings(ptr->ctx);
|
1988
|
+
return Qnil;
|
1764
1989
|
}
|
1765
1990
|
|
1766
|
-
static VALUE
|
1991
|
+
static VALUE _llama_context_reset_timings(VALUE self) {
|
1767
1992
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1768
1993
|
if (ptr->ctx == NULL) {
|
1769
1994
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1770
1995
|
return Qnil;
|
1771
1996
|
}
|
1772
|
-
|
1997
|
+
llama_reset_timings(ptr->ctx);
|
1998
|
+
return Qnil;
|
1773
1999
|
}
|
1774
2000
|
|
1775
|
-
static VALUE
|
2001
|
+
static VALUE _llama_context_kv_cache_token_count(VALUE self) {
|
1776
2002
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1777
2003
|
if (ptr->ctx == NULL) {
|
1778
2004
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1779
2005
|
return Qnil;
|
1780
2006
|
}
|
1781
|
-
|
1782
|
-
LLaMATimingsWrapper* tm_ptr = RbLLaMATimings::get_llama_timings(tm_obj);
|
1783
|
-
tm_ptr->timings = llama_get_timings(ptr->ctx);
|
1784
|
-
return tm_obj;
|
2007
|
+
return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
|
1785
2008
|
}
|
1786
2009
|
|
1787
|
-
static VALUE
|
2010
|
+
static VALUE _llama_context_kv_cache_tokens_rm(VALUE self, VALUE c0, VALUE c1) {
|
1788
2011
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1789
2012
|
if (ptr->ctx == NULL) {
|
1790
2013
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1791
2014
|
return Qnil;
|
1792
2015
|
}
|
1793
|
-
|
2016
|
+
llama_kv_cache_tokens_rm(ptr->ctx, NUM2INT(c0), NUM2INT(c1));
|
1794
2017
|
return Qnil;
|
1795
2018
|
}
|
1796
2019
|
|
1797
|
-
static VALUE
|
2020
|
+
static VALUE _llama_context_kv_cache_seq_rm(VALUE self, VALUE seq_id, VALUE p0, VALUE p1) {
|
1798
2021
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1799
2022
|
if (ptr->ctx == NULL) {
|
1800
2023
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1801
2024
|
return Qnil;
|
1802
2025
|
}
|
1803
|
-
|
2026
|
+
llama_kv_cache_seq_rm(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1));
|
1804
2027
|
return Qnil;
|
1805
2028
|
}
|
1806
2029
|
|
1807
|
-
static VALUE
|
2030
|
+
static VALUE _llama_context_kv_cache_seq_cp(VALUE self, VALUE seq_id_src, VALUE seq_id_dst, VALUE p0, VALUE p1) {
|
1808
2031
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1809
2032
|
if (ptr->ctx == NULL) {
|
1810
|
-
rb_raise(
|
2033
|
+
rb_raise(rb_eArgError, "LLaMA context is not initialized");
|
1811
2034
|
return Qnil;
|
1812
2035
|
}
|
1813
|
-
|
2036
|
+
llama_kv_cache_seq_cp(ptr->ctx, NUM2INT(seq_id_src), NUM2INT(seq_id_dst), NUM2INT(p0), NUM2INT(p1));
|
2037
|
+
return Qnil;
|
2038
|
+
}
|
2039
|
+
|
2040
|
+
static VALUE _llama_context_kv_cache_seq_keep(VALUE self, VALUE seq_id) {
|
2041
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2042
|
+
if (ptr->ctx == NULL) {
|
2043
|
+
rb_raise(rb_eArgError, "LLaMA context is not initialized");
|
2044
|
+
return Qnil;
|
2045
|
+
}
|
2046
|
+
llama_kv_cache_seq_keep(ptr->ctx, NUM2INT(seq_id));
|
2047
|
+
return Qnil;
|
2048
|
+
}
|
2049
|
+
|
2050
|
+
static VALUE _llama_context_kv_cache_seq_shift(VALUE self, VALUE seq_id, VALUE p0, VALUE p1, VALUE delta) {
|
2051
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2052
|
+
if (ptr->ctx == NULL) {
|
2053
|
+
rb_raise(rb_eArgError, "LLaMA context is not initialized");
|
2054
|
+
return Qnil;
|
2055
|
+
}
|
2056
|
+
llama_kv_cache_seq_shift(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1), NUM2INT(delta));
|
2057
|
+
return Qnil;
|
1814
2058
|
}
|
1815
2059
|
|
1816
2060
|
static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
|
@@ -1851,7 +2095,7 @@ private:
|
|
1851
2095
|
}
|
1852
2096
|
|
1853
2097
|
VALUE model = rb_iv_get(self, "@model");
|
1854
|
-
VALUE params = rb_iv_get(
|
2098
|
+
VALUE params = rb_iv_get(self, "@params");
|
1855
2099
|
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
|
1856
2100
|
const int n_ctx = prms_ptr->params.n_ctx;
|
1857
2101
|
|
@@ -2235,6 +2479,40 @@ private:
|
|
2235
2479
|
return Qnil;
|
2236
2480
|
}
|
2237
2481
|
|
2482
|
+
static VALUE _llama_context_sample_temp(int argc, VALUE* argv, VALUE self) {
|
2483
|
+
VALUE kw_args = Qnil;
|
2484
|
+
ID kw_table[1] = { rb_intern("temp") };
|
2485
|
+
VALUE kw_values[1] = { Qundef };
|
2486
|
+
VALUE candidates = Qnil;
|
2487
|
+
rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
|
2488
|
+
rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
|
2489
|
+
|
2490
|
+
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
2491
|
+
rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
|
2492
|
+
return Qnil;
|
2493
|
+
}
|
2494
|
+
if (!RB_FLOAT_TYPE_P(kw_values[0])) {
|
2495
|
+
rb_raise(rb_eArgError, "temp must be a float");
|
2496
|
+
return Qnil;
|
2497
|
+
}
|
2498
|
+
|
2499
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
2500
|
+
if (ctx_ptr->ctx == NULL) {
|
2501
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
2502
|
+
return Qnil;
|
2503
|
+
}
|
2504
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
2505
|
+
if (cnd_ptr->array.data == nullptr) {
|
2506
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
2507
|
+
return Qnil;
|
2508
|
+
}
|
2509
|
+
const float temp = NUM2DBL(kw_values[0]);
|
2510
|
+
|
2511
|
+
llama_sample_temp(ctx_ptr->ctx, &(cnd_ptr->array), temp);
|
2512
|
+
|
2513
|
+
return Qnil;
|
2514
|
+
}
|
2515
|
+
|
2238
2516
|
static VALUE _llama_context_sample_temperature(int argc, VALUE* argv, VALUE self) {
|
2239
2517
|
VALUE kw_args = Qnil;
|
2240
2518
|
ID kw_table[1] = { rb_intern("temperature") };
|
@@ -2560,6 +2838,7 @@ extern "C" void Init_llama_cpp(void) {
|
|
2560
2838
|
RbLLaMATokenData::define_class(rb_mLLaMACpp);
|
2561
2839
|
RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
|
2562
2840
|
RbLLaMAModel::define_class(rb_mLLaMACpp);
|
2841
|
+
RbLLaMAModelParams::define_class(rb_mLLaMACpp);
|
2563
2842
|
RbLLaMATimings::define_class(rb_mLLaMACpp);
|
2564
2843
|
RbLLaMAContext::define_class(rb_mLLaMACpp);
|
2565
2844
|
RbLLaMAContextParams::define_class(rb_mLLaMACpp);
|
@@ -2578,10 +2857,6 @@ extern "C" void Init_llama_cpp(void) {
|
|
2578
2857
|
|
2579
2858
|
rb_define_const(rb_mLLaMACpp, "LLAMA_MAX_DEVICES", INT2NUM(LLAMA_MAX_DEVICES));
|
2580
2859
|
|
2581
|
-
rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_ERROR", INT2NUM(LLAMA_LOG_LEVEL_ERROR));
|
2582
|
-
rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_WARN", INT2NUM(LLAMA_LOG_LEVEL_WARN));
|
2583
|
-
rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_INFO", INT2NUM(LLAMA_LOG_LEVEL_INFO));
|
2584
|
-
|
2585
2860
|
rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_SPM", INT2NUM(LLAMA_VOCAB_TYPE_SPM));
|
2586
2861
|
rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_BPE", INT2NUM(LLAMA_VOCAB_TYPE_BPE));
|
2587
2862
|
|