llama_cpp 0.3.2 → 0.3.3
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/ext/llama_cpp/extconf.rb +9 -0
- data/ext/llama_cpp/llama_cpp.cpp +165 -112
- data/ext/llama_cpp/src/ggml-cuda.cu +217 -76
- data/ext/llama_cpp/src/ggml-metal.h +5 -1
- data/ext/llama_cpp/src/ggml-metal.m +16 -5
- data/ext/llama_cpp/src/ggml-metal.metal +56 -47
- data/ext/llama_cpp/src/ggml-mpi.c +216 -0
- data/ext/llama_cpp/src/ggml-mpi.h +39 -0
- data/ext/llama_cpp/src/ggml.c +1082 -774
- data/ext/llama_cpp/src/ggml.h +64 -18
- data/ext/llama_cpp/src/llama.cpp +179 -51
- data/ext/llama_cpp/src/llama.h +15 -1
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +3 -1
- metadata +4 -2
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -17,9 +17,9 @@ public:
     data.id = 0;
     data.logit = 0.0;
     data.p = 0.0;
-  }
+  }

-  ~LLaMATokenDataWrapper(){}
+  ~LLaMATokenDataWrapper() {}
 };

 class RbLLaMATokenData {
@@ -28,22 +28,22 @@ public:
     LLaMATokenDataWrapper* ptr = (LLaMATokenDataWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataWrapper));
     new (ptr) LLaMATokenDataWrapper();
     return TypedData_Wrap_Struct(self, &llama_token_data_type, ptr);
-  }
+  }

   static void llama_token_data_free(void* ptr) {
     ((LLaMATokenDataWrapper*)ptr)->~LLaMATokenDataWrapper();
     ruby_xfree(ptr);
-  }
+  }

   static size_t llama_token_data_size(const void* ptr) {
     return sizeof(*((LLaMATokenDataWrapper*)ptr));
-  }
+  }

   static LLaMATokenDataWrapper* get_llama_token_data(VALUE self) {
     LLaMATokenDataWrapper* ptr;
     TypedData_Get_Struct(self, LLaMATokenDataWrapper, &llama_token_data_type, ptr);
     return ptr;
-  }
+  }

   static void define_class(VALUE outer) {
     rb_cLLaMATokenData = rb_define_class_under(outer, "TokenData", rb_cObject);
@@ -95,36 +95,36 @@ private:
     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
     ptr->data.id = NUM2INT(id);
     return INT2NUM(ptr->data.id);
-  }
+  }

   static VALUE _llama_token_data_get_id(VALUE self) {
     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
     return INT2NUM(ptr->data.id);
-  }
+  }

   // logit
   static VALUE _llama_token_data_set_logit(VALUE self, VALUE logit) {
     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
     ptr->data.logit = NUM2DBL(logit);
     return DBL2NUM(ptr->data.logit);
-  }
+  }

   static VALUE _llama_token_data_get_logit(VALUE self) {
     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
     return DBL2NUM(ptr->data.logit);
-  }
+  }

   // p
   static VALUE _llama_token_data_set_p(VALUE self, VALUE p) {
     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
     ptr->data.p = NUM2DBL(p);
     return DBL2NUM(ptr->data.p);
-  }
+  }

   static VALUE _llama_token_data_get_p(VALUE self) {
     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
     return DBL2NUM(ptr->data.p);
-  }
+  }
 };

 const rb_data_type_t RbLLaMATokenData::llama_token_data_type = {
@@ -145,14 +145,14 @@ public:
     array.data = nullptr;
     array.size = 0;
     array.sorted = false;
-  }
+  }

   ~LLaMATokenDataArrayWrapper() {
     if (array.data) {
       ruby_xfree(array.data);
       array.data = nullptr;
     }
-  }
+  }
 };

 class RbLLaMATokenDataArray {
@@ -161,22 +161,22 @@ public:
     LLaMATokenDataArrayWrapper* ptr = (LLaMATokenDataArrayWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataArrayWrapper));
     new (ptr) LLaMATokenDataArrayWrapper();
     return TypedData_Wrap_Struct(self, &llama_token_data_array_type, ptr);
-  }
+  }

   static void llama_token_data_array_free(void* ptr) {
     ((LLaMATokenDataArrayWrapper*)ptr)->~LLaMATokenDataArrayWrapper();
     ruby_xfree(ptr);
-  }
+  }

   static size_t llama_token_data_array_size(const void* ptr) {
     return sizeof(*((LLaMATokenDataArrayWrapper*)ptr));
-  }
+  }

   static LLaMATokenDataArrayWrapper* get_llama_token_data_array(VALUE self) {
     LLaMATokenDataArrayWrapper* ptr;
     TypedData_Get_Struct(self, LLaMATokenDataArrayWrapper, &llama_token_data_array_type, ptr);
     return ptr;
-  }
+  }

   static void define_class(VALUE outer) {
     rb_cLLaMATokenDataArray = rb_define_class_under(outer, "TokenDataArray", rb_cObject);
@@ -184,7 +184,7 @@ public:
     rb_define_method(rb_cLLaMATokenDataArray, "initialize", RUBY_METHOD_FUNC(_llama_token_data_array_init), -1);
     rb_define_method(rb_cLLaMATokenDataArray, "size", RUBY_METHOD_FUNC(_llama_token_data_array_get_size), 0);
     rb_define_method(rb_cLLaMATokenDataArray, "sorted", RUBY_METHOD_FUNC(_llama_token_data_array_get_sorted), 0);
-  }
+  }

 private:
   static const rb_data_type_t llama_token_data_array_type;
@@ -233,17 +233,17 @@ private:
     ptr->array.sorted = kw_values[0] == Qtrue;

     return self;
-  }
+  }

   static VALUE _llama_token_data_array_get_size(VALUE self) {
     LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
     return SIZET2NUM(ptr->array.size);
-  }
+  }

   static VALUE _llama_token_data_array_get_sorted(VALUE self) {
     LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
     return ptr->array.sorted ? Qtrue : Qfalse;
-  }
+  }
 };

 const rb_data_type_t RbLLaMATokenDataArray::llama_token_data_array_type = {
@@ -260,9 +260,9 @@ class LLaMATimingsWrapper {
 public:
   struct llama_timings timings;

-  LLaMATimingsWrapper(){}
+  LLaMATimingsWrapper() {}

-  ~LLaMATimingsWrapper(){}
+  ~LLaMATimingsWrapper() {}
 };

 class RbLLaMATimings {
@@ -365,9 +365,9 @@ class LLaMAContextParamsWrapper {
 public:
   struct llama_context_params params;

-  LLaMAContextParamsWrapper() : params(llama_context_default_params()){}
+  LLaMAContextParamsWrapper() : params(llama_context_default_params()) {}

-  ~LLaMAContextParamsWrapper(){}
+  ~LLaMAContextParamsWrapper() {}
 };

 class RbLLaMAContextParams {
@@ -376,22 +376,22 @@ public:
     LLaMAContextParamsWrapper* ptr = (LLaMAContextParamsWrapper*)ruby_xmalloc(sizeof(LLaMAContextParamsWrapper));
     new (ptr) LLaMAContextParamsWrapper();
     return TypedData_Wrap_Struct(self, &llama_context_params_type, ptr);
-  }
+  }

   static void llama_context_params_free(void* ptr) {
     ((LLaMAContextParamsWrapper*)ptr)->~LLaMAContextParamsWrapper();
     ruby_xfree(ptr);
-  }
+  }

   static size_t llama_context_params_size(const void* ptr) {
     return sizeof(*((LLaMAContextParamsWrapper*)ptr));
-  }
+  }

   static LLaMAContextParamsWrapper* get_llama_context_params(VALUE self) {
     LLaMAContextParamsWrapper* ptr;
     TypedData_Get_Struct(self, LLaMAContextParamsWrapper, &llama_context_params_type, ptr);
     return ptr;
-  }
+  }

   static void define_class(VALUE outer) {
     rb_cLLaMAContextParams = rb_define_class_under(outer, "ContextParams", rb_cObject);
@@ -422,7 +422,7 @@ public:
     rb_define_method(rb_cLLaMAContextParams, "use_mlock", RUBY_METHOD_FUNC(_llama_context_params_get_use_mlock), 0);
     rb_define_method(rb_cLLaMAContextParams, "embedding=", RUBY_METHOD_FUNC(_llama_context_params_set_embedding), 1);
     rb_define_method(rb_cLLaMAContextParams, "embedding", RUBY_METHOD_FUNC(_llama_context_params_get_embedding), 0);
-  }
+  }

 private:
   static const rb_data_type_t llama_context_params_type;
@@ -431,55 +431,55 @@ private:
   //   LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
   //   new (ptr) LLaMAContextParamsWrapper();
   //   return self;
-  // }
+  // }

   // n_ctx
   static VALUE _llama_context_params_set_n_ctx(VALUE self, VALUE n_ctx) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.n_ctx = NUM2INT(n_ctx);
     return INT2NUM(ptr->params.n_ctx);
-  }
+  }

   static VALUE _llama_context_params_get_n_ctx(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return INT2NUM(ptr->params.n_ctx);
-  }
+  }

   // n_batch
   static VALUE _llama_context_params_set_n_batch(VALUE self, VALUE n_batch) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.n_batch = NUM2INT(n_batch);
     return INT2NUM(ptr->params.n_batch);
-  }
+  }

   static VALUE _llama_context_params_get_n_batch(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return INT2NUM(ptr->params.n_batch);
-  }
+  }

   // n_gpu_layers
   static VALUE _llama_context_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
     return INT2NUM(ptr->params.n_gpu_layers);
-  }
+  }

   static VALUE _llama_context_params_get_n_gpu_layers(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return INT2NUM(ptr->params.n_gpu_layers);
-  }
+  }

   // main_gpu
   static VALUE _llama_context_params_set_main_gpu(VALUE self, VALUE main_gpu) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.main_gpu = NUM2INT(main_gpu);
     return INT2NUM(ptr->params.main_gpu);
-  }
+  }

   static VALUE _llama_context_params_get_main_gpu(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return INT2NUM(ptr->params.main_gpu);
-  }
+  }

   // tensor_split
   static VALUE _llama_context_params_get_tensor_split(VALUE self) {
@@ -492,19 +492,19 @@ private:
       rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
     }
     return ret;
-  }
+  }

   // low_vram
   static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.low_vram = low_vram == Qtrue ? true : false;
     return ptr->params.low_vram ? Qtrue : Qfalse;
-  }
+  }

   static VALUE _llama_context_params_get_low_vram(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return ptr->params.low_vram ? Qtrue : Qfalse;
-  }
+  }

   // seed
   static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
@@ -515,84 +515,84 @@ private:
     }
     ptr->params.seed = NUM2INT(seed);
     return INT2NUM(ptr->params.seed);
-  }
+  }

   static VALUE _llama_context_params_get_seed(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return INT2NUM(ptr->params.seed);
-  }
+  }

   // f16_kv
   static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.f16_kv = f16_kv == Qtrue ? true : false;
     return ptr->params.f16_kv ? Qtrue : Qfalse;
-  }
+  }

   static VALUE _llama_context_params_get_f16_kv(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return ptr->params.f16_kv ? Qtrue : Qfalse;
-  }
+  }

   // logits_all
   static VALUE _llama_context_params_set_logits_all(VALUE self, VALUE logits_all) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.logits_all = logits_all == Qtrue ? true : false;
     return ptr->params.logits_all ? Qtrue : Qfalse;
-  }
+  }

   static VALUE _llama_context_params_get_logits_all(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return ptr->params.logits_all ? Qtrue : Qfalse;
-  }
+  }

   // vocab_only
   static VALUE _llama_context_params_set_vocab_only(VALUE self, VALUE vocab_only) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.vocab_only = vocab_only == Qtrue ? true : false;
     return ptr->params.vocab_only ? Qtrue : Qfalse;
-  }
+  }

   static VALUE _llama_context_params_get_vocab_only(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return ptr->params.vocab_only ? Qtrue : Qfalse;
-  }
+  }

   // use_mmap
   static VALUE _llama_context_params_set_use_mmap(VALUE self, VALUE use_mmap) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.use_mmap = use_mmap == Qtrue ? true : false;
     return ptr->params.use_mmap ? Qtrue : Qfalse;
-  }
+  }

   static VALUE _llama_context_params_get_use_mmap(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return ptr->params.use_mmap ? Qtrue : Qfalse;
-  }
+  }

   // use_mlock
   static VALUE _llama_context_params_set_use_mlock(VALUE self, VALUE use_mlock) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.use_mlock = use_mlock == Qtrue ? true : false;
     return ptr->params.use_mlock ? Qtrue : Qfalse;
-  }
+  }

   static VALUE _llama_context_params_get_use_mlock(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return ptr->params.use_mlock ? Qtrue : Qfalse;
-  }
+  }

   // embedding
   static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     ptr->params.embedding = embedding == Qtrue ? true : false;
     return ptr->params.embedding ? Qtrue : Qfalse;
-  }
+  }

   static VALUE _llama_context_params_get_embedding(VALUE self) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
     return ptr->params.embedding ? Qtrue : Qfalse;
-  }
+  }
 };

 const rb_data_type_t RbLLaMAContextParams::llama_context_params_type = {
@@ -609,9 +609,9 @@ class LLaMAModelQuantizeParamsWrapper {
 public:
   llama_model_quantize_params params;

-  LLaMAModelQuantizeParamsWrapper() : params(llama_model_quantize_default_params()){}
+  LLaMAModelQuantizeParamsWrapper() : params(llama_model_quantize_default_params()) {}

-  ~LLaMAModelQuantizeParamsWrapper(){}
+  ~LLaMAModelQuantizeParamsWrapper() {}
 };

 class RbLLaMAModelQuantizeParams {
@@ -620,22 +620,22 @@ public:
     LLaMAModelQuantizeParamsWrapper* ptr = (LLaMAModelQuantizeParamsWrapper*)ruby_xmalloc(sizeof(LLaMAModelQuantizeParamsWrapper));
     new (ptr) LLaMAModelQuantizeParamsWrapper();
     return TypedData_Wrap_Struct(self, &llama_model_quantize_params_type, ptr);
-  }
+  }

   static void llama_model_quantize_params_free(void* ptr) {
     ((LLaMAModelQuantizeParamsWrapper*)ptr)->~LLaMAModelQuantizeParamsWrapper();
     ruby_xfree(ptr);
-  }
+  }

   static size_t llama_model_quantize_params_size(const void* ptr) {
     return sizeof(*((LLaMAModelQuantizeParamsWrapper*)ptr));
-  }
+  }

   static LLaMAModelQuantizeParamsWrapper* get_llama_model_quantize_params(VALUE self) {
     LLaMAModelQuantizeParamsWrapper* ptr;
     TypedData_Get_Struct(self, LLaMAModelQuantizeParamsWrapper, &llama_model_quantize_params_type, ptr);
     return ptr;
-  }
+  }

   static void define_class(VALUE outer) {
     rb_cLLaMAModelQuantizeParams = rb_define_class_under(outer, "ModelQuantizeParams", rb_cObject);
@@ -648,7 +648,7 @@ public:
     rb_define_method(rb_cLLaMAModelQuantizeParams, "allow_requantize", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_allow_requantize), 0);
     rb_define_method(rb_cLLaMAModelQuantizeParams, "quantize_output_tensor=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_quantize_output_tensor), 1);
     rb_define_method(rb_cLLaMAModelQuantizeParams, "quantize_output_tensor", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_quantize_output_tensor), 0);
-  }
+  }

 private:
   static const rb_data_type_t llama_model_quantize_params_type;
@@ -658,24 +658,24 @@ private:
     LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
     ptr->params.nthread = NUM2INT(n_thread);
     return INT2NUM(ptr->params.nthread);
-  }
+  }

   static VALUE _llama_model_quantize_params_get_n_thread(VALUE self) {
     LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
     return INT2NUM(ptr->params.nthread);
-  }
+  }

   // ftype
   static VALUE _llama_model_quantize_params_set_ftype(VALUE self, VALUE ftype) {
     LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
     ptr->params.ftype = static_cast<enum llama_ftype>(NUM2INT(ftype));
     return INT2NUM(ptr->params.ftype);
-  }
+  }

   static VALUE _llama_model_quantize_params_get_ftype(VALUE self) {
     LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
     return INT2NUM(ptr->params.ftype);
-  }
+  }

   // allow_requantize
   static VALUE _llama_model_quantize_params_set_allow_requantize(VALUE self, VALUE allow_requantize) {
@@ -686,12 +686,12 @@ private:
       ptr->params.allow_requantize = true;
     }
     return ptr->params.allow_requantize ? Qtrue : Qfalse;
-  }
+  }

   static VALUE _llama_model_quantize_params_get_allow_requantize(VALUE self) {
     LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
     return ptr->params.allow_requantize ? Qtrue : Qfalse;
-  }
+  }

   // quantize_output_tensor
   static VALUE _llama_model_quantize_params_set_quantize_output_tensor(VALUE self, VALUE quantize_output_tensor) {
@@ -702,12 +702,12 @@ private:
       ptr->params.quantize_output_tensor = true;
     }
     return ptr->params.quantize_output_tensor ? Qtrue : Qfalse;
-  }
+  }

   static VALUE _llama_model_quantize_params_get_quantize_output_tensor(VALUE self) {
     LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
     return ptr->params.quantize_output_tensor ? Qtrue : Qfalse;
-  }
+  }
 };

 const rb_data_type_t RbLLaMAModelQuantizeParams::llama_model_quantize_params_type = {
@@ -724,13 +724,13 @@ class LLaMAModelWrapper {
 public:
   struct llama_model* model;

-  LLaMAModelWrapper() : model(NULL){}
+  LLaMAModelWrapper() : model(NULL) {}

   ~LLaMAModelWrapper() {
     if (model != NULL) {
       llama_free_model(model);
     }
-  }
+  }
 };

 class RbLLaMAModel {
@@ -907,7 +907,7 @@ private:
       return Qnil;
     }
     return Qnil;
-  }
+  }
 };

 const rb_data_type_t RbLLaMAModel::llama_model_type = {
@@ -924,13 +924,13 @@ class LLaMAContextWrapper {
 public:
   struct llama_context* ctx;

-  LLaMAContextWrapper() : ctx(NULL){}
+  LLaMAContextWrapper() : ctx(NULL) {}

   ~LLaMAContextWrapper() {
     if (ctx != NULL) {
       llama_free(ctx);
     }
-  }
+  }
 };

 class RbLLaMAContext {
@@ -939,22 +939,22 @@ public:
     LLaMAContextWrapper* ptr = (LLaMAContextWrapper*)ruby_xmalloc(sizeof(LLaMAContextWrapper));
     new (ptr) LLaMAContextWrapper();
     return TypedData_Wrap_Struct(self, &llama_context_type, ptr);
-  }
+  }

   static void llama_context_free(void* ptr) {
     ((LLaMAContextWrapper*)ptr)->~LLaMAContextWrapper();
     ruby_xfree(ptr);
-  }
+  }

   static size_t llama_context_size(const void* ptr) {
     return sizeof(*((LLaMAContextWrapper*)ptr));
-  }
+  }

   static LLaMAContextWrapper* get_llama_context(VALUE self) {
     LLaMAContextWrapper* ptr;
     TypedData_Get_Struct(self, LLaMAContextWrapper, &llama_context_type, ptr);
     return ptr;
-  }
+  }

   static void define_class(VALUE outer) {
     rb_cLLaMAContext = rb_define_class_under(outer, "Context", rb_cObject);
@@ -980,6 +980,7 @@ public:
     rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
     rb_define_method(rb_cLLaMAContext, "sample_repetition_penalty", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalty), -1);
     rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
+    rb_define_method(rb_cLLaMAContext, "sample_classifier_free_guidance", RUBY_METHOD_FUNC(_llama_context_sample_classifier_free_guidance), -1);
     rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
     rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
     rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
@@ -990,7 +991,7 @@ public:
     rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
     rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
     rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
-  }
+  }

 private:
   static const rb_data_type_t llama_context_type;
@@ -1029,7 +1030,7 @@ private:
     rb_iv_set(self, "@has_evaluated", Qfalse);

     return Qnil;
-  }
+  }

   static VALUE _llama_context_eval(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1084,7 +1085,7 @@ private:
     rb_iv_set(self, "@has_evaluated", Qtrue);

     return Qnil;
-  }
+  }

   static VALUE _llama_context_eval_embd(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1157,7 +1158,7 @@ private:
     }
     RB_GC_GUARD(fname_);
     return Qtrue;
-  }
+  }

   static VALUE _llama_context_tokenize(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1203,7 +1204,7 @@ private:

     RB_GC_GUARD(text_);
     return output;
-  }
+  }

   static VALUE _llama_context_token_to_str(VALUE self, VALUE token_) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1214,7 +1215,7 @@ private:
     const llama_token token = NUM2INT(token_);
     const char* str = llama_token_to_str(ptr->ctx, token);
     return str != nullptr ? rb_utf8_str_new_cstr(str) : rb_utf8_str_new_cstr("");
-  }
+  }

   static VALUE _llama_context_logits(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1239,7 +1240,7 @@ private:
     }

     return output;
-  }
+  }

   static VALUE _llama_context_embeddings(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1267,7 +1268,7 @@ private:
     }

     return output;
-  }
+  }

   static VALUE _llama_context_vocab(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1304,7 +1305,7 @@ private:
     }

     return rb_ary_new_from_args(2, ret_strings, ret_scores);
-  }
+  }

   static VALUE _llama_context_n_vocab(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1313,7 +1314,7 @@ private:
       return Qnil;
     }
     return INT2NUM(llama_n_vocab(ptr->ctx));
-  }
+  }

   static VALUE _llama_context_n_ctx(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1322,7 +1323,7 @@ private:
       return Qnil;
     }
     return INT2NUM(llama_n_ctx(ptr->ctx));
-  }
+  }

   static VALUE _llama_context_n_embd(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1331,7 +1332,7 @@ private:
       return Qnil;
     }
     return INT2NUM(llama_n_embd(ptr->ctx));
-  }
+  }

   static VALUE _llama_context_get_timings(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1353,7 +1354,7 @@ private:
     }
     llama_print_timings(ptr->ctx);
     return Qnil;
-  }
+  }

   static VALUE _llama_context_reset_timings(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1363,7 +1364,7 @@ private:
     }
     llama_reset_timings(ptr->ctx);
     return Qnil;
-  }
+  }

   static VALUE _llama_context_kv_cache_token_count(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1372,7 +1373,7 @@ private:
       return Qnil;
     }
     return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
-  }
+  }

   static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1387,7 +1388,7 @@ private:
     const uint32_t seed = NUM2INT(seed_);
     llama_set_rng_seed(ptr->ctx, seed);
     return Qnil;
-  }
+  }

   static VALUE _llama_context_load_session_file(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1525,7 +1526,7 @@ private:
     llama_sample_repetition_penalty(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, penalty);

     return Qnil;
-  }
+  }

   static VALUE _llama_context_sample_frequency_and_presence_penalties(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1576,7 +1577,52 @@ private:
     llama_sample_frequency_and_presence_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, alpha_frequency, alpha_presence);

     return Qnil;
-  }
+  }
+
+  static VALUE _llama_context_sample_classifier_free_guidance(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[3] = { rb_intern("guidance"), rb_intern("scale"), rb_intern("smooth_factor") };
+    VALUE kw_values[3] = { Qundef, Qundef, Qundef };
+    VALUE candidates = Qnil;
+    rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
+
+    if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAContext)) {
+      rb_raise(rb_eArgError, "guidance must be a Context");
+      return Qnil;
+    }
+    if (!RB_FLOAT_TYPE_P(kw_values[1])) {
+      rb_raise(rb_eArgError, "scale must be a float");
+      return Qnil;
+    }
+    if (!RB_FLOAT_TYPE_P(kw_values[2])) {
+      rb_raise(rb_eArgError, "smooth_factor must be a float");
+      return Qnil;
+    }
+
+    LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+    if (ctx_ptr->ctx == NULL) {
+      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+      return Qnil;
+    }
+    LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+    if (cnd_ptr->array.data == nullptr) {
+      rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+      return Qnil;
+    }
+
+    LLaMAContextWrapper* guidance_ptr = get_llama_context(kw_values[0]);
+    if (guidance_ptr->ctx == NULL) {
+      rb_raise(rb_eRuntimeError, "guidance context is not initialized");
+      return Qnil;
+    }
+    const float scale = NUM2DBL(kw_values[1]);
+    const float smooth_factor = NUM2DBL(kw_values[2]);
+
+    llama_sample_classifier_free_guidance(ctx_ptr->ctx, &(cnd_ptr->array), guidance_ptr->ctx, scale, smooth_factor);
+
+    return Qnil;
+  }

   static VALUE _llama_context_sample_softmax(VALUE self, VALUE candidates) {
     if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
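Note: the hunk above adds the new `Context#sample_classifier_free_guidance` binding, which rescales a `TokenDataArray` of candidates toward a second `Context` that has been evaluated on a negative (guidance) prompt, then returns nil. A rough Ruby usage sketch; the model path, prompts, and the `scale:`/`smooth_factor:` values below are placeholders, not taken from this diff, and the surrounding setup assumes the usual 0.3.x Model/Context API:

```ruby
require 'llama_cpp'

params   = LLaMACpp::ContextParams.new
model    = LLaMACpp::Model.new(model_path: 'path/to/model.bin', params: params)
context  = LLaMACpp::Context.new(model: model) # evaluate the main prompt on this context
guidance = LLaMACpp::Context.new(model: model) # evaluate the negative prompt on this one

# ... call #eval on both contexts with their respective token sequences ...

# Build candidates from the main context's logits.
candidates = LLaMACpp::TokenDataArray.new(
  context.logits.each_with_index.map { |logit, id| LLaMACpp::TokenData.new(id: id, logit: logit, p: 0.0) }
)

# Mutates `candidates` in place; guidance:, scale:, and smooth_factor: are required keywords,
# and scale/smooth_factor must be Floats.
context.sample_classifier_free_guidance(candidates, guidance: guidance, scale: 1.5, smooth_factor: 0.1)
token = context.sample_token(candidates)
```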
@@ -1598,7 +1644,7 @@ private:
     llama_sample_softmax(ctx_ptr->ctx, &(cnd_ptr->array));

     return Qnil;
-  }
+  }

   static VALUE _llama_context_sample_top_k(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1637,7 +1683,7 @@ private:
     llama_sample_top_k(ctx_ptr->ctx, &(cnd_ptr->array), k, min_keep);

     return Qnil;
-  }
+  }

   static VALUE _llama_context_sample_top_p(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1676,7 +1722,7 @@ private:
     llama_sample_top_p(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);

     return Qnil;
-  }
+  }

   static VALUE _llama_context_sample_tail_free(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1715,7 +1761,7 @@ private:
     llama_sample_tail_free(ctx_ptr->ctx, &(cnd_ptr->array), z, min_keep);

     return Qnil;
-  }
+  }

   static VALUE _llama_context_sample_typical(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1754,7 +1800,7 @@ private:
     llama_sample_typical(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);

     return Qnil;
-  }
+  }

   static VALUE _llama_context_sample_temperature(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1788,7 +1834,7 @@ private:
     llama_sample_temperature(ctx_ptr->ctx, &(cnd_ptr->array), temperature);

     return Qnil;
-  }
+  }

   static VALUE _llama_context_sample_token_mirostat(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1840,7 +1886,7 @@ private:
     rb_ary_store(ret, 0, INT2NUM(id));
     rb_ary_store(ret, 1, DBL2NUM(mu));
     return ret;
-  }
+  }

   static VALUE _llama_context_sample_token_mirostat_v2(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
@@ -1887,7 +1933,7 @@ private:
     rb_ary_store(ret, 0, INT2NUM(id));
     rb_ary_store(ret, 1, DBL2NUM(mu));
     return ret;
-  }
+  }

   static VALUE _llama_context_sample_token_greedy(VALUE self, VALUE candidates) {
     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
@@ -1906,7 +1952,7 @@ private:
     }
     llama_token id = llama_sample_token_greedy(ctx_ptr->ctx, &(cnd_ptr->array));
     return INT2NUM(id);
-  }
+  }

   static VALUE _llama_context_sample_token(VALUE self, VALUE candidates) {
     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
@@ -1925,7 +1971,7 @@ private:
     }
     llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
     return INT2NUM(id);
-  }
+  }
 };

 const rb_data_type_t RbLLaMAContext::llama_context_type = {
@@ -1940,7 +1986,7 @@ const rb_data_type_t RbLLaMAContext::llama_context_type = {

 // module functions

-static VALUE rb_llama_llama_init_backend(int argc, VALUE* argv, VALUE self) {
+static VALUE rb_llama_llama_backend_init(int argc, VALUE* argv, VALUE self) {
   VALUE kw_args = Qnil;
   ID kw_table[1] = { rb_intern("numa") };
   VALUE kw_values[1] = { Qundef };
@@ -1948,7 +1994,13 @@ static VALUE rb_llama_llama_init_backend(int argc, VALUE* argv, VALUE self) {
   rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);

   const bool numa = kw_values[0] == Qundef ? false : (RTEST(kw_values[0]) ? true : false);
-  llama_init_backend(numa);
+  llama_backend_init(numa);
+
+  return Qnil;
+}
+
+static VALUE rb_llama_llama_backend_free(VALUE self) {
+  llama_backend_free();

   return Qnil;
 }
@@ -2021,7 +2073,8 @@ extern "C" void Init_llama_cpp(void) {
   RbLLaMAContextParams::define_class(rb_mLLaMACpp);
   RbLLaMAModelQuantizeParams::define_class(rb_mLLaMACpp);

-  rb_define_module_function(rb_mLLaMACpp, "init_backend", rb_llama_llama_init_backend, -1);
+  rb_define_module_function(rb_mLLaMACpp, "backend_init", rb_llama_llama_backend_init, -1);
+  rb_define_module_function(rb_mLLaMACpp, "backend_free", rb_llama_llama_backend_free, 0);
   rb_define_module_function(rb_mLLaMACpp, "model_quantize", rb_llama_model_quantize, -1);
   rb_define_module_function(rb_mLLaMACpp, "token_bos", rb_llama_token_bos, 0);
   rb_define_module_function(rb_mLLaMACpp, "token_eos", rb_llama_token_eos, 0);