llama_cpp 0.3.2 → 0.3.4

Sign up to get free protection for your applications and to get access to all the features.
@@ -17,9 +17,9 @@ public:
17
17
  data.id = 0;
18
18
  data.logit = 0.0;
19
19
  data.p = 0.0;
20
- };
20
+ }
21
21
 
22
- ~LLaMATokenDataWrapper(){};
22
+ ~LLaMATokenDataWrapper() {}
23
23
  };
24
24
 
25
25
  class RbLLaMATokenData {
@@ -28,22 +28,22 @@ public:
28
28
  LLaMATokenDataWrapper* ptr = (LLaMATokenDataWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataWrapper));
29
29
  new (ptr) LLaMATokenDataWrapper();
30
30
  return TypedData_Wrap_Struct(self, &llama_token_data_type, ptr);
31
- };
31
+ }
32
32
 
33
33
  static void llama_token_data_free(void* ptr) {
34
34
  ((LLaMATokenDataWrapper*)ptr)->~LLaMATokenDataWrapper();
35
35
  ruby_xfree(ptr);
36
- };
36
+ }
37
37
 
38
38
  static size_t llama_token_data_size(const void* ptr) {
39
39
  return sizeof(*((LLaMATokenDataWrapper*)ptr));
40
- };
40
+ }
41
41
 
42
42
  static LLaMATokenDataWrapper* get_llama_token_data(VALUE self) {
43
43
  LLaMATokenDataWrapper* ptr;
44
44
  TypedData_Get_Struct(self, LLaMATokenDataWrapper, &llama_token_data_type, ptr);
45
45
  return ptr;
46
- };
46
+ }
47
47
 
48
48
  static void define_class(VALUE outer) {
49
49
  rb_cLLaMATokenData = rb_define_class_under(outer, "TokenData", rb_cObject);
@@ -95,36 +95,36 @@ private:
95
95
  LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
96
96
  ptr->data.id = NUM2INT(id);
97
97
  return INT2NUM(ptr->data.id);
98
- };
98
+ }
99
99
 
100
100
  static VALUE _llama_token_data_get_id(VALUE self) {
101
101
  LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
102
102
  return INT2NUM(ptr->data.id);
103
- };
103
+ }
104
104
 
105
105
  // logit
106
106
  static VALUE _llama_token_data_set_logit(VALUE self, VALUE logit) {
107
107
  LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
108
108
  ptr->data.logit = NUM2DBL(logit);
109
109
  return DBL2NUM(ptr->data.logit);
110
- };
110
+ }
111
111
 
112
112
  static VALUE _llama_token_data_get_logit(VALUE self) {
113
113
  LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
114
114
  return DBL2NUM(ptr->data.logit);
115
- };
115
+ }
116
116
 
117
117
  // p
118
118
  static VALUE _llama_token_data_set_p(VALUE self, VALUE p) {
119
119
  LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
120
120
  ptr->data.p = NUM2DBL(p);
121
121
  return DBL2NUM(ptr->data.p);
122
- };
122
+ }
123
123
 
124
124
  static VALUE _llama_token_data_get_p(VALUE self) {
125
125
  LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
126
126
  return DBL2NUM(ptr->data.p);
127
- };
127
+ }
128
128
  };
129
129
 
130
130
  const rb_data_type_t RbLLaMATokenData::llama_token_data_type = {
@@ -145,14 +145,14 @@ public:
145
145
  array.data = nullptr;
146
146
  array.size = 0;
147
147
  array.sorted = false;
148
- };
148
+ }
149
149
 
150
150
  ~LLaMATokenDataArrayWrapper() {
151
151
  if (array.data) {
152
152
  ruby_xfree(array.data);
153
153
  array.data = nullptr;
154
154
  }
155
- };
155
+ }
156
156
  };
157
157
 
158
158
  class RbLLaMATokenDataArray {
@@ -161,22 +161,22 @@ public:
161
161
  LLaMATokenDataArrayWrapper* ptr = (LLaMATokenDataArrayWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataArrayWrapper));
162
162
  new (ptr) LLaMATokenDataArrayWrapper();
163
163
  return TypedData_Wrap_Struct(self, &llama_token_data_array_type, ptr);
164
- };
164
+ }
165
165
 
166
166
  static void llama_token_data_array_free(void* ptr) {
167
167
  ((LLaMATokenDataArrayWrapper*)ptr)->~LLaMATokenDataArrayWrapper();
168
168
  ruby_xfree(ptr);
169
- };
169
+ }
170
170
 
171
171
  static size_t llama_token_data_array_size(const void* ptr) {
172
172
  return sizeof(*((LLaMATokenDataArrayWrapper*)ptr));
173
- };
173
+ }
174
174
 
175
175
  static LLaMATokenDataArrayWrapper* get_llama_token_data_array(VALUE self) {
176
176
  LLaMATokenDataArrayWrapper* ptr;
177
177
  TypedData_Get_Struct(self, LLaMATokenDataArrayWrapper, &llama_token_data_array_type, ptr);
178
178
  return ptr;
179
- };
179
+ }
180
180
 
181
181
  static void define_class(VALUE outer) {
182
182
  rb_cLLaMATokenDataArray = rb_define_class_under(outer, "TokenDataArray", rb_cObject);
@@ -184,7 +184,7 @@ public:
184
184
  rb_define_method(rb_cLLaMATokenDataArray, "initialize", RUBY_METHOD_FUNC(_llama_token_data_array_init), -1);
185
185
  rb_define_method(rb_cLLaMATokenDataArray, "size", RUBY_METHOD_FUNC(_llama_token_data_array_get_size), 0);
186
186
  rb_define_method(rb_cLLaMATokenDataArray, "sorted", RUBY_METHOD_FUNC(_llama_token_data_array_get_sorted), 0);
187
- };
187
+ }
188
188
 
189
189
  private:
190
190
  static const rb_data_type_t llama_token_data_array_type;
@@ -233,17 +233,17 @@ private:
233
233
  ptr->array.sorted = kw_values[0] == Qtrue;
234
234
 
235
235
  return self;
