llama_cpp 0.3.2 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,9 +17,9 @@ public:
17
17
  data.id = 0;
18
18
  data.logit = 0.0;
19
19
  data.p = 0.0;
20
- };
20
+ }
21
21
 
22
- ~LLaMATokenDataWrapper(){};
22
+ ~LLaMATokenDataWrapper() {}
23
23
  };
24
24
 
25
25
  class RbLLaMATokenData {
@@ -28,22 +28,22 @@ public:
28
28
  LLaMATokenDataWrapper* ptr = (LLaMATokenDataWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataWrapper));
29
29
  new (ptr) LLaMATokenDataWrapper();
30
30
  return TypedData_Wrap_Struct(self, &llama_token_data_type, ptr);
31
- };
31
+ }
32
32
 
33
33
  static void llama_token_data_free(void* ptr) {
34
34
  ((LLaMATokenDataWrapper*)ptr)->~LLaMATokenDataWrapper();
35
35
  ruby_xfree(ptr);
36
- };
36
+ }
37
37
 
38
38
  static size_t llama_token_data_size(const void* ptr) {
39
39
  return sizeof(*((LLaMATokenDataWrapper*)ptr));
40
- };
40
+ }
41
41
 
42
42
  static LLaMATokenDataWrapper* get_llama_token_data(VALUE self) {
43
43
  LLaMATokenDataWrapper* ptr;
44
44
  TypedData_Get_Struct(self, LLaMATokenDataWrapper, &llama_token_data_type, ptr);
45
45
  return ptr;
46
- };
46
+ }
47
47
 
48
48
  static void define_class(VALUE outer) {
49
49
  rb_cLLaMATokenData = rb_define_class_under(outer, "TokenData", rb_cObject);
@@ -95,36 +95,36 @@ private:
95
95
  LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
96
96
  ptr->data.id = NUM2INT(id);
97
97
  return INT2NUM(ptr->data.id);
98
- };
98
+ }
99
99
 
100
100
  static VALUE _llama_token_data_get_id(VALUE self) {
101
101
  LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
102
102
  return INT2NUM(ptr->data.id);
103
- };
103
+ }
104
104
 
105
105
  // logit
106
106
  static VALUE _llama_token_data_set_logit(VALUE self, VALUE logit) {
107
107
  LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
108
108
  ptr->data.logit = NUM2DBL(logit);
109
109
  return DBL2NUM(ptr->data.logit);
110
- };
110
+ }
111
111
 
112
112
  static VALUE _llama_token_data_get_logit(VALUE self) {
113
113
  LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
114
114
  return DBL2NUM(ptr->data.logit);
115
- };
115
+ }
116
116
 
117
117
  // p
118
118
  static VALUE _llama_token_data_set_p(VALUE self, VALUE p) {
119
119
  LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
120
120
  ptr->data.p = NUM2DBL(p);
121
121
  return DBL2NUM(ptr->data.p);
122
- };
122
+ }
123
123
 
124
124
  static VALUE _llama_token_data_get_p(VALUE self) {
125
125
  LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
126
126
  return DBL2NUM(ptr->data.p);
127
- };
127
+ }
128
128
  };
129
129
 
130
130
  const rb_data_type_t RbLLaMATokenData::llama_token_data_type = {
@@ -145,14 +145,14 @@ public:
145
145
  array.data = nullptr;
146
146
  array.size = 0;
147
147
  array.sorted = false;
148
- };
148
+ }
149
149
 
150
150
  ~LLaMATokenDataArrayWrapper() {
151
151
  if (array.data) {
152
152
  ruby_xfree(array.data);
153
153
  array.data = nullptr;
154
154
  }
155
- };
155
+ }
156
156
  };
157
157
 
158
158
  class RbLLaMATokenDataArray {
@@ -161,22 +161,22 @@ public:
161
161
  LLaMATokenDataArrayWrapper* ptr = (LLaMATokenDataArrayWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataArrayWrapper));
162
162
  new (ptr) LLaMATokenDataArrayWrapper();
163
163
  return TypedData_Wrap_Struct(self, &llama_token_data_array_type, ptr);
164
- };
164
+ }
165
165
 
166
166
  static void llama_token_data_array_free(void* ptr) {
167
167
  ((LLaMATokenDataArrayWrapper*)ptr)->~LLaMATokenDataArrayWrapper();
168
168
  ruby_xfree(ptr);
169
- };
169
+ }
170
170
 
171
171
  static size_t llama_token_data_array_size(const void* ptr) {
172
172
  return sizeof(*((LLaMATokenDataArrayWrapper*)ptr));
173
- };
173
+ }
174
174
 
175
175
  static LLaMATokenDataArrayWrapper* get_llama_token_data_array(VALUE self) {
176
176
  LLaMATokenDataArrayWrapper* ptr;
177
177
  TypedData_Get_Struct(self, LLaMATokenDataArrayWrapper, &llama_token_data_array_type, ptr);
178
178
  return ptr;
179
- };
179
+ }
180
180
 
181
181
  static void define_class(VALUE outer) {
182
182
  rb_cLLaMATokenDataArray = rb_define_class_under(outer, "TokenDataArray", rb_cObject);
@@ -184,7 +184,7 @@ public:
184
184
  rb_define_method(rb_cLLaMATokenDataArray, "initialize", RUBY_METHOD_FUNC(_llama_token_data_array_init), -1);
185
185
  rb_define_method(rb_cLLaMATokenDataArray, "size", RUBY_METHOD_FUNC(_llama_token_data_array_get_size), 0);
186
186
  rb_define_method(rb_cLLaMATokenDataArray, "sorted", RUBY_METHOD_FUNC(_llama_token_data_array_get_sorted), 0);
187
- };
187
+ }
188
188
 
189
189
  private:
190
190
  static const rb_data_type_t llama_token_data_array_type;
@@ -233,17 +233,17 @@ private:
233
233
  ptr->array.sorted = kw_values[0] == Qtrue;
234
234
 
235
235
  return self;
236
- };
236
+ }
237
237
 
