llama_cpp 0.3.1 → 0.3.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +41 -0
- data/README.md +9 -0
- data/examples/chat.rb +1 -1
- data/examples/embedding.rb +1 -1
- data/examples/prompt_jp.txt +8 -0
- data/ext/llama_cpp/extconf.rb +11 -2
- data/ext/llama_cpp/llama_cpp.cpp +284 -111
- data/ext/llama_cpp/src/ggml-cuda.cu +639 -148
- data/ext/llama_cpp/src/ggml-cuda.h +0 -4
- data/ext/llama_cpp/src/ggml-metal.h +5 -1
- data/ext/llama_cpp/src/ggml-metal.m +19 -6
- data/ext/llama_cpp/src/ggml-metal.metal +56 -47
- data/ext/llama_cpp/src/ggml-mpi.c +216 -0
- data/ext/llama_cpp/src/ggml-mpi.h +39 -0
- data/ext/llama_cpp/src/ggml-opencl.cpp +11 -7
- data/ext/llama_cpp/src/ggml.c +1734 -2248
- data/ext/llama_cpp/src/ggml.h +152 -80
- data/ext/llama_cpp/src/llama.cpp +282 -90
- data/ext/llama_cpp/src/llama.h +30 -1
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +16 -13
- data/sig/llama_cpp.rbs +22 -2
- metadata +5 -2
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -1,8 +1,8 @@
|
|
1
|
-
|
2
1
|
#include "llama_cpp.h"
|
3
2
|
|
4
3
|
VALUE rb_mLLaMACpp;
|
5
4
|
VALUE rb_cLLaMAModel;
|
5
|
+
VALUE rb_cLLaMATimings;
|
6
6
|
VALUE rb_cLLaMAContext;
|
7
7
|
VALUE rb_cLLaMAContextParams;
|
8
8
|
VALUE rb_cLLaMAModelQuantizeParams;
|
@@ -17,9 +17,9 @@ public:
|
|
17
17
|
data.id = 0;
|
18
18
|
data.logit = 0.0;
|
19
19
|
data.p = 0.0;
|
20
|
-
}
|
20
|
+
}
|
21
21
|
|
22
|
-
~LLaMATokenDataWrapper(){}
|
22
|
+
~LLaMATokenDataWrapper() {}
|
23
23
|
};
|
24
24
|
|
25
25
|
class RbLLaMATokenData {
|
@@ -28,22 +28,22 @@ public:
|
|
28
28
|
LLaMATokenDataWrapper* ptr = (LLaMATokenDataWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataWrapper));
|
29
29
|
new (ptr) LLaMATokenDataWrapper();
|
30
30
|
return TypedData_Wrap_Struct(self, &llama_token_data_type, ptr);
|
31
|
-
}
|
31
|
+
}
|
32
32
|
|
33
33
|
static void llama_token_data_free(void* ptr) {
|
34
34
|
((LLaMATokenDataWrapper*)ptr)->~LLaMATokenDataWrapper();
|
35
35
|
ruby_xfree(ptr);
|
36
|
-
}
|
36
|
+
}
|
37
37
|
|
38
38
|
static size_t llama_token_data_size(const void* ptr) {
|
39
39
|
return sizeof(*((LLaMATokenDataWrapper*)ptr));
|
40
|
-
}
|
40
|
+
}
|
41
41
|
|
42
42
|
static LLaMATokenDataWrapper* get_llama_token_data(VALUE self) {
|
43
43
|
LLaMATokenDataWrapper* ptr;
|
44
44
|
TypedData_Get_Struct(self, LLaMATokenDataWrapper, &llama_token_data_type, ptr);
|
45
45
|
return ptr;
|
46
|
-
}
|
46
|
+
}
|
47
47
|
|
48
48
|
static void define_class(VALUE outer) {
|
49
49
|
rb_cLLaMATokenData = rb_define_class_under(outer, "TokenData", rb_cObject);
|
@@ -95,36 +95,36 @@ private:
|
|
95
95
|
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
96
96
|
ptr->data.id = NUM2INT(id);
|
97
97
|
return INT2NUM(ptr->data.id);
|
98
|
-
}
|
98
|
+
}
|
99
99
|
|
100
100
|
static VALUE _llama_token_data_get_id(VALUE self) {
|
101
101
|
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
102
102
|
return INT2NUM(ptr->data.id);
|
103
|
-
}
|
103
|
+
}
|
104
104
|
|
105
105
|
// logit
|
106
106
|
static VALUE _llama_token_data_set_logit(VALUE self, VALUE logit) {
|
107
107
|
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
108
108
|
ptr->data.logit = NUM2DBL(logit);
|
109
109
|
return DBL2NUM(ptr->data.logit);
|
110
|
-
}
|
110
|
+
}
|
111
111
|
|
112
112
|
static VALUE _llama_token_data_get_logit(VALUE self) {
|
113
113
|
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
114
114
|
return DBL2NUM(ptr->data.logit);
|
115
|
-
}
|
115
|
+
}
|
116
116
|
|
117
117
|
// p
|
118
118
|
static VALUE _llama_token_data_set_p(VALUE self, VALUE p) {
|
119
119
|
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
120
120
|
ptr->data.p = NUM2DBL(p);
|
121
121
|
return DBL2NUM(ptr->data.p);
|
122
|
-
}
|
122
|
+
}
|
123
123
|
|
124
124
|
static VALUE _llama_token_data_get_p(VALUE self) {
|
125
125
|
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
126
126
|
return DBL2NUM(ptr->data.p);
|
127
|
-
}
|
127
|
+
}
|
128
128
|
};
|
129
129
|
|
130
130
|
const rb_data_type_t RbLLaMATokenData::llama_token_data_type = {
|
@@ -145,14 +145,14 @@ public:
|
|
145
145
|
array.data = nullptr;
|
146
146
|
array.size = 0;
|
147
147
|
array.sorted = false;
|
148
|
-
}
|
148
|
+
}
|
149
149
|
|
150
150
|
~LLaMATokenDataArrayWrapper() {
|
151
151
|
if (array.data) {
|
152
152
|
ruby_xfree(array.data);
|
153
153
|
array.data = nullptr;
|
154
154
|
}
|
155
|
-
}
|
155
|
+
}
|
156
156
|
};
|
157
157
|
|
158
158
|
class RbLLaMATokenDataArray {
|
@@ -161,22 +161,22 @@ public:
|
|
161
161
|
LLaMATokenDataArrayWrapper* ptr = (LLaMATokenDataArrayWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataArrayWrapper));
|
162
162
|
new (ptr) LLaMATokenDataArrayWrapper();
|
163
163
|
return TypedData_Wrap_Struct(self, &llama_token_data_array_type, ptr);
|
164
|
-
}
|
164
|
+
}
|
165
165
|
|
166
166
|
static void llama_token_data_array_free(void* ptr) {
|
167
167
|
((LLaMATokenDataArrayWrapper*)ptr)->~LLaMATokenDataArrayWrapper();
|
168
168
|
ruby_xfree(ptr);
|
169
|
-
}
|
169
|
+
}
|
170
170
|
|
171
171
|
static size_t llama_token_data_array_size(const void* ptr) {
|
172
172
|
return sizeof(*((LLaMATokenDataArrayWrapper*)ptr));
|
173
|
-
}
|
173
|
+
}
|
174
174
|
|
175
175
|
static LLaMATokenDataArrayWrapper* get_llama_token_data_array(VALUE self) {
|
176
176
|
LLaMATokenDataArrayWrapper* ptr;
|
177
177
|
TypedData_Get_Struct(self, LLaMATokenDataArrayWrapper, &llama_token_data_array_type, ptr);
|
178
178
|
return ptr;
|
179
|
-
}
|
179
|
+
}
|
180
180
|
|
181
181
|
static void define_class(VALUE outer) {
|
182
182
|
rb_cLLaMATokenDataArray = rb_define_class_under(outer, "TokenDataArray", rb_cObject);
|
@@ -184,7 +184,7 @@ public:
|
|
184
184
|
rb_define_method(rb_cLLaMATokenDataArray, "initialize", RUBY_METHOD_FUNC(_llama_token_data_array_init), -1);
|
185
185
|
rb_define_method(rb_cLLaMATokenDataArray, "size", RUBY_METHOD_FUNC(_llama_token_data_array_get_size), 0);
|
186
186
|
rb_define_method(rb_cLLaMATokenDataArray, "sorted", RUBY_METHOD_FUNC(_llama_token_data_array_get_sorted), 0);
|
187
|
-
}
|
187
|
+
}
|
188
188
|
|
189
189
|
private:
|
190
190
|
static const rb_data_type_t llama_token_data_array_type;
|
@@ -233,17 +233,17 @@ private:
|
|
233
233
|
ptr->array.sorted = kw_values[0] == Qtrue;
|
234
234
|
|
235
235
|
return self;
|
236
|
-
}
|
236
|
+
}
|
237
237
|
|
238
238
|
static VALUE _llama_token_data_array_get_size(VALUE self) {
|
239
239
|
LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
|
240
240
|
return SIZET2NUM(ptr->array.size);
|
241
|
-
}
|
241
|
+
}
|
242
242
|
|
243
243
|
static VALUE _llama_token_data_array_get_sorted(VALUE self) {
|
244
244
|
LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
|
245
245
|
return ptr->array.sorted ? Qtrue : Qfalse;
|
246
|
-
}
|
246
|
+
}
|
247
247
|
};
|
248
248
|
|
249
249
|
const rb_data_type_t RbLLaMATokenDataArray::llama_token_data_array_type = {
|
@@ -256,13 +256,118 @@ const rb_data_type_t RbLLaMATokenDataArray::llama_token_data_array_type = {
|
|
256
256
|
RUBY_TYPED_FREE_IMMEDIATELY
|
257
257
|
};
|
258
258
|
|
259
|
+
class LLaMATimingsWrapper {
|
260
|
+
public:
|
261
|
+
struct llama_timings timings;
|
262
|
+
|
263
|
+
LLaMATimingsWrapper() {}
|
264
|
+
|
265
|
+
~LLaMATimingsWrapper() {}
|
266
|
+
};
|
267
|
+
|
268
|
+
class RbLLaMATimings {
|
269
|
+
public:
|
270
|
+
static VALUE llama_timings_alloc(VALUE self) {
|
271
|
+
LLaMATimingsWrapper* ptr = (LLaMATimingsWrapper*)ruby_xmalloc(sizeof(LLaMATimingsWrapper));
|
272
|
+
new (ptr) LLaMATimingsWrapper();
|
273
|
+
return TypedData_Wrap_Struct(self, &llama_timings_type, ptr);
|
274
|
+
}
|
275
|
+
|
276
|
+
static void llama_timings_free(void* ptr) {
|
277
|
+
((LLaMATimingsWrapper*)ptr)->~LLaMATimingsWrapper();
|
278
|
+
ruby_xfree(ptr);
|
279
|
+
}
|
280
|
+
|
281
|
+
static size_t llama_timings_size(const void* ptr) {
|
282
|
+
return sizeof(*((LLaMATimingsWrapper*)ptr));
|
283
|
+
}
|
284
|
+
|
285
|
+
static LLaMATimingsWrapper* get_llama_timings(VALUE self) {
|
286
|
+
LLaMATimingsWrapper* ptr;
|
287
|
+
TypedData_Get_Struct(self, LLaMATimingsWrapper, &llama_timings_type, ptr);
|
288
|
+
return ptr;
|
289
|
+
}
|
290
|
+
|
291
|
+
static void define_class(VALUE outer) {
|
292
|
+
rb_cLLaMATimings = rb_define_class_under(outer, "Timings", rb_cObject);
|
293
|
+
rb_define_alloc_func(rb_cLLaMATimings, llama_timings_alloc);
|
294
|
+
rb_define_method(rb_cLLaMATimings, "t_start_ms", RUBY_METHOD_FUNC(_llama_timings_get_t_start_ms), 0);
|
295
|
+
rb_define_method(rb_cLLaMATimings, "t_end_ms", RUBY_METHOD_FUNC(_llama_timings_get_t_end_ms), 0);
|
296
|
+
rb_define_method(rb_cLLaMATimings, "t_load_ms", RUBY_METHOD_FUNC(_llama_timings_get_t_load_ms), 0);
|
297
|
+
rb_define_method(rb_cLLaMATimings, "t_sample_ms", RUBY_METHOD_FUNC(_llama_timings_get_t_sample_ms), 0);
|
298
|
+
rb_define_method(rb_cLLaMATimings, "t_p_eval_ms", RUBY_METHOD_FUNC(_llama_timings_get_t_p_eval_ms), 0);
|
299
|
+
rb_define_method(rb_cLLaMATimings, "t_eval_ms", RUBY_METHOD_FUNC(_llama_timings_get_t_eval_ms), 0);
|
300
|
+
rb_define_method(rb_cLLaMATimings, "n_sample", RUBY_METHOD_FUNC(_llama_timings_get_n_sample), 0);
|
301
|
+
rb_define_method(rb_cLLaMATimings, "n_p_eval", RUBY_METHOD_FUNC(_llama_timings_get_n_p_eval), 0);
|
302
|
+
rb_define_method(rb_cLLaMATimings, "n_eval", RUBY_METHOD_FUNC(_llama_timings_get_n_eval), 0);
|
303
|
+
}
|
304
|
+
|
305
|
+
private:
|
306
|
+
static const rb_data_type_t llama_timings_type;
|
307
|
+
|
308
|
+
static VALUE _llama_timings_get_t_start_ms(VALUE self) {
|
309
|
+
LLaMATimingsWrapper* ptr = get_llama_timings(self);
|
310
|
+
return DBL2NUM(ptr->timings.t_start_ms);
|
311
|
+
}
|
312
|
+
|
313
|
+
static VALUE _llama_timings_get_t_end_ms(VALUE self) {
|
314
|
+
LLaMATimingsWrapper* ptr = get_llama_timings(self);
|
315
|
+
return DBL2NUM(ptr->timings.t_end_ms);
|
316
|
+
}
|
317
|
+
|
318
|
+
static VALUE _llama_timings_get_t_load_ms(VALUE self) {
|
319
|
+
LLaMATimingsWrapper* ptr = get_llama_timings(self);
|
320
|
+
return DBL2NUM(ptr->timings.t_load_ms);
|
321
|
+
}
|
322
|
+
|
323
|
+
static VALUE _llama_timings_get_t_sample_ms(VALUE self) {
|
324
|
+
LLaMATimingsWrapper* ptr = get_llama_timings(self);
|
325
|
+
return DBL2NUM(ptr->timings.t_sample_ms);
|
326
|
+
}
|
327
|
+
|
328
|
+
static VALUE _llama_timings_get_t_p_eval_ms(VALUE self) {
|
329
|
+
LLaMATimingsWrapper* ptr = get_llama_timings(self);
|
330
|
+
return DBL2NUM(ptr->timings.t_p_eval_ms);
|
331
|
+
}
|
332
|
+
|
333
|
+
static VALUE _llama_timings_get_t_eval_ms(VALUE self) {
|
334
|
+
LLaMATimingsWrapper* ptr = get_llama_timings(self);
|
335
|
+
return DBL2NUM(ptr->timings.t_eval_ms);
|
336
|
+
}
|
337
|
+
|
338
|
+
static VALUE _llama_timings_get_n_sample(VALUE self) {
|
339
|
+
LLaMATimingsWrapper* ptr = get_llama_timings(self);
|
340
|
+
return INT2NUM(ptr->timings.n_sample);
|
341
|
+
}
|
342
|
+
|
343
|
+
static VALUE _llama_timings_get_n_p_eval(VALUE self) {
|
344
|
+
LLaMATimingsWrapper* ptr = get_llama_timings(self);
|
345
|
+
return INT2NUM(ptr->timings.n_p_eval);
|
346
|
+
}
|
347
|
+
|
348
|
+
static VALUE _llama_timings_get_n_eval(VALUE self) {
|
349
|
+
LLaMATimingsWrapper* ptr = get_llama_timings(self);
|
350
|
+
return INT2NUM(ptr->timings.n_eval);
|
351
|
+
}
|
352
|
+
};
|
353
|
+
|
354
|
+
const rb_data_type_t RbLLaMATimings::llama_timings_type = {
|
355
|
+
"RbLLaMATimings",
|
356
|
+
{ NULL,
|
357
|
+
RbLLaMATimings::llama_timings_free,
|
358
|
+
RbLLaMATimings::llama_timings_size },
|
359
|
+
NULL,
|
360
|
+
NULL,
|
361
|
+
RUBY_TYPED_FREE_IMMEDIATELY
|
362
|
+
};
|
363
|
+
|
259
364
|
class LLaMAContextParamsWrapper {
|
260
365
|
public:
|
261
366
|
struct llama_context_params params;
|
262
367
|
|
263
|
-
LLaMAContextParamsWrapper() : params(llama_context_default_params()){}
|
368
|
+
LLaMAContextParamsWrapper() : params(llama_context_default_params()) {}
|
264
369
|
|
265
|
-
~LLaMAContextParamsWrapper(){}
|
370
|
+
~LLaMAContextParamsWrapper() {}
|
266
371
|
};
|
267
372
|
|
268
373
|
class RbLLaMAContextParams {
|
@@ -271,22 +376,22 @@ public:
|
|
271
376
|
LLaMAContextParamsWrapper* ptr = (LLaMAContextParamsWrapper*)ruby_xmalloc(sizeof(LLaMAContextParamsWrapper));
|
272
377
|
new (ptr) LLaMAContextParamsWrapper();
|
273
378
|
return TypedData_Wrap_Struct(self, &llama_context_params_type, ptr);
|
274
|
-
}
|
379
|
+
}
|
275
380
|
|
276
381
|
static void llama_context_params_free(void* ptr) {
|
277
382
|
((LLaMAContextParamsWrapper*)ptr)->~LLaMAContextParamsWrapper();
|
278
383
|
ruby_xfree(ptr);
|
279
|
-
}
|
384
|
+
}
|
280
385
|
|
281
386
|
static size_t llama_context_params_size(const void* ptr) {
|
282
387
|
return sizeof(*((LLaMAContextParamsWrapper*)ptr));
|
283
|
-
}
|
388
|
+
}
|
284
389
|
|
285
390
|
static LLaMAContextParamsWrapper* get_llama_context_params(VALUE self) {
|
286
391
|
LLaMAContextParamsWrapper* ptr;
|
287
392
|
TypedData_Get_Struct(self, LLaMAContextParamsWrapper, &llama_context_params_type, ptr);
|
288
393
|
return ptr;
|
289
|
-
}
|
394
|
+
}
|
290
395
|
|
291
396
|
static void define_class(VALUE outer) {
|
292
397
|
rb_cLLaMAContextParams = rb_define_class_under(outer, "ContextParams", rb_cObject);
|
@@ -317,7 +422,7 @@ public:
|
|
317
422
|
rb_define_method(rb_cLLaMAContextParams, "use_mlock", RUBY_METHOD_FUNC(_llama_context_params_get_use_mlock), 0);
|
318
423
|
rb_define_method(rb_cLLaMAContextParams, "embedding=", RUBY_METHOD_FUNC(_llama_context_params_set_embedding), 1);
|
319
424
|
rb_define_method(rb_cLLaMAContextParams, "embedding", RUBY_METHOD_FUNC(_llama_context_params_get_embedding), 0);
|
320
|
-
}
|
425
|
+
}
|
321
426
|
|
322
427
|
private:
|
323
428
|
static const rb_data_type_t llama_context_params_type;
|
@@ -326,55 +431,55 @@ private:
|
|
326
431
|
// LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
327
432
|
// new (ptr) LLaMAContextParamsWrapper();
|
328
433
|
// return self;
|
329
|
-
// }
|
434
|
+
// }
|
330
435
|
|
331
436
|
// n_ctx
|
332
437
|
static VALUE _llama_context_params_set_n_ctx(VALUE self, VALUE n_ctx) {
|
333
438
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
334
439
|
ptr->params.n_ctx = NUM2INT(n_ctx);
|
335
440
|
return INT2NUM(ptr->params.n_ctx);
|
336
|
-
}
|
441
|
+
}
|
337
442
|
|
338
443
|
static VALUE _llama_context_params_get_n_ctx(VALUE self) {
|
339
444
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
340
445
|
return INT2NUM(ptr->params.n_ctx);
|
341
|
-
}
|
446
|
+
}
|
342
447
|
|
343
448
|
// n_batch
|
344
449
|
static VALUE _llama_context_params_set_n_batch(VALUE self, VALUE n_batch) {
|
345
450
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
346
451
|
ptr->params.n_batch = NUM2INT(n_batch);
|
347
452
|
return INT2NUM(ptr->params.n_batch);
|
348
|
-
}
|
453
|
+
}
|
349
454
|
|
350
455
|
static VALUE _llama_context_params_get_n_batch(VALUE self) {
|
351
456
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
352
457
|
return INT2NUM(ptr->params.n_batch);
|
353
|
-
}
|
458
|
+
}
|
354
459
|
|
355
460
|
// n_gpu_layers
|
356
461
|
static VALUE _llama_context_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
|
357
462
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
358
463
|
ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
|
359
464
|
return INT2NUM(ptr->params.n_gpu_layers);
|
360
|
-
}
|
465
|
+
}
|
361
466
|
|
362
467
|
static VALUE _llama_context_params_get_n_gpu_layers(VALUE self) {
|
363
468
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
364
469
|
return INT2NUM(ptr->params.n_gpu_layers);
|
365
|
-
}
|
470
|
+
}
|
366
471
|
|
367
472
|
// main_gpu
|
368
473
|
static VALUE _llama_context_params_set_main_gpu(VALUE self, VALUE main_gpu) {
|
369
474
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
370
475
|
ptr->params.main_gpu = NUM2INT(main_gpu);
|
371
476
|
return INT2NUM(ptr->params.main_gpu);
|
372
|
-
}
|
477
|
+
}
|
373
478
|
|
374
479
|
static VALUE _llama_context_params_get_main_gpu(VALUE self) {
|
375
480
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
376
481
|
return INT2NUM(ptr->params.main_gpu);
|
377
|
-
}
|
482
|
+
}
|
378
483
|
|
379
484
|
// tensor_split
|
380
485
|
static VALUE _llama_context_params_get_tensor_split(VALUE self) {
|
@@ -387,19 +492,19 @@ private:
|
|
387
492
|
rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
|
388
493
|
}
|
389
494
|
return ret;
|
390
|
-
}
|
495
|
+
}
|
391
496
|
|
392
497
|
// low_vram
|
393
498
|
static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
|
394
499
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
395
500
|
ptr->params.low_vram = low_vram == Qtrue ? true : false;
|
396
501
|
return ptr->params.low_vram ? Qtrue : Qfalse;
|
397
|
-
}
|
502
|
+
}
|
398
503
|
|
399
504
|
static VALUE _llama_context_params_get_low_vram(VALUE self) {
|
400
505
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
401
506
|
return ptr->params.low_vram ? Qtrue : Qfalse;
|
402
|
-
}
|
507
|
+
}
|
403
508
|
|
404
509
|
// seed
|
405
510
|
static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
|
@@ -410,84 +515,84 @@ private:
|
|
410
515
|
}
|
411
516
|
ptr->params.seed = NUM2INT(seed);
|
412
517
|
return INT2NUM(ptr->params.seed);
|
413
|
-
}
|
518
|
+
}
|
414
519
|
|
415
520
|
static VALUE _llama_context_params_get_seed(VALUE self) {
|
416
521
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
417
522
|
return INT2NUM(ptr->params.seed);
|
418
|
-
}
|
523
|
+
}
|
419
524
|
|
420
525
|
// f16_kv
|
421
526
|
static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
|
422
527
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
423
528
|
ptr->params.f16_kv = f16_kv == Qtrue ? true : false;
|
424
529
|
return ptr->params.f16_kv ? Qtrue : Qfalse;
|
425
|
-
}
|
530
|
+
}
|
426
531
|
|
427
532
|
static VALUE _llama_context_params_get_f16_kv(VALUE self) {
|
428
533
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
429
534
|
return ptr->params.f16_kv ? Qtrue : Qfalse;
|
430
|
-
}
|
535
|
+
}
|
431
536
|
|
432
537
|
// logits_all
|
433
538
|
static VALUE _llama_context_params_set_logits_all(VALUE self, VALUE logits_all) {
|
434
539
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
435
540
|
ptr->params.logits_all = logits_all == Qtrue ? true : false;
|
436
541
|
return ptr->params.logits_all ? Qtrue : Qfalse;
|
437
|
-
}
|
542
|
+
}
|
438
543
|
|
439
544
|
static VALUE _llama_context_params_get_logits_all(VALUE self) {
|
440
545
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
441
546
|
return ptr->params.logits_all ? Qtrue : Qfalse;
|
442
|
-
}
|
547
|
+
}
|
443
548
|
|
444
549
|
// vocab_only
|
445
550
|
static VALUE _llama_context_params_set_vocab_only(VALUE self, VALUE vocab_only) {
|
446
551
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
447
552
|
ptr->params.vocab_only = vocab_only == Qtrue ? true : false;
|
448
553
|
return ptr->params.vocab_only ? Qtrue : Qfalse;
|
449
|
-
}
|
554
|
+
}
|
450
555
|
|
451
556
|
static VALUE _llama_context_params_get_vocab_only(VALUE self) {
|
452
557
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
453
558
|
return ptr->params.vocab_only ? Qtrue : Qfalse;
|
454
|
-
}
|
559
|
+
}
|
455
560
|
|
456
561
|
// use_mmap
|
457
562
|
static VALUE _llama_context_params_set_use_mmap(VALUE self, VALUE use_mmap) {
|
458
563
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
459
564
|
ptr->params.use_mmap = use_mmap == Qtrue ? true : false;
|
460
565
|
return ptr->params.use_mmap ? Qtrue : Qfalse;
|
461
|
-
}
|
566
|
+
}
|
462
567
|
|
463
568
|
static VALUE _llama_context_params_get_use_mmap(VALUE self) {
|
464
569
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
465
570
|
return ptr->params.use_mmap ? Qtrue : Qfalse;
|
466
|
-
}
|
571
|
+
}
|
467
572
|
|
468
573
|
// use_mlock
|
469
574
|
static VALUE _llama_context_params_set_use_mlock(VALUE self, VALUE use_mlock) {
|
470
575
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
471
576
|
ptr->params.use_mlock = use_mlock == Qtrue ? true : false;
|
472
577
|
return ptr->params.use_mlock ? Qtrue : Qfalse;
|
473
|
-
}
|
578
|
+
}
|
474
579
|
|
475
580
|
static VALUE _llama_context_params_get_use_mlock(VALUE self) {
|
476
581
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
477
582
|
return ptr->params.use_mlock ? Qtrue : Qfalse;
|
478
|
-
}
|
583
|
+
}
|
479
584
|
|
480
585
|
// embedding
|
481
586
|
static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
|
482
587
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
483
588
|
ptr->params.embedding = embedding == Qtrue ? true : false;
|
484
589
|
return ptr->params.embedding ? Qtrue : Qfalse;
|
485
|
-
}
|
590
|
+
}
|
486
591
|
|
487
592
|
static VALUE _llama_context_params_get_embedding(VALUE self) {
|
488
593
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
489
594
|
return ptr->params.embedding ? Qtrue : Qfalse;
|
490
|
-
}
|
595
|
+
}
|
491
596
|
};
|
492
597
|
|
493
598
|
const rb_data_type_t RbLLaMAContextParams::llama_context_params_type = {
|
@@ -504,9 +609,9 @@ class LLaMAModelQuantizeParamsWrapper {
|
|
504
609
|
public:
|
505
610
|
llama_model_quantize_params params;
|
506
611
|
|
507
|
-
LLaMAModelQuantizeParamsWrapper() : params(llama_model_quantize_default_params()){}
|
612
|
+
LLaMAModelQuantizeParamsWrapper() : params(llama_model_quantize_default_params()) {}
|
508
613
|
|
509
|
-
~LLaMAModelQuantizeParamsWrapper(){}
|
614
|
+
~LLaMAModelQuantizeParamsWrapper() {}
|
510
615
|
};
|
511
616
|
|
512
617
|
class RbLLaMAModelQuantizeParams {
|
@@ -515,22 +620,22 @@ public:
|
|
515
620
|
LLaMAModelQuantizeParamsWrapper* ptr = (LLaMAModelQuantizeParamsWrapper*)ruby_xmalloc(sizeof(LLaMAModelQuantizeParamsWrapper));
|
516
621
|
new (ptr) LLaMAModelQuantizeParamsWrapper();
|
517
622
|
return TypedData_Wrap_Struct(self, &llama_model_quantize_params_type, ptr);
|
518
|
-
}
|
623
|
+
}
|
519
624
|
|
520
625
|
static void llama_model_quantize_params_free(void* ptr) {
|
521
626
|
((LLaMAModelQuantizeParamsWrapper*)ptr)->~LLaMAModelQuantizeParamsWrapper();
|
522
627
|
ruby_xfree(ptr);
|
523
|
-
}
|
628
|
+
}
|
524
629
|
|
525
630
|
static size_t llama_model_quantize_params_size(const void* ptr) {
|
526
631
|
return sizeof(*((LLaMAModelQuantizeParamsWrapper*)ptr));
|
527
|
-
}
|
632
|
+
}
|
528
633
|
|
529
634
|
static LLaMAModelQuantizeParamsWrapper* get_llama_model_quantize_params(VALUE self) {
|
530
635
|
LLaMAModelQuantizeParamsWrapper* ptr;
|
531
636
|
TypedData_Get_Struct(self, LLaMAModelQuantizeParamsWrapper, &llama_model_quantize_params_type, ptr);
|
532
637
|
return ptr;
|
533
|
-
}
|
638
|
+
}
|
534
639
|
|
535
640
|
static void define_class(VALUE outer) {
|
536
641
|
rb_cLLaMAModelQuantizeParams = rb_define_class_under(outer, "ModelQuantizeParams", rb_cObject);
|
@@ -543,7 +648,7 @@ public:
|
|
543
648
|
rb_define_method(rb_cLLaMAModelQuantizeParams, "allow_requantize", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_allow_requantize), 0);
|
544
649
|
rb_define_method(rb_cLLaMAModelQuantizeParams, "quantize_output_tensor=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_quantize_output_tensor), 1);
|
545
650
|
rb_define_method(rb_cLLaMAModelQuantizeParams, "quantize_output_tensor", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_quantize_output_tensor), 0);
|
546
|
-
}
|
651
|
+
}
|
547
652
|
|
548
653
|
private:
|
549
654
|
static const rb_data_type_t llama_model_quantize_params_type;
|
@@ -553,24 +658,24 @@ private:
|
|
553
658
|
LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
|
554
659
|
ptr->params.nthread = NUM2INT(n_thread);
|
555
660
|
return INT2NUM(ptr->params.nthread);
|
556
|
-
}
|
661
|
+
}
|
557
662
|
|
558
663
|
static VALUE _llama_model_quantize_params_get_n_thread(VALUE self) {
|
559
664
|
LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
|
560
665
|
return INT2NUM(ptr->params.nthread);
|
561
|
-
}
|
666
|
+
}
|
562
667
|
|
563
668
|
// ftype
|
564
669
|
static VALUE _llama_model_quantize_params_set_ftype(VALUE self, VALUE ftype) {
|
565
670
|
LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
|
566
671
|
ptr->params.ftype = static_cast<enum llama_ftype>(NUM2INT(ftype));
|
567
672
|
return INT2NUM(ptr->params.ftype);
|
568
|
-
}
|
673
|
+
}
|
569
674
|
|
570
675
|
static VALUE _llama_model_quantize_params_get_ftype(VALUE self) {
|
571
676
|
LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
|
572
677
|
return INT2NUM(ptr->params.ftype);
|
573
|
-
}
|
678
|
+
}
|
574
679
|
|
575
680
|
// allow_requantize
|
576
681
|
static VALUE _llama_model_quantize_params_set_allow_requantize(VALUE self, VALUE allow_requantize) {
|
@@ -581,12 +686,12 @@ private:
|
|
581
686
|
ptr->params.allow_requantize = true;
|
582
687
|
}
|
583
688
|
return ptr->params.allow_requantize ? Qtrue : Qfalse;
|
584
|
-
}
|
689
|
+
}
|
585
690
|
|
586
691
|
static VALUE _llama_model_quantize_params_get_allow_requantize(VALUE self) {
|
587
692
|
LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
|
588
693
|
return ptr->params.allow_requantize ? Qtrue : Qfalse;
|
589
|
-
}
|
694
|
+
}
|
590
695
|
|
591
696
|
// quantize_output_tensor
|
592
697
|
static VALUE _llama_model_quantize_params_set_quantize_output_tensor(VALUE self, VALUE quantize_output_tensor) {
|
@@ -597,12 +702,12 @@ private:
|
|
597
702
|
ptr->params.quantize_output_tensor = true;
|
598
703
|
}
|
599
704
|
return ptr->params.quantize_output_tensor ? Qtrue : Qfalse;
|
600
|
-
}
|
705
|
+
}
|
601
706
|
|
602
707
|
static VALUE _llama_model_quantize_params_get_quantize_output_tensor(VALUE self) {
|
603
708
|
LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
|
604
709
|
return ptr->params.quantize_output_tensor ? Qtrue : Qfalse;
|
605
|
-
}
|
710
|
+
}
|
606
711
|
};
|
607
712
|
|
608
713
|
const rb_data_type_t RbLLaMAModelQuantizeParams::llama_model_quantize_params_type = {
|
@@ -619,13 +724,13 @@ class LLaMAModelWrapper {
|
|
619
724
|
public:
|
620
725
|
struct llama_model* model;
|
621
726
|
|
622
|
-
LLaMAModelWrapper() : model(NULL){}
|
727
|
+
LLaMAModelWrapper() : model(NULL) {}
|
623
728
|
|
624
729
|
~LLaMAModelWrapper() {
|
625
730
|
if (model != NULL) {
|
626
731
|
llama_free_model(model);
|
627
732
|
}
|
628
|
-
}
|
733
|
+
}
|
629
734
|
};
|
630
735
|
|
631
736
|
class RbLLaMAModel {
|
@@ -802,7 +907,7 @@ private:
|
|
802
907
|
return Qnil;
|
803
908
|
}
|
804
909
|
return Qnil;
|
805
|
-
}
|
910
|
+
}
|
806
911
|
};
|
807
912
|
|
808
913
|
const rb_data_type_t RbLLaMAModel::llama_model_type = {
|
@@ -819,13 +924,13 @@ class LLaMAContextWrapper {
|
|
819
924
|
public:
|
820
925
|
struct llama_context* ctx;
|
821
926
|
|
822
|
-
LLaMAContextWrapper() : ctx(NULL){}
|
927
|
+
LLaMAContextWrapper() : ctx(NULL) {}
|
823
928
|
|
824
929
|
~LLaMAContextWrapper() {
|
825
930
|
if (ctx != NULL) {
|
826
931
|
llama_free(ctx);
|
827
932
|
}
|
828
|
-
}
|
933
|
+
}
|
829
934
|
};
|
830
935
|
|
831
936
|
class RbLLaMAContext {
|
@@ -834,22 +939,22 @@ public:
|
|
834
939
|
LLaMAContextWrapper* ptr = (LLaMAContextWrapper*)ruby_xmalloc(sizeof(LLaMAContextWrapper));
|
835
940
|
new (ptr) LLaMAContextWrapper();
|
836
941
|
return TypedData_Wrap_Struct(self, &llama_context_type, ptr);
|
837
|
-
}
|
942
|
+
}
|
838
943
|
|
839
944
|
static void llama_context_free(void* ptr) {
|
840
945
|
((LLaMAContextWrapper*)ptr)->~LLaMAContextWrapper();
|
841
946
|
ruby_xfree(ptr);
|
842
|
-
}
|
947
|
+
}
|
843
948
|
|
844
949
|
static size_t llama_context_size(const void* ptr) {
|
845
950
|
return sizeof(*((LLaMAContextWrapper*)ptr));
|
846
|
-
}
|
951
|
+
}
|
847
952
|
|
848
953
|
static LLaMAContextWrapper* get_llama_context(VALUE self) {
|
849
954
|
LLaMAContextWrapper* ptr;
|
850
955
|
TypedData_Get_Struct(self, LLaMAContextWrapper, &llama_context_type, ptr);
|
851
956
|
return ptr;
|
852
|
-
}
|
957
|
+
}
|
853
958
|
|
854
959
|
static void define_class(VALUE outer) {
|
855
960
|
rb_cLLaMAContext = rb_define_class_under(outer, "Context", rb_cObject);
|
@@ -866,6 +971,7 @@ public:
|
|
866
971
|
rb_define_method(rb_cLLaMAContext, "n_vocab", RUBY_METHOD_FUNC(_llama_context_n_vocab), 0);
|
867
972
|
rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
|
868
973
|
rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
|
974
|
+
rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
|
869
975
|
rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
|
870
976
|
rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
|
871
977
|
rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
|
@@ -874,6 +980,7 @@ public:
|
|
874
980
|
rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
|
875
981
|
rb_define_method(rb_cLLaMAContext, "sample_repetition_penalty", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalty), -1);
|
876
982
|
rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
|
983
|
+
rb_define_method(rb_cLLaMAContext, "sample_classifier_free_guidance", RUBY_METHOD_FUNC(_llama_context_sample_classifier_free_guidance), -1);
|
877
984
|
rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
|
878
985
|
rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
|
879
986
|
rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
|
@@ -884,7 +991,7 @@ public:
|
|
884
991
|
rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
|
885
992
|
rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
|
886
993
|
rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
|
887
|
-
}
|
994
|
+
}
|
888
995
|
|
889
996
|
private:
|
890
997
|
static const rb_data_type_t llama_context_type;
|
@@ -923,7 +1030,7 @@ private:
|
|
923
1030
|
rb_iv_set(self, "@has_evaluated", Qfalse);
|
924
1031
|
|
925
1032
|
return Qnil;
|
926
|
-
}
|
1033
|
+
}
|
927
1034
|
|
928
1035
|
static VALUE _llama_context_eval(int argc, VALUE* argv, VALUE self) {
|
929
1036
|
VALUE kw_args = Qnil;
|
@@ -978,7 +1085,7 @@ private:
|
|
978
1085
|
rb_iv_set(self, "@has_evaluated", Qtrue);
|
979
1086
|
|
980
1087
|
return Qnil;
|
981
|
-
}
|
1088
|
+
}
|
982
1089
|
|
983
1090
|
static VALUE _llama_context_eval_embd(int argc, VALUE* argv, VALUE self) {
|
984
1091
|
VALUE kw_args = Qnil;
|
@@ -1051,7 +1158,7 @@ private:
|
|
1051
1158
|
}
|
1052
1159
|
RB_GC_GUARD(fname_);
|
1053
1160
|
return Qtrue;
|
1054
|
-
}
|
1161
|
+
}
|
1055
1162
|
|
1056
1163
|
static VALUE _llama_context_tokenize(int argc, VALUE* argv, VALUE self) {
|
1057
1164
|
VALUE kw_args = Qnil;
|
@@ -1097,7 +1204,7 @@ private:
|
|
1097
1204
|
|
1098
1205
|
RB_GC_GUARD(text_);
|
1099
1206
|
return output;
|
1100
|
-
}
|
1207
|
+
}
|
1101
1208
|
|
1102
1209
|
static VALUE _llama_context_token_to_str(VALUE self, VALUE token_) {
|
1103
1210
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1108,7 +1215,7 @@ private:
|
|
1108
1215
|
const llama_token token = NUM2INT(token_);
|
1109
1216
|
const char* str = llama_token_to_str(ptr->ctx, token);
|
1110
1217
|
return str != nullptr ? rb_utf8_str_new_cstr(str) : rb_utf8_str_new_cstr("");
|
1111
|
-
}
|
1218
|
+
}
|
1112
1219
|
|
1113
1220
|
static VALUE _llama_context_logits(VALUE self) {
|
1114
1221
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1133,7 +1240,7 @@ private:
|
|
1133
1240
|
}
|
1134
1241
|
|
1135
1242
|
return output;
|
1136
|
-
}
|
1243
|
+
}
|
1137
1244
|
|
1138
1245
|
static VALUE _llama_context_embeddings(VALUE self) {
|
1139
1246
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1161,7 +1268,7 @@ private:
|
|
1161
1268
|
}
|
1162
1269
|
|
1163
1270
|
return output;
|
1164
|
-
}
|
1271
|
+
}
|
1165
1272
|
|
1166
1273
|
static VALUE _llama_context_vocab(int argc, VALUE* argv, VALUE self) {
|
1167
1274
|
VALUE kw_args = Qnil;
|
@@ -1198,7 +1305,7 @@ private:
|
|
1198
1305
|
}
|
1199
1306
|
|
1200
1307
|
return rb_ary_new_from_args(2, ret_strings, ret_scores);
|
1201
|
-
}
|
1308
|
+
}
|
1202
1309
|
|
1203
1310
|
static VALUE _llama_context_n_vocab(VALUE self) {
|
1204
1311
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1207,7 +1314,7 @@ private:
|
|
1207
1314
|
return Qnil;
|
1208
1315
|
}
|
1209
1316
|
return INT2NUM(llama_n_vocab(ptr->ctx));
|
1210
|
-
}
|
1317
|
+
}
|
1211
1318
|
|
1212
1319
|
static VALUE _llama_context_n_ctx(VALUE self) {
|
1213
1320
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1216,7 +1323,7 @@ private:
|
|
1216
1323
|
return Qnil;
|
1217
1324
|
}
|
1218
1325
|
return INT2NUM(llama_n_ctx(ptr->ctx));
|
1219
|
-
}
|
1326
|
+
}
|
1220
1327
|
|
1221
1328
|
static VALUE _llama_context_n_embd(VALUE self) {
|
1222
1329
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1225,7 +1332,19 @@ private:
|
|
1225
1332
|
return Qnil;
|
1226
1333
|
}
|
1227
1334
|
return INT2NUM(llama_n_embd(ptr->ctx));
|
1228
|
-
}
|
1335
|
+
}
|
1336
|
+
|
1337
|
+
static VALUE _llama_context_get_timings(VALUE self) {
|
1338
|
+
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1339
|
+
if (ptr->ctx == NULL) {
|
1340
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1341
|
+
return Qnil;
|
1342
|
+
}
|
1343
|
+
VALUE tm_obj = rb_funcall(rb_cLLaMATimings, rb_intern("new"), 0);
|
1344
|
+
LLaMATimingsWrapper* tm_ptr = RbLLaMATimings::get_llama_timings(tm_obj);
|
1345
|
+
tm_ptr->timings = llama_get_timings(ptr->ctx);
|
1346
|
+
return tm_obj;
|
1347
|
+
}
|
1229
1348
|
|
1230
1349
|
static VALUE _llama_context_print_timings(VALUE self) {
|
1231
1350
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1235,7 +1354,7 @@ private:
|
|
1235
1354
|
}
|
1236
1355
|
llama_print_timings(ptr->ctx);
|
1237
1356
|
return Qnil;
|
1238
|
-
}
|
1357
|
+
}
|
1239
1358
|
|
1240
1359
|
static VALUE _llama_context_reset_timings(VALUE self) {
|
1241
1360
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1245,7 +1364,7 @@ private:
|
|
1245
1364
|
}
|
1246
1365
|
llama_reset_timings(ptr->ctx);
|
1247
1366
|
return Qnil;
|
1248
|
-
}
|
1367
|
+
}
|
1249
1368
|
|
1250
1369
|
static VALUE _llama_context_kv_cache_token_count(VALUE self) {
|
1251
1370
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1254,7 +1373,7 @@ private:
|
|
1254
1373
|
return Qnil;
|
1255
1374
|
}
|
1256
1375
|
return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
|
1257
|
-
}
|
1376
|
+
}
|
1258
1377
|
|
1259
1378
|
static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
|
1260
1379
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1269,7 +1388,7 @@ private:
|
|
1269
1388
|
const uint32_t seed = NUM2INT(seed_);
|
1270
1389
|
llama_set_rng_seed(ptr->ctx, seed);
|
1271
1390
|
return Qnil;
|
1272
|
-
}
|
1391
|
+
}
|
1273
1392
|
|
1274
1393
|
static VALUE _llama_context_load_session_file(int argc, VALUE* argv, VALUE self) {
|
1275
1394
|
VALUE kw_args = Qnil;
|
@@ -1407,7 +1526,7 @@ private:
|
|
1407
1526
|
llama_sample_repetition_penalty(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, penalty);
|
1408
1527
|
|
1409
1528
|
return Qnil;
|
1410
|
-
}
|
1529
|
+
}
|
1411
1530
|
|
1412
1531
|
static VALUE _llama_context_sample_frequency_and_presence_penalties(int argc, VALUE* argv, VALUE self) {
|
1413
1532
|
VALUE kw_args = Qnil;
|
@@ -1458,7 +1577,52 @@ private:
|
|
1458
1577
|
llama_sample_frequency_and_presence_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, alpha_frequency, alpha_presence);
|
1459
1578
|
|
1460
1579
|
return Qnil;
|
1461
|
-
}
|
1580
|
+
}
|
1581
|
+
|
1582
|
+
static VALUE _llama_context_sample_classifier_free_guidance(int argc, VALUE* argv, VALUE self) {
|
1583
|
+
VALUE kw_args = Qnil;
|
1584
|
+
ID kw_table[3] = { rb_intern("guidance"), rb_intern("scale"), rb_intern("smooth_factor") };
|
1585
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
1586
|
+
VALUE candidates = Qnil;
|
1587
|
+
rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
|
1588
|
+
rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
|
1589
|
+
|
1590
|
+
if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAContext)) {
|
1591
|
+
rb_raise(rb_eArgError, "guidance must be a Context");
|
1592
|
+
return Qnil;
|
1593
|
+
}
|
1594
|
+
if (!RB_FLOAT_TYPE_P(kw_values[1])) {
|
1595
|
+
rb_raise(rb_eArgError, "scale must be a float");
|
1596
|
+
return Qnil;
|
1597
|
+
}
|
1598
|
+
if (!RB_FLOAT_TYPE_P(kw_values[2])) {
|
1599
|
+
rb_raise(rb_eArgError, "smooth_factor must be a float");
|
1600
|
+
return Qnil;
|
1601
|
+
}
|
1602
|
+
|
1603
|
+
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
1604
|
+
if (ctx_ptr->ctx == NULL) {
|
1605
|
+
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1606
|
+
return Qnil;
|
1607
|
+
}
|
1608
|
+
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
1609
|
+
if (cnd_ptr->array.data == nullptr) {
|
1610
|
+
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
1611
|
+
return Qnil;
|
1612
|
+
}
|
1613
|
+
|
1614
|
+
LLaMAContextWrapper* guidance_ptr = get_llama_context(kw_values[0]);
|
1615
|
+
if (guidance_ptr->ctx == NULL) {
|
1616
|
+
rb_raise(rb_eRuntimeError, "guidance context is not initialized");
|
1617
|
+
return Qnil;
|
1618
|
+
}
|
1619
|
+
const float scale = NUM2DBL(kw_values[1]);
|
1620
|
+
const float smooth_factor = NUM2DBL(kw_values[2]);
|
1621
|
+
|
1622
|
+
llama_sample_classifier_free_guidance(ctx_ptr->ctx, &(cnd_ptr->array), guidance_ptr->ctx, scale, smooth_factor);
|
1623
|
+
|
1624
|
+
return Qnil;
|
1625
|
+
}
|
1462
1626
|
|
1463
1627
|
static VALUE _llama_context_sample_softmax(VALUE self, VALUE candidates) {
|
1464
1628
|
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
@@ -1480,7 +1644,7 @@ private:
|
|
1480
1644
|
llama_sample_softmax(ctx_ptr->ctx, &(cnd_ptr->array));
|
1481
1645
|
|
1482
1646
|
return Qnil;
|
1483
|
-
}
|
1647
|
+
}
|
1484
1648
|
|
1485
1649
|
static VALUE _llama_context_sample_top_k(int argc, VALUE* argv, VALUE self) {
|
1486
1650
|
VALUE kw_args = Qnil;
|
@@ -1519,7 +1683,7 @@ private:
|
|
1519
1683
|
llama_sample_top_k(ctx_ptr->ctx, &(cnd_ptr->array), k, min_keep);
|
1520
1684
|
|
1521
1685
|
return Qnil;
|
1522
|
-
}
|
1686
|
+
}
|
1523
1687
|
|
1524
1688
|
static VALUE _llama_context_sample_top_p(int argc, VALUE* argv, VALUE self) {
|
1525
1689
|
VALUE kw_args = Qnil;
|
@@ -1558,7 +1722,7 @@ private:
|
|
1558
1722
|
llama_sample_top_p(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
|
1559
1723
|
|
1560
1724
|
return Qnil;
|
1561
|
-
}
|
1725
|
+
}
|
1562
1726
|
|
1563
1727
|
static VALUE _llama_context_sample_tail_free(int argc, VALUE* argv, VALUE self) {
|
1564
1728
|
VALUE kw_args = Qnil;
|
@@ -1597,7 +1761,7 @@ private:
|
|
1597
1761
|
llama_sample_tail_free(ctx_ptr->ctx, &(cnd_ptr->array), z, min_keep);
|
1598
1762
|
|
1599
1763
|
return Qnil;
|
1600
|
-
}
|
1764
|
+
}
|
1601
1765
|
|
1602
1766
|
static VALUE _llama_context_sample_typical(int argc, VALUE* argv, VALUE self) {
|
1603
1767
|
VALUE kw_args = Qnil;
|
@@ -1636,7 +1800,7 @@ private:
|
|
1636
1800
|
llama_sample_typical(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
|
1637
1801
|
|
1638
1802
|
return Qnil;
|
1639
|
-
}
|
1803
|
+
}
|
1640
1804
|
|
1641
1805
|
static VALUE _llama_context_sample_temperature(int argc, VALUE* argv, VALUE self) {
|
1642
1806
|
VALUE kw_args = Qnil;
|
@@ -1670,7 +1834,7 @@ private:
|
|
1670
1834
|
llama_sample_temperature(ctx_ptr->ctx, &(cnd_ptr->array), temperature);
|
1671
1835
|
|
1672
1836
|
return Qnil;
|
1673
|
-
}
|
1837
|
+
}
|
1674
1838
|
|
1675
1839
|
static VALUE _llama_context_sample_token_mirostat(int argc, VALUE* argv, VALUE self) {
|
1676
1840
|
VALUE kw_args = Qnil;
|
@@ -1722,7 +1886,7 @@ private:
|
|
1722
1886
|
rb_ary_store(ret, 0, INT2NUM(id));
|
1723
1887
|
rb_ary_store(ret, 1, DBL2NUM(mu));
|
1724
1888
|
return ret;
|
1725
|
-
}
|
1889
|
+
}
|
1726
1890
|
|
1727
1891
|
static VALUE _llama_context_sample_token_mirostat_v2(int argc, VALUE* argv, VALUE self) {
|
1728
1892
|
VALUE kw_args = Qnil;
|
@@ -1769,7 +1933,7 @@ private:
|
|
1769
1933
|
rb_ary_store(ret, 0, INT2NUM(id));
|
1770
1934
|
rb_ary_store(ret, 1, DBL2NUM(mu));
|
1771
1935
|
return ret;
|
1772
|
-
}
|
1936
|
+
}
|
1773
1937
|
|
1774
1938
|
static VALUE _llama_context_sample_token_greedy(VALUE self, VALUE candidates) {
|
1775
1939
|
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
@@ -1788,7 +1952,7 @@ private:
|
|
1788
1952
|
}
|
1789
1953
|
llama_token id = llama_sample_token_greedy(ctx_ptr->ctx, &(cnd_ptr->array));
|
1790
1954
|
return INT2NUM(id);
|
1791
|
-
}
|
1955
|
+
}
|
1792
1956
|
|
1793
1957
|
static VALUE _llama_context_sample_token(VALUE self, VALUE candidates) {
|
1794
1958
|
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
@@ -1807,7 +1971,7 @@ private:
|
|
1807
1971
|
}
|
1808
1972
|
llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
|
1809
1973
|
return INT2NUM(id);
|
1810
|
-
}
|
1974
|
+
}
|
1811
1975
|
};
|
1812
1976
|
|
1813
1977
|
const rb_data_type_t RbLLaMAContext::llama_context_type = {
|
@@ -1822,7 +1986,7 @@ const rb_data_type_t RbLLaMAContext::llama_context_type = {
|
|
1822
1986
|
|
1823
1987
|
// module functions
|
1824
1988
|
|
1825
|
-
static VALUE rb_llama_llama_init_backend(int argc, VALUE* argv, VALUE self) {
|
1989
|
+
static VALUE rb_llama_llama_backend_init(int argc, VALUE* argv, VALUE self) {
|
1826
1990
|
VALUE kw_args = Qnil;
|
1827
1991
|
ID kw_table[1] = { rb_intern("numa") };
|
1828
1992
|
VALUE kw_values[1] = { Qundef };
|
@@ -1830,7 +1994,13 @@ static VALUE rb_llama_llama_init_backend(int argc, VALUE* argv, VALUE self) {
|
|
1830
1994
|
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
1831
1995
|
|
1832
1996
|
const bool numa = kw_values[0] == Qundef ? false : (RTEST ? true : false);
|
1833
|
-
|
1997
|
+
llama_backend_init(numa);
|
1998
|
+
|
1999
|
+
return Qnil;
|
2000
|
+
}
|
2001
|
+
|
2002
|
+
static VALUE rb_llama_llama_backend_free(VALUE self) {
|
2003
|
+
llama_backend_free();
|
1834
2004
|
|
1835
2005
|
return Qnil;
|
1836
2006
|
}
|
@@ -1898,10 +2068,13 @@ extern "C" void Init_llama_cpp(void) {
|
|
1898
2068
|
RbLLaMATokenData::define_class(rb_mLLaMACpp);
|
1899
2069
|
RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
|
1900
2070
|
RbLLaMAModel::define_class(rb_mLLaMACpp);
|
2071
|
+
RbLLaMATimings::define_class(rb_mLLaMACpp);
|
1901
2072
|
RbLLaMAContext::define_class(rb_mLLaMACpp);
|
1902
2073
|
RbLLaMAContextParams::define_class(rb_mLLaMACpp);
|
2074
|
+
RbLLaMAModelQuantizeParams::define_class(rb_mLLaMACpp);
|
1903
2075
|
|
1904
|
-
rb_define_module_function(rb_mLLaMACpp, "init_backend", rb_llama_llama_init_backend, -1);
|
2076
|
+
rb_define_module_function(rb_mLLaMACpp, "backend_init", rb_llama_llama_backend_init, -1);
|
2077
|
+
rb_define_module_function(rb_mLLaMACpp, "backend_free", rb_llama_llama_backend_free, 0);
|
1905
2078
|
rb_define_module_function(rb_mLLaMACpp, "model_quantize", rb_llama_model_quantize, -1);
|
1906
2079
|
rb_define_module_function(rb_mLLaMACpp, "token_bos", rb_llama_token_bos, 0);
|
1907
2080
|
rb_define_module_function(rb_mLLaMACpp, "token_eos", rb_llama_token_eos, 0);
|