236
- };
236
+ }
237
237
 
238
238
  static VALUE _llama_token_data_array_get_size(VALUE self) {
239
239
  LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
240
240
  return SIZET2NUM(ptr->array.size);
241
- };
241
+ }
242
242
 
243
243
  static VALUE _llama_token_data_array_get_sorted(VALUE self) {
244
244
  LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
245
245
  return ptr->array.sorted ? Qtrue : Qfalse;
246
- };
246
+ }
247
247
  };
248
248
 
249
249
  const rb_data_type_t RbLLaMATokenDataArray::llama_token_data_array_type = {
@@ -260,9 +260,9 @@ class LLaMATimingsWrapper {
260
260
  public:
261
261
  struct llama_timings timings;
262
262
 
263
- LLaMATimingsWrapper(){};
263
+ LLaMATimingsWrapper() {}
264
264
 
265
- ~LLaMATimingsWrapper(){};
265
+ ~LLaMATimingsWrapper() {}
266
266
  };
267
267
 
268
268
  class RbLLaMATimings {
@@ -365,9 +365,9 @@ class LLaMAContextParamsWrapper {
365
365
  public:
366
366
  struct llama_context_params params;
367
367
 
368
- LLaMAContextParamsWrapper() : params(llama_context_default_params()){};
368
+ LLaMAContextParamsWrapper() : params(llama_context_default_params()) {}
369
369
 
370
- ~LLaMAContextParamsWrapper(){};
370
+ ~LLaMAContextParamsWrapper() {}
371
371
  };
372
372
 
373
373
  class RbLLaMAContextParams {
@@ -376,22 +376,22 @@ public:
376
376
  LLaMAContextParamsWrapper* ptr = (LLaMAContextParamsWrapper*)ruby_xmalloc(sizeof(LLaMAContextParamsWrapper));
377
377
  new (ptr) LLaMAContextParamsWrapper();
378
378
  return TypedData_Wrap_Struct(self, &llama_context_params_type, ptr);
379
- };
379
+ }
380
380
 
381
381
  static void llama_context_params_free(void* ptr) {
382
382
  ((LLaMAContextParamsWrapper*)ptr)->~LLaMAContextParamsWrapper();
383
383
  ruby_xfree(ptr);
384
- };
384
+ }
385
385
 
386
386
  static size_t llama_context_params_size(const void* ptr) {
387
387
  return sizeof(*((LLaMAContextParamsWrapper*)ptr));
388
- };
388
+ }
389
389
 
390
390
  static LLaMAContextParamsWrapper* get_llama_context_params(VALUE self) {
391
391
  LLaMAContextParamsWrapper* ptr;
392
392
  TypedData_Get_Struct(self, LLaMAContextParamsWrapper, &llama_context_params_type, ptr);
393
393
  return ptr;
394
- };
394
+ }
395
395
 
396
396
  static void define_class(VALUE outer) {
397
397
  rb_cLLaMAContextParams = rb_define_class_under(outer, "ContextParams", rb_cObject);
@@ -406,6 +406,10 @@ public:
406
406
  rb_define_method(rb_cLLaMAContextParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_context_params_set_main_gpu), 1);
407
407
  rb_define_method(rb_cLLaMAContextParams, "main_gpu", RUBY_METHOD_FUNC(_llama_context_params_get_main_gpu), 0);
408
408
  rb_define_method(rb_cLLaMAContextParams, "tensor_split", RUBY_METHOD_FUNC(_llama_context_params_get_tensor_split), 0);
409
+ rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
410
+ rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
411
+ rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
412
+ rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
409
413
  rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
410
414
  rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
411
415
  rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
@@ -422,7 +426,7 @@ public:
422
426
  rb_define_method(rb_cLLaMAContextParams, "use_mlock", RUBY_METHOD_FUNC(_llama_context_params_get_use_mlock), 0);
423
427
  rb_define_method(rb_cLLaMAContextParams, "embedding=", RUBY_METHOD_FUNC(_llama_context_params_set_embedding), 1);
424
428
  rb_define_method(rb_cLLaMAContextParams, "embedding", RUBY_METHOD_FUNC(_llama_context_params_get_embedding), 0);
425
- };
429
+ }
426
430
 
427
431
  private:
428
432
  static const rb_data_type_t llama_context_params_type;
@@ -431,55 +435,55 @@ private:
431
435
  // LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
432
436
  // new (ptr) LLaMAContextParamsWrapper();
433
437
  // return self;
434
- // };
438
+ // }
435
439
 
436
440
  // n_ctx
437
441
  static VALUE _llama_context_params_set_n_ctx(VALUE self, VALUE n_ctx) {
438
442
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
439
443
  ptr->params.n_ctx = NUM2INT(n_ctx);
440
444
  return INT2NUM(ptr->params.n_ctx);
441
- };
445
+ }
442
446
 
443
447
  static VALUE _llama_context_params_get_n_ctx(VALUE self) {
444
448
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
445
449
  return INT2NUM(ptr->params.n_ctx);
446
- };
450
+ }
447
451
 
448
452
  // n_batch
449
453
  static VALUE _llama_context_params_set_n_batch(VALUE self, VALUE n_batch) {
450
454
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
451
455
  ptr->params.n_batch = NUM2INT(n_batch);
452
456
  return INT2NUM(ptr->params.n_batch);
453
- };
457
+ }
454
458
 
455
459
  static VALUE _llama_context_params_get_n_batch(VALUE self) {
456
460
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
457
461
  return INT2NUM(ptr->params.n_batch);
458
- };
462
+ }
459
463
 
460
464
  // n_gpu_layers
461
465
  static VALUE _llama_context_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
462
466
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
463
467
  ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
464
468
  return INT2NUM(ptr->params.n_gpu_layers);
465
- };
469
+ }
466
470
 
467
471
  static VALUE _llama_context_params_get_n_gpu_layers(VALUE self) {
468
472
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
469
473
  return INT2NUM(ptr->params.n_gpu_layers);
470
- };
474
+ }
471
475
 
472
476
  // main_gpu