238
238
  static VALUE _llama_token_data_array_get_size(VALUE self) {
239
239
  LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
240
240
  return SIZET2NUM(ptr->array.size);
241
- };
241
+ }
242
242
 
243
243
  static VALUE _llama_token_data_array_get_sorted(VALUE self) {
244
244
  LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
245
245
  return ptr->array.sorted ? Qtrue : Qfalse;
246
- };
246
+ }
247
247
  };
248
248
 
249
249
  const rb_data_type_t RbLLaMATokenDataArray::llama_token_data_array_type = {
@@ -260,9 +260,9 @@ class LLaMATimingsWrapper {
260
260
  public:
261
261
  struct llama_timings timings;
262
262
 
263
- LLaMATimingsWrapper(){};
263
+ LLaMATimingsWrapper() {}
264
264
 
265
- ~LLaMATimingsWrapper(){};
265
+ ~LLaMATimingsWrapper() {}
266
266
  };
267
267
 
268
268
  class RbLLaMATimings {
@@ -365,9 +365,9 @@ class LLaMAContextParamsWrapper {
365
365
  public:
366
366
  struct llama_context_params params;
367
367
 
368
- LLaMAContextParamsWrapper() : params(llama_context_default_params()){};
368
+ LLaMAContextParamsWrapper() : params(llama_context_default_params()) {}
369
369
 
370
- ~LLaMAContextParamsWrapper(){};
370
+ ~LLaMAContextParamsWrapper() {}
371
371
  };
372
372
 
373
373
  class RbLLaMAContextParams {
@@ -376,22 +376,22 @@ public:
376
376
  LLaMAContextParamsWrapper* ptr = (LLaMAContextParamsWrapper*)ruby_xmalloc(sizeof(LLaMAContextParamsWrapper));
377
377
  new (ptr) LLaMAContextParamsWrapper();
378
378
  return TypedData_Wrap_Struct(self, &llama_context_params_type, ptr);
379
- };
379
+ }
380
380
 
381
381
  static void llama_context_params_free(void* ptr) {
382
382
  ((LLaMAContextParamsWrapper*)ptr)->~LLaMAContextParamsWrapper();
383
383
  ruby_xfree(ptr);
384
- };
384
+ }
385
385
 
386
386
  static size_t llama_context_params_size(const void* ptr) {
387
387
  return sizeof(*((LLaMAContextParamsWrapper*)ptr));
388
- };
388
+ }
389
389
 
390
390
  static LLaMAContextParamsWrapper* get_llama_context_params(VALUE self) {
391
391
  LLaMAContextParamsWrapper* ptr;
392
392
  TypedData_Get_Struct(self, LLaMAContextParamsWrapper, &llama_context_params_type, ptr);
393
393
  return ptr;
394
- };
394
+ }
395
395
 
396
396
  static void define_class(VALUE outer) {
397
397
  rb_cLLaMAContextParams = rb_define_class_under(outer, "ContextParams", rb_cObject);
@@ -406,6 +406,10 @@ public:
406
406
  rb_define_method(rb_cLLaMAContextParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_context_params_set_main_gpu), 1);
407
407
  rb_define_method(rb_cLLaMAContextParams, "main_gpu", RUBY_METHOD_FUNC(_llama_context_params_get_main_gpu), 0);
408
408
  rb_define_method(rb_cLLaMAContextParams, "tensor_split", RUBY_METHOD_FUNC(_llama_context_params_get_tensor_split), 0);
409
+ rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
410
+ rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
411
+ rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
412
+ rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
409
413
  rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
410
414
  rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
411
415
  rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
@@ -422,7 +426,7 @@ public:
422
426
  rb_define_method(rb_cLLaMAContextParams, "use_mlock", RUBY_METHOD_FUNC(_llama_context_params_get_use_mlock), 0);
423
427
  rb_define_method(rb_cLLaMAContextParams, "embedding=", RUBY_METHOD_FUNC(_llama_context_params_set_embedding), 1);
424
428
  rb_define_method(rb_cLLaMAContextParams, "embedding", RUBY_METHOD_FUNC(_llama_context_params_get_embedding), 0);
425
- };
429
+ }
426
430
 
427
431
  private:
428
432
  static const rb_data_type_t llama_context_params_type;
@@ -431,55 +435,55 @@ private:
431
435
  // LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
432
436
  // new (ptr) LLaMAContextParamsWrapper();
433
437
  // return self;
434
- // };
438
+ // }
435
439
 
436
440
  // n_ctx
437
441
  static VALUE _llama_context_params_set_n_ctx(VALUE self, VALUE n_ctx) {
438
442
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
439
443
  ptr->params.n_ctx = NUM2INT(n_ctx);
440
444
  return INT2NUM(ptr->params.n_ctx);
441
- };
445
+ }
442
446
 
443
447
  static VALUE _llama_context_params_get_n_ctx(VALUE self) {
444
448
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
445
449
  return INT2NUM(ptr->params.n_ctx);
446
- };
450
+ }
447
451
 
448
452
  // n_batch
449
453
  static VALUE _llama_context_params_set_n_batch(VALUE self, VALUE n_batch) {
450
454
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
451
455
  ptr->params.n_batch = NUM2INT(n_batch);
452
456
  return INT2NUM(ptr->params.n_batch);
453
- };
457
+ }
454
458
 
455
459
  static VALUE _llama_context_params_get_n_batch(VALUE self) {
456
460
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
457
461
  return INT2NUM(ptr->params.n_batch);
458
- };
462
+ }
459
463
 
460
464
  // n_gpu_layers
461
465
  static VALUE _llama_context_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
462
466
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
463
467
  ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
464
468
  return INT2NUM(ptr->params.n_gpu_layers);
465
- };
469
+ }
466
470
 
467
471
  static VALUE _llama_context_params_get_n_gpu_layers(VALUE self) {
468
472
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
469
473
  return INT2NUM(ptr->params.n_gpu_layers);
470
- };
474
+ }
471
475
 
472
476
  // main_gpu
