llama_cpp 0.5.3 → 0.7.0

This diff shows the changes between two publicly released versions of the package, as published to its public registry. It is provided for informational purposes only.
@@ -1,7 +1,9 @@
1
1
  #include "llama_cpp.h"
2
2
 
3
3
  VALUE rb_mLLaMACpp;
4
+ VALUE rb_cLLaMABatch;
4
5
  VALUE rb_cLLaMAModel;
6
+ VALUE rb_cLLaMAModelParams;
5
7
  VALUE rb_cLLaMATimings;
6
8
  VALUE rb_cLLaMAContext;
7
9
  VALUE rb_cLLaMAContextParams;
@@ -11,6 +13,238 @@ VALUE rb_cLLaMATokenDataArray;
11
13
  VALUE rb_cLLaMAGrammarElement;
12
14
  VALUE rb_cLLaMAGrammar;
13
15
 
16
+ class LLaMABatchWrapper {
17
+ public:
18
+ llama_batch batch;
19
+
20
+ LLaMABatchWrapper() {}
21
+
22
+ ~LLaMABatchWrapper() {
23
+ llama_batch_free(batch);
24
+ }
25
+ };
26
+
27
+ class RbLLaMABatch {
28
+ public:
29
+ static VALUE llama_batch_alloc(VALUE self) {
30
+ LLaMABatchWrapper* ptr = (LLaMABatchWrapper*)ruby_xmalloc(sizeof(LLaMABatchWrapper));
31
+ new (ptr) LLaMABatchWrapper();
32
+ return TypedData_Wrap_Struct(self, &llama_batch_type, ptr);
33
+ }
34
+
35
+ static void llama_batch_free(void* ptr) {
36
+ ((LLaMABatchWrapper*)ptr)->~LLaMABatchWrapper();
37
+ ruby_xfree(ptr);
38
+ }
39
+
40
+ static size_t llama_batch_size(const void* ptr) {
41
+ return sizeof(*((LLaMABatchWrapper*)ptr));
42
+ }
43
+
44
+ static LLaMABatchWrapper* get_llama_batch(VALUE self) {
45
+ LLaMABatchWrapper* ptr;
46
+ TypedData_Get_Struct(self, LLaMABatchWrapper, &llama_batch_type, ptr);
47
+ return ptr;
48
+ }
49
+
50
+ static void define_class(VALUE outer) {
51
+ rb_cLLaMABatch = rb_define_class_under(outer, "Batch", rb_cObject);
52
+ rb_define_alloc_func(rb_cLLaMABatch, llama_batch_alloc);
53
+ rb_define_method(rb_cLLaMABatch, "initialize", RUBY_METHOD_FUNC(_llama_batch_initialize), -1);
54
+ rb_define_method(rb_cLLaMABatch, "n_tokens=", RUBY_METHOD_FUNC(_llama_batch_set_n_tokens), 1);
55
+ rb_define_method(rb_cLLaMABatch, "n_tokens", RUBY_METHOD_FUNC(_llama_batch_get_n_tokens), 0);
56
+ rb_define_method(rb_cLLaMABatch, "all_pos_zero=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_zero), 1);
57
+ rb_define_method(rb_cLLaMABatch, "all_pos_zero", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_zero), 0);
58
+ rb_define_method(rb_cLLaMABatch, "all_pos_one=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_one), 1);
59
+ rb_define_method(rb_cLLaMABatch, "all_pos_one", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_one), 0);
60
+ rb_define_method(rb_cLLaMABatch, "all_seq_id=", RUBY_METHOD_FUNC(_llama_batch_set_all_seq_id), 1);
61
+ rb_define_method(rb_cLLaMABatch, "all_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_all_seq_id), 0);
62
+ rb_define_method(rb_cLLaMABatch, "set_token", RUBY_METHOD_FUNC(_llama_batch_set_token), 2);
63
+ rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
64
+ rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
65
+ rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
66
+ rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 2);
67
+ rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 1);
68
+ rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
69
+ rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
70
+ }
71
+
72
+ private:
73
+ static const rb_data_type_t llama_batch_type;
74
+
75
+ static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
76
+ VALUE kw_args = Qnil;
77
+ ID kw_table[2] = { rb_intern("n_tokens"), rb_intern("embd") };
78
+ VALUE kw_values[2] = { Qundef, Qundef };
79
+ rb_scan_args(argc, argv, ":", &kw_args);
80
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
81
+
82
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
83
+ rb_raise(rb_eArgError, "n_tokens must be an integer");
84
+ return Qnil;
85
+ }
86
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
87
+ rb_raise(rb_eArgError, "embd must be an integer");
88
+ return Qnil;
89
+ }
90
+
91
+ const int32_t n_tokens = NUM2INT(kw_values[0]);
92
+ const int32_t embd = NUM2INT(kw_values[1]);
93
+
94
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
95
+ ptr->batch = llama_batch_init(n_tokens, embd);
96
+
97
+ return Qnil;
98
+ }
99
+
100
+ // n_tokens
101
+ static VALUE _llama_batch_set_n_tokens(VALUE self, VALUE n_tokens) {
102
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
103
+ ptr->batch.n_tokens = NUM2INT(n_tokens);
104
+ return INT2NUM(ptr->batch.n_tokens);
105
+ }
106
+
107
+ static VALUE _llama_batch_get_n_tokens(VALUE self) {
108
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
109
+ return INT2NUM(ptr->batch.n_tokens);
110
+ }
111
+
112
+ // all_pos_0
113
+ static VALUE _llama_batch_set_all_pos_zero(VALUE self, VALUE all_pos_0) {
114
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
115
+ ptr->batch.all_pos_0 = NUM2INT(all_pos_0);
116
+ return INT2NUM(ptr->batch.all_pos_0);
117
+ }
118
+
119
+ static VALUE _llama_batch_get_all_pos_zero(VALUE self) {
120
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
121
+ return INT2NUM(ptr->batch.all_pos_0);
122
+ }
123
+
124
+ // all_pos_1
125
+ static VALUE _llama_batch_set_all_pos_one(VALUE self, VALUE all_pos_1) {
126
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
127
+ ptr->batch.all_pos_1 = NUM2INT(all_pos_1);
128
+ return INT2NUM(ptr->batch.all_pos_1);
129
+ }
130
+
131
+ static VALUE _llama_batch_get_all_pos_one(VALUE self) {
132
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
133
+ return INT2NUM(ptr->batch.all_pos_1);
134
+ }
135
+
136
+ // all_seq_id
137
+ static VALUE _llama_batch_set_all_seq_id(VALUE self, VALUE all_seq_id) {
138
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
139
+ ptr->batch.all_seq_id = NUM2INT(all_seq_id);
140
+ return INT2NUM(ptr->batch.all_seq_id);
141
+ }
142
+
143
+ static VALUE _llama_batch_get_all_seq_id(VALUE self) {
144
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
145
+ return INT2NUM(ptr->batch.all_seq_id);
146
+ }
147
+
148
+ // token
149
+ static VALUE _llama_batch_set_token(VALUE self, VALUE idx, VALUE value) {
150
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
151
+ const int32_t id = NUM2INT(idx);
152
+ if (id < 0 || id >= ptr->batch.n_tokens) {
153
+ rb_raise(rb_eArgError, "idx must be in [0, n_tokens)");
154
+ return Qnil;
155
+ }
156
+ ptr->batch.token[id] = NUM2INT(value);
157
+ return INT2NUM(ptr->batch.token[id]);
158
+ }
159
+
160
+ static VALUE _llama_batch_get_token(VALUE self, VALUE idx) {
161
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
162
+ const int32_t id = NUM2INT(idx);
163
+ if (id < 0 || id >= ptr->batch.n_tokens) {
164
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
165
+ return Qnil;
166
+ }
167
+ return INT2NUM(ptr->batch.token[id]);
168
+ }
169
+
170
+ // pos
171
+ static VALUE _llama_batch_set_pos(VALUE self, VALUE idx, VALUE value) {
172
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
173
+ const int32_t id = NUM2INT(idx);
174
+ if (id < 0 || id >= ptr->batch.n_tokens) {
175
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
176
+ return Qnil;
177
+ }
178
+ ptr->batch.pos[id] = NUM2INT(value);
179
+ return INT2NUM(ptr->batch.pos[id]);
180
+ }
181
+
182
+ static VALUE _llama_batch_get_pos(VALUE self, VALUE idx) {
183
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
184
+ const int32_t id = NUM2INT(idx);
185
+ if (id < 0 || id >= ptr->batch.n_tokens) {
186
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
187
+ return Qnil;
188
+ }
189
+ return INT2NUM(ptr->batch.pos[id]);
190
+ }
191
+
192
+ // seq_id
193
+ static VALUE _llama_batch_set_seq_id(VALUE self, VALUE idx, VALUE value) {
194
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
195
+ const int32_t id = NUM2INT(idx);
196
+ if (id < 0 || id >= ptr->batch.n_tokens) {
197
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
198
+ return Qnil;
199
+ }
200
+ ptr->batch.seq_id[id] = NUM2INT(value);
201
+ return INT2NUM(ptr->batch.seq_id[id]);
202
+ }
203
+
204
+ static VALUE _llama_batch_get_seq_id(VALUE self, VALUE idx) {
205
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
206
+ const int32_t id = NUM2INT(idx);
207
+ if (id < 0 || id >= ptr->batch.n_tokens) {
208
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
209
+ return Qnil;
210
+ }
211
+ return INT2NUM(ptr->batch.seq_id[id]);
212
+ }
213
+
214
+ // logits
215
+ static VALUE _llama_batch_set_logits(VALUE self, VALUE idx, VALUE value) {
216
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
217
+ const int32_t id = NUM2INT(idx);
218
+ if (id < 0 || id >= ptr->batch.n_tokens) {
219
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
220
+ return Qnil;
221
+ }
222
+ ptr->batch.logits[id] = RTEST(value) ? true : false;
223
+ return ptr->batch.logits[id] ? Qtrue : Qfalse;
224
+ }
225
+
226
+ static VALUE _llama_batch_get_logits(VALUE self, VALUE idx) {
227
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
228
+ const int32_t id = NUM2INT(idx);
229
+ if (id < 0 || id >= ptr->batch.n_tokens) {
230
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
231
+ return Qnil;
232
+ }
233
+ return ptr->batch.logits[id] ? Qtrue : Qfalse;
234
+ }
235
+ };
236
+
237
+ const rb_data_type_t RbLLaMABatch::llama_batch_type = {
238
+ "RbLLaMABatch",
239
+ { NULL,
240
+ RbLLaMABatch::llama_batch_free,
241
+ RbLLaMABatch::llama_batch_size },
242
+ NULL,
243
+ NULL,
244
+ RUBY_TYPED_FREE_IMMEDIATELY
245
+
246
+ };
247
+
14
248
  class LLaMATokenDataWrapper {
15
249
  public:
16
250
  llama_token_data data;
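
The hunk above adds a Batch class wrapping llama_batch. Below is a minimal Ruby sketch of how these bindings appear to be used, assuming the class is exposed as LLaMACpp::Batch (the module bound to rb_mLLaMACpp) and using placeholder token ids; note that n_tokens and embd are both required keywords, and n_tokens= should be set before the per-index writers because set_token and friends bound-check against it:

    batch = LLaMACpp::Batch.new(n_tokens: 512, embd: 0)  # capacity 512, token (not embedding) batch
    tokens = [1, 15043, 3186]                            # placeholder token ids
    batch.n_tokens = tokens.size
    tokens.each_with_index do |tok, i|
      batch.set_token(i, tok)
      batch.set_pos(i, i)
      batch.set_seq_id(i, 0)
      batch.set_logits(i, i == tokens.size - 1)          # request logits only for the last position
    end
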
@@ -363,6 +597,144 @@ const rb_data_type_t RbLLaMATimings::llama_timings_type = {
363
597
  RUBY_TYPED_FREE_IMMEDIATELY
364
598
  };
365
599
 
600
+ class LLaMAModelParamsWrapper {
601
+ public:
602
+ struct llama_model_params params;
603
+
604
+ LLaMAModelParamsWrapper() : params(llama_model_default_params()) {}
605
+
606
+ ~LLaMAModelParamsWrapper() {}
607
+ };
608
+
609
+ class RbLLaMAModelParams {
610
+ public:
611
+ static VALUE llama_model_params_alloc(VALUE self) {
612
+ LLaMAModelParamsWrapper* ptr = (LLaMAModelParamsWrapper*)ruby_xmalloc(sizeof(LLaMAModelParamsWrapper));
613
+ new (ptr) LLaMAModelParamsWrapper();
614
+ return TypedData_Wrap_Struct(self, &llama_model_params_type, ptr);
615
+ }
616
+
617
+ static void llama_model_params_free(void* ptr) {
618
+ ((LLaMAModelParamsWrapper*)ptr)->~LLaMAModelParamsWrapper();
619
+ ruby_xfree(ptr);
620
+ }
621
+
622
+ static size_t llama_model_params_size(const void* ptr) {
623
+ return sizeof(*((LLaMAModelParamsWrapper*)ptr));
624
+ }
625
+
626
+ static LLaMAModelParamsWrapper* get_llama_model_params(VALUE self) {
627
+ LLaMAModelParamsWrapper* ptr;
628
+ TypedData_Get_Struct(self, LLaMAModelParamsWrapper, &llama_model_params_type, ptr);
629
+ return ptr;
630
+ }
631
+
632
+ static void define_class(VALUE outer) {
633
+ rb_cLLaMAModelParams = rb_define_class_under(outer, "ModelParams", rb_cObject);
634
+ rb_define_alloc_func(rb_cLLaMAModelParams, llama_model_params_alloc);
635
+ rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers=", RUBY_METHOD_FUNC(_llama_model_params_set_n_gpu_layers), 1);
636
+ rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers", RUBY_METHOD_FUNC(_llama_model_params_get_n_gpu_layers), 0);
637
+ rb_define_method(rb_cLLaMAModelParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_model_params_set_main_gpu), 1);
638
+ rb_define_method(rb_cLLaMAModelParams, "main_gpu", RUBY_METHOD_FUNC(_llama_model_params_get_main_gpu), 0);
639
+ rb_define_method(rb_cLLaMAModelParams, "tensor_split", RUBY_METHOD_FUNC(_llama_model_params_get_tensor_split), 0);
640
+ rb_define_method(rb_cLLaMAModelParams, "vocab_only=", RUBY_METHOD_FUNC(_llama_model_params_set_vocab_only), 1);
641
+ rb_define_method(rb_cLLaMAModelParams, "vocab_only", RUBY_METHOD_FUNC(_llama_model_params_get_vocab_only), 0);
642
+ rb_define_method(rb_cLLaMAModelParams, "use_mmap=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mmap), 1);
643
+ rb_define_method(rb_cLLaMAModelParams, "use_mmap", RUBY_METHOD_FUNC(_llama_model_params_get_use_mmap), 0);
644
+ rb_define_method(rb_cLLaMAModelParams, "use_mlock=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mlock), 1);
645
+ rb_define_method(rb_cLLaMAModelParams, "use_mlock", RUBY_METHOD_FUNC(_llama_model_params_get_use_mlock), 0);
646
+ }
647
+
648
+ private:
649
+ static const rb_data_type_t llama_model_params_type;
650
+
651
+ // n_gpu_layers
652
+ static VALUE _llama_model_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
653
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
654
+ ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
655
+ return INT2NUM(ptr->params.n_gpu_layers);
656
+ }
657
+
658
+ static VALUE _llama_model_params_get_n_gpu_layers(VALUE self) {
659
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
660
+ return INT2NUM(ptr->params.n_gpu_layers);
661
+ }
662
+
663
+ // main_gpu
664
+ static VALUE _llama_model_params_set_main_gpu(VALUE self, VALUE main_gpu) {
665
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
666
+ ptr->params.main_gpu = NUM2INT(main_gpu);
667
+ return INT2NUM(ptr->params.main_gpu);
668
+ }
669
+
670
+ static VALUE _llama_model_params_get_main_gpu(VALUE self) {
671
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
672
+ return INT2NUM(ptr->params.main_gpu);
673
+ }
674
+
675
+ // tensor_split
676
+ static VALUE _llama_model_params_get_tensor_split(VALUE self) {
677
+ if (LLAMA_MAX_DEVICES < 1) {
678
+ return rb_ary_new();
679
+ }
680
+ VALUE ret = rb_ary_new2(LLAMA_MAX_DEVICES);
681
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
682
+ if (ptr->params.tensor_split == nullptr) {
683
+ return rb_ary_new();
684
+ }
685
+ for (size_t i = 0; i < LLAMA_MAX_DEVICES; i++) {
686
+ rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
687
+ }
688
+ return ret;
689
+ }
690
+
691
+ // vocab_only
692
+ static VALUE _llama_model_params_set_vocab_only(VALUE self, VALUE vocab_only) {
693
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
694
+ ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
695
+ return ptr->params.vocab_only ? Qtrue : Qfalse;
696
+ }
697
+
698
+ static VALUE _llama_model_params_get_vocab_only(VALUE self) {
699
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
700
+ return ptr->params.vocab_only ? Qtrue : Qfalse;
701
+ }
702
+
703
+ // use_mmap
704
+ static VALUE _llama_model_params_set_use_mmap(VALUE self, VALUE use_mmap) {
705
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
706
+ ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
707
+ return ptr->params.use_mmap ? Qtrue : Qfalse;
708
+ }
709
+
710
+ static VALUE _llama_model_params_get_use_mmap(VALUE self) {
711
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
712
+ return ptr->params.use_mmap ? Qtrue : Qfalse;
713
+ }
714
+
715
+ // use_mlock
716
+ static VALUE _llama_model_params_set_use_mlock(VALUE self, VALUE use_mlock) {
717
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
718
+ ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
719
+ return ptr->params.use_mlock ? Qtrue : Qfalse;
720
+ }
721
+
722
+ static VALUE _llama_model_params_get_use_mlock(VALUE self) {
723
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
724
+ return ptr->params.use_mlock ? Qtrue : Qfalse;
725
+ }
726
+ };
727
+
728
+ const rb_data_type_t RbLLaMAModelParams::llama_model_params_type = {
729
+ "RbLLaMAModelParams",
730
+ { NULL,
731
+ RbLLaMAModelParams::llama_model_params_free,
732
+ RbLLaMAModelParams::llama_model_params_size },
733
+ NULL,
734
+ NULL,
735
+ RUBY_TYPED_FREE_IMMEDIATELY
736
+ };
737
+
366
738
  class LLaMAContextParamsWrapper {
367
739
  public:
368
740
  struct llama_context_params params;
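
The new ModelParams class above takes over the model-level fields that 0.5.x kept on ContextParams. A short sketch, assuming it is exposed as LLaMACpp::ModelParams and that defaults come from llama_model_default_params(); all values are placeholders:

    params = LLaMACpp::ModelParams.new   # wraps llama_model_default_params()
    params.n_gpu_layers = 35
    params.main_gpu = 0
    params.use_mmap = true
    params.use_mlock = false
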
@@ -399,35 +771,26 @@ public:
399
771
  rb_cLLaMAContextParams = rb_define_class_under(outer, "ContextParams", rb_cObject);
400
772
  rb_define_alloc_func(rb_cLLaMAContextParams, llama_context_params_alloc);
401
773
  // rb_define_method(rb_cLLaMAContextParams, "initialize", RUBY_METHOD_FUNC(_llama_context_params_init), 0);
774
+ rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
775
+ rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
402
776
  rb_define_method(rb_cLLaMAContextParams, "n_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_n_ctx), 1);
403
777
  rb_define_method(rb_cLLaMAContextParams, "n_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_n_ctx), 0);
404
778
  rb_define_method(rb_cLLaMAContextParams, "n_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_batch), 1);
405
779
  rb_define_method(rb_cLLaMAContextParams, "n_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_batch), 0);
406
- rb_define_method(rb_cLLaMAContextParams, "n_gpu_layers=", RUBY_METHOD_FUNC(_llama_context_params_set_n_gpu_layers), 1);
407
- rb_define_method(rb_cLLaMAContextParams, "n_gpu_layers", RUBY_METHOD_FUNC(_llama_context_params_get_n_gpu_layers), 0);
408
- rb_define_method(rb_cLLaMAContextParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_context_params_set_main_gpu), 1);
409
- rb_define_method(rb_cLLaMAContextParams, "main_gpu", RUBY_METHOD_FUNC(_llama_context_params_get_main_gpu), 0);
410
- rb_define_method(rb_cLLaMAContextParams, "tensor_split", RUBY_METHOD_FUNC(_llama_context_params_get_tensor_split), 0);
780
+ rb_define_method(rb_cLLaMAContextParams, "n_threads=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads), 1);
781
+ rb_define_method(rb_cLLaMAContextParams, "n_threads", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads), 0);
782
+ rb_define_method(rb_cLLaMAContextParams, "n_threads_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads_batch), 1);
783
+ rb_define_method(rb_cLLaMAContextParams, "n_threads_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads_batch), 0);
411
784
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
412
785
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
413
786
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
414
787
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
415
- rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
416
- rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
417
788
  rb_define_method(rb_cLLaMAContextParams, "mul_mat_q=", RUBY_METHOD_FUNC(_llama_context_params_set_mul_mat_q), 1);
418
789
  rb_define_method(rb_cLLaMAContextParams, "mul_mat_q", RUBY_METHOD_FUNC(_llama_context_params_get_mul_mat_q), 0);
419
- rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
420
- rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
421
790
  rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
422
791
  rb_define_method(rb_cLLaMAContextParams, "f16_kv", RUBY_METHOD_FUNC(_llama_context_params_get_f16_kv), 0);
423
792
  rb_define_method(rb_cLLaMAContextParams, "logits_all=", RUBY_METHOD_FUNC(_llama_context_params_set_logits_all), 1);
424
793
  rb_define_method(rb_cLLaMAContextParams, "logits_all", RUBY_METHOD_FUNC(_llama_context_params_get_logits_all), 0);
425
- rb_define_method(rb_cLLaMAContextParams, "vocab_only=", RUBY_METHOD_FUNC(_llama_context_params_set_vocab_only), 1);
426
- rb_define_method(rb_cLLaMAContextParams, "vocab_only", RUBY_METHOD_FUNC(_llama_context_params_get_vocab_only), 0);
427
- rb_define_method(rb_cLLaMAContextParams, "use_mmap=", RUBY_METHOD_FUNC(_llama_context_params_set_use_mmap), 1);
428
- rb_define_method(rb_cLLaMAContextParams, "use_mmap", RUBY_METHOD_FUNC(_llama_context_params_get_use_mmap), 0);
429
- rb_define_method(rb_cLLaMAContextParams, "use_mlock=", RUBY_METHOD_FUNC(_llama_context_params_set_use_mlock), 1);
430
- rb_define_method(rb_cLLaMAContextParams, "use_mlock", RUBY_METHOD_FUNC(_llama_context_params_get_use_mlock), 0);
431
794
  rb_define_method(rb_cLLaMAContextParams, "embedding=", RUBY_METHOD_FUNC(_llama_context_params_set_embedding), 1);
432
795
  rb_define_method(rb_cLLaMAContextParams, "embedding", RUBY_METHOD_FUNC(_llama_context_params_get_embedding), 0);
433
796
  }
@@ -441,6 +804,22 @@ private:
441
804
  // return self;
442
805
  // }
443
806
 
807
+ // seed
808
+ static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
809
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
810
+ if (NUM2INT(seed) < 0) {
811
+ rb_raise(rb_eArgError, "seed must be positive");
812
+ return Qnil;
813
+ }
814
+ ptr->params.seed = NUM2INT(seed);
815
+ return INT2NUM(ptr->params.seed);
816
+ }
817
+
818
+ static VALUE _llama_context_params_get_seed(VALUE self) {
819
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
820
+ return INT2NUM(ptr->params.seed);
821
+ }
822
+
444
823
  // n_ctx
445
824
  static VALUE _llama_context_params_set_n_ctx(VALUE self, VALUE n_ctx) {
446
825
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -465,41 +844,28 @@ private:
465
844
  return INT2NUM(ptr->params.n_batch);
466
845
  }
467
846
 
468
- // n_gpu_layers
469
- static VALUE _llama_context_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
847
+ // n_threads
848
+ static VALUE _llama_context_params_set_n_threads(VALUE self, VALUE n_threads) {
470
849
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
471
- ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
472
- return INT2NUM(ptr->params.n_gpu_layers);
850
+ ptr->params.n_threads = NUM2INT(n_threads);
851
+ return INT2NUM(ptr->params.n_threads);
473
852
  }
474
853
 
475
- static VALUE _llama_context_params_get_n_gpu_layers(VALUE self) {
854
+ static VALUE _llama_context_params_get_n_threads(VALUE self) {
476
855
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
477
- return INT2NUM(ptr->params.n_gpu_layers);
856
+ return INT2NUM(ptr->params.n_threads);
478
857
  }
479
858
 
480
- // main_gpu
481
- static VALUE _llama_context_params_set_main_gpu(VALUE self, VALUE main_gpu) {
859
+ // n_threads_batch
860
+ static VALUE _llama_context_params_set_n_threads_batch(VALUE self, VALUE n_threads_batch) {
482
861
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
483
- ptr->params.main_gpu = NUM2INT(main_gpu);
484
- return INT2NUM(ptr->params.main_gpu);
862
+ ptr->params.n_threads_batch = NUM2INT(n_threads_batch);
863
+ return INT2NUM(ptr->params.n_threads_batch);
485
864
  }
486
865
 
487
- static VALUE _llama_context_params_get_main_gpu(VALUE self) {
866
+ static VALUE _llama_context_params_get_n_threads_batch(VALUE self) {
488
867
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
489
- return INT2NUM(ptr->params.main_gpu);
490
- }
491
-
492
- // tensor_split
493
- static VALUE _llama_context_params_get_tensor_split(VALUE self) {
494
- if (LLAMA_MAX_DEVICES < 1) {
495
- return rb_ary_new();
496
- }
497
- VALUE ret = rb_ary_new2(LLAMA_MAX_DEVICES);
498
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
499
- for (size_t i = 0; i < LLAMA_MAX_DEVICES; i++) {
500
- rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
501
- }
502
- return ret;
868
+ return INT2NUM(ptr->params.n_threads_batch);
503
869
  }
504
870
 
505
871
  // rope_freq_base
@@ -526,18 +892,6 @@ private:
526
892
  return DBL2NUM(ptr->params.rope_freq_scale);
527
893
  }
528
894
 
529
- // low_vram
530
- static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
531
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
532
- ptr->params.low_vram = RTEST(low_vram) ? true : false;
533
- return ptr->params.low_vram ? Qtrue : Qfalse;
534
- }
535
-
536
- static VALUE _llama_context_params_get_low_vram(VALUE self) {
537
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
538
- return ptr->params.low_vram ? Qtrue : Qfalse;
539
- }
540
-
541
895
  // mul_mat_q
542
896
  static VALUE _llama_context_params_set_mul_mat_q(VALUE self, VALUE mul_mat_q) {
543
897
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -550,22 +904,6 @@ private:
550
904
  return ptr->params.mul_mat_q ? Qtrue : Qfalse;
551
905
  }
552
906
 
553
- // seed
554
- static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
555
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
556
- if (NUM2INT(seed) < 0) {
557
- rb_raise(rb_eArgError, "seed must be positive");
558
- return Qnil;
559
- }
560
- ptr->params.seed = NUM2INT(seed);
561
- return INT2NUM(ptr->params.seed);
562
- }
563
-
564
- static VALUE _llama_context_params_get_seed(VALUE self) {
565
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
566
- return INT2NUM(ptr->params.seed);
567
- }
568
-
569
907
  // f16_kv
570
908
  static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
571
909
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -590,42 +928,6 @@ private:
590
928
  return ptr->params.logits_all ? Qtrue : Qfalse;
591
929
  }
592
930
 
593
- // vocab_only
594
- static VALUE _llama_context_params_set_vocab_only(VALUE self, VALUE vocab_only) {
595
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
596
- ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
597
- return ptr->params.vocab_only ? Qtrue : Qfalse;
598
- }
599
-
600
- static VALUE _llama_context_params_get_vocab_only(VALUE self) {
601
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
602
- return ptr->params.vocab_only ? Qtrue : Qfalse;
603
- }
604
-
605
- // use_mmap
606
- static VALUE _llama_context_params_set_use_mmap(VALUE self, VALUE use_mmap) {
607
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
608
- ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
609
- return ptr->params.use_mmap ? Qtrue : Qfalse;
610
- }
611
-
612
- static VALUE _llama_context_params_get_use_mmap(VALUE self) {
613
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
614
- return ptr->params.use_mmap ? Qtrue : Qfalse;
615
- }
616
-
617
- // use_mlock
618
- static VALUE _llama_context_params_set_use_mlock(VALUE self, VALUE use_mlock) {
619
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
620
- ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
621
- return ptr->params.use_mlock ? Qtrue : Qfalse;
622
- }
623
-
624
- static VALUE _llama_context_params_get_use_mlock(VALUE self) {
625
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
626
- return ptr->params.use_mlock ? Qtrue : Qfalse;
627
- }
628
-
629
931
  // embedding
630
932
  static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
631
933
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -823,11 +1125,11 @@ public:
823
1125
  rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
824
1126
  rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
825
1127
  rb_define_method(rb_cLLaMAModel, "n_vocab", RUBY_METHOD_FUNC(_llama_model_get_model_n_vocab), 0);
826
- rb_define_method(rb_cLLaMAModel, "n_ctx", RUBY_METHOD_FUNC(_llama_model_get_model_n_ctx), 0);
827
1128
  rb_define_method(rb_cLLaMAModel, "n_ctx_train", RUBY_METHOD_FUNC(_llama_model_get_model_n_ctx_train), 0);
828
1129
  rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_model_n_embd), 0);
829
- rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece_with_model), 1);
830
- rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize_with_model), -1);
1130
+ rb_define_method(rb_cLLaMAModel, "rope_freq_scale_train", RUBY_METHOD_FUNC(_llama_model_rope_freq_scale_train), 0);
1131
+ rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece), 1);
1132
+ rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize), -1);
831
1133
  rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
832
1134
  rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
833
1135
  rb_define_method(rb_cLLaMAModel, "n_params", RUBY_METHOD_FUNC(_llama_model_get_model_n_params), 0);
@@ -841,30 +1143,21 @@ private:
841
1143
  ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
842
1144
  VALUE kw_values[2] = { Qundef, Qundef };
843
1145
  rb_scan_args(argc, argv, ":", &kw_args);
844
- rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
845
-
846
- if (kw_values[0] == Qundef && kw_values[1] == Qundef) {
847
- rb_iv_set(self, "@params", Qnil);
848
- return Qnil;
849
- }
1146
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
850
1147
 
851
1148
  if (!RB_TYPE_P(kw_values[0], T_STRING)) {
852
1149
  rb_raise(rb_eArgError, "model_path must be a string");
853
1150
  return Qnil;
854
1151
  }
855
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
856
- rb_raise(rb_eArgError, "params must be a ContextParams");
1152
+ if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
1153
+ rb_raise(rb_eArgError, "params must be a ModelParams");
857
1154
  return Qnil;
858
1155
  }
859
1156
 
860
1157
  VALUE filename = kw_values[0];
861
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
1158
+ LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
862
1159
  LLaMAModelWrapper* model_ptr = get_llama_model(self);
863
1160
 
864
- if (prms_ptr->params.seed == LLAMA_DEFAULT_SEED) {
865
- prms_ptr->params.seed = time(NULL);
866
- }
867
-
868
1161
  try {
869
1162
  model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
870
1163
  } catch (const std::runtime_error& e) {
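
With the hunk above, Model#initialize now requires both model_path and a ModelParams (a ContextParams is rejected), and tokenize/token_to_piece live on the model under their new names, as registered a few hunks earlier. A hedged sketch, assuming the class is exposed as LLaMACpp::Model and using a placeholder model path:

    model = LLaMACpp::Model.new(model_path: 'path/to/model.gguf', params: params)  # params from the ModelParams sketch above
    ids = model.tokenize(text: 'Hello, world!', n_max_tokens: 32, add_bos: true)
    pieces = ids.map { |id| model.token_to_piece(id) }   # pieces are now returned as UTF-8 strings
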
@@ -912,8 +1205,8 @@ private:
912
1205
  rb_raise(rb_eArgError, "model_path must be a string");
913
1206
  return Qnil;
914
1207
  }
915
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
916
- rb_raise(rb_eArgError, "params must be a LLaMAContextParams");
1208
+ if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
1209
+ rb_raise(rb_eArgError, "params must be a LLaMAModelParams");
917
1210
  return Qnil;
918
1211
  }
919
1212
 
@@ -924,7 +1217,7 @@ private:
924
1217
  }
925
1218
 
926
1219
  VALUE filename = kw_values[0];
927
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
1220
+ LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
928
1221
 
929
1222
  try {
930
1223
  model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
@@ -946,10 +1239,10 @@ private:
946
1239
 
947
1240
  static VALUE _llama_model_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
948
1241
  VALUE kw_args = Qnil;
949
- ID kw_table[3] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads") };
950
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1242
+ ID kw_table[4] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads"), rb_intern("scale") };
1243
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
951
1244
  rb_scan_args(argc, argv, ":", &kw_args);
952
- rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
1245
+ rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);
953
1246
 
954
1247
  if (!RB_TYPE_P(kw_values[0], T_STRING)) {
955
1248
  rb_raise(rb_eArgError, "lora_path must be a string");
@@ -963,13 +1256,18 @@ private:
963
1256
  rb_raise(rb_eArgError, "n_threads must be an integer");
964
1257
  return Qnil;
965
1258
  }
1259
+ if (kw_values[3] != Qundef && !RB_FLOAT_TYPE_P(kw_values[3])) {
1260
+ rb_raise(rb_eArgError, "scale must be a float");
1261
+ return Qnil;
1262
+ }
966
1263
 
967
1264
  const char* lora_path = StringValueCStr(kw_values[0]);
968
1265
  const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
969
1266
  const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
1267
+ const float scale = kw_values[3] == Qundef ? 1.0 : NUM2DBL(kw_values[3]);
970
1268
 
971
1269
  LLaMAModelWrapper* ptr = get_llama_model(self);
972
- if (llama_model_apply_lora_from_file(ptr->model, lora_path, base_model_path, n_threads) != 0) {
1270
+ if (llama_model_apply_lora_from_file(ptr->model, lora_path, scale, base_model_path, n_threads) != 0) {
973
1271
  rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
974
1272
  return Qnil;
975
1273
  }
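
apply_lora_from_file gains an optional scale keyword (a Float, default 1.0) that is forwarded to llama_model_apply_lora_from_file. A sketch with placeholder paths and values:

    model.apply_lora_from_file(lora_path: 'adapter.bin', scale: 0.8, n_threads: 4)
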
@@ -978,25 +1276,25 @@ private:
978
1276
 
979
1277
  static VALUE _llama_model_get_model_n_vocab(VALUE self) {
980
1278
  LLaMAModelWrapper* ptr = get_llama_model(self);
981
- return INT2NUM(llama_model_n_vocab(ptr->model));
1279
+ return INT2NUM(llama_n_vocab(ptr->model));
982
1280
  }
983
1281
 
984
- static VALUE _llama_model_get_model_n_ctx(VALUE self) {
1282
+ static VALUE _llama_model_get_model_n_ctx_train(VALUE self) {
985
1283
  LLaMAModelWrapper* ptr = get_llama_model(self);
986
- return INT2NUM(llama_model_n_ctx(ptr->model));
1284
+ return INT2NUM(llama_n_ctx_train(ptr->model));
987
1285
  }
988
1286
 
989
- static VALUE _llama_model_get_model_n_ctx_train(VALUE self) {
1287
+ static VALUE _llama_model_get_model_n_embd(VALUE self) {
990
1288
  LLaMAModelWrapper* ptr = get_llama_model(self);
991
- return INT2NUM(llama_model_n_ctx_train(ptr->model));
1289
+ return INT2NUM(llama_n_embd(ptr->model));
992
1290
  }
993
1291
 
994
- static VALUE _llama_model_get_model_n_embd(VALUE self) {
1292
+ static VALUE _llama_model_rope_freq_scale_train(VALUE self) {
995
1293
  LLaMAModelWrapper* ptr = get_llama_model(self);
996
- return INT2NUM(llama_model_n_embd(ptr->model));
1294
+ return DBL2NUM(llama_rope_freq_scale_train(ptr->model));
997
1295
  }
998
1296
 
999
- static VALUE _llama_model_token_to_piece_with_model(VALUE self, VALUE token_) {
1297
+ static VALUE _llama_model_token_to_piece(VALUE self, VALUE token_) {
1000
1298
  if (!RB_INTEGER_TYPE_P(token_)) {
1001
1299
  rb_raise(rb_eArgError, "token must be an integer");
1002
1300
  return Qnil;
@@ -1004,10 +1302,10 @@ private:
1004
1302
  const llama_token token = NUM2INT(token_);
1005
1303
  LLaMAModelWrapper* ptr = get_llama_model(self);
1006
1304
  std::vector<char> result(8, 0);
1007
- const int n_tokens = llama_token_to_piece_with_model(ptr->model, token, result.data(), result.size());
1305
+ const int n_tokens = llama_token_to_piece(ptr->model, token, result.data(), result.size());
1008
1306
  if (n_tokens < 0) {
1009
1307
  result.resize(-n_tokens);
1010
- const int check = llama_token_to_piece_with_model(ptr->model, token, result.data(), result.size());
1308
+ const int check = llama_token_to_piece(ptr->model, token, result.data(), result.size());
1011
1309
  if (check != -n_tokens) {
1012
1310
  rb_raise(rb_eRuntimeError, "failed to convert");
1013
1311
  return Qnil;
@@ -1016,10 +1314,10 @@ private:
1016
1314
  result.resize(n_tokens);
1017
1315
  }
1018
1316
  std::string ret(result.data(), result.size());
1019
- return rb_str_new_cstr(ret.c_str());
1317
+ return rb_utf8_str_new_cstr(ret.c_str());
1020
1318
  }
1021
1319
 
1022
- static VALUE _llama_model_tokenize_with_model(int argc, VALUE* argv, VALUE self) {
1320
+ static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
1023
1321
  VALUE kw_args = Qnil;
1024
1322
  ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
1025
1323
  VALUE kw_values[3] = { Qundef, Qundef, Qundef };
@@ -1046,7 +1344,7 @@ private:
1046
1344
 
1047
1345
  llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
1048
1346
  LLaMAModelWrapper* ptr = get_llama_model(self);
1049
- const int n_tokens = llama_tokenize_with_model(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
1347
+ const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
1050
1348
 
1051
1349
  if (n_tokens < 0) {
1052
1350
  rb_raise(rb_eRuntimeError, "failed to tokenize. The numebr of tokens (%d) is greater than n_max_tokens.", -n_tokens);
@@ -1066,7 +1364,7 @@ private:
1066
1364
  LLaMAModelWrapper* ptr = get_llama_model(self);
1067
1365
  char buf[128];
1068
1366
  llama_model_desc(ptr->model, buf, sizeof(buf));
1069
- return rb_str_new_cstr(buf);
1367
+ return rb_utf8_str_new_cstr(buf);
1070
1368
  }
1071
1369
 
1072
1370
  static VALUE _llama_model_get_model_size(VALUE self) {
@@ -1345,11 +1643,11 @@ public:
1345
1643
  static void define_class(VALUE outer) {
1346
1644
  rb_cLLaMAContext = rb_define_class_under(outer, "Context", rb_cObject);
1347
1645
  rb_define_alloc_func(rb_cLLaMAContext, llama_context_alloc);
1646
+ rb_define_attr(rb_cLLaMAContext, "model", 1, 0);
1348
1647
  rb_define_method(rb_cLLaMAContext, "initialize", RUBY_METHOD_FUNC(_llama_context_initialize), -1);
1349
1648
  rb_define_method(rb_cLLaMAContext, "eval", RUBY_METHOD_FUNC(_llama_context_eval), -1);
1350
1649
  rb_define_method(rb_cLLaMAContext, "eval_embd", RUBY_METHOD_FUNC(_llama_context_eval_embd), -1);
1351
- rb_define_method(rb_cLLaMAContext, "eval_export", RUBY_METHOD_FUNC(_llama_context_eval_export), 1);
1352
- rb_define_method(rb_cLLaMAContext, "tokenize", RUBY_METHOD_FUNC(_llama_context_tokenize), -1);
1650
+ rb_define_method(rb_cLLaMAContext, "decode", RUBY_METHOD_FUNC(_llama_context_decode), 1);
1353
1651
  rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
1354
1652
  rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
1355
1653
  rb_define_method(rb_cLLaMAContext, "text", RUBY_METHOD_FUNC(_llama_context_text), 1);
@@ -1358,15 +1656,20 @@ public:
1358
1656
  rb_define_method(rb_cLLaMAContext, "token_bos", RUBY_METHOD_FUNC(_llama_context_token_bos), 0);
1359
1657
  rb_define_method(rb_cLLaMAContext, "token_eos", RUBY_METHOD_FUNC(_llama_context_token_eos), 0);
1360
1658
  rb_define_method(rb_cLLaMAContext, "token_nl", RUBY_METHOD_FUNC(_llama_context_token_nl), 0);
1361
- rb_define_method(rb_cLLaMAContext, "token_to_piece", RUBY_METHOD_FUNC(_llama_context_token_to_piece), 1);
1362
- rb_define_method(rb_cLLaMAContext, "n_vocab", RUBY_METHOD_FUNC(_llama_context_n_vocab), 0);
1659
+ rb_define_method(rb_cLLaMAContext, "token_prefix", RUBY_METHOD_FUNC(_llama_context_token_prefix), 0);
1660
+ rb_define_method(rb_cLLaMAContext, "token_middle", RUBY_METHOD_FUNC(_llama_context_token_middle), 0);
1661
+ rb_define_method(rb_cLLaMAContext, "token_suffix", RUBY_METHOD_FUNC(_llama_context_token_suffix), 0);
1662
+ rb_define_method(rb_cLLaMAContext, "token_eot", RUBY_METHOD_FUNC(_llama_context_token_eot), 0);
1363
1663
  rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
1364
- rb_define_method(rb_cLLaMAContext, "n_ctx_train", RUBY_METHOD_FUNC(_llama_context_n_ctx_train), 0);
1365
- rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
1366
1664
  rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
1367
1665
  rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
1368
1666
  rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
1369
1667
  rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
1668
+ rb_define_method(rb_cLLaMAContext, "kv_cache_tokens_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_tokens_rm), 2);
1669
+ rb_define_method(rb_cLLaMAContext, "kv_cache_seq_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_rm), 3);
1670
+ rb_define_method(rb_cLLaMAContext, "kv_cache_seq_cp", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_cp), 4);
1671
+ rb_define_method(rb_cLLaMAContext, "kv_cache_seq_keep", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_keep), 1);
1672
+ rb_define_method(rb_cLLaMAContext, "kv_cache_seq_shift", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_shift), 4);
1370
1673
  rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
1371
1674
  rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
1372
1675
  rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
@@ -1378,6 +1681,7 @@ public:
1378
1681
  rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
1379
1682
  rb_define_method(rb_cLLaMAContext, "sample_tail_free", RUBY_METHOD_FUNC(_llama_context_sample_tail_free), -1);
1380
1683
  rb_define_method(rb_cLLaMAContext, "sample_typical", RUBY_METHOD_FUNC(_llama_context_sample_typical), -1);
1684
+ rb_define_method(rb_cLLaMAContext, "sample_temp", RUBY_METHOD_FUNC(_llama_context_sample_temp), -1);
1381
1685
  rb_define_method(rb_cLLaMAContext, "sample_temperature", RUBY_METHOD_FUNC(_llama_context_sample_temperature), -1);
1382
1686
  rb_define_method(rb_cLLaMAContext, "sample_token_mirostat", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat), -1);
1383
1687
  rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
@@ -1392,24 +1696,27 @@ private:
1392
1696
 
1393
1697
  static VALUE _llama_context_initialize(int argc, VALUE* argv, VALUE self) {
1394
1698
  VALUE kw_args = Qnil;
1395
- ID kw_table[1] = { rb_intern("model") };
1396
- VALUE kw_values[1] = { Qundef };
1699
+ ID kw_table[2] = { rb_intern("model"), rb_intern("params") };
1700
+ VALUE kw_values[2] = { Qundef, Qundef };
1397
1701
  rb_scan_args(argc, argv, ":", &kw_args);
1398
- rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
1702
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
1399
1703
 
1400
1704
  VALUE model = kw_values[0];
1401
1705
  if (!rb_obj_is_kind_of(model, rb_cLLaMAModel)) {
1402
1706
  rb_raise(rb_eArgError, "model must be a Model");
1403
1707
  return Qnil;
1404
1708
  }
1709
+ VALUE params = kw_values[1];
1710
+ if (!rb_obj_is_kind_of(params, rb_cLLaMAContextParams)) {
1711
+ rb_raise(rb_eArgError, "params must be a ContextParams");
1712
+ return Qnil;
1713
+ }
1405
1714
 
1406
1715
  LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
1407
1716
  if (model_ptr->model == NULL) {
1408
1717
  rb_raise(rb_eRuntimeError, "Model is empty");
1409
1718
  return Qnil;
1410
1719
  }
1411
-
1412
- VALUE params = rb_iv_get(model, "@params");
1413
1720
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1414
1721
  LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1415
1722
 
@@ -1421,6 +1728,7 @@ private:
1421
1728
  }
1422
1729
 
1423
1730
  rb_iv_set(self, "@model", model);
1731
+ rb_iv_set(self, "@params", params);
1424
1732
  rb_iv_set(self, "@has_evaluated", Qfalse);
1425
1733
 
1426
1734
  return Qnil;
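
Context#initialize now takes the ContextParams explicitly instead of reading @params off the model. A sketch, assuming the classes are exposed as LLaMACpp::ContextParams and LLaMACpp::Context; values are placeholders:

    ctx_params = LLaMACpp::ContextParams.new
    ctx_params.seed = 42
    ctx_params.n_ctx = 2048
    ctx_params.n_threads = 4          # thread counts moved here from the per-call eval arguments
    ctx_params.n_threads_batch = 4
    context = LLaMACpp::Context.new(model: model, params: ctx_params)
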
@@ -1428,8 +1736,8 @@ private:
1428
1736
 
1429
1737
  static VALUE _llama_context_eval(int argc, VALUE* argv, VALUE self) {
1430
1738
  VALUE kw_args = Qnil;
1431
- ID kw_table[4] = { rb_intern("tokens"), rb_intern("n_past"), rb_intern("n_tokens"), rb_intern("n_threads") };
1432
- VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
1739
+ ID kw_table[3] = { rb_intern("tokens"), rb_intern("n_past"), rb_intern("n_tokens") };
1740
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1433
1741
  rb_scan_args(argc, argv, ":", &kw_args);
1434
1742
  rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
1435
1743
 
@@ -1445,10 +1753,6 @@ private:
1445
1753
  rb_raise(rb_eArgError, "n_tokens must be an integer");
1446
1754
  return Qnil;
1447
1755
  }
1448
- if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
1449
- rb_raise(rb_eArgError, "n_threads must be an integer");
1450
- return Qnil;
1451
- }
1452
1756
 
1453
1757
  const size_t tokens_len = RARRAY_LEN(kw_values[0]);
1454
1758
  std::vector<llama_token> embd(tokens_len);
@@ -1463,14 +1767,13 @@ private:
1463
1767
 
1464
1768
  const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
1465
1769
  const int n_past = NUM2INT(kw_values[1]);
1466
- const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
1467
1770
 
1468
1771
  LLaMAContextWrapper* ptr = get_llama_context(self);
1469
1772
  if (ptr->ctx == NULL) {
1470
1773
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1471
1774
  return Qnil;
1472
1775
  }
1473
- if (llama_eval(ptr->ctx, embd.data(), n_tokens, n_past, n_threads) != 0) {
1776
+ if (llama_eval(ptr->ctx, embd.data(), n_tokens, n_past) != 0) {
1474
1777
  rb_raise(rb_eRuntimeError, "Failed to evaluate");
1475
1778
  return Qnil;
1476
1779
  }
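
Because thread counts now live on ContextParams, eval no longer accepts n_threads. A sketch reusing the objects from the earlier sketches:

    context.eval(tokens: ids, n_past: 0)   # n_threads: is no longer accepted here
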
@@ -1483,8 +1786,8 @@ private:
1483
1786
 
1484
1787
  static VALUE _llama_context_eval_embd(int argc, VALUE* argv, VALUE self) {
1485
1788
  VALUE kw_args = Qnil;
1486
- ID kw_table[4] = { rb_intern("embd"), rb_intern("n_past"), rb_intern("n_tokens"), rb_intern("n_threads") };
1487
- VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
1789
+ ID kw_table[3] = { rb_intern("embd"), rb_intern("n_past"), rb_intern("n_tokens") };
1790
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1488
1791
  rb_scan_args(argc, argv, ":", &kw_args);
1489
1792
  rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
1490
1793
 
@@ -1500,10 +1803,6 @@ private:
1500
1803
  rb_raise(rb_eArgError, "n_tokens must be an integer");
1501
1804
  return Qnil;
1502
1805
  }
1503
- if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
1504
- rb_raise(rb_eArgError, "n_threads must be an integer");
1505
- return Qnil;
1506
- }
1507
1806
 
1508
1807
  const size_t tokens_len = RARRAY_LEN(kw_values[0]);
1509
1808
  std::vector<float> embd(tokens_len);
@@ -1518,14 +1817,13 @@ private:
1518
1817
 
1519
1818
  const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
1520
1819
  const int n_past = NUM2INT(kw_values[1]);
1521
- const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
1522
1820
 
1523
1821
  LLaMAContextWrapper* ptr = get_llama_context(self);
1524
1822
  if (ptr->ctx == NULL) {
1525
1823
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1526
1824
  return Qnil;
1527
1825
  }
1528
- if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past, n_threads) != 0) {
1826
+ if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past) != 0) {
1529
1827
  rb_raise(rb_eRuntimeError, "Failed to evaluate");
1530
1828
  return Qnil;
1531
1829
  }
@@ -1536,91 +1834,22 @@ private:
1536
1834
  return Qnil;
1537
1835
  }
1538
1836
 
1539
- static VALUE _llama_context_eval_export(VALUE self, VALUE fname_) {
1837
+ static VALUE _llama_context_decode(VALUE self, VALUE batch) {
1540
1838
  LLaMAContextWrapper* ptr = get_llama_context(self);
1541
1839
  if (ptr->ctx == NULL) {
1542
1840
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1543
1841
  return Qnil;
1544
1842
  }
1545
- if (!RB_TYPE_P(fname_, T_STRING)) {
1546
- rb_raise(rb_eArgError, "fname must be a string");
1547
- return Qnil;
1548
- }
1549
- const char* fname = StringValueCStr(fname_);
1550
- if (llama_eval_export(ptr->ctx, fname) != 0) {
1551
- return Qfalse;
1552
- }
1553
- RB_GC_GUARD(fname_);
1554
- return Qtrue;
1555
- }
1556
-
1557
- static VALUE _llama_context_tokenize(int argc, VALUE* argv, VALUE self) {
1558
- VALUE kw_args = Qnil;
1559
- ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
1560
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1561
- rb_scan_args(argc, argv, ":", &kw_args);
1562
- rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
1563
-
1564
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1565
- rb_raise(rb_eArgError, "text must be a String");
1566
- return Qnil;
1567
- }
1568
- if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
1569
- rb_raise(rb_eArgError, "n_max_tokens must be an integer");
1843
+ if (!rb_obj_is_kind_of(batch, rb_cLLaMABatch)) {
1844
+ rb_raise(rb_eArgError, "batch must be a Batch");
1570
1845
  return Qnil;
1571
1846
  }
1572
- if (kw_values[2] != Qundef && (kw_values[2] != Qtrue && kw_values[2] != Qfalse)) {
1573
- rb_raise(rb_eArgError, "add_bos must be a boolean");
1847
+ LLaMABatchWrapper* batch_ptr = RbLLaMABatch::get_llama_batch(batch);
1848
+ if (llama_decode(ptr->ctx, batch_ptr->batch) < 0) {
1849
+ rb_raise(rb_eRuntimeError, "Failed to decode");
1574
1850
  return Qnil;
1575
1851
  }
1576
-
1577
- VALUE text_ = kw_values[0];
1578
- std::string text = StringValueCStr(text_);
1579
- const bool add_bos = kw_values[2] == Qtrue ? true : false;
1580
- const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
1581
-
1582
- std::vector<llama_token> tokens(n_max_tokens);
1583
- LLaMAContextWrapper* ptr = get_llama_context(self);
1584
- if (ptr->ctx == NULL) {
1585
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1586
- return Qnil;
1587
- }
1588
- const int n = llama_tokenize(ptr->ctx, text.c_str(), text.size(), tokens.data(), n_max_tokens, add_bos);
1589
- if (n < 0) {
1590
- rb_raise(rb_eRuntimeError, "Failed to tokenize");
1591
- return Qnil;
1592
- }
1593
-
1594
- VALUE output = rb_ary_new();
1595
- for (int i = 0; i < n; i++) {
1596
- rb_ary_push(output, INT2NUM(tokens[i]));
1597
- }
1598
-
1599
- RB_GC_GUARD(text_);
1600
- return output;
1601
- }
1602
-
1603
- static VALUE _llama_context_token_to_piece(VALUE self, VALUE token_) {
1604
- LLaMAContextWrapper* ptr = get_llama_context(self);
1605
- if (ptr->ctx == NULL) {
1606
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1607
- return Qnil;
1608
- }
1609
- const llama_token token = NUM2INT(token_);
1610
- std::vector<char> result(8, 0);
1611
- const int n_tokens = llama_token_to_piece(ptr->ctx, token, result.data(), result.size());
1612
- if (n_tokens < 0) {
1613
- result.resize(-n_tokens);
1614
- const int check = llama_token_to_piece(ptr->ctx, token, result.data(), result.size());
1615
- if (check != -n_tokens) {
1616
- rb_raise(rb_eRuntimeError, "failed to convert");
1617
- return Qnil;
1618
- }
1619
- } else {
1620
- result.resize(n_tokens);
1621
- }
1622
- std::string ret(result.data(), result.size());
1623
- return rb_str_new_cstr(ret.c_str());
1852
+ return Qnil;
1624
1853
  }
1625
1854
 
1626
1855
  static VALUE _llama_context_logits(VALUE self) {
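
The hunk above removes eval_export and the context-level tokenize and adds decode(batch), the batched entry point wrapping llama_decode. A sketch using the Batch built in the earlier sketch:

    context.decode(batch)   # raises RuntimeError if llama_decode reports an error
    # logits for the positions flagged with set_logits can then be read from the context
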
@@ -1635,10 +1864,11 @@ private:
1635
1864
  }
1636
1865
 
1637
1866
  VALUE model = rb_iv_get(self, "@model");
1638
- VALUE params = rb_iv_get(model, "@params");
1867
+ LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
1868
+ VALUE params = rb_iv_get(self, "@params");
1639
1869
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1640
1870
  const int n_tokens = prms_ptr->params.logits_all ? NUM2INT(rb_iv_get(self, "@n_tokens")) : 1;
1641
- const int n_vocab = llama_n_vocab(ptr->ctx);
1871
+ const int n_vocab = llama_n_vocab(model_ptr->model);
1642
1872
  const float* logits = llama_get_logits(ptr->ctx);
1643
1873
  VALUE output = rb_ary_new();
1644
1874
  for (int i = 0; i < n_tokens * n_vocab; i++) {
@@ -1655,7 +1885,8 @@ private:
1655
1885
  return Qnil;
1656
1886
  }
1657
1887
  VALUE model = rb_iv_get(self, "@model");
1658
- VALUE params = rb_iv_get(model, "@params");
1888
+ LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
1889
+ VALUE params = rb_iv_get(self, "@params");
1659
1890
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1660
1891
  if (!prms_ptr->params.embedding) {
1661
1892
  rb_raise(rb_eRuntimeError, "embedding parameter is false");
@@ -1666,7 +1897,7 @@ private:
1666
1897
  return Qnil;
1667
1898
  }
1668
1899
 
1669
- const int n_embd = llama_n_embd(ptr->ctx);
1900
+ const int n_embd = llama_n_embd(model_ptr->model);
1670
1901
  const float* embd = llama_get_embeddings(ptr->ctx);
1671
1902
  VALUE output = rb_ary_new();
1672
1903
  for (int i = 0; i < n_embd; i++) {
@@ -1684,7 +1915,7 @@ private:
1684
1915
  }
1685
1916
  const llama_token token = NUM2INT(token_);
1686
1917
  const char* text = llama_token_get_text(ptr->ctx, token);
1687
- return rb_str_new_cstr(text);
1918
+ return rb_utf8_str_new_cstr(text);
1688
1919
  }
1689
1920
 
1690
1921
  static VALUE _llama_context_score(VALUE self, VALUE token_) {
@@ -1736,40 +1967,49 @@ private:
1736
1967
  return INT2NUM(llama_token_nl(ptr->ctx));
1737
1968
  }
1738
1969
 
1739
- static VALUE _llama_context_n_vocab(VALUE self) {
1970
+ static VALUE _llama_context_token_prefix(VALUE self) {
1740
1971
  LLaMAContextWrapper* ptr = get_llama_context(self);
1741
1972
  if (ptr->ctx == NULL) {
1742
1973
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1743
1974
  return Qnil;
1744
1975
  }
1745
- return INT2NUM(llama_n_vocab(ptr->ctx));
1976
+ return INT2NUM(llama_token_prefix(ptr->ctx));
1746
1977
  }
1747
1978
 
1748
- static VALUE _llama_context_n_ctx(VALUE self) {
1979
+ static VALUE _llama_context_token_middle(VALUE self) {
1749
1980
  LLaMAContextWrapper* ptr = get_llama_context(self);
1750
1981
  if (ptr->ctx == NULL) {
1751
1982
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1752
1983
  return Qnil;
1753
1984
  }
1754
- return INT2NUM(llama_n_ctx(ptr->ctx));
1985
+ return INT2NUM(llama_token_middle(ptr->ctx));
1986
+ }
1987
+
1988
+ static VALUE _llama_context_token_suffix(VALUE self) {
1989
+ LLaMAContextWrapper* ptr = get_llama_context(self);
1990
+ if (ptr->ctx == NULL) {
1991
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1992
+ return Qnil;
1993
+ }
1994
+ return INT2NUM(llama_token_suffix(ptr->ctx));
1755
1995
  }
1756
1996
 
1757
- static VALUE _llama_context_n_ctx_train(VALUE self) {
1997
+ static VALUE _llama_context_token_eot(VALUE self) {
1758
1998
  LLaMAContextWrapper* ptr = get_llama_context(self);
1759
1999
  if (ptr->ctx == NULL) {
1760
2000
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1761
2001
  return Qnil;
1762
2002
  }
1763
- return INT2NUM(llama_n_ctx_train(ptr->ctx));
2003
+ return INT2NUM(llama_token_eot(ptr->ctx));
1764
2004
  }
1765
2005
 
1766
- static VALUE _llama_context_n_embd(VALUE self) {
2006
+ static VALUE _llama_context_n_ctx(VALUE self) {
1767
2007
  LLaMAContextWrapper* ptr = get_llama_context(self);
1768
2008
  if (ptr->ctx == NULL) {
1769
2009
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1770
2010
  return Qnil;
1771
2011
  }
1772
- return INT2NUM(llama_n_embd(ptr->ctx));
2012
+ return INT2NUM(llama_n_ctx(ptr->ctx));
1773
2013
  }
1774
2014
 
1775
2015
  static VALUE _llama_context_get_timings(VALUE self) {
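
The hunk above swaps the per-context n_vocab/n_ctx_train/n_embd readers (now on Model) for the infill special-token readers. A sketch:

    prefix_id = context.token_prefix   # FIM prefix token id
    middle_id = context.token_middle
    suffix_id = context.token_suffix
    eot_id    = context.token_eot
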
@@ -1813,6 +2053,56 @@ private:
1813
2053
  return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
1814
2054
  }
1815
2055
 
2056
+ static VALUE _llama_context_kv_cache_tokens_rm(VALUE self, VALUE c0, VALUE c1) {
2057
+ LLaMAContextWrapper* ptr = get_llama_context(self);
2058
+ if (ptr->ctx == NULL) {
2059
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2060
+ return Qnil;
2061
+ }
2062
+ llama_kv_cache_tokens_rm(ptr->ctx, NUM2INT(c0), NUM2INT(c1));
2063
+ return Qnil;
2064
+ }
2065
+
2066
+ static VALUE _llama_context_kv_cache_seq_rm(VALUE self, VALUE seq_id, VALUE p0, VALUE p1) {
2067
+ LLaMAContextWrapper* ptr = get_llama_context(self);
2068
+ if (ptr->ctx == NULL) {
2069
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2070
+ return Qnil;
2071
+ }
2072
+ llama_kv_cache_seq_rm(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1));
2073
+ return Qnil;
2074
+ }
2075
+
2076
+ static VALUE _llama_context_kv_cache_seq_cp(VALUE self, VALUE seq_id_src, VALUE seq_id_dst, VALUE p0, VALUE p1) {
2077
+ LLaMAContextWrapper* ptr = get_llama_context(self);
2078
+ if (ptr->ctx == NULL) {
2079
+ rb_raise(rb_eArgError, "LLaMA context is not initialized");
2080
+ return Qnil;
2081
+ }
2082
+ llama_kv_cache_seq_cp(ptr->ctx, NUM2INT(seq_id_src), NUM2INT(seq_id_dst), NUM2INT(p0), NUM2INT(p1));
2083
+ return Qnil;
2084
+ }
2085
+
2086
+ static VALUE _llama_context_kv_cache_seq_keep(VALUE self, VALUE seq_id) {
2087
+ LLaMAContextWrapper* ptr = get_llama_context(self);
2088
+ if (ptr->ctx == NULL) {
2089
+ rb_raise(rb_eArgError, "LLaMA context is not initialized");
2090
+ return Qnil;
2091
+ }
2092
+ llama_kv_cache_seq_keep(ptr->ctx, NUM2INT(seq_id));
2093
+ return Qnil;
2094
+ }
2095
+
2096
+ static VALUE _llama_context_kv_cache_seq_shift(VALUE self, VALUE seq_id, VALUE p0, VALUE p1, VALUE delta) {
2097
+ LLaMAContextWrapper* ptr = get_llama_context(self);
2098
+ if (ptr->ctx == NULL) {
2099
+ rb_raise(rb_eArgError, "LLaMA context is not initialized");
2100
+ return Qnil;
2101
+ }
2102
+ llama_kv_cache_seq_shift(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1), NUM2INT(delta));
2103
+ return Qnil;
2104
+ }
2105
+
1816
2106
  static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
1817
2107
  LLaMAContextWrapper* ptr = get_llama_context(self);
1818
2108
  if (ptr->ctx == NULL) {
@@ -1851,7 +2141,7 @@ private:
1851
2141
  }
1852
2142
 
1853
2143
  VALUE model = rb_iv_get(self, "@model");
1854
- VALUE params = rb_iv_get(model, "@params");
2144
+ VALUE params = rb_iv_get(self, "@params");
1855
2145
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1856
2146
  const int n_ctx = prms_ptr->params.n_ctx;
1857
2147
 
@@ -2235,6 +2525,40 @@ private:
2235
2525
  return Qnil;
2236
2526
  }
2237
2527
 
2528
+ static VALUE _llama_context_sample_temp(int argc, VALUE* argv, VALUE self) {
2529
+ VALUE kw_args = Qnil;
2530
+ ID kw_table[1] = { rb_intern("temp") };
2531
+ VALUE kw_values[1] = { Qundef };
2532
+ VALUE candidates = Qnil;
2533
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
2534
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
2535
+
2536
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
2537
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
2538
+ return Qnil;
2539
+ }
2540
+ if (!RB_FLOAT_TYPE_P(kw_values[0])) {
2541
+ rb_raise(rb_eArgError, "temp must be a float");
2542
+ return Qnil;
2543
+ }
2544
+
2545
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2546
+ if (ctx_ptr->ctx == NULL) {
2547
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2548
+ return Qnil;
2549
+ }
2550
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
2551
+ if (cnd_ptr->array.data == nullptr) {
2552
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
2553
+ return Qnil;
2554
+ }
2555
+ const float temp = NUM2DBL(kw_values[0]);
2556
+
2557
+ llama_sample_temp(ctx_ptr->ctx, &(cnd_ptr->array), temp);
2558
+
2559
+ return Qnil;
2560
+ }
2561
+
2238
2562
  static VALUE _llama_context_sample_temperature(int argc, VALUE* argv, VALUE self) {
2239
2563
  VALUE kw_args = Qnil;
2240
2564
  ID kw_table[1] = { rb_intern("temperature") };
@@ -2560,6 +2884,7 @@ extern "C" void Init_llama_cpp(void) {
2560
2884
  RbLLaMATokenData::define_class(rb_mLLaMACpp);
2561
2885
  RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
2562
2886
  RbLLaMAModel::define_class(rb_mLLaMACpp);
2887
+ RbLLaMAModelParams::define_class(rb_mLLaMACpp);
2563
2888
  RbLLaMATimings::define_class(rb_mLLaMACpp);
2564
2889
  RbLLaMAContext::define_class(rb_mLLaMACpp);
2565
2890
  RbLLaMAContextParams::define_class(rb_mLLaMACpp);
@@ -2578,10 +2903,6 @@ extern "C" void Init_llama_cpp(void) {
2578
2903
 
2579
2904
  rb_define_const(rb_mLLaMACpp, "LLAMA_MAX_DEVICES", INT2NUM(LLAMA_MAX_DEVICES));
2580
2905
 
2581
- rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_ERROR", INT2NUM(LLAMA_LOG_LEVEL_ERROR));
2582
- rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_WARN", INT2NUM(LLAMA_LOG_LEVEL_WARN));
2583
- rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_INFO", INT2NUM(LLAMA_LOG_LEVEL_INFO));
2584
-
2585
2906
  rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_SPM", INT2NUM(LLAMA_VOCAB_TYPE_SPM));
2586
2907
  rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_BPE", INT2NUM(LLAMA_VOCAB_TYPE_BPE));
2587
2908