473
477
  static VALUE _llama_context_params_set_main_gpu(VALUE self, VALUE main_gpu) {
474
478
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
475
479
  ptr->params.main_gpu = NUM2INT(main_gpu);
476
480
  return INT2NUM(ptr->params.main_gpu);
477
- };
481
+ }
478
482
 
479
483
  static VALUE _llama_context_params_get_main_gpu(VALUE self) {
480
484
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
481
485
  return INT2NUM(ptr->params.main_gpu);
482
- };
486
+ }
483
487
 
484
488
  // tensor_split
485
489
  static VALUE _llama_context_params_get_tensor_split(VALUE self) {
@@ -492,19 +496,43 @@ private:
492
496
  rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
493
497
  }
494
498
  return ret;
495
- };
499
+ }
500
+
501
+ // rope_freq_base
502
+ static VALUE _llama_context_params_set_rope_freq_base(VALUE self, VALUE rope_freq_base) {
503
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
504
+ ptr->params.rope_freq_base = NUM2DBL(rope_freq_base);
505
+ return DBL2NUM(ptr->params.rope_freq_base);
506
+ }
507
+
508
+ static VALUE _llama_context_params_get_rope_freq_base(VALUE self) {
509
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
510
+ return DBL2NUM(ptr->params.rope_freq_base);
511
+ }
512
+
513
+ // rope_freq_scale
514
+ static VALUE _llama_context_params_set_rope_freq_scale(VALUE self, VALUE rope_freq_scale) {
515
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
516
+ ptr->params.rope_freq_scale = NUM2DBL(rope_freq_scale);
517
+ return DBL2NUM(ptr->params.rope_freq_scale);
518
+ }
519
+
520
+ static VALUE _llama_context_params_get_rope_freq_scale(VALUE self) {
521
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
522
+ return DBL2NUM(ptr->params.rope_freq_scale);
523
+ }
496
524
 
497
525
  // low_vram
498
526
  static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
499
527
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
500
528
  ptr->params.low_vram = low_vram == Qtrue ? true : false;
501
529
  return ptr->params.low_vram ? Qtrue : Qfalse;
502
- };
530
+ }
503
531
 
504
532
  static VALUE _llama_context_params_get_low_vram(VALUE self) {
505
533
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
506
534
  return ptr->params.low_vram ? Qtrue : Qfalse;
507
- };
535
+ }
508
536
 
509
537
  // seed
510
538
  static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
@@ -515,84 +543,84 @@ private:
515
543
  }
516
544
  ptr->params.seed = NUM2INT(seed);
517
545
  return INT2NUM(ptr->params.seed);
518
- };
546
+ }
519
547
 
520
548
  static VALUE _llama_context_params_get_seed(VALUE self) {
521
549
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
522
550
  return INT2NUM(ptr->params.seed);
523
- };
551
+ }
524
552
 
525
553
  // f16_kv
526
554
  static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
527
555
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
528
556
  ptr->params.f16_kv = f16_kv == Qtrue ? true : false;
529
557
  return ptr->params.f16_kv ? Qtrue : Qfalse;
530
- };
558
+ }
531
559
 
532
560
  static VALUE _llama_context_params_get_f16_kv(VALUE self) {
533
561
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
534
562
  return ptr->params.f16_kv ? Qtrue : Qfalse;
535
- };
563
+ }
536
564
 
537
565
  // logits_all
538
566
  static VALUE _llama_context_params_set_logits_all(VALUE self, VALUE logits_all) {
539
567
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
540
568
  ptr->params.logits_all = logits_all == Qtrue ? true : false;
541
569
  return ptr->params.logits_all ? Qtrue : Qfalse;
542
- };
570
+ }
543
571
 
544
572
  static VALUE _llama_context_params_get_logits_all(VALUE self) {
545
573
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
546
574
  return ptr->params.logits_all ? Qtrue : Qfalse;
547
- };
575
+ }
548
576
 
549
577
  // vocab_only
550
578
  static VALUE _llama_context_params_set_vocab_only(VALUE self, VALUE vocab_only) {
551
579
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
552
580
  ptr->params.vocab_only = vocab_only == Qtrue ? true : false;
553
581
  return ptr->params.vocab_only ? Qtrue : Qfalse;
554
- };
582
+ }
555
583
 
556
584
  static VALUE _llama_context_params_get_vocab_only(VALUE self) {
557
585
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
558
586
  return ptr->params.vocab_only ? Qtrue : Qfalse;
559
- };
587
+ }
560
588
 
561
589
  // use_mmap
562
590
  static VALUE _llama_context_params_set_use_mmap(VALUE self, VALUE use_mmap) {
563
591
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
564
592
  ptr->params.use_mmap = use_mmap == Qtrue ? true : false;
565
593
  return ptr->params.use_mmap ? Qtrue : Qfalse;
566
- };
594
+ }
567
595
 
568
596
  static VALUE _llama_context_params_get_use_mmap(VALUE self) {
569
597
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
570
598
  return ptr->params.use_mmap ? Qtrue : Qfalse;
571
- };
599
+ }
572
600
 
573
601
  // use_mlock
574
602
  static VALUE _llama_context_params_set_use_mlock(VALUE self, VALUE use_mlock) {
575
603
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
576
604
  ptr->params.use_mlock = use_mlock == Qtrue ? true : false;
577
605
  return ptr->params.use_mlock ? Qtrue : Qfalse;
578
- };
606
+ }
579
607
 
580
608
  static VALUE _llama_context_params_get_use_mlock(VALUE self) {
581
609
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
582
610
  return ptr->params.use_mlock ? Qtrue : Qfalse;
583
- };
611
+ }
584
612
 
585
613
  // embedding
586
614
  static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
587
615
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
588
616
  ptr->params.embedding = embedding == Qtrue ? true : false;
589
617
  return ptr->params.embedding ? Qtrue : Qfalse;
590
- };
618
+ }
591
619
 
592
620
  static VALUE _llama_context_params_get_embedding(VALUE self) {
593
621
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
594
622
  return ptr->params.embedding ? Qtrue : Qfalse;
595
- };
623
+ }
596
624
  };