473
477
  static VALUE _llama_context_params_set_main_gpu(VALUE self, VALUE main_gpu) {
474
478
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
475
479
  ptr->params.main_gpu = NUM2INT(main_gpu);
476
480
  return INT2NUM(ptr->params.main_gpu);
477
- };
481
+ }
478
482
 
479
483
  static VALUE _llama_context_params_get_main_gpu(VALUE self) {
480
484
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
481
485
  return INT2NUM(ptr->params.main_gpu);
482
- };
486
+ }
483
487
 
484
488
  // tensor_split
485
489
  static VALUE _llama_context_params_get_tensor_split(VALUE self) {
@@ -492,19 +496,43 @@ private:
492
496
  rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
493
497
  }
494
498
  return ret;
495
- };
499
+ }
500
+
501
+ // rope_freq_base
502
+ static VALUE _llama_context_params_set_rope_freq_base(VALUE self, VALUE rope_freq_base) {
503
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
504
+ ptr->params.rope_freq_base = NUM2DBL(rope_freq_base);
505
+ return DBL2NUM(ptr->params.rope_freq_base);
506
+ }
507
+
508
+ static VALUE _llama_context_params_get_rope_freq_base(VALUE self) {
509
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
510
+ return DBL2NUM(ptr->params.rope_freq_base);
511
+ }
512
+
513
+ // rope_freq_scale
514
+ static VALUE _llama_context_params_set_rope_freq_scale(VALUE self, VALUE rope_freq_scale) {
515
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
516
+ ptr->params.rope_freq_scale = NUM2DBL(rope_freq_scale);
517
+ return DBL2NUM(ptr->params.rope_freq_scale);
518
+ }
519
+
520
+ static VALUE _llama_context_params_get_rope_freq_scale(VALUE self) {
521
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
522
+ return DBL2NUM(ptr->params.rope_freq_scale);
523
+ }
496
524
 
497
525
  // low_vram
498
526
  static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
499
527
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
500
528
  ptr->params.low_vram = low_vram == Qtrue ? true : false;
501
529
  return ptr->params.low_vram ? Qtrue : Qfalse;
502
- };
530
+ }
503
531
 
504
532
  static VALUE _llama_context_params_get_low_vram(VALUE self) {
505
533
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
506
534
  return ptr->params.low_vram ? Qtrue : Qfalse;
507
- };
535
+ }
508
536
 
509
537
  // seed
510
538
  static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
@@ -515,84 +543,84 @@ private:
515
543
  }
516
544
  ptr->params.seed = NUM2INT(seed);
517
545
  return INT2NUM(ptr->params.seed);
518
- };
546
+ }
519
547
 
520
548
  static VALUE _llama_context_params_get_seed(VALUE self) {
521
549
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
522
550
  return INT2NUM(ptr->params.seed);
523
- };
551
+ }
524
552
 
525
553
  // f16_kv
526
554
  static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
527
555
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
528
556
  ptr->params.f16_kv = f16_kv == Qtrue ? true : false;
529
557
  return ptr->params.f16_kv ? Qtrue : Qfalse;
530
- };
558
+ }
531
559
 
532
560
  static VALUE _llama_context_params_get_f16_kv(VALUE self) {
533
561
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
534
562
  return ptr->params.f16_kv ? Qtrue : Qfalse;
535
- };
563
+ }
536
564
 
537
565
  // logits_all
538
566
  static VALUE _llama_context_params_set_logits_all(VALUE self, VALUE logits_all) {
539
567
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
540
568
  ptr->params.logits_all = logits_all == Qtrue ? true : false;
541
569
  return ptr->params.logits_all ? Qtrue : Qfalse;
542
- };
570
+ }
543
571
 
544
572
  static VALUE _llama_context_params_get_logits_all(VALUE self) {
545
573
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
546
574
  return ptr->params.logits_all ? Qtrue : Qfalse;
547
- };
575
+ }
548
576
 
549
577
  // vocab_only
550
578
  static VALUE _llama_context_params_set_vocab_only(VALUE self, VALUE vocab_only) {
551
579
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
552
580
  ptr->params.vocab_only = vocab_only == Qtrue ? true : false;
553
581
  return ptr->params.vocab_only ? Qtrue : Qfalse;
554
- };
582
+ }
555
583
 
556
584
  static VALUE _llama_context_params_get_vocab_only(VALUE self) {
557
585
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
558
586
  return ptr->params.vocab_only ? Qtrue : Qfalse;
559
- };
587
+ }
560
588
 
561
589
  // use_mmap
562
590
  static VALUE _llama_context_params_set_use_mmap(VALUE self, VALUE use_mmap) {
563
591
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
564
592
  ptr->params.use_mmap = use_mmap == Qtrue ? true : false;
565
593
  return ptr->params.use_mmap ? Qtrue : Qfalse;
566
- };
594
+ }
567
595
 
568
596
  static VALUE _llama_context_params_get_use_mmap(VALUE self) {
569
597
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
570
598
  return ptr->params.use_mmap ? Qtrue : Qfalse;
571
- };
599
+ }
572
600
 
573
601
  // use_mlock
574
602
  static VALUE _llama_context_params_set_use_mlock(VALUE self, VALUE use_mlock) {
575
603
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
576
604
  ptr->params.use_mlock = use_mlock == Qtrue ? true : false;
577
605
  return ptr->params.use_mlock ? Qtrue : Qfalse;
578
- };
606
+ }
579
607
 
580
608
  static VALUE _llama_context_params_get_use_mlock(VALUE self) {
581
609
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
582
610
  return ptr->params.use_mlock ? Qtrue : Qfalse;
583
- };
611
+ }
584
612
 
585
613
  // embedding
586
614
  static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
587
615
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
588
616
  ptr->params.embedding = embedding == Qtrue ? true : false;
589
617
  return ptr->params.embedding ? Qtrue : Qfalse;
590
- };
618
+ }
591
619
 
592
620
  static VALUE _llama_context_params_get_embedding(VALUE self) {
593
621
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
594
622
  return ptr->params.embedding ? Qtrue : Qfalse;
595
- };
623
+ }
596
624
  };
597
625
 
