llama_cpp 0.3.2 → 0.3.4
- checksums.yaml +4 -4
- data/CHANGELOG.md +37 -0
- data/ext/llama_cpp/extconf.rb +9 -0
- data/ext/llama_cpp/llama_cpp.cpp +302 -112
- data/ext/llama_cpp/src/ggml-cuda.cu +677 -118
- data/ext/llama_cpp/src/ggml-metal.h +5 -1
- data/ext/llama_cpp/src/ggml-metal.m +65 -45
- data/ext/llama_cpp/src/ggml-metal.metal +610 -484
- data/ext/llama_cpp/src/ggml-mpi.c +216 -0
- data/ext/llama_cpp/src/ggml-mpi.h +39 -0
- data/ext/llama_cpp/src/ggml.c +1146 -812
- data/ext/llama_cpp/src/ggml.h +77 -19
- data/ext/llama_cpp/src/k_quants.h +8 -0
- data/ext/llama_cpp/src/llama.cpp +289 -104
- data/ext/llama_cpp/src/llama.h +46 -3
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +2 -1
- data/sig/llama_cpp.rbs +14 -1
- metadata +4 -2
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -17,9 +17,9 @@ public:
     data.id = 0;
     data.logit = 0.0;
     data.p = 0.0;
-    }
+  }
 
-  ~LLaMATokenDataWrapper(){}
+  ~LLaMATokenDataWrapper() {}
 };
 
 class RbLLaMATokenData {
@@ -28,22 +28,22 @@ public:
     LLaMATokenDataWrapper* ptr = (LLaMATokenDataWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataWrapper));
     new (ptr) LLaMATokenDataWrapper();
     return TypedData_Wrap_Struct(self, &llama_token_data_type, ptr);
-    }
+  }
 
   static void llama_token_data_free(void* ptr) {
     ((LLaMATokenDataWrapper*)ptr)->~LLaMATokenDataWrapper();
     ruby_xfree(ptr);
-    }
+  }
 
   static size_t llama_token_data_size(const void* ptr) {
     return sizeof(*((LLaMATokenDataWrapper*)ptr));
-    }
+  }
 
   static LLaMATokenDataWrapper* get_llama_token_data(VALUE self) {
     LLaMATokenDataWrapper* ptr;
     TypedData_Get_Struct(self, LLaMATokenDataWrapper, &llama_token_data_type, ptr);
     return ptr;
-    }
+  }
 
   static void define_class(VALUE outer) {
     rb_cLLaMATokenData = rb_define_class_under(outer, "TokenData", rb_cObject);
@@ -95,36 +95,36 @@ private:
     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
     ptr->data.id = NUM2INT(id);
     return INT2NUM(ptr->data.id);
-    }
+  }
 
   static VALUE _llama_token_data_get_id(VALUE self) {
     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
     return INT2NUM(ptr->data.id);
-    }
+  }
 
   // logit
   static VALUE _llama_token_data_set_logit(VALUE self, VALUE logit) {
     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
     ptr->data.logit = NUM2DBL(logit);
     return DBL2NUM(ptr->data.logit);
-    }
+  }
 
   static VALUE _llama_token_data_get_logit(VALUE self) {
     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
     return DBL2NUM(ptr->data.logit);
-    }
+  }
 
   // p
   static VALUE _llama_token_data_set_p(VALUE self, VALUE p) {
     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
     ptr->data.p = NUM2DBL(p);
     return DBL2NUM(ptr->data.p);
-    }
+  }
 
   static VALUE _llama_token_data_get_p(VALUE self) {
     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
     return DBL2NUM(ptr->data.p);
-    }
+  }
 };
 
 const rb_data_type_t RbLLaMATokenData::llama_token_data_type = {
@@ -145,14 +145,14 @@ public:
     array.data = nullptr;
     array.size = 0;
     array.sorted = false;
-    }
+  }
 
   ~LLaMATokenDataArrayWrapper() {
     if (array.data) {
       ruby_xfree(array.data);
       array.data = nullptr;
     }
-    }
+  }
 };
 
 class RbLLaMATokenDataArray {
@@ -161,22 +161,22 @@ public:
     LLaMATokenDataArrayWrapper* ptr = (LLaMATokenDataArrayWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataArrayWrapper));
     new (ptr) LLaMATokenDataArrayWrapper();
     return TypedData_Wrap_Struct(self, &llama_token_data_array_type, ptr);
-    }
+  }
 
   static void llama_token_data_array_free(void* ptr) {
     ((LLaMATokenDataArrayWrapper*)ptr)->~LLaMATokenDataArrayWrapper();
     ruby_xfree(ptr);
-    }
+  }
 
   static size_t llama_token_data_array_size(const void* ptr) {
     return sizeof(*((LLaMATokenDataArrayWrapper*)ptr));
-    }
+  }
 
   static LLaMATokenDataArrayWrapper* get_llama_token_data_array(VALUE self) {
     LLaMATokenDataArrayWrapper* ptr;
     TypedData_Get_Struct(self, LLaMATokenDataArrayWrapper, &llama_token_data_array_type, ptr);
     return ptr;
-    }
+  }
 
   static void define_class(VALUE outer) {
     rb_cLLaMATokenDataArray = rb_define_class_under(outer, "TokenDataArray", rb_cObject);
@@ -184,7 +184,7 @@ public:
     rb_define_method(rb_cLLaMATokenDataArray, "initialize", RUBY_METHOD_FUNC(_llama_token_data_array_init), -1);
     rb_define_method(rb_cLLaMATokenDataArray, "size", RUBY_METHOD_FUNC(_llama_token_data_array_get_size), 0);
     rb_define_method(rb_cLLaMATokenDataArray, "sorted", RUBY_METHOD_FUNC(_llama_token_data_array_get_sorted), 0);
-    }
+  }
 
 private:
   static const rb_data_type_t llama_token_data_array_type;
@@ -233,17 +233,17 @@ private:
     ptr->array.sorted = kw_values[0] == Qtrue;
 
     return self;
-    }
+  }
 
   static VALUE _llama_token_data_array_get_size(VALUE self) {
     LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
     return SIZET2NUM(ptr->array.size);
-    }
+  }
 
   static VALUE _llama_token_data_array_get_sorted(VALUE self) {
     LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
     return ptr->array.sorted ? Qtrue : Qfalse;
-    }
+  }
 };
 
 const rb_data_type_t RbLLaMATokenDataArray::llama_token_data_array_type = {
@@ -260,9 +260,9 @@ class LLaMATimingsWrapper {
 public:
   struct llama_timings timings;
 
-  LLaMATimingsWrapper(){}
+  LLaMATimingsWrapper() {}
 
-  ~LLaMATimingsWrapper(){}
+  ~LLaMATimingsWrapper() {}
 };
 
 class RbLLaMATimings {
@@ -365,9 +365,9 @@ class LLaMAContextParamsWrapper {
 public:
   struct llama_context_params params;
 
-  LLaMAContextParamsWrapper() : params(llama_context_default_params()){}
+  LLaMAContextParamsWrapper() : params(llama_context_default_params()) {}
 
-  ~LLaMAContextParamsWrapper(){}
+  ~LLaMAContextParamsWrapper() {}
 };
 
 class RbLLaMAContextParams {
@@ -376,22 +376,22 @@ public:
     LLaMAContextParamsWrapper* ptr = (LLaMAContextParamsWrapper*)ruby_xmalloc(sizeof(LLaMAContextParamsWrapper));
     new (ptr) LLaMAContextParamsWrapper();
     return TypedData_Wrap_Struct(self, &llama_context_params_type, ptr);
-    }
+  }
 
   static void llama_context_params_free(void* ptr) {
     ((LLaMAContextParamsWrapper*)ptr)->~LLaMAContextParamsWrapper();
     ruby_xfree(ptr);
-    }
+  }
 
   static size_t llama_context_params_size(const void* ptr) {
     return sizeof(*((LLaMAContextParamsWrapper*)ptr));
-    }
+  }
 
   static LLaMAContextParamsWrapper* get_llama_context_params(VALUE self) {
     LLaMAContextParamsWrapper* ptr;
     TypedData_Get_Struct(self, LLaMAContextParamsWrapper, &llama_context_params_type, ptr);
     return ptr;
-    }
+  }
 
   static void define_class(VALUE outer) {
     rb_cLLaMAContextParams = rb_define_class_under(outer, "ContextParams", rb_cObject);
@@ -406,6 +406,10 @@ public:
     rb_define_method(rb_cLLaMAContextParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_context_params_set_main_gpu), 1);
     rb_define_method(rb_cLLaMAContextParams, "main_gpu", RUBY_METHOD_FUNC(_llama_context_params_get_main_gpu), 0);
     rb_define_method(rb_cLLaMAContextParams, "tensor_split", RUBY_METHOD_FUNC(_llama_context_params_get_tensor_split), 0);
+    rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
+    rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
+    rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
+    rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
     rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
     rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
     rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
@@ -422,7 +426,7 @@ public:
     rb_define_method(rb_cLLaMAContextParams, "use_mlock", RUBY_METHOD_FUNC(_llama_context_params_get_use_mlock), 0);
     rb_define_method(rb_cLLaMAContextParams, "embedding=", RUBY_METHOD_FUNC(_llama_context_params_set_embedding), 1);
     rb_define_method(rb_cLLaMAContextParams, "embedding", RUBY_METHOD_FUNC(_llama_context_params_get_embedding), 0);
-    }
+  }
 
 private:
   static const rb_data_type_t llama_context_params_type;
@@ -431,55 +435,55 @@ private:
   //   LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
   //   new (ptr) LLaMAContextParamsWrapper();
   //   return self;
-  //   }
+  // }
 
   // n_ctx
   static VALUE _llama_context_params_set_n_ctx(VALUE self, VALUE n_ctx) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.n_ctx = NUM2INT(n_ctx);
     return INT2NUM(ptr->params.n_ctx);
-    }
+  }
 
   static VALUE _llama_context_params_get_n_ctx(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return INT2NUM(ptr->params.n_ctx);
-    }
+  }
 
   // n_batch
   static VALUE _llama_context_params_set_n_batch(VALUE self, VALUE n_batch) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.n_batch = NUM2INT(n_batch);
     return INT2NUM(ptr->params.n_batch);
-    }
+  }
 
   static VALUE _llama_context_params_get_n_batch(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return INT2NUM(ptr->params.n_batch);
-    }
+  }
 
   // n_gpu_layers
   static VALUE _llama_context_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
     return INT2NUM(ptr->params.n_gpu_layers);
-    }
+  }
 
   static VALUE _llama_context_params_get_n_gpu_layers(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return INT2NUM(ptr->params.n_gpu_layers);
-    }
+  }
 
   // main_gpu
   static VALUE _llama_context_params_set_main_gpu(VALUE self, VALUE main_gpu) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.main_gpu = NUM2INT(main_gpu);
     return INT2NUM(ptr->params.main_gpu);
-    }
+  }
 
   static VALUE _llama_context_params_get_main_gpu(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return INT2NUM(ptr->params.main_gpu);
-    }
+  }
 
   // tensor_split
   static VALUE _llama_context_params_get_tensor_split(VALUE self) {
@@ -492,19 +496,43 @@ private:
       rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
     }
     return ret;
-    }
+  }
+
+  // rope_freq_base
+  static VALUE _llama_context_params_set_rope_freq_base(VALUE self, VALUE rope_freq_base) {
+    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+    ptr->params.rope_freq_base = NUM2DBL(rope_freq_base);
+    return DBL2NUM(ptr->params.rope_freq_base);
+  }
+
+  static VALUE _llama_context_params_get_rope_freq_base(VALUE self) {
+    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+    return DBL2NUM(ptr->params.rope_freq_base);
+  }
+
+  // rope_freq_scale
+  static VALUE _llama_context_params_set_rope_freq_scale(VALUE self, VALUE rope_freq_scale) {
+    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+    ptr->params.rope_freq_scale = NUM2DBL(rope_freq_scale);
+    return DBL2NUM(ptr->params.rope_freq_scale);
+  }
+
+  static VALUE _llama_context_params_get_rope_freq_scale(VALUE self) {
+    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+    return DBL2NUM(ptr->params.rope_freq_scale);
+  }
 
   // low_vram
   static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.low_vram = low_vram == Qtrue ? true : false;
     return ptr->params.low_vram ? Qtrue : Qfalse;
-    }
+  }
 
   static VALUE _llama_context_params_get_low_vram(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return ptr->params.low_vram ? Qtrue : Qfalse;
-    }
+  }
 
   // seed
   static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
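The hunk above exposes llama.cpp's new RoPE frequency parameters to Ruby as plain Float accessors on LLaMACpp::ContextParams. A minimal usage sketch, not taken from the gem's docs (the values are illustrative; upstream llama.cpp defaults are rope_freq_base = 10000.0 and rope_freq_scale = 1.0):

  require 'llama_cpp'

  params = LLaMACpp::ContextParams.new
  params.n_ctx = 4096
  params.rope_freq_base = 10000.0  # maps to llama_context_params.rope_freq_base
  params.rope_freq_scale = 0.5     # scaling RoPE frequency stretches the usable context window

Both accessors round-trip through NUM2DBL/DBL2NUM, so they accept and return Ruby Floats.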
@@ -515,84 +543,84 @@ private:
     }
     ptr->params.seed = NUM2INT(seed);
     return INT2NUM(ptr->params.seed);
-    }
+  }
 
   static VALUE _llama_context_params_get_seed(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return INT2NUM(ptr->params.seed);
-    }
+  }
 
   // f16_kv
   static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.f16_kv = f16_kv == Qtrue ? true : false;
     return ptr->params.f16_kv ? Qtrue : Qfalse;
-    }
+  }
 
   static VALUE _llama_context_params_get_f16_kv(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return ptr->params.f16_kv ? Qtrue : Qfalse;
-    }
+  }
 
   // logits_all
   static VALUE _llama_context_params_set_logits_all(VALUE self, VALUE logits_all) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.logits_all = logits_all == Qtrue ? true : false;
     return ptr->params.logits_all ? Qtrue : Qfalse;
-    }
+  }
 
   static VALUE _llama_context_params_get_logits_all(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return ptr->params.logits_all ? Qtrue : Qfalse;
-    }
+  }
 
   // vocab_only
   static VALUE _llama_context_params_set_vocab_only(VALUE self, VALUE vocab_only) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.vocab_only = vocab_only == Qtrue ? true : false;
     return ptr->params.vocab_only ? Qtrue : Qfalse;
-    }
+  }
 
   static VALUE _llama_context_params_get_vocab_only(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return ptr->params.vocab_only ? Qtrue : Qfalse;
-    }
+  }
 
   // use_mmap
   static VALUE _llama_context_params_set_use_mmap(VALUE self, VALUE use_mmap) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.use_mmap = use_mmap == Qtrue ? true : false;
     return ptr->params.use_mmap ? Qtrue : Qfalse;
-    }
+  }
 
   static VALUE _llama_context_params_get_use_mmap(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return ptr->params.use_mmap ? Qtrue : Qfalse;
-    }
+  }
 
   // use_mlock
   static VALUE _llama_context_params_set_use_mlock(VALUE self, VALUE use_mlock) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.use_mlock = use_mlock == Qtrue ? true : false;
     return ptr->params.use_mlock ? Qtrue : Qfalse;
-    }
+  }
 
   static VALUE _llama_context_params_get_use_mlock(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return ptr->params.use_mlock ? Qtrue : Qfalse;
-    }
+  }
 
   // embedding
   static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.embedding = embedding == Qtrue ? true : false;
     return ptr->params.embedding ? Qtrue : Qfalse;
-    }
+  }
 
   static VALUE _llama_context_params_get_embedding(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return ptr->params.embedding ? Qtrue : Qfalse;
-    }
+  }
 };
 
 const rb_data_type_t RbLLaMAContextParams::llama_context_params_type = {
@@ -609,9 +637,9 @@ class LLaMAModelQuantizeParamsWrapper {
 public:
   llama_model_quantize_params params;
 
-  LLaMAModelQuantizeParamsWrapper() : params(llama_model_quantize_default_params()){}
+  LLaMAModelQuantizeParamsWrapper() : params(llama_model_quantize_default_params()) {}
 
-  ~LLaMAModelQuantizeParamsWrapper(){}
+  ~LLaMAModelQuantizeParamsWrapper() {}
 };
 
 class RbLLaMAModelQuantizeParams {
@@ -620,22 +648,22 @@ public:
     LLaMAModelQuantizeParamsWrapper* ptr = (LLaMAModelQuantizeParamsWrapper*)ruby_xmalloc(sizeof(LLaMAModelQuantizeParamsWrapper));
     new (ptr) LLaMAModelQuantizeParamsWrapper();
     return TypedData_Wrap_Struct(self, &llama_model_quantize_params_type, ptr);
-    }
+  }
 
   static void llama_model_quantize_params_free(void* ptr) {
     ((LLaMAModelQuantizeParamsWrapper*)ptr)->~LLaMAModelQuantizeParamsWrapper();
     ruby_xfree(ptr);
-    }
+  }
 
   static size_t llama_model_quantize_params_size(const void* ptr) {
     return sizeof(*((LLaMAModelQuantizeParamsWrapper*)ptr));
-    }
+  }
 
   static LLaMAModelQuantizeParamsWrapper* get_llama_model_quantize_params(VALUE self) {
     LLaMAModelQuantizeParamsWrapper* ptr;
     TypedData_Get_Struct(self, LLaMAModelQuantizeParamsWrapper, &llama_model_quantize_params_type, ptr);
     return ptr;
-    }
+  }
 
   static void define_class(VALUE outer) {
     rb_cLLaMAModelQuantizeParams = rb_define_class_under(outer, "ModelQuantizeParams", rb_cObject);
@@ -648,7 +676,7 @@ public:
     rb_define_method(rb_cLLaMAModelQuantizeParams, "allow_requantize", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_allow_requantize), 0);
     rb_define_method(rb_cLLaMAModelQuantizeParams, "quantize_output_tensor=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_quantize_output_tensor), 1);
     rb_define_method(rb_cLLaMAModelQuantizeParams, "quantize_output_tensor", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_quantize_output_tensor), 0);
-    }
+  }
 
 private:
   static const rb_data_type_t llama_model_quantize_params_type;
@@ -658,24 +686,24 @@ private:
     LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
     ptr->params.nthread = NUM2INT(n_thread);
     return INT2NUM(ptr->params.nthread);
-    }
+  }
 
   static VALUE _llama_model_quantize_params_get_n_thread(VALUE self) {
     LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
     return INT2NUM(ptr->params.nthread);
-    }
+  }
 
   // ftype
   static VALUE _llama_model_quantize_params_set_ftype(VALUE self, VALUE ftype) {
     LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
     ptr->params.ftype = static_cast<enum llama_ftype>(NUM2INT(ftype));
     return INT2NUM(ptr->params.ftype);
-    }
+  }
 
   static VALUE _llama_model_quantize_params_get_ftype(VALUE self) {
     LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
     return INT2NUM(ptr->params.ftype);
-    }
+  }
 
   // allow_requantize
   static VALUE _llama_model_quantize_params_set_allow_requantize(VALUE self, VALUE allow_requantize) {
@@ -686,12 +714,12 @@ private:
       ptr->params.allow_requantize = true;
     }
     return ptr->params.allow_requantize ? Qtrue : Qfalse;
-    }
+  }
 
   static VALUE _llama_model_quantize_params_get_allow_requantize(VALUE self) {
     LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
     return ptr->params.allow_requantize ? Qtrue : Qfalse;
-    }
+  }
 
   // quantize_output_tensor
   static VALUE _llama_model_quantize_params_set_quantize_output_tensor(VALUE self, VALUE quantize_output_tensor) {
@@ -702,12 +730,12 @@ private:
       ptr->params.quantize_output_tensor = true;
     }
     return ptr->params.quantize_output_tensor ? Qtrue : Qfalse;
-    }
+  }
 
   static VALUE _llama_model_quantize_params_get_quantize_output_tensor(VALUE self) {
     LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
     return ptr->params.quantize_output_tensor ? Qtrue : Qfalse;
-    }
+  }
 };
 
 const rb_data_type_t RbLLaMAModelQuantizeParams::llama_model_quantize_params_type = {
@@ -724,13 +752,13 @@ class LLaMAModelWrapper {
 public:
   struct llama_model* model;
 
-  LLaMAModelWrapper() : model(NULL){}
+  LLaMAModelWrapper() : model(NULL) {}
 
   ~LLaMAModelWrapper() {
     if (model != NULL) {
       llama_free_model(model);
     }
-    }
+  }
 };
 
 class RbLLaMAModel {
@@ -764,6 +792,12 @@ public:
     rb_define_method(rb_cLLaMAModel, "free", RUBY_METHOD_FUNC(_llama_model_free), 0);
     rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
     rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
+    rb_define_method(rb_cLLaMAModel, "n_vocab", RUBY_METHOD_FUNC(_llama_model_get_n_vocab_from_model), 0);
+    rb_define_method(rb_cLLaMAModel, "n_ctx", RUBY_METHOD_FUNC(_llama_model_get_n_ctx_from_model), 0);
+    rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_n_embd_from_model), 0);
+    rb_define_method(rb_cLLaMAModel, "vocab", RUBY_METHOD_FUNC(_llama_model_get_vocab_from_model), -1);
+    rb_define_method(rb_cLLaMAModel, "token_to_str", RUBY_METHOD_FUNC(_llama_model_token_to_str_with_model), 1);
+    rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize_with_model), -1);
   }
 
 private:
@@ -907,7 +941,110 @@ private:
       return Qnil;
     }
     return Qnil;
-    }
+  }
+
+  static VALUE _llama_model_get_n_vocab_from_model(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_n_vocab_from_model(ptr->model));
+  }
+
+  static VALUE _llama_model_get_n_ctx_from_model(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_n_ctx_from_model(ptr->model));
+  }
+
+  static VALUE _llama_model_get_n_embd_from_model(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_n_embd_from_model(ptr->model));
+  }
+
+  static VALUE _llama_model_get_vocab_from_model(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[1] = { rb_intern("capacity") };
+    VALUE kw_values[1] = { Qundef };
+    rb_scan_args(argc, argv, ":", &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
+
+    if (!RB_INTEGER_TYPE_P(kw_values[0])) {
+      rb_raise(rb_eArgError, "capacity must be an integer");
+      return Qnil;
+    }
+
+    const int capacity = NUM2INT(kw_values[0]);
+
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    const int n = std::min(capacity, llama_n_vocab_from_model(ptr->model));
+    const char** vocabs = ALLOCA_N(const char*, n);
+    float* scores = ALLOCA_N(float, n);
+
+    llama_get_vocab_from_model(ptr->model, vocabs, scores, capacity);
+
+    VALUE vocabs_ary = rb_ary_new();
+    VALUE scores_ary = rb_ary_new();
+
+    for (int i = 0; i < n; i++) {
+      rb_ary_push(vocabs_ary, rb_str_new_cstr(vocabs[i]));
+      rb_ary_push(scores_ary, DBL2NUM(scores[i]));
+    }
+
+    VALUE ret = rb_ary_new3(2, vocabs_ary, scores_ary);
+
+    return ret;
+  }
+
+  static VALUE _llama_model_token_to_str_with_model(VALUE self, VALUE token_) {
+    if (!RB_INTEGER_TYPE_P(token_)) {
+      rb_raise(rb_eArgError, "token must be an integer");
+      return Qnil;
+    }
+    const llama_token token = NUM2INT(token_);
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    const char* str = llama_token_to_str_with_model(ptr->model, token);
+    return rb_str_new_cstr(str);
+  }
+
+  static VALUE _llama_model_tokenize_with_model(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
+    VALUE kw_values[3] = { Qundef, Qundef, Qundef };
+    rb_scan_args(argc, argv, ":", &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
+
+    if (!RB_TYPE_P(kw_values[0], T_STRING)) {
+      rb_raise(rb_eArgError, "text must be a String");
+      return Qnil;
+    }
+    if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
+      rb_raise(rb_eArgError, "n_max_tokens must be an integer");
+      return Qnil;
+    }
+    if (kw_values[2] != Qundef && (kw_values[2] != Qtrue && kw_values[2] != Qfalse)) {
+      rb_raise(rb_eArgError, "add_bos must be a boolean");
+      return Qnil;
+    }
+
+    VALUE text_ = kw_values[0];
+    std::string text = StringValueCStr(text_);
+    const bool add_bos = kw_values[2] == Qtrue ? true : false;
+    const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
+
+    llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    const int n_tokens = llama_tokenize_with_model(ptr->model, text.c_str(), tokens, n_max_tokens, add_bos);
+
+    if (n_tokens < 0) {
+      rb_raise(rb_eRuntimeError, "failed to tokenize. The numebr of tokens (%d) is greater than n_max_tokens.", -n_tokens);
+      return Qnil;
+    }
+
+    VALUE ret = rb_ary_new2(n_tokens);
+    for (int i = 0; i < n_tokens; i++) {
+      rb_ary_store(ret, i, INT2NUM(tokens[i]));
+    }
+
+    RB_GC_GUARD(text_);
+    return ret;
+  }
 };
 
 const rb_data_type_t RbLLaMAModel::llama_model_type = {
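This hunk gives LLaMACpp::Model direct accessors (n_vocab, n_ctx, n_embd, vocab, token_to_str, tokenize) backed by the new *_from_model functions in llama.h, so vocabulary lookups and tokenization no longer require constructing a Context. A rough sketch of the Ruby side (the model path is hypothetical, and the Model.new keyword names are assumed to follow the gem's existing API):

  require 'llama_cpp'

  model = LLaMACpp::Model.new(model_path: '/path/to/model.bin',
                              params: LLaMACpp::ContextParams.new)
  model.n_vocab                                  # vocabulary size, via llama_n_vocab_from_model
  tokens = model.tokenize(text: 'Hello, world!', add_bos: true)
  tokens.map { |t| model.token_to_str(t) }.join  # back to (roughly) the original text

Note that tokenize raises a RuntimeError when the token count exceeds n_max_tokens, mirroring the negative return value of llama_tokenize_with_model.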
@@ -924,13 +1061,13 @@ class LLaMAContextWrapper {
 public:
   struct llama_context* ctx;
 
-  LLaMAContextWrapper() : ctx(NULL){}
+  LLaMAContextWrapper() : ctx(NULL) {}
 
   ~LLaMAContextWrapper() {
     if (ctx != NULL) {
       llama_free(ctx);
     }
-    }
+  }
 };
 
 class RbLLaMAContext {
@@ -939,22 +1076,22 @@ public:
     LLaMAContextWrapper* ptr = (LLaMAContextWrapper*)ruby_xmalloc(sizeof(LLaMAContextWrapper));
     new (ptr) LLaMAContextWrapper();
     return TypedData_Wrap_Struct(self, &llama_context_type, ptr);
-    }
+  }
 
   static void llama_context_free(void* ptr) {
     ((LLaMAContextWrapper*)ptr)->~LLaMAContextWrapper();
     ruby_xfree(ptr);
-    }
+  }
 
   static size_t llama_context_size(const void* ptr) {
     return sizeof(*((LLaMAContextWrapper*)ptr));
-    }
+  }
 
   static LLaMAContextWrapper* get_llama_context(VALUE self) {
     LLaMAContextWrapper* ptr;
     TypedData_Get_Struct(self, LLaMAContextWrapper, &llama_context_type, ptr);
     return ptr;
-    }
+  }
 
   static void define_class(VALUE outer) {
     rb_cLLaMAContext = rb_define_class_under(outer, "Context", rb_cObject);
@@ -980,6 +1117,7 @@ public:
     rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
     rb_define_method(rb_cLLaMAContext, "sample_repetition_penalty", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalty), -1);
     rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
+    rb_define_method(rb_cLLaMAContext, "sample_classifier_free_guidance", RUBY_METHOD_FUNC(_llama_context_sample_classifier_free_guidance), -1);
     rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
     rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
     rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
@@ -990,7 +1128,7 @@ public:
     rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
     rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
     rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
-    }
+  }
 
 private:
   static const rb_data_type_t llama_context_type;
@@ -1029,7 +1167,7 @@ private:
     rb_iv_set(self, "@has_evaluated", Qfalse);
 
     return Qnil;
-    }
+  }
 
   static VALUE _llama_context_eval(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1084,7 +1222,7 @@ private:
     rb_iv_set(self, "@has_evaluated", Qtrue);
 
     return Qnil;
-    }
+  }
 
   static VALUE _llama_context_eval_embd(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1157,7 +1295,7 @@ private:
     }
     RB_GC_GUARD(fname_);
     return Qtrue;
-    }
+  }
 
   static VALUE _llama_context_tokenize(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1203,7 +1341,7 @@ private:
 
     RB_GC_GUARD(text_);
     return output;
-    }
+  }
 
   static VALUE _llama_context_token_to_str(VALUE self, VALUE token_) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1214,7 +1352,7 @@ private:
     const llama_token token = NUM2INT(token_);
     const char* str = llama_token_to_str(ptr->ctx, token);
     return str != nullptr ? rb_utf8_str_new_cstr(str) : rb_utf8_str_new_cstr("");
-    }
+  }
 
   static VALUE _llama_context_logits(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1239,7 +1377,7 @@ private:
     }
 
     return output;
-    }
+  }
 
   static VALUE _llama_context_embeddings(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1267,7 +1405,7 @@ private:
     }
 
     return output;
-    }
+  }
 
   static VALUE _llama_context_vocab(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1304,7 +1442,7 @@ private:
     }
 
     return rb_ary_new_from_args(2, ret_strings, ret_scores);
-    }
+  }
 
   static VALUE _llama_context_n_vocab(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1313,7 +1451,7 @@ private:
       return Qnil;
     }
     return INT2NUM(llama_n_vocab(ptr->ctx));
-    }
+  }
 
   static VALUE _llama_context_n_ctx(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1322,7 +1460,7 @@ private:
       return Qnil;
     }
     return INT2NUM(llama_n_ctx(ptr->ctx));
-    }
+  }
 
   static VALUE _llama_context_n_embd(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1331,7 +1469,7 @@ private:
      return Qnil;
     }
     return INT2NUM(llama_n_embd(ptr->ctx));
-    }
+  }
 
   static VALUE _llama_context_get_timings(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1353,7 +1491,7 @@ private:
     }
     llama_print_timings(ptr->ctx);
     return Qnil;
-    }
+  }
 
   static VALUE _llama_context_reset_timings(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1363,7 +1501,7 @@ private:
     }
     llama_reset_timings(ptr->ctx);
     return Qnil;
-    }
+  }
 
   static VALUE _llama_context_kv_cache_token_count(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1372,7 +1510,7 @@ private:
      return Qnil;
    }
     return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
-    }
+  }
 
   static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1387,7 +1525,7 @@ private:
     const uint32_t seed = NUM2INT(seed_);
     llama_set_rng_seed(ptr->ctx, seed);
     return Qnil;
-    }
+  }
 
   static VALUE _llama_context_load_session_file(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1525,7 +1663,7 @@ private:
     llama_sample_repetition_penalty(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, penalty);
 
     return Qnil;
-    }
+  }
 
   static VALUE _llama_context_sample_frequency_and_presence_penalties(int argc, VALUE* argv, VALUE self) {
    VALUE kw_args = Qnil;
@@ -1576,7 +1714,47 @@ private:
     llama_sample_frequency_and_presence_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, alpha_frequency, alpha_presence);
 
     return Qnil;
-    }
+  }
+
+  static VALUE _llama_context_sample_classifier_free_guidance(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[2] = { rb_intern("guidance"), rb_intern("scale") };
+    VALUE kw_values[2] = { Qundef, Qundef };
+    VALUE candidates = Qnil;
+    rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
+
+    if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAContext)) {
+      rb_raise(rb_eArgError, "guidance must be a Context");
+      return Qnil;
+    }
+    if (!RB_FLOAT_TYPE_P(kw_values[1])) {
+      rb_raise(rb_eArgError, "scale must be a float");
+      return Qnil;
+    }
+
+    LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+    if (ctx_ptr->ctx == NULL) {
+      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+      return Qnil;
+    }
+    LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+    if (cnd_ptr->array.data == nullptr) {
+      rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+      return Qnil;
+    }
+
+    LLaMAContextWrapper* guidance_ptr = get_llama_context(kw_values[0]);
+    if (guidance_ptr->ctx == NULL) {
+      rb_raise(rb_eRuntimeError, "guidance context is not initialized");
+      return Qnil;
+    }
+    const float scale = NUM2DBL(kw_values[1]);
+
+    llama_sample_classifier_free_guidance(ctx_ptr->ctx, &(cnd_ptr->array), guidance_ptr->ctx, scale);
+
+    return Qnil;
+  }
 
   static VALUE _llama_context_sample_softmax(VALUE self, VALUE candidates) {
     if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
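The new sample_classifier_free_guidance binding wraps llama_sample_classifier_free_guidance: it adjusts the candidates' logits away from a negative prompt that was evaluated on a second context. A sketch of the calling convention only (building the contexts and the TokenDataArray from Context#logits is elided here, and would follow the same pattern as the other sampling helpers):

  # context      : LLaMACpp::Context that evaluated the main prompt
  # guidance_ctx : LLaMACpp::Context that evaluated the negative prompt
  # candidates   : LLaMACpp::TokenDataArray built from context's logits
  context.sample_classifier_free_guidance(candidates, guidance: guidance_ctx, scale: 1.5)
  token = context.sample_token(candidates)

Per the argument checks above, guidance must be a Context and scale must be a Float (an Integer such as 2 is rejected by the RB_FLOAT_TYPE_P check).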
@@ -1598,7 +1776,7 @@ private:
     llama_sample_softmax(ctx_ptr->ctx, &(cnd_ptr->array));
 
     return Qnil;
-    }
+  }
 
   static VALUE _llama_context_sample_top_k(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1637,7 +1815,7 @@ private:
     llama_sample_top_k(ctx_ptr->ctx, &(cnd_ptr->array), k, min_keep);
 
     return Qnil;
-    }
+  }
 
   static VALUE _llama_context_sample_top_p(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1676,7 +1854,7 @@ private:
     llama_sample_top_p(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
 
     return Qnil;
-    }
+  }
 
   static VALUE _llama_context_sample_tail_free(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1715,7 +1893,7 @@ private:
     llama_sample_tail_free(ctx_ptr->ctx, &(cnd_ptr->array), z, min_keep);
 
     return Qnil;
-    }
+  }
 
   static VALUE _llama_context_sample_typical(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1754,7 +1932,7 @@ private:
     llama_sample_typical(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
 
     return Qnil;
-    }
+  }
 
   static VALUE _llama_context_sample_temperature(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1788,7 +1966,7 @@ private:
     llama_sample_temperature(ctx_ptr->ctx, &(cnd_ptr->array), temperature);
 
     return Qnil;
-    }
+  }
 
   static VALUE _llama_context_sample_token_mirostat(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1840,7 +2018,7 @@ private:
     rb_ary_store(ret, 0, INT2NUM(id));
     rb_ary_store(ret, 1, DBL2NUM(mu));
     return ret;
-    }
+  }
 
   static VALUE _llama_context_sample_token_mirostat_v2(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1887,7 +2065,7 @@ private:
     rb_ary_store(ret, 0, INT2NUM(id));
     rb_ary_store(ret, 1, DBL2NUM(mu));
     return ret;
-    }
+  }
 
   static VALUE _llama_context_sample_token_greedy(VALUE self, VALUE candidates) {
     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
@@ -1906,7 +2084,7 @@ private:
     }
     llama_token id = llama_sample_token_greedy(ctx_ptr->ctx, &(cnd_ptr->array));
     return INT2NUM(id);
-    }
+  }
 
   static VALUE _llama_context_sample_token(VALUE self, VALUE candidates) {
     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
@@ -1925,7 +2103,7 @@ private:
     }
     llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
     return INT2NUM(id);
-    }
+  }
 };
 
 const rb_data_type_t RbLLaMAContext::llama_context_type = {
@@ -1940,7 +2118,7 @@ const rb_data_type_t RbLLaMAContext::llama_context_type = {
 
 // module functions
 
-static VALUE rb_llama_llama_init_backend(int argc, VALUE* argv, VALUE self) {
+static VALUE rb_llama_llama_backend_init(int argc, VALUE* argv, VALUE self) {
   VALUE kw_args = Qnil;
   ID kw_table[1] = { rb_intern("numa") };
   VALUE kw_values[1] = { Qundef };
@@ -1948,7 +2126,13 @@ static VALUE rb_llama_llama_init_backend(int argc, VALUE* argv, VALUE self) {
   rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
 
   const bool numa = kw_values[0] == Qundef ? false : (RTEST(kw_values[0]) ? true : false);
-  llama_init_backend(numa);
+  llama_backend_init(numa);
+
+  return Qnil;
+}
+
+static VALUE rb_llama_llama_backend_free(VALUE self) {
+  llama_backend_free();
 
   return Qnil;
 }
@@ -2010,6 +2194,10 @@ static VALUE rb_llama_mlock_supported(VALUE self) {
   return llama_mlock_supported() ? Qtrue : Qfalse;
 }
 
+static VALUE rb_llama_max_devices(VALUE self) {
+  return INT2NUM(llama_max_devices());
+}
+
 extern "C" void Init_llama_cpp(void) {
   rb_mLLaMACpp = rb_define_module("LLaMACpp");
 
@@ -2021,7 +2209,8 @@ extern "C" void Init_llama_cpp(void) {
   RbLLaMAContextParams::define_class(rb_mLLaMACpp);
   RbLLaMAModelQuantizeParams::define_class(rb_mLLaMACpp);
 
-  rb_define_module_function(rb_mLLaMACpp, "init_backend", rb_llama_llama_init_backend, -1);
+  rb_define_module_function(rb_mLLaMACpp, "backend_init", rb_llama_llama_backend_init, -1);
+  rb_define_module_function(rb_mLLaMACpp, "backend_free", rb_llama_llama_backend_free, 0);
   rb_define_module_function(rb_mLLaMACpp, "model_quantize", rb_llama_model_quantize, -1);
   rb_define_module_function(rb_mLLaMACpp, "token_bos", rb_llama_token_bos, 0);
   rb_define_module_function(rb_mLLaMACpp, "token_eos", rb_llama_token_eos, 0);
@@ -2029,6 +2218,7 @@ extern "C" void Init_llama_cpp(void) {
   rb_define_module_function(rb_mLLaMACpp, "print_system_info", rb_llama_print_system_info, 0);
   rb_define_module_function(rb_mLLaMACpp, "mmap_supported?", rb_llama_mmap_supported, 0);
   rb_define_module_function(rb_mLLaMACpp, "mlock_supported?", rb_llama_mlock_supported, 0);
+  rb_define_module_function(rb_mLLaMACpp, "max_devices", rb_llama_max_devices, 0);
 
   rb_define_const(rb_mLLaMACpp, "LLAMA_MAX_DEVICES", INT2NUM(LLAMA_MAX_DEVICES));
 
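Taken together, the module-function hunks rename init_backend to backend_init (keeping the optional numa: keyword), add backend_free for releasing backend resources, and expose max_devices, mirroring llama_backend_init, llama_backend_free, and llama_max_devices upstream. A minimal sketch of the new lifecycle:

  require 'llama_cpp'

  LLaMACpp.backend_init(numa: false)  # formerly LLaMACpp.init_backend
  LLaMACpp.max_devices                # device count reported by llama.cpp
  # ... load a model, evaluate, sample ...
  LLaMACpp.backend_free               # pair with backend_init before process exit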