597
625
 
598
626
  const rb_data_type_t RbLLaMAContextParams::llama_context_params_type = {
@@ -609,9 +637,9 @@ class LLaMAModelQuantizeParamsWrapper {
609
637
  public:
610
638
  llama_model_quantize_params params;
611
639
 
612
- LLaMAModelQuantizeParamsWrapper() : params(llama_model_quantize_default_params()){};
640
+ LLaMAModelQuantizeParamsWrapper() : params(llama_model_quantize_default_params()) {}
613
641
 
614
- ~LLaMAModelQuantizeParamsWrapper(){};
642
+ ~LLaMAModelQuantizeParamsWrapper() {}
615
643
  };
616
644
 
617
645
  class RbLLaMAModelQuantizeParams {
@@ -620,22 +648,22 @@ public:
620
648
  LLaMAModelQuantizeParamsWrapper* ptr = (LLaMAModelQuantizeParamsWrapper*)ruby_xmalloc(sizeof(LLaMAModelQuantizeParamsWrapper));
621
649
  new (ptr) LLaMAModelQuantizeParamsWrapper();
622
650
  return TypedData_Wrap_Struct(self, &llama_model_quantize_params_type, ptr);
623
- };
651
+ }
624
652
 
625
653
  static void llama_model_quantize_params_free(void* ptr) {
626
654
  ((LLaMAModelQuantizeParamsWrapper*)ptr)->~LLaMAModelQuantizeParamsWrapper();
627
655
  ruby_xfree(ptr);
628
- };
656
+ }
629
657
 
630
658
  static size_t llama_model_quantize_params_size(const void* ptr) {
631
659
  return sizeof(*((LLaMAModelQuantizeParamsWrapper*)ptr));
632
- };
660
+ }
633
661
 
634
662
  static LLaMAModelQuantizeParamsWrapper* get_llama_model_quantize_params(VALUE self) {
635
663
  LLaMAModelQuantizeParamsWrapper* ptr;
636
664
  TypedData_Get_Struct(self, LLaMAModelQuantizeParamsWrapper, &llama_model_quantize_params_type, ptr);
637
665
  return ptr;
638
- };
666
+ }
639
667
 
640
668
  static void define_class(VALUE outer) {
641
669
  rb_cLLaMAModelQuantizeParams = rb_define_class_under(outer, "ModelQuantizeParams", rb_cObject);
@@ -648,7 +676,7 @@ public:
648
676
  rb_define_method(rb_cLLaMAModelQuantizeParams, "allow_requantize", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_allow_requantize), 0);
649
677
  rb_define_method(rb_cLLaMAModelQuantizeParams, "quantize_output_tensor=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_quantize_output_tensor), 1);
650
678
  rb_define_method(rb_cLLaMAModelQuantizeParams, "quantize_output_tensor", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_quantize_output_tensor), 0);
651
- };
679
+ }
652
680
 
653
681
  private:
654
682
  static const rb_data_type_t llama_model_quantize_params_type;
@@ -658,24 +686,24 @@ private:
658
686
  LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
659
687
  ptr->params.nthread = NUM2INT(n_thread);
660
688
  return INT2NUM(ptr->params.nthread);
661
- };
689
+ }
662
690
 
663
691
  static VALUE _llama_model_quantize_params_get_n_thread(VALUE self) {
664
692
  LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
665
693
  return INT2NUM(ptr->params.nthread);
666
- };
694
+ }
667
695
 
668
696
  // ftype
669
697
  static VALUE _llama_model_quantize_params_set_ftype(VALUE self, VALUE ftype) {
670
698
  LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
671
699
  ptr->params.ftype = static_cast<enum llama_ftype>(NUM2INT(ftype));
672
700
  return INT2NUM(ptr->params.ftype);
673
- };
701
+ }
674
702
 
675
703
  static VALUE _llama_model_quantize_params_get_ftype(VALUE self) {
676
704
  LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
677
705
  return INT2NUM(ptr->params.ftype);
678
- };
706
+ }
679
707
 
680
708
  // allow_requantize
681
709
  static VALUE _llama_model_quantize_params_set_allow_requantize(VALUE self, VALUE allow_requantize) {
@@ -686,12 +714,12 @@ private:
686
714
  ptr->params.allow_requantize = true;
687
715
  }
688
716
  return ptr->params.allow_requantize ? Qtrue : Qfalse;
689
- };
717
+ }
690
718
 
691
719
  static VALUE _llama_model_quantize_params_get_allow_requantize(VALUE self) {
692
720
  LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
693
721
  return ptr->params.allow_requantize ? Qtrue : Qfalse;
694
- };
722
+ }
695
723
 
696
724
  // quantize_output_tensor
697
725
  static VALUE _llama_model_quantize_params_set_quantize_output_tensor(VALUE self, VALUE quantize_output_tensor) {
@@ -702,12 +730,12 @@ private:
702
730
  ptr->params.quantize_output_tensor = true;
703
731
  }
704
732
  return ptr->params.quantize_output_tensor ? Qtrue : Qfalse;
705
- };
733
+ }
706
734
 
707
735
  static VALUE _llama_model_quantize_params_get_quantize_output_tensor(VALUE self) {
708
736
  LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
709
737
  return ptr->params.quantize_output_tensor ? Qtrue : Qfalse;
710
- };
738
+ }
711
739
  };
712
740
 
713
741
  const rb_data_type_t RbLLaMAModelQuantizeParams::llama_model_quantize_params_type = {
@@ -724,13 +752,13 @@ class LLaMAModelWrapper {
724
752
  public:
725
753
  struct llama_model* model;
726
754
 
727
- LLaMAModelWrapper() : model(NULL){};
755
+ LLaMAModelWrapper() : model(NULL) {}
728
756
 
729
757
  ~LLaMAModelWrapper() {
730
758
  if (model != NULL) {
731
759
  llama_free_model(model);
732
760
  }
733
- };
761
+ }
734
762
  };
735
763
 