598
626
  const rb_data_type_t RbLLaMAContextParams::llama_context_params_type = {
@@ -609,9 +637,9 @@ class LLaMAModelQuantizeParamsWrapper {
609
637
  public:
610
638
  llama_model_quantize_params params;
611
639
 
612
- LLaMAModelQuantizeParamsWrapper() : params(llama_model_quantize_default_params()){};
640
+ LLaMAModelQuantizeParamsWrapper() : params(llama_model_quantize_default_params()) {}
613
641
 
614
- ~LLaMAModelQuantizeParamsWrapper(){};
642
+ ~LLaMAModelQuantizeParamsWrapper() {}
615
643
  };
616
644
 
617
645
  class RbLLaMAModelQuantizeParams {
@@ -620,22 +648,22 @@ public:
620
648
  LLaMAModelQuantizeParamsWrapper* ptr = (LLaMAModelQuantizeParamsWrapper*)ruby_xmalloc(sizeof(LLaMAModelQuantizeParamsWrapper));
621
649
  new (ptr) LLaMAModelQuantizeParamsWrapper();
622
650
  return TypedData_Wrap_Struct(self, &llama_model_quantize_params_type, ptr);
623
- };
651
+ }
624
652
 
625
653
  static void llama_model_quantize_params_free(void* ptr) {
626
654
  ((LLaMAModelQuantizeParamsWrapper*)ptr)->~LLaMAModelQuantizeParamsWrapper();
627
655
  ruby_xfree(ptr);
628
- };
656
+ }
629
657
 
630
658
  static size_t llama_model_quantize_params_size(const void* ptr) {
631
659
  return sizeof(*((LLaMAModelQuantizeParamsWrapper*)ptr));
632
- };
660
+ }
633
661
 
634
662
  static LLaMAModelQuantizeParamsWrapper* get_llama_model_quantize_params(VALUE self) {
635
663
  LLaMAModelQuantizeParamsWrapper* ptr;
636
664
  TypedData_Get_Struct(self, LLaMAModelQuantizeParamsWrapper, &llama_model_quantize_params_type, ptr);
637
665
  return ptr;
638
- };
666
+ }
639
667
 
640
668
  static void define_class(VALUE outer) {
641
669
  rb_cLLaMAModelQuantizeParams = rb_define_class_under(outer, "ModelQuantizeParams", rb_cObject);
@@ -648,7 +676,7 @@ public:
648
676
  rb_define_method(rb_cLLaMAModelQuantizeParams, "allow_requantize", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_allow_requantize), 0);
649
677
  rb_define_method(rb_cLLaMAModelQuantizeParams, "quantize_output_tensor=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_quantize_output_tensor), 1);
650
678
  rb_define_method(rb_cLLaMAModelQuantizeParams, "quantize_output_tensor", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_quantize_output_tensor), 0);
651
- };
679
+ }
652
680
 
653
681
  private:
654
682
  static const rb_data_type_t llama_model_quantize_params_type;
@@ -658,24 +686,24 @@ private:
658
686
  LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
659
687
  ptr->params.nthread = NUM2INT(n_thread);
660
688
  return INT2NUM(ptr->params.nthread);
661
- };
689
+ }
662
690
 
663
691
  static VALUE _llama_model_quantize_params_get_n_thread(VALUE self) {
664
692
  LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
665
693
  return INT2NUM(ptr->params.nthread);
666
- };
694
+ }
667
695
 
668
696
  // ftype
669
697
  static VALUE _llama_model_quantize_params_set_ftype(VALUE self, VALUE ftype) {
670
698
  LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
671
699
  ptr->params.ftype = static_cast<enum llama_ftype>(NUM2INT(ftype));
672
700
  return INT2NUM(ptr->params.ftype);
673
- };
701
+ }
674
702
 
675
703
  static VALUE _llama_model_quantize_params_get_ftype(VALUE self) {
676
704
  LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
677
705
  return INT2NUM(ptr->params.ftype);
678
- };
706
+ }
679
707
 
680
708
  // allow_requantize
681
709
  static VALUE _llama_model_quantize_params_set_allow_requantize(VALUE self, VALUE allow_requantize) {
@@ -686,12 +714,12 @@ private:
686
714
  ptr->params.allow_requantize = true;
687
715
  }
688
716
  return ptr->params.allow_requantize ? Qtrue : Qfalse;
689
- };
717
+ }
690
718
 
691
719
  static VALUE _llama_model_quantize_params_get_allow_requantize(VALUE self) {
692
720
  LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
693
721
  return ptr->params.allow_requantize ? Qtrue : Qfalse;
694
- };
722
+ }
695
723
 
696
724
  // quantize_output_tensor
697
725
  static VALUE _llama_model_quantize_params_set_quantize_output_tensor(VALUE self, VALUE quantize_output_tensor) {
@@ -702,12 +730,12 @@ private:
702
730
  ptr->params.quantize_output_tensor = true;
703
731
  }
704
732
  return ptr->params.quantize_output_tensor ? Qtrue : Qfalse;
705
- };
733
+ }
706
734
 
707
735
  static VALUE _llama_model_quantize_params_get_quantize_output_tensor(VALUE self) {
708
736
  LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
709
737
  return ptr->params.quantize_output_tensor ? Qtrue : Qfalse;
710
- };
738
+ }
711
739
  };
712
740
 
713
741
  const rb_data_type_t RbLLaMAModelQuantizeParams::llama_model_quantize_params_type = {
@@ -724,13 +752,13 @@ class LLaMAModelWrapper {
724
752
  public:
725
753
  struct llama_model* model;
726
754
 
727
- LLaMAModelWrapper() : model(NULL){};
755
+ LLaMAModelWrapper() : model(NULL) {}
728
756
 
729
757
  ~LLaMAModelWrapper() {
730
758
  if (model != NULL) {
731
759
  llama_free_model(model);
732
760
  }
733
- };
761
+ }
734
762
  };
735
763
 
