llama_cpp 0.3.2 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/ext/llama_cpp/extconf.rb +9 -0
- data/ext/llama_cpp/llama_cpp.cpp +165 -112
- data/ext/llama_cpp/src/ggml-cuda.cu +217 -76
- data/ext/llama_cpp/src/ggml-metal.h +5 -1
- data/ext/llama_cpp/src/ggml-metal.m +16 -5
- data/ext/llama_cpp/src/ggml-metal.metal +56 -47
- data/ext/llama_cpp/src/ggml-mpi.c +216 -0
- data/ext/llama_cpp/src/ggml-mpi.h +39 -0
- data/ext/llama_cpp/src/ggml.c +1082 -774
- data/ext/llama_cpp/src/ggml.h +64 -18
- data/ext/llama_cpp/src/llama.cpp +179 -51
- data/ext/llama_cpp/src/llama.h +15 -1
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +3 -1
- metadata +4 -2
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -17,9 +17,9 @@ public:
|
|
17
17
|
data.id = 0;
|
18
18
|
data.logit = 0.0;
|
19
19
|
data.p = 0.0;
|
20
|
-
}
|
20
|
+
}
|
21
21
|
|
22
|
-
~LLaMATokenDataWrapper(){}
|
22
|
+
~LLaMATokenDataWrapper() {}
|
23
23
|
};
|
24
24
|
|
25
25
|
class RbLLaMATokenData {
|
@@ -28,22 +28,22 @@ public:
|
|
28
28
|
LLaMATokenDataWrapper* ptr = (LLaMATokenDataWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataWrapper));
|
29
29
|
new (ptr) LLaMATokenDataWrapper();
|
30
30
|
return TypedData_Wrap_Struct(self, &llama_token_data_type, ptr);
|
31
|
-
}
|
31
|
+
}
|
32
32
|
|
33
33
|
static void llama_token_data_free(void* ptr) {
|
34
34
|
((LLaMATokenDataWrapper*)ptr)->~LLaMATokenDataWrapper();
|
35
35
|
ruby_xfree(ptr);
|
36
|
-
}
|
36
|
+
}
|
37
37
|
|
38
38
|
static size_t llama_token_data_size(const void* ptr) {
|
39
39
|
return sizeof(*((LLaMATokenDataWrapper*)ptr));
|
40
|
-
}
|
40
|
+
}
|
41
41
|
|
42
42
|
static LLaMATokenDataWrapper* get_llama_token_data(VALUE self) {
|
43
43
|
LLaMATokenDataWrapper* ptr;
|
44
44
|
TypedData_Get_Struct(self, LLaMATokenDataWrapper, &llama_token_data_type, ptr);
|
45
45
|
return ptr;
|
46
|
-
}
|
46
|
+
}
|
47
47
|
|
48
48
|
static void define_class(VALUE outer) {
|
49
49
|
rb_cLLaMATokenData = rb_define_class_under(outer, "TokenData", rb_cObject);
|
@@ -95,36 +95,36 @@ private:
|
|
95
95
|
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
96
96
|
ptr->data.id = NUM2INT(id);
|
97
97
|
return INT2NUM(ptr->data.id);
|
98
|
-
}
|
98
|
+
}
|
99
99
|
|
100
100
|
static VALUE _llama_token_data_get_id(VALUE self) {
|
101
101
|
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
102
102
|
return INT2NUM(ptr->data.id);
|
103
|
-
}
|
103
|
+
}
|
104
104
|
|
105
105
|
// logit
|
106
106
|
static VALUE _llama_token_data_set_logit(VALUE self, VALUE logit) {
|
107
107
|
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
108
108
|
ptr->data.logit = NUM2DBL(logit);
|
109
109
|
return DBL2NUM(ptr->data.logit);
|
110
|
-
}
|
110
|
+
}
|
111
111
|
|
112
112
|
static VALUE _llama_token_data_get_logit(VALUE self) {
|
113
113
|
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
114
114
|
return DBL2NUM(ptr->data.logit);
|
115
|
-
}
|
115
|
+
}
|
116
116
|
|
117
117
|
// p
|
118
118
|
static VALUE _llama_token_data_set_p(VALUE self, VALUE p) {
|
119
119
|
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
120
120
|
ptr->data.p = NUM2DBL(p);
|
121
121
|
return DBL2NUM(ptr->data.p);
|
122
|
-
}
|
122
|
+
}
|
123
123
|
|
124
124
|
static VALUE _llama_token_data_get_p(VALUE self) {
|
125
125
|
LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
|
126
126
|
return DBL2NUM(ptr->data.p);
|
127
|
-
}
|
127
|
+
}
|
128
128
|
};
|
129
129
|
|
130
130
|
const rb_data_type_t RbLLaMATokenData::llama_token_data_type = {
|
@@ -145,14 +145,14 @@ public:
|
|
145
145
|
array.data = nullptr;
|
146
146
|
array.size = 0;
|
147
147
|
array.sorted = false;
|
148
|
-
}
|
148
|
+
}
|
149
149
|
|
150
150
|
~LLaMATokenDataArrayWrapper() {
|
151
151
|
if (array.data) {
|
152
152
|
ruby_xfree(array.data);
|
153
153
|
array.data = nullptr;
|
154
154
|
}
|
155
|
-
}
|
155
|
+
}
|
156
156
|
};
|
157
157
|
|
158
158
|
class RbLLaMATokenDataArray {
|
@@ -161,22 +161,22 @@ public:
|
|
161
161
|
LLaMATokenDataArrayWrapper* ptr = (LLaMATokenDataArrayWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataArrayWrapper));
|
162
162
|
new (ptr) LLaMATokenDataArrayWrapper();
|
163
163
|
return TypedData_Wrap_Struct(self, &llama_token_data_array_type, ptr);
|
164
|
-
}
|
164
|
+
}
|
165
165
|
|
166
166
|
static void llama_token_data_array_free(void* ptr) {
|
167
167
|
((LLaMATokenDataArrayWrapper*)ptr)->~LLaMATokenDataArrayWrapper();
|
168
168
|
ruby_xfree(ptr);
|
169
|
-
}
|
169
|
+
}
|
170
170
|
|
171
171
|
static size_t llama_token_data_array_size(const void* ptr) {
|
172
172
|
return sizeof(*((LLaMATokenDataArrayWrapper*)ptr));
|
173
|
-
}
|
173
|
+
}
|
174
174
|
|
175
175
|
static LLaMATokenDataArrayWrapper* get_llama_token_data_array(VALUE self) {
|
176
176
|
LLaMATokenDataArrayWrapper* ptr;
|
177
177
|
TypedData_Get_Struct(self, LLaMATokenDataArrayWrapper, &llama_token_data_array_type, ptr);
|
178
178
|
return ptr;
|
179
|
-
}
|
179
|
+
}
|
180
180
|
|
181
181
|
static void define_class(VALUE outer) {
|
182
182
|
rb_cLLaMATokenDataArray = rb_define_class_under(outer, "TokenDataArray", rb_cObject);
|
@@ -184,7 +184,7 @@ public:
|
|
184
184
|
rb_define_method(rb_cLLaMATokenDataArray, "initialize", RUBY_METHOD_FUNC(_llama_token_data_array_init), -1);
|
185
185
|
rb_define_method(rb_cLLaMATokenDataArray, "size", RUBY_METHOD_FUNC(_llama_token_data_array_get_size), 0);
|
186
186
|
rb_define_method(rb_cLLaMATokenDataArray, "sorted", RUBY_METHOD_FUNC(_llama_token_data_array_get_sorted), 0);
|
187
|
-
}
|
187
|
+
}
|
188
188
|
|
189
189
|
private:
|
190
190
|
static const rb_data_type_t llama_token_data_array_type;
|
@@ -233,17 +233,17 @@ private:
|
|
233
233
|
ptr->array.sorted = kw_values[0] == Qtrue;
|
234
234
|
|
235
235
|
return self;
|
236
|
-
}
|
236
|
+
}
|
237
237
|
|
238
238
|
static VALUE _llama_token_data_array_get_size(VALUE self) {
|
239
239
|
LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
|
240
240
|
return SIZET2NUM(ptr->array.size);
|
241
|
-
}
|
241
|
+
}
|
242
242
|
|
243
243
|
static VALUE _llama_token_data_array_get_sorted(VALUE self) {
|
244
244
|
LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
|
245
245
|
return ptr->array.sorted ? Qtrue : Qfalse;
|
246
|
-
}
|
246
|
+
}
|
247
247
|
};
|
248
248
|
|
249
249
|
const rb_data_type_t RbLLaMATokenDataArray::llama_token_data_array_type = {
|
@@ -260,9 +260,9 @@ class LLaMATimingsWrapper {
|
|
260
260
|
public:
|
261
261
|
struct llama_timings timings;
|
262
262
|
|
263
|
-
LLaMATimingsWrapper(){}
|
263
|
+
LLaMATimingsWrapper() {}
|
264
264
|
|
265
|
-
~LLaMATimingsWrapper(){}
|
265
|
+
~LLaMATimingsWrapper() {}
|
266
266
|
};
|
267
267
|
|
268
268
|
class RbLLaMATimings {
|
@@ -365,9 +365,9 @@ class LLaMAContextParamsWrapper {
|
|
365
365
|
public:
|
366
366
|
struct llama_context_params params;
|
367
367
|
|
368
|
-
LLaMAContextParamsWrapper() : params(llama_context_default_params()){}
|
368
|
+
LLaMAContextParamsWrapper() : params(llama_context_default_params()) {}
|
369
369
|
|
370
|
-
~LLaMAContextParamsWrapper(){}
|
370
|
+
~LLaMAContextParamsWrapper() {}
|
371
371
|
};
|
372
372
|
|
373
373
|
class RbLLaMAContextParams {
|
@@ -376,22 +376,22 @@ public:
|
|
376
376
|
LLaMAContextParamsWrapper* ptr = (LLaMAContextParamsWrapper*)ruby_xmalloc(sizeof(LLaMAContextParamsWrapper));
|
377
377
|
new (ptr) LLaMAContextParamsWrapper();
|
378
378
|
return TypedData_Wrap_Struct(self, &llama_context_params_type, ptr);
|
379
|
-
}
|
379
|
+
}
|
380
380
|
|
381
381
|
static void llama_context_params_free(void* ptr) {
|
382
382
|
((LLaMAContextParamsWrapper*)ptr)->~LLaMAContextParamsWrapper();
|
383
383
|
ruby_xfree(ptr);
|
384
|
-
}
|
384
|
+
}
|
385
385
|
|
386
386
|
static size_t llama_context_params_size(const void* ptr) {
|
387
387
|
return sizeof(*((LLaMAContextParamsWrapper*)ptr));
|
388
|
-
}
|
388
|
+
}
|
389
389
|
|
390
390
|
static LLaMAContextParamsWrapper* get_llama_context_params(VALUE self) {
|
391
391
|
LLaMAContextParamsWrapper* ptr;
|
392
392
|
TypedData_Get_Struct(self, LLaMAContextParamsWrapper, &llama_context_params_type, ptr);
|
393
393
|
return ptr;
|
394
|
-
}
|
394
|
+
}
|
395
395
|
|
396
396
|
static void define_class(VALUE outer) {
|
397
397
|
rb_cLLaMAContextParams = rb_define_class_under(outer, "ContextParams", rb_cObject);
|
@@ -422,7 +422,7 @@ public:
|
|
422
422
|
rb_define_method(rb_cLLaMAContextParams, "use_mlock", RUBY_METHOD_FUNC(_llama_context_params_get_use_mlock), 0);
|
423
423
|
rb_define_method(rb_cLLaMAContextParams, "embedding=", RUBY_METHOD_FUNC(_llama_context_params_set_embedding), 1);
|
424
424
|
rb_define_method(rb_cLLaMAContextParams, "embedding", RUBY_METHOD_FUNC(_llama_context_params_get_embedding), 0);
|
425
|
-
}
|
425
|
+
}
|
426
426
|
|
427
427
|
private:
|
428
428
|
static const rb_data_type_t llama_context_params_type;
|
@@ -431,55 +431,55 @@ private:
|
|
431
431
|
// LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
432
432
|
// new (ptr) LLaMAContextParamsWrapper();
|
433
433
|
// return self;
|
434
|
-
// }
|
434
|
+
// }
|
435
435
|
|
436
436
|
// n_ctx
|
437
437
|
static VALUE _llama_context_params_set_n_ctx(VALUE self, VALUE n_ctx) {
|
438
438
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
439
439
|
ptr->params.n_ctx = NUM2INT(n_ctx);
|
440
440
|
return INT2NUM(ptr->params.n_ctx);
|
441
|
-
}
|
441
|
+
}
|
442
442
|
|
443
443
|
static VALUE _llama_context_params_get_n_ctx(VALUE self) {
|
444
444
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
445
445
|
return INT2NUM(ptr->params.n_ctx);
|
446
|
-
}
|
446
|
+
}
|
447
447
|
|
448
448
|
// n_batch
|
449
449
|
static VALUE _llama_context_params_set_n_batch(VALUE self, VALUE n_batch) {
|
450
450
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
451
451
|
ptr->params.n_batch = NUM2INT(n_batch);
|
452
452
|
return INT2NUM(ptr->params.n_batch);
|
453
|
-
}
|
453
|
+
}
|
454
454
|
|
455
455
|
static VALUE _llama_context_params_get_n_batch(VALUE self) {
|
456
456
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
457
457
|
return INT2NUM(ptr->params.n_batch);
|
458
|
-
}
|
458
|
+
}
|
459
459
|
|
460
460
|
// n_gpu_layers
|
461
461
|
static VALUE _llama_context_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
|
462
462
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
463
463
|
ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
|
464
464
|
return INT2NUM(ptr->params.n_gpu_layers);
|
465
|
-
}
|
465
|
+
}
|
466
466
|
|
467
467
|
static VALUE _llama_context_params_get_n_gpu_layers(VALUE self) {
|
468
468
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
469
469
|
return INT2NUM(ptr->params.n_gpu_layers);
|
470
|
-
}
|
470
|
+
}
|
471
471
|
|
472
472
|
// main_gpu
|
473
473
|
static VALUE _llama_context_params_set_main_gpu(VALUE self, VALUE main_gpu) {
|
474
474
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
475
475
|
ptr->params.main_gpu = NUM2INT(main_gpu);
|
476
476
|
return INT2NUM(ptr->params.main_gpu);
|
477
|
-
}
|
477
|
+
}
|
478
478
|
|
479
479
|
static VALUE _llama_context_params_get_main_gpu(VALUE self) {
|
480
480
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
481
481
|
return INT2NUM(ptr->params.main_gpu);
|
482
|
-
}
|
482
|
+
}
|
483
483
|
|
484
484
|
// tensor_split
|
485
485
|
static VALUE _llama_context_params_get_tensor_split(VALUE self) {
|
@@ -492,19 +492,19 @@ private:
|
|
492
492
|
rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
|
493
493
|
}
|
494
494
|
return ret;
|
495
|
-
}
|
495
|
+
}
|
496
496
|
|
497
497
|
// low_vram
|
498
498
|
static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
|
499
499
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
500
500
|
ptr->params.low_vram = low_vram == Qtrue ? true : false;
|
501
501
|
return ptr->params.low_vram ? Qtrue : Qfalse;
|
502
|
-
}
|
502
|
+
}
|
503
503
|
|
504
504
|
static VALUE _llama_context_params_get_low_vram(VALUE self) {
|
505
505
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
506
506
|
return ptr->params.low_vram ? Qtrue : Qfalse;
|
507
|
-
}
|
507
|
+
}
|
508
508
|
|
509
509
|
// seed
|
510
510
|
static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
|
@@ -515,84 +515,84 @@ private:
|
|
515
515
|
}
|
516
516
|
ptr->params.seed = NUM2INT(seed);
|
517
517
|
return INT2NUM(ptr->params.seed);
|
518
|
-
}
|
518
|
+
}
|
519
519
|
|
520
520
|
static VALUE _llama_context_params_get_seed(VALUE self) {
|
521
521
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
522
522
|
return INT2NUM(ptr->params.seed);
|
523
|
-
}
|
523
|
+
}
|
524
524
|
|
525
525
|
// f16_kv
|
526
526
|
static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
|
527
527
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
528
528
|
ptr->params.f16_kv = f16_kv == Qtrue ? true : false;
|
529
529
|
return ptr->params.f16_kv ? Qtrue : Qfalse;
|
530
|
-
}
|
530
|
+
}
|
531
531
|
|
532
532
|
static VALUE _llama_context_params_get_f16_kv(VALUE self) {
|
533
533
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
534
534
|
return ptr->params.f16_kv ? Qtrue : Qfalse;
|
535
|
-
}
|
535
|
+
}
|
536
536
|
|
537
537
|
// logits_all
|
538
538
|
static VALUE _llama_context_params_set_logits_all(VALUE self, VALUE logits_all) {
|
539
539
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
540
540
|
ptr->params.logits_all = logits_all == Qtrue ? true : false;
|
541
541
|
return ptr->params.logits_all ? Qtrue : Qfalse;
|
542
|
-
}
|
542
|
+
}
|
543
543
|
|
544
544
|
static VALUE _llama_context_params_get_logits_all(VALUE self) {
|
545
545
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
546
546
|
return ptr->params.logits_all ? Qtrue : Qfalse;
|
547
|
-
}
|
547
|
+
}
|
548
548
|
|
549
549
|
// vocab_only
|
550
550
|
static VALUE _llama_context_params_set_vocab_only(VALUE self, VALUE vocab_only) {
|
551
551
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
552
552
|
ptr->params.vocab_only = vocab_only == Qtrue ? true : false;
|
553
553
|
return ptr->params.vocab_only ? Qtrue : Qfalse;
|
554
|
-
}
|
554
|
+
}
|
555
555
|
|
556
556
|
static VALUE _llama_context_params_get_vocab_only(VALUE self) {
|
557
557
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
558
558
|
return ptr->params.vocab_only ? Qtrue : Qfalse;
|
559
|
-
}
|
559
|
+
}
|
560
560
|
|
561
561
|
// use_mmap
|
562
562
|
static VALUE _llama_context_params_set_use_mmap(VALUE self, VALUE use_mmap) {
|
563
563
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
564
564
|
ptr->params.use_mmap = use_mmap == Qtrue ? true : false;
|
565
565
|
return ptr->params.use_mmap ? Qtrue : Qfalse;
|
566
|
-
}
|
566
|
+
}
|
567
567
|
|
568
568
|
static VALUE _llama_context_params_get_use_mmap(VALUE self) {
|
569
569
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
570
570
|
return ptr->params.use_mmap ? Qtrue : Qfalse;
|
571
|
-
}
|
571
|
+
}
|
572
572
|
|
573
573
|
// use_mlock
|
574
574
|
static VALUE _llama_context_params_set_use_mlock(VALUE self, VALUE use_mlock) {
|
575
575
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
576
576
|
ptr->params.use_mlock = use_mlock == Qtrue ? true : false;
|
577
577
|
return ptr->params.use_mlock ? Qtrue : Qfalse;
|
578
|
-
}
|
578
|
+
}
|
579
579
|
|
580
580
|
static VALUE _llama_context_params_get_use_mlock(VALUE self) {
|
581
581
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
582
582
|
return ptr->params.use_mlock ? Qtrue : Qfalse;
|
583
|
-
}
|
583
|
+
}
|
584
584
|
|
585
585
|
// embedding
|
586
586
|
static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
|
587
587
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
588
588
|
ptr->params.embedding = embedding == Qtrue ? true : false;
|
589
589
|
return ptr->params.embedding ? Qtrue : Qfalse;
|
590
|
-
}
|
590
|
+
}
|
591
591
|
|
592
592
|
static VALUE _llama_context_params_get_embedding(VALUE self) {
|
593
593
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
594
594
|
return ptr->params.embedding ? Qtrue : Qfalse;
|
595
|
-
}
|
595
|
+
}
|
596
596
|
};
|
597
597
|
|
598
598
|
const rb_data_type_t RbLLaMAContextParams::llama_context_params_type = {
|
@@ -609,9 +609,9 @@ class LLaMAModelQuantizeParamsWrapper {
|
|
609
609
|
public:
|
610
610
|
llama_model_quantize_params params;
|
611
611
|
|
612
|
-
LLaMAModelQuantizeParamsWrapper() : params(llama_model_quantize_default_params()){}
|
612
|
+
LLaMAModelQuantizeParamsWrapper() : params(llama_model_quantize_default_params()) {}
|
613
613
|
|
614
|
-
~LLaMAModelQuantizeParamsWrapper(){}
|
614
|
+
~LLaMAModelQuantizeParamsWrapper() {}
|
615
615
|
};
|
616
616
|
|
617
617
|
class RbLLaMAModelQuantizeParams {
|
@@ -620,22 +620,22 @@ public:
|
|
620
620
|
LLaMAModelQuantizeParamsWrapper* ptr = (LLaMAModelQuantizeParamsWrapper*)ruby_xmalloc(sizeof(LLaMAModelQuantizeParamsWrapper));
|
621
621
|
new (ptr) LLaMAModelQuantizeParamsWrapper();
|
622
622
|
return TypedData_Wrap_Struct(self, &llama_model_quantize_params_type, ptr);
|
623
|
-
}
|
623
|
+
}
|
624
624
|
|
625
625
|
static void llama_model_quantize_params_free(void* ptr) {
|
626
626
|
((LLaMAModelQuantizeParamsWrapper*)ptr)->~LLaMAModelQuantizeParamsWrapper();
|
627
627
|
ruby_xfree(ptr);
|
628
|
-
}
|
628
|
+
}
|
629
629
|
|
630
630
|
static size_t llama_model_quantize_params_size(const void* ptr) {
|
631
631
|
return sizeof(*((LLaMAModelQuantizeParamsWrapper*)ptr));
|
632
|
-
}
|
632
|
+
}
|
633
633
|
|
634
634
|
static LLaMAModelQuantizeParamsWrapper* get_llama_model_quantize_params(VALUE self) {
|
635
635
|
LLaMAModelQuantizeParamsWrapper* ptr;
|
636
636
|
TypedData_Get_Struct(self, LLaMAModelQuantizeParamsWrapper, &llama_model_quantize_params_type, ptr);
|
637
637
|
return ptr;
|
638
|
-
}
|
638
|
+
}
|
639
639
|
|
640
640
|
static void define_class(VALUE outer) {
|
641
641
|
rb_cLLaMAModelQuantizeParams = rb_define_class_under(outer, "ModelQuantizeParams", rb_cObject);
|
@@ -648,7 +648,7 @@ public:
|
|
648
648
|
rb_define_method(rb_cLLaMAModelQuantizeParams, "allow_requantize", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_allow_requantize), 0);
|
649
649
|
rb_define_method(rb_cLLaMAModelQuantizeParams, "quantize_output_tensor=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_quantize_output_tensor), 1);
|
650
650
|
rb_define_method(rb_cLLaMAModelQuantizeParams, "quantize_output_tensor", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_quantize_output_tensor), 0);
|
651
|
-
}
|
651
|
+
}
|
652
652
|
|
653
653
|
private:
|
654
654
|
static const rb_data_type_t llama_model_quantize_params_type;
|
@@ -658,24 +658,24 @@ private:
|
|
658
658
|
LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
|
659
659
|
ptr->params.nthread = NUM2INT(n_thread);
|
660
660
|
return INT2NUM(ptr->params.nthread);
|
661
|
-
}
|
661
|
+
}
|
662
662
|
|
663
663
|
static VALUE _llama_model_quantize_params_get_n_thread(VALUE self) {
|
664
664
|
LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
|
665
665
|
return INT2NUM(ptr->params.nthread);
|
666
|
-
}
|
666
|
+
}
|
667
667
|
|
668
668
|
// ftype
|
669
669
|
static VALUE _llama_model_quantize_params_set_ftype(VALUE self, VALUE ftype) {
|
670
670
|
LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
|
671
671
|
ptr->params.ftype = static_cast<enum llama_ftype>(NUM2INT(ftype));
|
672
672
|
return INT2NUM(ptr->params.ftype);
|
673
|
-
}
|
673
|
+
}
|
674
674
|
|
675
675
|
static VALUE _llama_model_quantize_params_get_ftype(VALUE self) {
|
676
676
|
LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
|
677
677
|
return INT2NUM(ptr->params.ftype);
|
678
|
-
}
|
678
|
+
}
|
679
679
|
|
680
680
|
// allow_requantize
|
681
681
|
static VALUE _llama_model_quantize_params_set_allow_requantize(VALUE self, VALUE allow_requantize) {
|
@@ -686,12 +686,12 @@ private:
|
|
686
686
|
ptr->params.allow_requantize = true;
|
687
687
|
}
|
688
688
|
return ptr->params.allow_requantize ? Qtrue : Qfalse;
|
689
|
-
}
|
689
|
+
}
|
690
690
|
|
691
691
|
static VALUE _llama_model_quantize_params_get_allow_requantize(VALUE self) {
|
692
692
|
LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
|
693
693
|
return ptr->params.allow_requantize ? Qtrue : Qfalse;
|
694
|
-
}
|
694
|
+
}
|
695
695
|
|
696
696
|
// quantize_output_tensor
|
697
697
|
static VALUE _llama_model_quantize_params_set_quantize_output_tensor(VALUE self, VALUE quantize_output_tensor) {
|
@@ -702,12 +702,12 @@ private:
|
|
702
702
|
ptr->params.quantize_output_tensor = true;
|
703
703
|
}
|
704
704
|
return ptr->params.quantize_output_tensor ? Qtrue : Qfalse;
|
705
|
-
}
|
705
|
+
}
|
706
706
|
|
707
707
|
static VALUE _llama_model_quantize_params_get_quantize_output_tensor(VALUE self) {
|
708
708
|
LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
|
709
709
|
return ptr->params.quantize_output_tensor ? Qtrue : Qfalse;
|
710
|
-
}
|
710
|
+
}
|
711
711
|
};
|
712
712
|
|
713
713
|
const rb_data_type_t RbLLaMAModelQuantizeParams::llama_model_quantize_params_type = {
|
@@ -724,13 +724,13 @@ class LLaMAModelWrapper {
|
|
724
724
|
public:
|
725
725
|
struct llama_model* model;
|
726
726
|
|
727
|
-
LLaMAModelWrapper() : model(NULL){}
|
727
|
+
LLaMAModelWrapper() : model(NULL) {}
|
728
728
|
|
729
729
|
~LLaMAModelWrapper() {
|
730
730
|
if (model != NULL) {
|
731
731
|
llama_free_model(model);
|
732
732
|
}
|
733
|
-
}
|
733
|
+
}
|
734
734
|
};
|
735
735
|
|
736
736
|
class RbLLaMAModel {
|
@@ -907,7 +907,7 @@ private:
|
|
907
907
|
return Qnil;
|
908
908
|
}
|
909
909
|
return Qnil;
|
910
|
-
}
|
910
|
+
}
|
911
911
|
};
|
912
912
|
|
913
913
|
const rb_data_type_t RbLLaMAModel::llama_model_type = {
|
@@ -924,13 +924,13 @@ class LLaMAContextWrapper {
|
|
924
924
|
public:
|
925
925
|
struct llama_context* ctx;
|
926
926
|
|
927
|
-
LLaMAContextWrapper() : ctx(NULL){}
|
927
|
+
LLaMAContextWrapper() : ctx(NULL) {}
|
928
928
|
|
929
929
|
~LLaMAContextWrapper() {
|
930
930
|
if (ctx != NULL) {
|
931
931
|
llama_free(ctx);
|
932
932
|
}
|
933
|
-
}
|
933
|
+
}
|
934
934
|
};
|
935
935
|
|
936
936
|
class RbLLaMAContext {
|
@@ -939,22 +939,22 @@ public:
|
|
939
939
|
LLaMAContextWrapper* ptr = (LLaMAContextWrapper*)ruby_xmalloc(sizeof(LLaMAContextWrapper));
|
940
940
|
new (ptr) LLaMAContextWrapper();
|
941
941
|
return TypedData_Wrap_Struct(self, &llama_context_type, ptr);
|
942
|
-
}
|
942
|
+
}
|
943
943
|
|
944
944
|
static void llama_context_free(void* ptr) {
|
945
945
|
((LLaMAContextWrapper*)ptr)->~LLaMAContextWrapper();
|
946
946
|
ruby_xfree(ptr);
|
947
|
-
}
|
947
|
+
}
|
948
948
|
|
949
949
|
static size_t llama_context_size(const void* ptr) {
|
950
950
|
return sizeof(*((LLaMAContextWrapper*)ptr));
|
951
|
-
}
|
951
|
+
}
|
952
952
|
|
953
953
|
static LLaMAContextWrapper* get_llama_context(VALUE self) {
|
954
954
|
LLaMAContextWrapper* ptr;
|
955
955
|
TypedData_Get_Struct(self, LLaMAContextWrapper, &llama_context_type, ptr);
|
956
956
|
return ptr;
|
957
|
-
}
|
957
|
+
}
|
958
958
|
|
959
959
|
static void define_class(VALUE outer) {
|
960
960
|
rb_cLLaMAContext = rb_define_class_under(outer, "Context", rb_cObject);
|
@@ -980,6 +980,7 @@ public:
|
|
980
980
|
rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
|
981
981
|
rb_define_method(rb_cLLaMAContext, "sample_repetition_penalty", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalty), -1);
|
982
982
|
rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
|
983
|
+
rb_define_method(rb_cLLaMAContext, "sample_classifier_free_guidance", RUBY_METHOD_FUNC(_llama_context_sample_classifier_free_guidance), -1);
|
983
984
|
rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
|
984
985
|
rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
|
985
986
|
rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
|
@@ -990,7 +991,7 @@ public:
|
|
990
991
|
rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
|
991
992
|
rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
|
992
993
|
rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
|
993
|
-
}
|
994
|
+
}
|
994
995
|
|
995
996
|
private:
|
996
997
|
static const rb_data_type_t llama_context_type;
|
@@ -1029,7 +1030,7 @@ private:
|
|
1029
1030
|
rb_iv_set(self, "@has_evaluated", Qfalse);
|
1030
1031
|
|
1031
1032
|
return Qnil;
|
1032
|
-
}
|
1033
|
+
}
|
1033
1034
|
|
1034
1035
|
static VALUE _llama_context_eval(int argc, VALUE* argv, VALUE self) {
|
1035
1036
|
VALUE kw_args = Qnil;
|
@@ -1084,7 +1085,7 @@ private:
|
|
1084
1085
|
rb_iv_set(self, "@has_evaluated", Qtrue);
|
1085
1086
|
|
1086
1087
|
return Qnil;
|
1087
|
-
}
|
1088
|
+
}
|
1088
1089
|
|
1089
1090
|
static VALUE _llama_context_eval_embd(int argc, VALUE* argv, VALUE self) {
|
1090
1091
|
VALUE kw_args = Qnil;
|
@@ -1157,7 +1158,7 @@ private:
|
|
1157
1158
|
}
|
1158
1159
|
RB_GC_GUARD(fname_);
|
1159
1160
|
return Qtrue;
|
1160
|
-
}
|
1161
|
+
}
|
1161
1162
|
|
1162
1163
|
static VALUE _llama_context_tokenize(int argc, VALUE* argv, VALUE self) {
|
1163
1164
|
VALUE kw_args = Qnil;
|
@@ -1203,7 +1204,7 @@ private:
|
|
1203
1204
|
|
1204
1205
|
RB_GC_GUARD(text_);
|
1205
1206
|
return output;
|
1206
|
-
}
|
1207
|
+
}
|
1207
1208
|
|
1208
1209
|
static VALUE _llama_context_token_to_str(VALUE self, VALUE token_) {
|
1209
1210
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1214,7 +1215,7 @@ private:
|
|
1214
1215
|
const llama_token token = NUM2INT(token_);
|
1215
1216
|
const char* str = llama_token_to_str(ptr->ctx, token);
|
1216
1217
|
return str != nullptr ? rb_utf8_str_new_cstr(str) : rb_utf8_str_new_cstr("");
|
1217
|
-
}
|
1218
|
+
}
|
1218
1219
|
|
1219
1220
|
static VALUE _llama_context_logits(VALUE self) {
|
1220
1221
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1239,7 +1240,7 @@ private:
|
|
1239
1240
|
}
|
1240
1241
|
|
1241
1242
|
return output;
|
1242
|
-
}
|
1243
|
+
}
|
1243
1244
|
|
1244
1245
|
static VALUE _llama_context_embeddings(VALUE self) {
|
1245
1246
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1267,7 +1268,7 @@ private:
|
|
1267
1268
|
}
|
1268
1269
|
|
1269
1270
|
return output;
|
1270
|
-
}
|
1271
|
+
}
|
1271
1272
|
|
1272
1273
|
static VALUE _llama_context_vocab(int argc, VALUE* argv, VALUE self) {
|
1273
1274
|
VALUE kw_args = Qnil;
|
@@ -1304,7 +1305,7 @@ private:
|
|
1304
1305
|
}
|
1305
1306
|
|
1306
1307
|
return rb_ary_new_from_args(2, ret_strings, ret_scores);
|
1307
|
-
}
|
1308
|
+
}
|
1308
1309
|
|
1309
1310
|
static VALUE _llama_context_n_vocab(VALUE self) {
|
1310
1311
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1313,7 +1314,7 @@ private:
|
|
1313
1314
|
return Qnil;
|
1314
1315
|
}
|
1315
1316
|
return INT2NUM(llama_n_vocab(ptr->ctx));
|
1316
|
-
}
|
1317
|
+
}
|
1317
1318
|
|
1318
1319
|
static VALUE _llama_context_n_ctx(VALUE self) {
|
1319
1320
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1322,7 +1323,7 @@ private:
|
|
1322
1323
|
return Qnil;
|
1323
1324
|
}
|
1324
1325
|
return INT2NUM(llama_n_ctx(ptr->ctx));
|
1325
|
-
}
|
1326
|
+
}
|
1326
1327
|
|
1327
1328
|
static VALUE _llama_context_n_embd(VALUE self) {
|
1328
1329
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1331,7 +1332,7 @@ private:
|
|
1331
1332
|
return Qnil;
|
1332
1333
|
}
|
1333
1334
|
return INT2NUM(llama_n_embd(ptr->ctx));
|
1334
|
-
}
|
1335
|
+
}
|
1335
1336
|
|
1336
1337
|
static VALUE _llama_context_get_timings(VALUE self) {
|
1337
1338
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1353,7 +1354,7 @@ private:
|
|
1353
1354
|
}
|
1354
1355
|
llama_print_timings(ptr->ctx);
|
1355
1356
|
return Qnil;
|
1356
|
-
}
|
1357
|
+
}
|
1357
1358
|
|
1358
1359
|
static VALUE _llama_context_reset_timings(VALUE self) {
|
1359
1360
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1363,7 +1364,7 @@ private:
|
|
1363
1364
|
}
|
1364
1365
|
llama_reset_timings(ptr->ctx);
|
1365
1366
|
return Qnil;
|
1366
|
-
}
|
1367
|
+
}
|
1367
1368
|
|
1368
1369
|
static VALUE _llama_context_kv_cache_token_count(VALUE self) {
|
1369
1370
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1372,7 +1373,7 @@ private:
|
|
1372
1373
|
return Qnil;
|
1373
1374
|
}
|
1374
1375
|
return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
|
1375
|
-
}
|
1376
|
+
}
|
1376
1377
|
|
1377
1378
|
static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
|
1378
1379
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
@@ -1387,7 +1388,7 @@ private:
|
|
1387
1388
|
const uint32_t seed = NUM2INT(seed_);
|
1388
1389
|
llama_set_rng_seed(ptr->ctx, seed);
|
1389
1390
|
return Qnil;
|
1390
|
-
}
|
1391
|
+
}
|
1391
1392
|
|
1392
1393
|
static VALUE _llama_context_load_session_file(int argc, VALUE* argv, VALUE self) {
|
1393
1394
|
VALUE kw_args = Qnil;
|
@@ -1525,7 +1526,7 @@ private:
|
|
1525
1526
|
llama_sample_repetition_penalty(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, penalty);
|
1526
1527
|
|
1527
1528
|
return Qnil;
|
1528
|
-
}
|
1529
|
+
}
|
1529
1530
|
|
1530
1531
|
static VALUE _llama_context_sample_frequency_and_presence_penalties(int argc, VALUE* argv, VALUE self) {
|
1531
1532
|
VALUE kw_args = Qnil;
|
@@ -1576,7 +1577,52 @@ private:
|
|
1576
1577
|
llama_sample_frequency_and_presence_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, alpha_frequency, alpha_presence);
|
1577
1578
|
|
1578
1579
|
return Qnil;
|
1579
|
-
}
|
1580
|
+
}
|
1581
|
+
|
1582
|
+
// Ruby binding that forwards to llama_sample_classifier_free_guidance().
//
// Arguments (parsed with rb_scan_args "1:"):
//   candidates     - positional; must be a TokenDataArray. Its wrapped array is
//                    handed to the sampler by pointer.
//   guidance:      - keyword; must be a Context whose ctx is initialized.
//   scale:         - keyword; must be a Float.
//   smooth_factor: - keyword; must be a Float.
// All three keywords are required (rb_get_kwargs is called with required = 3,
// optional = 0).
//
// Raises ArgumentError when an argument has the wrong type, and RuntimeError
// when either context is uninitialized or the candidates array has no data.
// Returns nil.
static VALUE _llama_context_sample_classifier_free_guidance(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[3] = { rb_intern("guidance"), rb_intern("scale"), rb_intern("smooth_factor") };
  VALUE kw_values[3] = { Qundef, Qundef, Qundef };
  VALUE candidates = Qnil;
  rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
  rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);

  // Validate the positional argument before unwrapping it, matching the guard
  // used by _llama_context_sample_softmax.
  if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
    rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
    return Qnil;
  }
  if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAContext)) {
    rb_raise(rb_eArgError, "guidance must be a Context");
    return Qnil;
  }
  if (!RB_FLOAT_TYPE_P(kw_values[1])) {
    rb_raise(rb_eArgError, "scale must be a float");
    return Qnil;
  }
  if (!RB_FLOAT_TYPE_P(kw_values[2])) {
    rb_raise(rb_eArgError, "smooth_factor must be a float");
    return Qnil;
  }

  LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
  if (ctx_ptr->ctx == NULL) {
    rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
    return Qnil;
  }
  LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
  if (cnd_ptr->array.data == nullptr) {
    rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
    return Qnil;
  }

  LLaMAContextWrapper* guidance_ptr = get_llama_context(kw_values[0]);
  if (guidance_ptr->ctx == NULL) {
    rb_raise(rb_eRuntimeError, "guidance context is not initialized");
    return Qnil;
  }
  // Ruby Floats are doubles; they are narrowed to float here for the call below.
  const float scale = NUM2DBL(kw_values[1]);
  const float smooth_factor = NUM2DBL(kw_values[2]);

  llama_sample_classifier_free_guidance(ctx_ptr->ctx, &(cnd_ptr->array), guidance_ptr->ctx, scale, smooth_factor);

  return Qnil;
}
|
1580
1626
|
|
1581
1627
|
static VALUE _llama_context_sample_softmax(VALUE self, VALUE candidates) {
|
1582
1628
|
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
@@ -1598,7 +1644,7 @@ private:
|
|
1598
1644
|
llama_sample_softmax(ctx_ptr->ctx, &(cnd_ptr->array));
|
1599
1645
|
|
1600
1646
|
return Qnil;
|
1601
|
-
}
|
1647
|
+
}
|
1602
1648
|
|
1603
1649
|
static VALUE _llama_context_sample_top_k(int argc, VALUE* argv, VALUE self) {
|
1604
1650
|
VALUE kw_args = Qnil;
|
@@ -1637,7 +1683,7 @@ private:
|
|
1637
1683
|
llama_sample_top_k(ctx_ptr->ctx, &(cnd_ptr->array), k, min_keep);
|
1638
1684
|
|
1639
1685
|
return Qnil;
|
1640
|
-
}
|
1686
|
+
}
|
1641
1687
|
|
1642
1688
|
static VALUE _llama_context_sample_top_p(int argc, VALUE* argv, VALUE self) {
|
1643
1689
|
VALUE kw_args = Qnil;
|
@@ -1676,7 +1722,7 @@ private:
|
|
1676
1722
|
llama_sample_top_p(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
|
1677
1723
|
|
1678
1724
|
return Qnil;
|
1679
|
-
}
|
1725
|
+
}
|
1680
1726
|
|
1681
1727
|
static VALUE _llama_context_sample_tail_free(int argc, VALUE* argv, VALUE self) {
|
1682
1728
|
VALUE kw_args = Qnil;
|
@@ -1715,7 +1761,7 @@ private:
|
|
1715
1761
|
llama_sample_tail_free(ctx_ptr->ctx, &(cnd_ptr->array), z, min_keep);
|
1716
1762
|
|
1717
1763
|
return Qnil;
|
1718
|
-
}
|
1764
|
+
}
|
1719
1765
|
|
1720
1766
|
static VALUE _llama_context_sample_typical(int argc, VALUE* argv, VALUE self) {
|
1721
1767
|
VALUE kw_args = Qnil;
|
@@ -1754,7 +1800,7 @@ private:
|
|
1754
1800
|
llama_sample_typical(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
|
1755
1801
|
|
1756
1802
|
return Qnil;
|
1757
|
-
}
|
1803
|
+
}
|
1758
1804
|
|
1759
1805
|
static VALUE _llama_context_sample_temperature(int argc, VALUE* argv, VALUE self) {
|
1760
1806
|
VALUE kw_args = Qnil;
|
@@ -1788,7 +1834,7 @@ private:
|
|
1788
1834
|
llama_sample_temperature(ctx_ptr->ctx, &(cnd_ptr->array), temperature);
|
1789
1835
|
|
1790
1836
|
return Qnil;
|
1791
|
-
}
|
1837
|
+
}
|
1792
1838
|
|
1793
1839
|
static VALUE _llama_context_sample_token_mirostat(int argc, VALUE* argv, VALUE self) {
|
1794
1840
|
VALUE kw_args = Qnil;
|
@@ -1840,7 +1886,7 @@ private:
|
|
1840
1886
|
rb_ary_store(ret, 0, INT2NUM(id));
|
1841
1887
|
rb_ary_store(ret, 1, DBL2NUM(mu));
|
1842
1888
|
return ret;
|
1843
|
-
}
|
1889
|
+
}
|
1844
1890
|
|
1845
1891
|
static VALUE _llama_context_sample_token_mirostat_v2(int argc, VALUE* argv, VALUE self) {
|
1846
1892
|
VALUE kw_args = Qnil;
|
@@ -1887,7 +1933,7 @@ private:
|
|
1887
1933
|
rb_ary_store(ret, 0, INT2NUM(id));
|
1888
1934
|
rb_ary_store(ret, 1, DBL2NUM(mu));
|
1889
1935
|
return ret;
|
1890
|
-
}
|
1936
|
+
}
|
1891
1937
|
|
1892
1938
|
static VALUE _llama_context_sample_token_greedy(VALUE self, VALUE candidates) {
|
1893
1939
|
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
@@ -1906,7 +1952,7 @@ private:
|
|
1906
1952
|
}
|
1907
1953
|
llama_token id = llama_sample_token_greedy(ctx_ptr->ctx, &(cnd_ptr->array));
|
1908
1954
|
return INT2NUM(id);
|
1909
|
-
}
|
1955
|
+
}
|
1910
1956
|
|
1911
1957
|
static VALUE _llama_context_sample_token(VALUE self, VALUE candidates) {
|
1912
1958
|
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
@@ -1925,7 +1971,7 @@ private:
|
|
1925
1971
|
}
|
1926
1972
|
llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
|
1927
1973
|
return INT2NUM(id);
|
1928
|
-
}
|
1974
|
+
}
|
1929
1975
|
};
|
1930
1976
|
|
1931
1977
|
const rb_data_type_t RbLLaMAContext::llama_context_type = {
|
@@ -1940,7 +1986,7 @@ const rb_data_type_t RbLLaMAContext::llama_context_type = {
|
|
1940
1986
|
|
1941
1987
|
// module functions
|
1942
1988
|
|
1943
|
-
static VALUE rb_llama_llama_init_backend(int argc, VALUE* argv, VALUE self) {
|
1989
|
+
static VALUE rb_llama_llama_backend_init(int argc, VALUE* argv, VALUE self) {
|
1944
1990
|
VALUE kw_args = Qnil;
|
1945
1991
|
ID kw_table[1] = { rb_intern("numa") };
|
1946
1992
|
VALUE kw_values[1] = { Qundef };
|
@@ -1948,7 +1994,13 @@ static VALUE rb_llama_llama_init_backend(int argc, VALUE* argv, VALUE self) {
|
|
1948
1994
|
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
1949
1995
|
|
1950
1996
|
// RTEST has to be applied to the keyword value; without an argument the
// expression never looks at what the caller passed for `numa`.
const bool numa = kw_values[0] != Qundef && RTEST(kw_values[0]);
|
1951
|
-
|
1997
|
+
llama_backend_init(numa);
|
1998
|
+
|
1999
|
+
return Qnil;
|
2000
|
+
}
|
2001
|
+
|
2002
|
+
// Module function taking no Ruby arguments: calls llama_backend_free()
// and returns nil.
static VALUE rb_llama_llama_backend_free(VALUE self) {
  llama_backend_free();
  return Qnil;
}
|
@@ -2021,7 +2073,8 @@ extern "C" void Init_llama_cpp(void) {
|
|
2021
2073
|
RbLLaMAContextParams::define_class(rb_mLLaMACpp);
|
2022
2074
|
RbLLaMAModelQuantizeParams::define_class(rb_mLLaMACpp);
|
2023
2075
|
|
2024
|
-
rb_define_module_function(rb_mLLaMACpp, "init_backend", rb_llama_llama_init_backend, -1);
|
2076
|
+
rb_define_module_function(rb_mLLaMACpp, "backend_init", rb_llama_llama_backend_init, -1);
|
2077
|
+
rb_define_module_function(rb_mLLaMACpp, "backend_free", rb_llama_llama_backend_free, 0);
|
2025
2078
|
rb_define_module_function(rb_mLLaMACpp, "model_quantize", rb_llama_model_quantize, -1);
|
2026
2079
|
rb_define_module_function(rb_mLLaMACpp, "token_bos", rb_llama_token_bos, 0);
|
2027
2080
|
rb_define_module_function(rb_mLLaMACpp, "token_eos", rb_llama_token_eos, 0);
|