736
764
  class RbLLaMAModel {
@@ -764,6 +792,12 @@ public:
764
792
  rb_define_method(rb_cLLaMAModel, "free", RUBY_METHOD_FUNC(_llama_model_free), 0);
765
793
  rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
766
794
  rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
795
+ rb_define_method(rb_cLLaMAModel, "n_vocab", RUBY_METHOD_FUNC(_llama_model_get_n_vocab_from_model), 0);
796
+ rb_define_method(rb_cLLaMAModel, "n_ctx", RUBY_METHOD_FUNC(_llama_model_get_n_ctx_from_model), 0);
797
+ rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_n_embd_from_model), 0);
798
+ rb_define_method(rb_cLLaMAModel, "vocab", RUBY_METHOD_FUNC(_llama_model_get_vocab_from_model), -1);
799
+ rb_define_method(rb_cLLaMAModel, "token_to_str", RUBY_METHOD_FUNC(_llama_model_token_to_str_with_model), 1);
800
+ rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize_with_model), -1);
767
801
  }
768
802
 
769
803
  private:
@@ -907,7 +941,110 @@ private:
907
941
  return Qnil;
908
942
  }
909
943
  return Qnil;
910
- };
944
+ }
945
+
946
+ static VALUE _llama_model_get_n_vocab_from_model(VALUE self) {
947
+ LLaMAModelWrapper* ptr = get_llama_model(self);
948
+ return INT2NUM(llama_n_vocab_from_model(ptr->model));
949
+ }
950
+
951
+ static VALUE _llama_model_get_n_ctx_from_model(VALUE self) {
952
+ LLaMAModelWrapper* ptr = get_llama_model(self);
953
+ return INT2NUM(llama_n_ctx_from_model(ptr->model));
954
+ }
955
+
956
+ static VALUE _llama_model_get_n_embd_from_model(VALUE self) {
957
+ LLaMAModelWrapper* ptr = get_llama_model(self);
958
+ return INT2NUM(llama_n_embd_from_model(ptr->model));
959
+ }
960
+
961
+ static VALUE _llama_model_get_vocab_from_model(int argc, VALUE* argv, VALUE self) {
962
+ VALUE kw_args = Qnil;
963
+ ID kw_table[1] = { rb_intern("capacity") };
964
+ VALUE kw_values[1] = { Qundef };
965
+ rb_scan_args(argc, argv, ":", &kw_args);
966
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
967
+
968
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
969
+ rb_raise(rb_eArgError, "capacity must be an integer");
970
+ return Qnil;
971
+ }
972
+
973
+ const int capacity = NUM2INT(kw_values[0]);
974
+
975
+ LLaMAModelWrapper* ptr = get_llama_model(self);
976
+ const int n = std::min(capacity, llama_n_vocab_from_model(ptr->model));
977
+ const char** vocabs = ALLOCA_N(const char*, n);
978
+ float* scores = ALLOCA_N(float, n);
979
+
980
+ llama_get_vocab_from_model(ptr->model, vocabs, scores, capacity);
981
+
982
+ VALUE vocabs_ary = rb_ary_new();
983
+ VALUE scores_ary = rb_ary_new();
984
+
985
+ for (int i = 0; i < n; i++) {
986
+ rb_ary_push(vocabs_ary, rb_str_new_cstr(vocabs[i]));
987
+ rb_ary_push(scores_ary, DBL2NUM(scores[i]));
988
+ }
989
+
990
+ VALUE ret = rb_ary_new3(2, vocabs_ary, scores_ary);
991
+
992
+ return ret;
993
+ }
994
+
995
+ static VALUE _llama_model_token_to_str_with_model(VALUE self, VALUE token_) {
996
+ if (!RB_INTEGER_TYPE_P(token_)) {
997
+ rb_raise(rb_eArgError, "token must be an integer");
998
+ return Qnil;
999
+ }
1000
+ const llama_token token = NUM2INT(token_);
1001
+ LLaMAModelWrapper* ptr = get_llama_model(self);
1002
+ const char* str = llama_token_to_str_with_model(ptr->model, token);
1003
+ return rb_str_new_cstr(str);
1004
+ }
1005
+
1006
+ static VALUE _llama_model_tokenize_with_model(int argc, VALUE* argv, VALUE self) {
1007
+ VALUE kw_args = Qnil;
1008
+ ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
1009
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1010
+ rb_scan_args(argc, argv, ":", &kw_args);
1011
+ rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
1012
+
1013
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1014
+ rb_raise(rb_eArgError, "text must be a String");
1015
+ return Qnil;
1016
+ }
1017
+ if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
1018
+ rb_raise(rb_eArgError, "n_max_tokens must be an integer");
1019
+ return Qnil;
1020
+ }
1021
+ if (kw_values[2] != Qundef && (kw_values[2] != Qtrue && kw_values[2] != Qfalse)) {
1022
+ rb_raise(rb_eArgError, "add_bos must be a boolean");
1023
+ return Qnil;
1024
+ }
1025
+
1026
+ VALUE text_ = kw_values[0];
1027
+ std::string text = StringValueCStr(text_);
1028
+ const bool add_bos = kw_values[2] == Qtrue ? true : false;
1029
+ const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
1030
+
1031
+ llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
1032
+ LLaMAModelWrapper* ptr = get_llama_model(self);
1033
+ const int n_tokens = llama_tokenize_with_model(ptr->model, text.c_str(), tokens, n_max_tokens, add_bos);
1034
+
1035
+ if (n_tokens < 0) {
1036
+ rb_raise(rb_eRuntimeError, "failed to tokenize. The numebr of tokens (%d) is greater than n_max_tokens.", -n_tokens);
1037
+ return Qnil;
1038
+ }
1039
+
1040
+ VALUE ret = rb_ary_new2(n_tokens);
1041
+ for (int i = 0; i < n_tokens; i++) {
1042
+ rb_ary_store(ret, i, INT2NUM(tokens[i]));
1043
+ }
1044
+
1045
+ RB_GC_GUARD(text_);
1046
+ return ret;
1047
+ }
911
1048
  };
912
1049
 