736
764
  class RbLLaMAModel {
@@ -764,6 +792,12 @@ public:
764
792
  rb_define_method(rb_cLLaMAModel, "free", RUBY_METHOD_FUNC(_llama_model_free), 0);
765
793
  rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
766
794
  rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
795
+ rb_define_method(rb_cLLaMAModel, "n_vocab", RUBY_METHOD_FUNC(_llama_model_get_n_vocab_from_model), 0);
796
+ rb_define_method(rb_cLLaMAModel, "n_ctx", RUBY_METHOD_FUNC(_llama_model_get_n_ctx_from_model), 0);
797
+ rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_n_embd_from_model), 0);
798
+ rb_define_method(rb_cLLaMAModel, "vocab", RUBY_METHOD_FUNC(_llama_model_get_vocab_from_model), -1);
799
+ rb_define_method(rb_cLLaMAModel, "token_to_str", RUBY_METHOD_FUNC(_llama_model_token_to_str_with_model), 1);
800
+ rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize_with_model), -1);
767
801
  }
768
802
 
769
803
  private:
@@ -907,7 +941,110 @@ private:
907
941
  return Qnil;
908
942
  }
909
943
  return Qnil;
910
- };
944
+ }
945
+
946
+ static VALUE _llama_model_get_n_vocab_from_model(VALUE self) {
947
+ LLaMAModelWrapper* ptr = get_llama_model(self);
948
+ return INT2NUM(llama_n_vocab_from_model(ptr->model));
949
+ }
950
+
951
+ static VALUE _llama_model_get_n_ctx_from_model(VALUE self) {
952
+ LLaMAModelWrapper* ptr = get_llama_model(self);
953
+ return INT2NUM(llama_n_ctx_from_model(ptr->model));
954
+ }
955
+
956
+ static VALUE _llama_model_get_n_embd_from_model(VALUE self) {
957
+ LLaMAModelWrapper* ptr = get_llama_model(self);
958
+ return INT2NUM(llama_n_embd_from_model(ptr->model));
959
+ }
960
+
961
+ static VALUE _llama_model_get_vocab_from_model(int argc, VALUE* argv, VALUE self) {
962
+ VALUE kw_args = Qnil;
963
+ ID kw_table[1] = { rb_intern("capacity") };
964
+ VALUE kw_values[1] = { Qundef };
965
+ rb_scan_args(argc, argv, ":", &kw_args);
966
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
967
+
968
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
969
+ rb_raise(rb_eArgError, "capacity must be an integer");
970
+ return Qnil;
971
+ }
972
+
973
+ const int capacity = NUM2INT(kw_values[0]);
974
+
975
+ LLaMAModelWrapper* ptr = get_llama_model(self);
976
+ const int n = std::min(capacity, llama_n_vocab_from_model(ptr->model));
977
+ const char** vocabs = ALLOCA_N(const char*, n);
978
+ float* scores = ALLOCA_N(float, n);
979
+
980
+ llama_get_vocab_from_model(ptr->model, vocabs, scores, capacity);
981
+
982
+ VALUE vocabs_ary = rb_ary_new();
983
+ VALUE scores_ary = rb_ary_new();
984
+
985
+ for (int i = 0; i < n; i++) {
986
+ rb_ary_push(vocabs_ary, rb_str_new_cstr(vocabs[i]));
987
+ rb_ary_push(scores_ary, DBL2NUM(scores[i]));
988
+ }
989
+
990
+ VALUE ret = rb_ary_new3(2, vocabs_ary, scores_ary);
991
+
992
+ return ret;
993
+ }
994
+
995
+ static VALUE _llama_model_token_to_str_with_model(VALUE self, VALUE token_) {
996
+ if (!RB_INTEGER_TYPE_P(token_)) {
997
+ rb_raise(rb_eArgError, "token must be an integer");
998
+ return Qnil;
999
+ }
1000
+ const llama_token token = NUM2INT(token_);
1001
+ LLaMAModelWrapper* ptr = get_llama_model(self);
1002
+ const char* str = llama_token_to_str_with_model(ptr->model, token);
1003
+ return rb_str_new_cstr(str);
1004
+ }
1005
+
1006
+ static VALUE _llama_model_tokenize_with_model(int argc, VALUE* argv, VALUE self) {
1007
+ VALUE kw_args = Qnil;
1008
+ ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
1009
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1010
+ rb_scan_args(argc, argv, ":", &kw_args);
1011
+ rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
1012
+
1013
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1014
+ rb_raise(rb_eArgError, "text must be a String");
1015
+ return Qnil;
1016
+ }
1017
+ if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
1018
+ rb_raise(rb_eArgError, "n_max_tokens must be an integer");
1019
+ return Qnil;
1020
+ }
1021
+ if (kw_values[2] != Qundef && (kw_values[2] != Qtrue && kw_values[2] != Qfalse)) {
1022
+ rb_raise(rb_eArgError, "add_bos must be a boolean");
1023
+ return Qnil;
1024
+ }
1025
+
1026
+ VALUE text_ = kw_values[0];
1027
+ std::string text = StringValueCStr(text_);
1028
+ const bool add_bos = kw_values[2] == Qtrue ? true : false;
1029
+ const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
1030
+
1031
+ llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
1032
+ LLaMAModelWrapper* ptr = get_llama_model(self);
1033
+ const int n_tokens = llama_tokenize_with_model(ptr->model, text.c_str(), tokens, n_max_tokens, add_bos);
1034
+
1035
+ if (n_tokens < 0) {
1036
+ rb_raise(rb_eRuntimeError, "failed to tokenize. The numebr of tokens (%d) is greater than n_max_tokens.", -n_tokens);
1037
+ return Qnil;
1038
+ }
1039
+
1040
+ VALUE ret = rb_ary_new2(n_tokens);
1041
+ for (int i = 0; i < n_tokens; i++) {
1042
+ rb_ary_store(ret, i, INT2NUM(tokens[i]));
1043
+ }
1044
+
1045
+ RB_GC_GUARD(text_);
1046
+ return ret;
1047
+ }
911
1048
  };
912
1049
 