913
1050
  const rb_data_type_t RbLLaMAModel::llama_model_type = {
@@ -924,13 +1061,13 @@ class LLaMAContextWrapper {
924
1061
  public:
925
1062
  struct llama_context* ctx;
926
1063
 
927
- LLaMAContextWrapper() : ctx(NULL){};
1064
+ LLaMAContextWrapper() : ctx(NULL) {}
928
1065
 
929
1066
  ~LLaMAContextWrapper() {
930
1067
  if (ctx != NULL) {
931
1068
  llama_free(ctx);
932
1069
  }
933
- };
1070
+ }
934
1071
  };
935
1072
 
936
1073
  class RbLLaMAContext {
@@ -939,22 +1076,22 @@ public:
939
1076
  LLaMAContextWrapper* ptr = (LLaMAContextWrapper*)ruby_xmalloc(sizeof(LLaMAContextWrapper));
940
1077
  new (ptr) LLaMAContextWrapper();
941
1078
  return TypedData_Wrap_Struct(self, &llama_context_type, ptr);
942
- };
1079
+ }
943
1080
 
944
1081
  static void llama_context_free(void* ptr) {
945
1082
  ((LLaMAContextWrapper*)ptr)->~LLaMAContextWrapper();
946
1083
  ruby_xfree(ptr);
947
- };
1084
+ }
948
1085
 
949
1086
  static size_t llama_context_size(const void* ptr) {
950
1087
  return sizeof(*((LLaMAContextWrapper*)ptr));
951
- };
1088
+ }
952
1089
 
953
1090
  static LLaMAContextWrapper* get_llama_context(VALUE self) {
954
1091
  LLaMAContextWrapper* ptr;
955
1092
  TypedData_Get_Struct(self, LLaMAContextWrapper, &llama_context_type, ptr);
956
1093
  return ptr;
957
- };
1094
+ }
958
1095
 
959
1096
  static void define_class(VALUE outer) {
960
1097
  rb_cLLaMAContext = rb_define_class_under(outer, "Context", rb_cObject);
@@ -980,6 +1117,7 @@ public:
980
1117
  rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
981
1118
  rb_define_method(rb_cLLaMAContext, "sample_repetition_penalty", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalty), -1);
982
1119
  rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
1120
+ rb_define_method(rb_cLLaMAContext, "sample_classifier_free_guidance", RUBY_METHOD_FUNC(_llama_context_sample_classifier_free_guidance), -1);
983
1121
  rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
984
1122
  rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
985
1123
  rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
@@ -990,7 +1128,7 @@ public:
990
1128
  rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
991
1129
  rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
992
1130
  rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
993
- };
1131
+ }
994
1132
 
995
1133
  private:
996
1134
  static const rb_data_type_t llama_context_type;
@@ -1029,7 +1167,7 @@ private:
1029
1167
  rb_iv_set(self, "@has_evaluated", Qfalse);
1030
1168
 
1031
1169
  return Qnil;
1032
- };
1170
+ }
1033
1171
 
1034
1172
  static VALUE _llama_context_eval(int argc, VALUE* argv, VALUE self) {
1035
1173
  VALUE kw_args = Qnil;
@@ -1084,7 +1222,7 @@ private:
1084
1222
  rb_iv_set(self, "@has_evaluated", Qtrue);
1085
1223
 
1086
1224
  return Qnil;
1087
- };
1225
+ }
1088
1226
 
1089
1227
  static VALUE _llama_context_eval_embd(int argc, VALUE* argv, VALUE self) {
1090
1228
  VALUE kw_args = Qnil;
@@ -1157,7 +1295,7 @@ private:
1157
1295
  }
1158
1296
  RB_GC_GUARD(fname_);
1159
1297
  return Qtrue;
1160
- };
1298
+ }
1161
1299
 
1162
1300
  static VALUE _llama_context_tokenize(int argc, VALUE* argv, VALUE self) {
1163
1301
  VALUE kw_args = Qnil;
@@ -1203,7 +1341,7 @@ private:
1203
1341
 
1204
1342
  RB_GC_GUARD(text_);
1205
1343
  return output;
1206
- };
1344
+ }
1207
1345
 
1208
1346
  static VALUE _llama_context_token_to_str(VALUE self, VALUE token_) {
1209
1347
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1214,7 +1352,7 @@ private:
1214
1352
  const llama_token token = NUM2INT(token_);
1215
1353
  const char* str = llama_token_to_str(ptr->ctx, token);
1216
1354
  return str != nullptr ? rb_utf8_str_new_cstr(str) : rb_utf8_str_new_cstr("");
1217
- };
1355
+ }
1218
1356
 
1219
1357
  static VALUE _llama_context_logits(VALUE self) {
1220
1358
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1239,7 +1377,7 @@ private:
1239
1377
  }
1240
1378
 
1241
1379
  return output;
1242
- };
1380
+ }
1243
1381
 
1244
1382
  static VALUE _llama_context_embeddings(VALUE self) {
1245
1383
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1267,7 +1405,7 @@ private:
1267
1405
  }
1268
1406
 
1269
1407
  return output;
1270
- };
1408
+ }
1271
1409
 
1272
1410
  static VALUE _llama_context_vocab(int argc, VALUE* argv, VALUE self) {
1273
1411
  VALUE kw_args = Qnil;
@@ -1304,7 +1442,7 @@ private:
1304
1442
  }
1305
1443
 
1306
1444
  return rb_ary_new_from_args(2, ret_strings, ret_scores);
1307
- };
1445
+ }
1308
1446
 
1309
1447
  static VALUE _llama_context_n_vocab(VALUE self) {
1310
1448
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1313,7 +1451,7 @@ private:
1313
1451
  return Qnil;
1314
1452
  }
1315
1453
  return INT2NUM(llama_n_vocab(ptr->ctx));
1316
- };
1454
+ }
1317
1455
 
1318
1456
  static VALUE _llama_context_n_ctx(VALUE self) {
1319
1457
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1322,7 +1460,7 @@ private:
1322
1460
  return Qnil;
1323
1461
  }
1324
1462
  return INT2NUM(llama_n_ctx(ptr->ctx));
1325
- };
1463
+ }
1326
1464
 
1327
1465
  static VALUE _llama_context_n_embd(VALUE self) {
1328
1466
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1331,7 +1469,7 @@ private:
1331
1469
  return Qnil;
1332
1470
  }
1333
1471
  return INT2NUM(llama_n_embd(ptr->ctx));
1334
- };
1472
+ }
1335
1473
 
1336
1474
  static VALUE _llama_context_get_timings(VALUE self) {
1337
1475
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1353,7 +1491,7 @@ private:
1353
1491
  }
1354
1492
  llama_print_timings(ptr->ctx);
1355
1493
  return Qnil;
1356
- };
1494
+ }
1357
1495
 
1358
1496
  static VALUE _llama_context_reset_timings(VALUE self) {
1359
1497
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1363,7 +1501,7 @@ private:
1363
1501
  }
1364
1502
  llama_reset_timings(ptr->ctx);
1365
1503
  return Qnil;
1366
- };
1504
+ }
1367
1505
 
1368
1506
  static VALUE _llama_context_kv_cache_token_count(VALUE self) {
1369
1507
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1372,7 +1510,7 @@ private:
1372
1510
  return Qnil;
1373
1511
  }
1374
1512
  return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
1375
- };
1513
+ }
1376
1514
 
1377
1515
  static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
1378
1516
  LLaMAContextWrapper* ptr = get_llama_context(self);
@@ -1387,7 +1525,7 @@ private:
1387
1525
  const uint32_t seed = NUM2INT(seed_);
1388
1526
  llama_set_rng_seed(ptr->ctx, seed);
1389
1527
  return Qnil;
1390
- };
1528
+ }
1391
1529
 
1392
1530
  static VALUE _llama_context_load_session_file(int argc, VALUE* argv, VALUE self) {
1393
1531
  VALUE kw_args = Qnil;
@@ -1525,7 +1663,7 @@ private:
1525
1663
  llama_sample_repetition_penalty(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, penalty);
1526
1664
 
1527
1665
  return Qnil;
1528
- };
1666
+ }
1529
1667
 
1530
1668
  static VALUE _llama_context_sample_frequency_and_presence_penalties(int argc, VALUE* argv, VALUE self) {
1531
1669
  VALUE kw_args = Qnil;
@@ -1576,7 +1714,47 @@ private:
1576
1714
  llama_sample_frequency_and_presence_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, alpha_frequency, alpha_presence);
1577
1715
 
1578
1716
  return Qnil;
1579
- };
1717
+ }
1718
+
1719
+ static VALUE _llama_context_sample_classifier_free_guidance(int argc, VALUE* argv, VALUE self) {
1720
+ VALUE kw_args = Qnil;
1721
+ ID kw_table[2] = { rb_intern("guidance"), rb_intern("scale") };
1722
+ VALUE kw_values[2] = { Qundef, Qundef };
1723
+ VALUE candidates = Qnil;
1724
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
1725
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
1726
+
1727
+ if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAContext)) {
1728
+ rb_raise(rb_eArgError, "guidance must be a Context");
1729
+ return Qnil;
1730
+ }
1731
+ if (!RB_FLOAT_TYPE_P(kw_values[1])) {
1732
+ rb_raise(rb_eArgError, "scale must be a float");
1733
+ return Qnil;
1734
+ }
1735
+
1736
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1737
+ if (ctx_ptr->ctx == NULL) {
1738
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1739
+ return Qnil;
1740
+ }
1741
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
1742
+ if (cnd_ptr->array.data == nullptr) {
1743
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
1744
+ return Qnil;
1745
+ }
1746
+
1747
+ LLaMAContextWrapper* guidance_ptr = get_llama_context(kw_values[0]);
1748
+ if (guidance_ptr->ctx == NULL) {
1749
+ rb_raise(rb_eRuntimeError, "guidance context is not initialized");
1750
+ return Qnil;
1751
+ }
1752
+ const float scale = NUM2DBL(kw_values[1]);
1753
+
1754
+ llama_sample_classifier_free_guidance(ctx_ptr->ctx, &(cnd_ptr->array), guidance_ptr->ctx, scale);
1755
+
1756
+ return Qnil;
1757
+ }
1580
1758
 
1581
1759
  static VALUE _llama_context_sample_softmax(VALUE self, VALUE candidates) {
1582
1760
  if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
@@ -1598,7 +1776,7 @@ private:
1598
1776
  llama_sample_softmax(ctx_ptr->ctx, &(cnd_ptr->array));
1599
1777
 
1600
1778
  return Qnil;
1601
- };
1779
+ }
1602
1780
 
1603
1781
  static VALUE _llama_context_sample_top_k(int argc, VALUE* argv, VALUE self) {
1604
1782
  VALUE kw_args = Qnil;
@@ -1637,7 +1815,7 @@ private:
1637
1815
  llama_sample_top_k(ctx_ptr->ctx, &(cnd_ptr->array), k, min_keep);
1638
1816
 
1639
1817
  return Qnil;
1640
- };
1818
+ }
1641
1819
 
1642
1820
  static VALUE _llama_context_sample_top_p(int argc, VALUE* argv, VALUE self) {
1643
1821
  VALUE kw_args = Qnil;
@@ -1676,7 +1854,7 @@ private:
1676
1854
  llama_sample_top_p(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
1677
1855
 
1678
1856
  return Qnil;
1679
- };
1857
+ }
1680
1858
 
1681
1859
  static VALUE _llama_context_sample_tail_free(int argc, VALUE* argv, VALUE self) {
1682
1860
  VALUE kw_args = Qnil;
@@ -1715,7 +1893,7 @@ private:
1715
1893
  llama_sample_tail_free(ctx_ptr->ctx, &(cnd_ptr->array), z, min_keep);
1716
1894
 
1717
1895
  return Qnil;
1718
- };
1896
+ }
1719
1897
 
1720
1898
  static VALUE _llama_context_sample_typical(int argc, VALUE* argv, VALUE self) {
1721
1899
  VALUE kw_args = Qnil;
@@ -1754,7 +1932,7 @@ private:
1754
1932
  llama_sample_typical(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
1755
1933
 
1756
1934
  return Qnil;
1757
- };
1935
+ }
1758
1936
 