913
1050
  const rb_data_type_t RbLLaMAModel::llama_model_type = {
@@ -924,13 +1061,13 @@ class LLaMAContextWrapper {
924
1061
  public:
925
1062
  struct llama_context* ctx;
926
1063
 
927
- LLaMAContextWrapper() : ctx(NULL){};
1064
+ LLaMAContextWrapper() : ctx(NULL) {}
928
1065
 
929
1066
  ~LLaMAContextWrapper() {
930
1067
  if (ctx != NULL) {
931
1068
  llama_free(ctx);
932
1069
  }
933
- };
1070
+ }
934
1071
  };
935
1072
 
936
1073
  class RbLLaMAContext {
@@ -939,22 +1076,22 @@ public:
939
1076
  LLaMAContextWrapper* ptr = (LLaMAContextWrapper*)ruby_xmalloc(sizeof(LLaMAContextWrapper));
940
1077
  new (ptr) LLaMAContextWrapper();
941
1078
  return TypedData_Wrap_Struct(self, &llama_context_type, ptr);
942
- };
1079
+ }
943
1080
 
944
1081
  static void llama_context_free(void* ptr) {
945
1082
  ((LLaMAContextWrapper*)ptr)->~LLaMAContextWrapper();
946
1083
  ruby_xfree(ptr);
947
- };
1084
+ }
948
1085
 
949
1086
  static size_t llama_context_size(const void* ptr) {
950
1087
  return sizeof(*((LLaMAContextWrapper*)ptr));
951
- };
1088
+ }
952
1089
 
953
1090
  static LLaMAContextWrapper* get_llama_context(VALUE self) {
954
1091
  LLaMAContextWrapper* ptr;
955
1092
  TypedData_Get_Struct(self, LLaMAContextWrapper, &llama_context_type, ptr);
956
1093
  return ptr;
957
- };
1094
+ }
958
1095
 
959
1096
  static void define_class(VALUE outer) {
960
1097
  rb_cLLaMAContext = rb_define_class_under(outer, "Context", rb_cObject);
@@ -980,6 +1117,7 @@ public:
980
1117
  rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
981
1118
  rb_define_method(rb_cLLaMAContext, "sample_repetition_penalty", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalty), -1);
982
1119
  rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
1120
+ rb_define_method(rb_cLLaMAContext, "sample_classifier_free_guidance", RUBY_METHOD_FUNC(_llama_context_sample_classifier_free_guidance), -1);
983
1121
  rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
984
1122
  rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
985
1123
  rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
@@ -990,7 +1128,7 @@ public:
990
1128
  rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
991
1129
  rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
992
1130
  rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
993
- };
1131
+ }
994
1132
 
995
1133
  private:
996
1134
  static const rb_data_type_t llama_context_type;
@@ -1029,7 +1167,7 @@ private:
1029
1167
  rb_iv_set(self, "@has_evaluated", Qfalse);
1030
1168
 
1031
1169
  return Qnil;
1032
- };
1170
+ }
1033
1171
 
1034
1172
  static VALUE _llama_context_eval(int argc, VALUE* argv, VALUE self) {
1035
1173
  VALUE kw_args = Qnil;
@@ -1084,7 +1222,7 @@ private:
1084
1222
  rb_iv_set(self, "@has_evaluated", Qtrue);
1085
1223
 
1086
1224
  return Qnil;
1087
- };
1225
+ }
1088
1226
 
1089
1227
  static VALUE _llama_context_eval_embd(int argc, VALUE* argv, VALUE self) {
1090
1228
  VALUE kw_args = Qnil;
@@ -1157,7 +1295,7 @@ private:
1157
1295
  }
1158
1296
  RB_GC_GUARD(fname_);
1159
1297
  return Qtrue;
1160
- };
1298
+ }
1161
1299
 
1162
1300
  static VALUE _llama_context_tokenize(int argc, VALUE* argv, VALUE self) {
1163
1301
  VALUE kw_args = Qnil;
@@ -1203,7 +1341,7 @@ private:
1203
1341
 
1204
1342
  RB_GC_GUARD(text_);
1205
1343
  return output;
1206
- };
1344
+ }
1207
1345
 
1208
1346
  static VALUE _llama_context_token_to_str(VALUE self, VALUE token_) {
1209
1347
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1214,7 +1352,7 @@ private:
1214
1352
  const llama_token token = NUM2INT(token_);
1215
1353
  const char* str = llama_token_to_str(ptr->ctx, token);
1216
1354
  return str != nullptr ? rb_utf8_str_new_cstr(str) : rb_utf8_str_new_cstr("");
1217
- };
1355
+ }
1218
1356
 
1219
1357
  static VALUE _llama_context_logits(VALUE self) {
1220
1358
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1239,7 +1377,7 @@ private:
1239
1377
  }
1240
1378
 
1241
1379
  return output;
1242
- };
1380
+ }
1243
1381
 
1244
1382
  static VALUE _llama_context_embeddings(VALUE self) {
1245
1383
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1267,7 +1405,7 @@ private:
1267
1405
  }
1268
1406
 
1269
1407
  return output;
1270
- };
1408
+ }
1271
1409
 
1272
1410
  static VALUE _llama_context_vocab(int argc, VALUE* argv, VALUE self) {
1273
1411
  VALUE kw_args = Qnil;
@@ -1304,7 +1442,7 @@ private:
1304
1442
  }
1305
1443
 
1306
1444
  return rb_ary_new_from_args(2, ret_strings, ret_scores);
1307
- };
1445
+ }
1308
1446
 
1309
1447
  static VALUE _llama_context_n_vocab(VALUE self) {
1310
1448
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1313,7 +1451,7 @@ private:
1313
1451
  return Qnil;
1314
1452
  }
1315
1453
  return INT2NUM(llama_n_vocab(ptr->ctx));
1316
- };
1454
+ }
1317
1455
 
1318
1456
  static VALUE _llama_context_n_ctx(VALUE self) {
1319
1457
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1322,7 +1460,7 @@ private:
1322
1460
  return Qnil;
1323
1461
  }
1324
1462
  return INT2NUM(llama_n_ctx(ptr->ctx));
1325
- };
1463
+ }
1326
1464
 
1327
1465
  static VALUE _llama_context_n_embd(VALUE self) {
1328
1466
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1331,7 +1469,7 @@ private:
1331
1469
  return Qnil;
1332
1470
  }
1333
1471
  return INT2NUM(llama_n_embd(ptr->ctx));
1334
- };
1472
+ }
1335
1473
 
1336
1474
  static VALUE _llama_context_get_timings(VALUE self) {
1337
1475
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1353,7 +1491,7 @@ private:
1353
1491
  }
1354
1492
  llama_print_timings(ptr->ctx);
1355
1493
  return Qnil;
1356
- };
1494
+ }
1357
1495
 
1358
1496
  static VALUE _llama_context_reset_timings(VALUE self) {
1359
1497
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1363,7 +1501,7 @@ private:
1363
1501
  }
1364
1502
  llama_reset_timings(ptr->ctx);
1365
1503
  return Qnil;
1366
- };
1504
+ }
1367
1505
 
1368
1506
  static VALUE _llama_context_kv_cache_token_count(VALUE self) {
1369
1507
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1372,7 +1510,7 @@ private:
1372
1510
  return Qnil;
1373
1511
  }
1374
1512
  return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
1375
- };
1513
+ }
1376
1514
 
1377
1515
  static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
1378
1516
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1387,7 +1525,7 @@ private:
1387
1525
  const uint32_t seed = NUM2INT(seed_);
1388
1526
  llama_set_rng_seed(ptr->ctx, seed);
1389
1527
  return Qnil;
1390
- };
1528
+ }
1391
1529
 
1392
1530
  static VALUE _llama_context_load_session_file(int argc, VALUE* argv, VALUE self) {
1393
1531
  VALUE kw_args = Qnil;
@@ -1525,7 +1663,7 @@ private:
1525
1663
  llama_sample_repetition_penalty(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, penalty);
1526
1664
 
1527
1665
  return Qnil;
1528
- };
1666
+ }
1529
1667
 
1530
1668
  static VALUE _llama_context_sample_frequency_and_presence_penalties(int argc, VALUE* argv, VALUE self) {
1531
1669
  VALUE kw_args = Qnil;
@@ -1576,7 +1714,47 @@ private:
1576
1714
  llama_sample_frequency_and_presence_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, alpha_frequency, alpha_presence);
1577
1715
 
1578
1716
  return Qnil;
1579
- };
1717
+ }
1718
+
1719
+ static VALUE _llama_context_sample_classifier_free_guidance(int argc, VALUE* argv, VALUE self) {
1720
+ VALUE kw_args = Qnil;
1721
+ ID kw_table[2] = { rb_intern("guidance"), rb_intern("scale") };
1722
+ VALUE kw_values[2] = { Qundef, Qundef };
1723
+ VALUE candidates = Qnil;
1724
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
1725
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
1726
+
1727
+ if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAContext)) {
1728
+ rb_raise(rb_eArgError, "guidance must be a Context");
1729
+ return Qnil;
1730
+ }
1731
+ if (!RB_FLOAT_TYPE_P(kw_values[1])) {
1732
+ rb_raise(rb_eArgError, "scale must be a float");
1733
+ return Qnil;
1734
+ }
1735
+
1736
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1737
+ if (ctx_ptr->ctx == NULL) {
1738
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1739
+ return Qnil;
1740
+ }
1741
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
1742
+ if (cnd_ptr->array.data == nullptr) {
1743
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
1744
+ return Qnil;
1745
+ }
1746
+
1747
+ LLaMAContextWrapper* guidance_ptr = get_llama_context(kw_values[0]);
1748
+ if (guidance_ptr->ctx == NULL) {
1749
+ rb_raise(rb_eRuntimeError, "guidance context is not initialized");
1750
+ return Qnil;
1751
+ }
1752
+ const float scale = NUM2DBL(kw_values[1]);
1753
+
1754
+ llama_sample_classifier_free_guidance(ctx_ptr->ctx, &(cnd_ptr->array), guidance_ptr->ctx, scale);
1755
+
1756
+ return Qnil;
1757
+ }
1580
1758
 
1581
1759
  static VALUE _llama_context_sample_softmax(VALUE self, VALUE candidates) {
1582
1760
  if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
@@ -1598,7 +1776,7 @@ private:
1598
1776
  llama_sample_softmax(ctx_ptr->ctx, &(cnd_ptr->array));
1599
1777
 
1600
1778
  return Qnil;
1601
- };
1779
+ }
1602
1780
 
1603
1781
  static VALUE _llama_context_sample_top_k(int argc, VALUE* argv, VALUE self) {
1604
1782
  VALUE kw_args = Qnil;
@@ -1637,7 +1815,7 @@ private:
1637
1815
  llama_sample_top_k(ctx_ptr->ctx, &(cnd_ptr->array), k, min_keep);
1638
1816
 
1639
1817
  return Qnil;
1640
- };
1818
+ }
1641
1819
 
1642
1820
  static VALUE _llama_context_sample_top_p(int argc, VALUE* argv, VALUE self) {
1643
1821
  VALUE kw_args = Qnil;
@@ -1676,7 +1854,7 @@ private:
1676
1854
  llama_sample_top_p(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
1677
1855
 
1678
1856
  return Qnil;
1679
- };
1857
+ }
1680
1858
 
1681
1859
  static VALUE _llama_context_sample_tail_free(int argc, VALUE* argv, VALUE self) {
1682
1860
  VALUE kw_args = Qnil;
@@ -1715,7 +1893,7 @@ private:
1715
1893
  llama_sample_tail_free(ctx_ptr->ctx, &(cnd_ptr->array), z, min_keep);
1716
1894
 
1717
1895
  return Qnil;
1718
- };
1896
+ }
1719
1897
 
1720
1898
  static VALUE _llama_context_sample_typical(int argc, VALUE* argv, VALUE self) {
1721
1899
  VALUE kw_args = Qnil;
@@ -1754,7 +1932,7 @@ private:
1754
1932
  llama_sample_typical(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
1755
1933
 
1756
1934
  return Qnil;
1757
- };
1935
+ }
1758
1936
 