1759
1937
  static VALUE _llama_context_sample_temperature(int argc, VALUE* argv, VALUE self) {
1760
1938
  VALUE kw_args = Qnil;
@@ -1788,7 +1966,7 @@ private:
1788
1966
  llama_sample_temperature(ctx_ptr->ctx, &(cnd_ptr->array), temperature);
1789
1967
 
1790
1968
  return Qnil;
1791
- };
1969
+ }
1792
1970
 
1793
1971
  static VALUE _llama_context_sample_token_mirostat(int argc, VALUE* argv, VALUE self) {
1794
1972
  VALUE kw_args = Qnil;
@@ -1840,7 +2018,7 @@ private:
1840
2018
  rb_ary_store(ret, 0, INT2NUM(id));
1841
2019
  rb_ary_store(ret, 1, DBL2NUM(mu));
1842
2020
  return ret;
1843
- };
2021
+ }
1844
2022
 
1845
2023
  static VALUE _llama_context_sample_token_mirostat_v2(int argc, VALUE* argv, VALUE self) {
1846
2024
  VALUE kw_args = Qnil;
@@ -1887,7 +2065,7 @@ private:
1887
2065
  rb_ary_store(ret, 0, INT2NUM(id));
1888
2066
  rb_ary_store(ret, 1, DBL2NUM(mu));
1889
2067
  return ret;
1890
- };
2068
+ }
1891
2069
 
1892
2070
  static VALUE _llama_context_sample_token_greedy(VALUE self, VALUE candidates) {
1893
2071
  LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
@@ -1906,7 +2084,7 @@ private:
1906
2084
  }
1907
2085
  llama_token id = llama_sample_token_greedy(ctx_ptr->ctx, &(cnd_ptr->array));
1908
2086
  return INT2NUM(id);
1909
- };
2087
+ }
1910
2088
 
1911
2089
  static VALUE _llama_context_sample_token(VALUE self, VALUE candidates) {
1912
2090
  LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
@@ -1925,7 +2103,7 @@ private:
1925
2103
  }
1926
2104
  llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
1927
2105
  return INT2NUM(id);
1928
- };
2106
+ }
1929
2107
  };
1930
2108
 
1931
2109
  const rb_data_type_t RbLLaMAContext::llama_context_type = {
@@ -1940,7 +2118,7 @@ const rb_data_type_t RbLLaMAContext::llama_context_type = {
1940
2118
 
1941
2119
  // module functions
1942
2120
 
1943
- static VALUE rb_llama_llama_init_backend(int argc, VALUE* argv, VALUE self) {
2121
+ static VALUE rb_llama_llama_backend_init(int argc, VALUE* argv, VALUE self) {
1944
2122
  VALUE kw_args = Qnil;
1945
2123
  ID kw_table[1] = { rb_intern("numa") };
1946
2124
  VALUE kw_values[1] = { Qundef };
@@ -1948,7 +2126,13 @@ static VALUE rb_llama_llama_init_backend(int argc, VALUE* argv, VALUE self) {
1948
2126
  rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
1949
2127
 
1950
2128
  const bool numa = kw_values[0] == Qundef ? false : (RTEST ? true : false);
1951
- llama_init_backend(numa);
2129
+ llama_backend_init(numa);
2130
+
2131
+ return Qnil;
2132
+ }
2133
+
2134
+ static VALUE rb_llama_llama_backend_free(VALUE self) {
2135
+ llama_backend_free();
1952
2136
 
1953
2137
  return Qnil;
1954
2138
  }
@@ -2010,6 +2194,10 @@ static VALUE rb_llama_mlock_supported(VALUE self) {
2010
2194
  return llama_mlock_supported() ? Qtrue : Qfalse;
2011
2195
  }
2012
2196
 
2197
+ static VALUE rb_llama_max_devices(VALUE self) {
2198
+ return INT2NUM(llama_max_devices());
2199
+ }
2200
+
2013
2201
  extern "C" void Init_llama_cpp(void) {
2014
2202
  rb_mLLaMACpp = rb_define_module("LLaMACpp");
2015
2203
 
@@ -2021,7 +2209,8 @@ extern "C" void Init_llama_cpp(void) {
2021
2209
  RbLLaMAContextParams::define_class(rb_mLLaMACpp);
2022
2210
  RbLLaMAModelQuantizeParams::define_class(rb_mLLaMACpp);
2023
2211
 
2024
- rb_define_module_function(rb_mLLaMACpp, "init_backend", rb_llama_llama_init_backend, -1);
2212
+ rb_define_module_function(rb_mLLaMACpp, "backend_init", rb_llama_llama_backend_init, -1);
2213
+ rb_define_module_function(rb_mLLaMACpp, "backend_free", rb_llama_llama_backend_free, 0);
2025
2214
  rb_define_module_function(rb_mLLaMACpp, "model_quantize", rb_llama_model_quantize, -1);
2026
2215
  rb_define_module_function(rb_mLLaMACpp, "token_bos", rb_llama_token_bos, 0);
2027
2216
  rb_define_module_function(rb_mLLaMACpp, "token_eos", rb_llama_token_eos, 0);
@@ -2029,6 +2218,7 @@ extern "C" void Init_llama_cpp(void) {
2029
2218
  rb_define_module_function(rb_mLLaMACpp, "print_system_info", rb_llama_print_system_info, 0);
2030
2219
  rb_define_module_function(rb_mLLaMACpp, "mmap_supported?", rb_llama_mmap_supported, 0);
2031
2220
  rb_define_module_function(rb_mLLaMACpp, "mlock_supported?", rb_llama_mlock_supported, 0);
2221
+ rb_define_module_function(rb_mLLaMACpp, "max_devices", rb_llama_max_devices, 0);
2032
2222
 
2033
2223
  rb_define_const(rb_mLLaMACpp, "LLAMA_MAX_DEVICES", INT2NUM(LLAMA_MAX_DEVICES));
2034
2224