1759
1937
  static VALUE _llama_context_sample_temperature(int argc, VALUE* argv, VALUE self) {
1760
1938
  VALUE kw_args = Qnil;
@@ -1788,7 +1966,7 @@ private:
1788
1966
  llama_sample_temperature(ctx_ptr->ctx, &(cnd_ptr->array), temperature);
1789
1967
 
1790
1968
  return Qnil;
1791
- };
1969
+ }
1792
1970
 
1793
1971
  static VALUE _llama_context_sample_token_mirostat(int argc, VALUE* argv, VALUE self) {
1794
1972
  VALUE kw_args = Qnil;
@@ -1840,7 +2018,7 @@ private:
1840
2018
  rb_ary_store(ret, 0, INT2NUM(id));
1841
2019
  rb_ary_store(ret, 1, DBL2NUM(mu));
1842
2020
  return ret;
1843
- };
2021
+ }
1844
2022
 
1845
2023
  static VALUE _llama_context_sample_token_mirostat_v2(int argc, VALUE* argv, VALUE self) {
1846
2024
  VALUE kw_args = Qnil;
@@ -1887,7 +2065,7 @@ private:
1887
2065
  rb_ary_store(ret, 0, INT2NUM(id));
1888
2066
  rb_ary_store(ret, 1, DBL2NUM(mu));
1889
2067
  return ret;
1890
- };
2068
+ }
1891
2069
 
1892
2070
  static VALUE _llama_context_sample_token_greedy(VALUE self, VALUE candidates) {
1893
2071
  LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
@@ -1906,7 +2084,7 @@ private:
1906
2084
  }
1907
2085
  llama_token id = llama_sample_token_greedy(ctx_ptr->ctx, &(cnd_ptr->array));
1908
2086
  return INT2NUM(id);
1909
- };
2087
+ }
1910
2088
 
1911
2089
  static VALUE _llama_context_sample_token(VALUE self, VALUE candidates) {
1912
2090
  LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
@@ -1925,7 +2103,7 @@ private:
1925
2103
  }
1926
2104
  llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
1927
2105
  return INT2NUM(id);
1928
- };
2106
+ }
1929
2107
  };
1930
2108
 
1931
2109
  const rb_data_type_t RbLLaMAContext::llama_context_type = {
@@ -1940,7 +2118,7 @@ const rb_data_type_t RbLLaMAContext::llama_context_type = {
1940
2118
 
1941
2119
  // module functions
1942
2120
 
1943
- static VALUE rb_llama_llama_init_backend(int argc, VALUE* argv, VALUE self) {
2121
+ static VALUE rb_llama_llama_backend_init(int argc, VALUE* argv, VALUE self) {
1944
2122
  VALUE kw_args = Qnil;
1945
2123
  ID kw_table[1] = { rb_intern("numa") };
1946
2124
  VALUE kw_values[1] = { Qundef };
@@ -1948,7 +2126,13 @@ static VALUE rb_llama_llama_init_backend(int argc, VALUE* argv, VALUE self) {
1948
2126
  rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
1949
2127
 
1950
2128
  const bool numa = kw_values[0] == Qundef ? false : (RTEST ? true : false);
1951
- llama_init_backend(numa);
2129
+ llama_backend_init(numa);
2130
+
2131
+ return Qnil;
2132
+ }
2133
+
2134
+ static VALUE rb_llama_llama_backend_free(VALUE self) {
2135
+ llama_backend_free();
1952
2136
 
1953
2137
  return Qnil;
1954
2138
  }
@@ -2010,6 +2194,10 @@ static VALUE rb_llama_mlock_supported(VALUE self) {
2010
2194
  return llama_mlock_supported() ? Qtrue : Qfalse;
2011
2195
  }
2012
2196
 
2197
+ static VALUE rb_llama_max_devices(VALUE self) {
2198
+ return INT2NUM(llama_max_devices());
2199
+ }
2200
+
2013
2201
  extern "C" void Init_llama_cpp(void) {
2014
2202
  rb_mLLaMACpp = rb_define_module("LLaMACpp");
2015
2203
 
@@ -2021,7 +2209,8 @@ extern "C" void Init_llama_cpp(void) {
2021
2209
  RbLLaMAContextParams::define_class(rb_mLLaMACpp);
2022
2210
  RbLLaMAModelQuantizeParams::define_class(rb_mLLaMACpp);
2023
2211
 
2024
- rb_define_module_function(rb_mLLaMACpp, "init_backend", rb_llama_llama_init_backend, -1);
2212
+ rb_define_module_function(rb_mLLaMACpp, "backend_init", rb_llama_llama_backend_init, -1);
2213
+ rb_define_module_function(rb_mLLaMACpp, "backend_free", rb_llama_llama_backend_free, 0);
2025
2214
  rb_define_module_function(rb_mLLaMACpp, "model_quantize", rb_llama_model_quantize, -1);
2026
2215
  rb_define_module_function(rb_mLLaMACpp, "token_bos", rb_llama_token_bos, 0);
2027
2216
  rb_define_module_function(rb_mLLaMACpp, "token_eos", rb_llama_token_eos, 0);
@@ -2029,6 +2218,7 @@ extern "C" void Init_llama_cpp(void) {
2029
2218
  rb_define_module_function(rb_mLLaMACpp, "print_system_info", rb_llama_print_system_info, 0);
2030
2219
  rb_define_module_function(rb_mLLaMACpp, "mmap_supported?", rb_llama_mmap_supported, 0);
2031
2220
  rb_define_module_function(rb_mLLaMACpp, "mlock_supported?", rb_llama_mlock_supported, 0);
2221
+ rb_define_module_function(rb_mLLaMACpp, "max_devices", rb_llama_max_devices, 0);
2032
2222
 
2033
2223
  rb_define_const(rb_mLLaMACpp, "LLAMA_MAX_DEVICES", INT2NUM(LLAMA_MAX_DEVICES));
2034
2224