llama_cpp 0.5.3 → 0.7.0
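
Summary of the changes to this C extension: model-loading options move out of ContextParams into a new LLaMACpp::ModelParams class, a new LLaMACpp::Batch class wraps llama_batch, Context gains decode, sample_temp, and KV-cache sequence helpers, the per-call n_threads arguments of eval/eval_embd move into ContextParams (n_threads, n_threads_batch), and the Context-level tokenize, token_to_piece, and eval_export methods are dropped (tokenization stays on Model). The sketch below is not part of the gem; it outlines the upstream llama.cpp C calls the rewritten bindings wrap, using only calls that appear in this diff except for llama_backend_init/llama_backend_free, llama_context_default_params, llama_new_context_with_model, llama_free, and llama_free_model, which are assumed from the upstream API of this period.

// Minimal sketch (not from the gem) of the upstream llama.cpp flow the rewritten
// bindings wrap. Calls not shown in this diff are marked as assumptions.
#include "llama.h"
#include <cstdio>

int main() {
  llama_backend_init(false);  // assumption: upstream init call of this period (numa = false)

  // Model loading now takes llama_model_params (wrapped by LLaMACpp::ModelParams).
  llama_model_params mparams = llama_model_default_params();
  mparams.n_gpu_layers = 0;
  llama_model* model = llama_load_model_from_file("model.gguf", mparams);
  if (model == nullptr) return 1;

  // Context creation takes llama_context_params (wrapped by LLaMACpp::ContextParams);
  // n_threads now lives here instead of being passed to eval.
  llama_context_params cparams = llama_context_default_params();               // assumption: not shown in diff
  cparams.n_ctx = 512;
  cparams.n_threads = 4;
  cparams.n_threads_batch = 4;
  llama_context* ctx = llama_new_context_with_model(model, cparams);           // assumption: not shown in diff

  // Evaluation goes through llama_batch (wrapped by LLaMACpp::Batch) and llama_decode.
  llama_batch batch = llama_batch_init(/*n_tokens=*/8, /*embd=*/0);
  batch.n_tokens  = 1;
  batch.token[0]  = llama_token_bos(ctx);
  batch.pos[0]    = 0;
  batch.seq_id[0] = 0;
  batch.logits[0] = true;

  if (llama_decode(ctx, batch) < 0) {
    fprintf(stderr, "llama_decode failed\n");
  } else {
    const float* logits = llama_get_logits(ctx);
    printf("first logit: %f (n_vocab = %d)\n", logits[0], llama_n_vocab(model));
  }

  llama_batch_free(batch);
  llama_free(ctx);          // assumption: standard teardown, not shown in diff
  llama_free_model(model);
  llama_backend_free();
  return 0;
}

In Ruby terms, Model.new(model_path:, params:) now requires a ModelParams (both keywords are mandatory), Context.new(model:, params:) takes an explicit ContextParams instead of reusing the model's, and apply_lora_from_file accepts an optional scale: keyword.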

@@ -1,7 +1,9 @@
1
1
  #include "llama_cpp.h"
2
2
 
3
3
  VALUE rb_mLLaMACpp;
4
+ VALUE rb_cLLaMABatch;
4
5
  VALUE rb_cLLaMAModel;
6
+ VALUE rb_cLLaMAModelParams;
5
7
  VALUE rb_cLLaMATimings;
6
8
  VALUE rb_cLLaMAContext;
7
9
  VALUE rb_cLLaMAContextParams;
@@ -11,6 +13,238 @@ VALUE rb_cLLaMATokenDataArray;
11
13
  VALUE rb_cLLaMAGrammarElement;
12
14
  VALUE rb_cLLaMAGrammar;
13
15
 
16
+ class LLaMABatchWrapper {
17
+ public:
18
+ llama_batch batch;
19
+
20
+ LLaMABatchWrapper() {}
21
+
22
+ ~LLaMABatchWrapper() {
23
+ llama_batch_free(batch);
24
+ }
25
+ };
26
+
27
+ class RbLLaMABatch {
28
+ public:
29
+ static VALUE llama_batch_alloc(VALUE self) {
30
+ LLaMABatchWrapper* ptr = (LLaMABatchWrapper*)ruby_xmalloc(sizeof(LLaMABatchWrapper));
31
+ new (ptr) LLaMABatchWrapper();
32
+ return TypedData_Wrap_Struct(self, &llama_batch_type, ptr);
33
+ }
34
+
35
+ static void llama_batch_free(void* ptr) {
36
+ ((LLaMABatchWrapper*)ptr)->~LLaMABatchWrapper();
37
+ ruby_xfree(ptr);
38
+ }
39
+
40
+ static size_t llama_batch_size(const void* ptr) {
41
+ return sizeof(*((LLaMABatchWrapper*)ptr));
42
+ }
43
+
44
+ static LLaMABatchWrapper* get_llama_batch(VALUE self) {
45
+ LLaMABatchWrapper* ptr;
46
+ TypedData_Get_Struct(self, LLaMABatchWrapper, &llama_batch_type, ptr);
47
+ return ptr;
48
+ }
49
+
50
+ static void define_class(VALUE outer) {
51
+ rb_cLLaMABatch = rb_define_class_under(outer, "Batch", rb_cObject);
52
+ rb_define_alloc_func(rb_cLLaMABatch, llama_batch_alloc);
53
+ rb_define_method(rb_cLLaMABatch, "initialize", RUBY_METHOD_FUNC(_llama_batch_initialize), -1);
54
+ rb_define_method(rb_cLLaMABatch, "n_tokens=", RUBY_METHOD_FUNC(_llama_batch_set_n_tokens), 1);
55
+ rb_define_method(rb_cLLaMABatch, "n_tokens", RUBY_METHOD_FUNC(_llama_batch_get_n_tokens), 0);
56
+ rb_define_method(rb_cLLaMABatch, "all_pos_zero=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_zero), 1);
57
+ rb_define_method(rb_cLLaMABatch, "all_pos_zero", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_zero), 0);
58
+ rb_define_method(rb_cLLaMABatch, "all_pos_one=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_one), 1);
59
+ rb_define_method(rb_cLLaMABatch, "all_pos_one", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_one), 0);
60
+ rb_define_method(rb_cLLaMABatch, "all_seq_id=", RUBY_METHOD_FUNC(_llama_batch_set_all_seq_id), 1);
61
+ rb_define_method(rb_cLLaMABatch, "all_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_all_seq_id), 0);
62
+ rb_define_method(rb_cLLaMABatch, "set_token", RUBY_METHOD_FUNC(_llama_batch_set_token), 2);
63
+ rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
64
+ rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
65
+ rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
66
+ rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 2);
67
+ rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 1);
68
+ rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
69
+ rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
70
+ }
71
+
72
+ private:
73
+ static const rb_data_type_t llama_batch_type;
74
+
75
+ static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
76
+ VALUE kw_args = Qnil;
77
+ ID kw_table[2] = { rb_intern("n_tokens"), rb_intern("embd") };
78
+ VALUE kw_values[2] = { Qundef, Qundef };
79
+ rb_scan_args(argc, argv, ":", &kw_args);
80
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
81
+
82
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
83
+ rb_raise(rb_eArgError, "n_tokens must be an integer");
84
+ return Qnil;
85
+ }
86
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
87
+ rb_raise(rb_eArgError, "embd must be an integer");
88
+ return Qnil;
89
+ }
90
+
91
+ const int32_t n_tokens = NUM2INT(kw_values[0]);
92
+ const int32_t embd = NUM2INT(kw_values[1]);
93
+
94
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
95
+ ptr->batch = llama_batch_init(n_tokens, embd);
96
+
97
+ return Qnil;
98
+ }
99
+
100
+ // n_tokens
101
+ static VALUE _llama_batch_set_n_tokens(VALUE self, VALUE n_tokens) {
102
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
103
+ ptr->batch.n_tokens = NUM2INT(n_tokens);
104
+ return INT2NUM(ptr->batch.n_tokens);
105
+ }
106
+
107
+ static VALUE _llama_batch_get_n_tokens(VALUE self) {
108
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
109
+ return INT2NUM(ptr->batch.n_tokens);
110
+ }
111
+
112
+ // all_pos_0
113
+ static VALUE _llama_batch_set_all_pos_zero(VALUE self, VALUE all_pos_0) {
114
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
115
+ ptr->batch.all_pos_0 = NUM2INT(all_pos_0);
116
+ return INT2NUM(ptr->batch.all_pos_0);
117
+ }
118
+
119
+ static VALUE _llama_batch_get_all_pos_zero(VALUE self) {
120
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
121
+ return INT2NUM(ptr->batch.all_pos_0);
122
+ }
123
+
124
+ // all_pos_1
125
+ static VALUE _llama_batch_set_all_pos_one(VALUE self, VALUE all_pos_1) {
126
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
127
+ ptr->batch.all_pos_1 = NUM2INT(all_pos_1);
128
+ return INT2NUM(ptr->batch.all_pos_1);
129
+ }
130
+
131
+ static VALUE _llama_batch_get_all_pos_one(VALUE self) {
132
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
133
+ return INT2NUM(ptr->batch.all_pos_1);
134
+ }
135
+
136
+ // all_seq_id
137
+ static VALUE _llama_batch_set_all_seq_id(VALUE self, VALUE all_seq_id) {
138
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
139
+ ptr->batch.all_seq_id = NUM2INT(all_seq_id);
140
+ return INT2NUM(ptr->batch.all_seq_id);
141
+ }
142
+
143
+ static VALUE _llama_batch_get_all_seq_id(VALUE self) {
144
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
145
+ return INT2NUM(ptr->batch.all_seq_id);
146
+ }
147
+
148
+ // token
149
+ static VALUE _llama_batch_set_token(VALUE self, VALUE idx, VALUE value) {
150
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
151
+ const int32_t id = NUM2INT(idx);
152
+ if (id < 0 || id >= ptr->batch.n_tokens) {
153
+ rb_raise(rb_eArgError, "idx must be in [0, n_tokens)");
154
+ return Qnil;
155
+ }
156
+ ptr->batch.token[id] = NUM2INT(value);
157
+ return INT2NUM(ptr->batch.token[id]);
158
+ }
159
+
160
+ static VALUE _llama_batch_get_token(VALUE self, VALUE idx) {
161
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
162
+ const int32_t id = NUM2INT(idx);
163
+ if (id < 0 || id >= ptr->batch.n_tokens) {
164
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
165
+ return Qnil;
166
+ }
167
+ return INT2NUM(ptr->batch.token[id]);
168
+ }
169
+
170
+ // pos
171
+ static VALUE _llama_batch_set_pos(VALUE self, VALUE idx, VALUE value) {
172
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
173
+ const int32_t id = NUM2INT(idx);
174
+ if (id < 0 || id >= ptr->batch.n_tokens) {
175
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
176
+ return Qnil;
177
+ }
178
+ ptr->batch.pos[id] = NUM2INT(value);
179
+ return INT2NUM(ptr->batch.pos[id]);
180
+ }
181
+
182
+ static VALUE _llama_batch_get_pos(VALUE self, VALUE idx) {
183
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
184
+ const int32_t id = NUM2INT(idx);
185
+ if (id < 0 || id >= ptr->batch.n_tokens) {
186
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
187
+ return Qnil;
188
+ }
189
+ return INT2NUM(ptr->batch.pos[id]);
190
+ }
191
+
192
+ // seq_id
193
+ static VALUE _llama_batch_set_seq_id(VALUE self, VALUE idx, VALUE value) {
194
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
195
+ const int32_t id = NUM2INT(idx);
196
+ if (id < 0 || id >= ptr->batch.n_tokens) {
197
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
198
+ return Qnil;
199
+ }
200
+ ptr->batch.seq_id[id] = NUM2INT(value);
201
+ return INT2NUM(ptr->batch.seq_id[id]);
202
+ }
203
+
204
+ static VALUE _llama_batch_get_seq_id(VALUE self, VALUE idx) {
205
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
206
+ const int32_t id = NUM2INT(idx);
207
+ if (id < 0 || id >= ptr->batch.n_tokens) {
208
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
209
+ return Qnil;
210
+ }
211
+ return INT2NUM(ptr->batch.seq_id[id]);
212
+ }
213
+
214
+ // logits
215
+ static VALUE _llama_batch_set_logits(VALUE self, VALUE idx, VALUE value) {
216
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
217
+ const int32_t id = NUM2INT(idx);
218
+ if (id < 0 || id >= ptr->batch.n_tokens) {
219
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
220
+ return Qnil;
221
+ }
222
+ ptr->batch.logits[id] = RTEST(value) ? true : false;
223
+ return ptr->batch.logits[id] ? Qtrue : Qfalse;
224
+ }
225
+
226
+ static VALUE _llama_batch_get_logits(VALUE self, VALUE idx) {
227
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
228
+ const int32_t id = NUM2INT(idx);
229
+ if (id < 0 || id >= ptr->batch.n_tokens) {
230
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
231
+ return Qnil;
232
+ }
233
+ return ptr->batch.logits[id] ? Qtrue : Qfalse;
234
+ }
235
+ };
236
+
237
+ const rb_data_type_t RbLLaMABatch::llama_batch_type = {
238
+ "RbLLaMABatch",
239
+ { NULL,
240
+ RbLLaMABatch::llama_batch_free,
241
+ RbLLaMABatch::llama_batch_size },
242
+ NULL,
243
+ NULL,
244
+ RUBY_TYPED_FREE_IMMEDIATELY
245
+
246
+ };
247
+
14
248
  class LLaMATokenDataWrapper {
15
249
  public:
16
250
  llama_token_data data;
@@ -363,6 +597,144 @@ const rb_data_type_t RbLLaMATimings::llama_timings_type = {
363
597
  RUBY_TYPED_FREE_IMMEDIATELY
364
598
  };
365
599
 
600
+ class LLaMAModelParamsWrapper {
601
+ public:
602
+ struct llama_model_params params;
603
+
604
+ LLaMAModelParamsWrapper() : params(llama_model_default_params()) {}
605
+
606
+ ~LLaMAModelParamsWrapper() {}
607
+ };
608
+
609
+ class RbLLaMAModelParams {
610
+ public:
611
+ static VALUE llama_model_params_alloc(VALUE self) {
612
+ LLaMAModelParamsWrapper* ptr = (LLaMAModelParamsWrapper*)ruby_xmalloc(sizeof(LLaMAModelParamsWrapper));
613
+ new (ptr) LLaMAModelParamsWrapper();
614
+ return TypedData_Wrap_Struct(self, &llama_model_params_type, ptr);
615
+ }
616
+
617
+ static void llama_model_params_free(void* ptr) {
618
+ ((LLaMAModelParamsWrapper*)ptr)->~LLaMAModelParamsWrapper();
619
+ ruby_xfree(ptr);
620
+ }
621
+
622
+ static size_t llama_model_params_size(const void* ptr) {
623
+ return sizeof(*((LLaMAModelParamsWrapper*)ptr));
624
+ }
625
+
626
+ static LLaMAModelParamsWrapper* get_llama_model_params(VALUE self) {
627
+ LLaMAModelParamsWrapper* ptr;
628
+ TypedData_Get_Struct(self, LLaMAModelParamsWrapper, &llama_model_params_type, ptr);
629
+ return ptr;
630
+ }
631
+
632
+ static void define_class(VALUE outer) {
633
+ rb_cLLaMAModelParams = rb_define_class_under(outer, "ModelParams", rb_cObject);
634
+ rb_define_alloc_func(rb_cLLaMAModelParams, llama_model_params_alloc);
635
+ rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers=", RUBY_METHOD_FUNC(_llama_model_params_set_n_gpu_layers), 1);
636
+ rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers", RUBY_METHOD_FUNC(_llama_model_params_get_n_gpu_layers), 0);
637
+ rb_define_method(rb_cLLaMAModelParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_model_params_set_main_gpu), 1);
638
+ rb_define_method(rb_cLLaMAModelParams, "main_gpu", RUBY_METHOD_FUNC(_llama_model_params_get_main_gpu), 0);
639
+ rb_define_method(rb_cLLaMAModelParams, "tensor_split", RUBY_METHOD_FUNC(_llama_model_params_get_tensor_split), 0);
640
+ rb_define_method(rb_cLLaMAModelParams, "vocab_only=", RUBY_METHOD_FUNC(_llama_model_params_set_vocab_only), 1);
641
+ rb_define_method(rb_cLLaMAModelParams, "vocab_only", RUBY_METHOD_FUNC(_llama_model_params_get_vocab_only), 0);
642
+ rb_define_method(rb_cLLaMAModelParams, "use_mmap=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mmap), 1);
643
+ rb_define_method(rb_cLLaMAModelParams, "use_mmap", RUBY_METHOD_FUNC(_llama_model_params_get_use_mmap), 0);
644
+ rb_define_method(rb_cLLaMAModelParams, "use_mlock=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mlock), 1);
645
+ rb_define_method(rb_cLLaMAModelParams, "use_mlock", RUBY_METHOD_FUNC(_llama_model_params_get_use_mlock), 0);
646
+ }
647
+
648
+ private:
649
+ static const rb_data_type_t llama_model_params_type;
650
+
651
+ // n_gpu_layers
652
+ static VALUE _llama_model_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
653
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
654
+ ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
655
+ return INT2NUM(ptr->params.n_gpu_layers);
656
+ }
657
+
658
+ static VALUE _llama_model_params_get_n_gpu_layers(VALUE self) {
659
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
660
+ return INT2NUM(ptr->params.n_gpu_layers);
661
+ }
662
+
663
+ // main_gpu
664
+ static VALUE _llama_model_params_set_main_gpu(VALUE self, VALUE main_gpu) {
665
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
666
+ ptr->params.main_gpu = NUM2INT(main_gpu);
667
+ return INT2NUM(ptr->params.main_gpu);
668
+ }
669
+
670
+ static VALUE _llama_model_params_get_main_gpu(VALUE self) {
671
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
672
+ return INT2NUM(ptr->params.main_gpu);
673
+ }
674
+
675
+ // tensor_split
676
+ static VALUE _llama_model_params_get_tensor_split(VALUE self) {
677
+ if (LLAMA_MAX_DEVICES < 1) {
678
+ return rb_ary_new();
679
+ }
680
+ VALUE ret = rb_ary_new2(LLAMA_MAX_DEVICES);
681
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
682
+ if (ptr->params.tensor_split == nullptr) {
683
+ return rb_ary_new();
684
+ }
685
+ for (size_t i = 0; i < LLAMA_MAX_DEVICES; i++) {
686
+ rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
687
+ }
688
+ return ret;
689
+ }
690
+
691
+ // vocab_only
692
+ static VALUE _llama_model_params_set_vocab_only(VALUE self, VALUE vocab_only) {
693
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
694
+ ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
695
+ return ptr->params.vocab_only ? Qtrue : Qfalse;
696
+ }
697
+
698
+ static VALUE _llama_model_params_get_vocab_only(VALUE self) {
699
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
700
+ return ptr->params.vocab_only ? Qtrue : Qfalse;
701
+ }
702
+
703
+ // use_mmap
704
+ static VALUE _llama_model_params_set_use_mmap(VALUE self, VALUE use_mmap) {
705
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
706
+ ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
707
+ return ptr->params.use_mmap ? Qtrue : Qfalse;
708
+ }
709
+
710
+ static VALUE _llama_model_params_get_use_mmap(VALUE self) {
711
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
712
+ return ptr->params.use_mmap ? Qtrue : Qfalse;
713
+ }
714
+
715
+ // use_mlock
716
+ static VALUE _llama_model_params_set_use_mlock(VALUE self, VALUE use_mlock) {
717
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
718
+ ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
719
+ return ptr->params.use_mlock ? Qtrue : Qfalse;
720
+ }
721
+
722
+ static VALUE _llama_model_params_get_use_mlock(VALUE self) {
723
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
724
+ return ptr->params.use_mlock ? Qtrue : Qfalse;
725
+ }
726
+ };
727
+
728
+ const rb_data_type_t RbLLaMAModelParams::llama_model_params_type = {
729
+ "RbLLaMAModelParams",
730
+ { NULL,
731
+ RbLLaMAModelParams::llama_model_params_free,
732
+ RbLLaMAModelParams::llama_model_params_size },
733
+ NULL,
734
+ NULL,
735
+ RUBY_TYPED_FREE_IMMEDIATELY
736
+ };
737
+
366
738
  class LLaMAContextParamsWrapper {
367
739
  public:
368
740
  struct llama_context_params params;
@@ -399,35 +771,26 @@ public:
399
771
  rb_cLLaMAContextParams = rb_define_class_under(outer, "ContextParams", rb_cObject);
400
772
  rb_define_alloc_func(rb_cLLaMAContextParams, llama_context_params_alloc);
401
773
  // rb_define_method(rb_cLLaMAContextParams, "initialize", RUBY_METHOD_FUNC(_llama_context_params_init), 0);
774
+ rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
775
+ rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
402
776
  rb_define_method(rb_cLLaMAContextParams, "n_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_n_ctx), 1);
403
777
  rb_define_method(rb_cLLaMAContextParams, "n_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_n_ctx), 0);
404
778
  rb_define_method(rb_cLLaMAContextParams, "n_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_batch), 1);
405
779
  rb_define_method(rb_cLLaMAContextParams, "n_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_batch), 0);
406
- rb_define_method(rb_cLLaMAContextParams, "n_gpu_layers=", RUBY_METHOD_FUNC(_llama_context_params_set_n_gpu_layers), 1);
407
- rb_define_method(rb_cLLaMAContextParams, "n_gpu_layers", RUBY_METHOD_FUNC(_llama_context_params_get_n_gpu_layers), 0);
408
- rb_define_method(rb_cLLaMAContextParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_context_params_set_main_gpu), 1);
409
- rb_define_method(rb_cLLaMAContextParams, "main_gpu", RUBY_METHOD_FUNC(_llama_context_params_get_main_gpu), 0);
410
- rb_define_method(rb_cLLaMAContextParams, "tensor_split", RUBY_METHOD_FUNC(_llama_context_params_get_tensor_split), 0);
780
+ rb_define_method(rb_cLLaMAContextParams, "n_threads=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads), 1);
781
+ rb_define_method(rb_cLLaMAContextParams, "n_threads", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads), 0);
782
+ rb_define_method(rb_cLLaMAContextParams, "n_threads_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads_batch), 1);
783
+ rb_define_method(rb_cLLaMAContextParams, "n_threads_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads_batch), 0);
411
784
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
412
785
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
413
786
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
414
787
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
415
- rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
416
- rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
417
788
  rb_define_method(rb_cLLaMAContextParams, "mul_mat_q=", RUBY_METHOD_FUNC(_llama_context_params_set_mul_mat_q), 1);
418
789
  rb_define_method(rb_cLLaMAContextParams, "mul_mat_q", RUBY_METHOD_FUNC(_llama_context_params_get_mul_mat_q), 0);
419
- rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
420
- rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
421
790
  rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
422
791
  rb_define_method(rb_cLLaMAContextParams, "f16_kv", RUBY_METHOD_FUNC(_llama_context_params_get_f16_kv), 0);
423
792
  rb_define_method(rb_cLLaMAContextParams, "logits_all=", RUBY_METHOD_FUNC(_llama_context_params_set_logits_all), 1);
424
793
  rb_define_method(rb_cLLaMAContextParams, "logits_all", RUBY_METHOD_FUNC(_llama_context_params_get_logits_all), 0);
425
- rb_define_method(rb_cLLaMAContextParams, "vocab_only=", RUBY_METHOD_FUNC(_llama_context_params_set_vocab_only), 1);
426
- rb_define_method(rb_cLLaMAContextParams, "vocab_only", RUBY_METHOD_FUNC(_llama_context_params_get_vocab_only), 0);
427
- rb_define_method(rb_cLLaMAContextParams, "use_mmap=", RUBY_METHOD_FUNC(_llama_context_params_set_use_mmap), 1);
428
- rb_define_method(rb_cLLaMAContextParams, "use_mmap", RUBY_METHOD_FUNC(_llama_context_params_get_use_mmap), 0);
429
- rb_define_method(rb_cLLaMAContextParams, "use_mlock=", RUBY_METHOD_FUNC(_llama_context_params_set_use_mlock), 1);
430
- rb_define_method(rb_cLLaMAContextParams, "use_mlock", RUBY_METHOD_FUNC(_llama_context_params_get_use_mlock), 0);
431
794
  rb_define_method(rb_cLLaMAContextParams, "embedding=", RUBY_METHOD_FUNC(_llama_context_params_set_embedding), 1);
432
795
  rb_define_method(rb_cLLaMAContextParams, "embedding", RUBY_METHOD_FUNC(_llama_context_params_get_embedding), 0);
433
796
  }
@@ -441,6 +804,22 @@ private:
441
804
  // return self;
442
805
  // }
443
806
 
807
+ // seed
808
+ static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
809
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
810
+ if (NUM2INT(seed) < 0) {
811
+ rb_raise(rb_eArgError, "seed must be positive");
812
+ return Qnil;
813
+ }
814
+ ptr->params.seed = NUM2INT(seed);
815
+ return INT2NUM(ptr->params.seed);
816
+ }
817
+
818
+ static VALUE _llama_context_params_get_seed(VALUE self) {
819
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
820
+ return INT2NUM(ptr->params.seed);
821
+ }
822
+
444
823
  // n_ctx
445
824
  static VALUE _llama_context_params_set_n_ctx(VALUE self, VALUE n_ctx) {
446
825
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -465,41 +844,28 @@ private:
465
844
  return INT2NUM(ptr->params.n_batch);
466
845
  }
467
846
 
468
- // n_gpu_layers
469
- static VALUE _llama_context_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
847
+ // n_threads
848
+ static VALUE _llama_context_params_set_n_threads(VALUE self, VALUE n_threads) {
470
849
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
471
- ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
472
- return INT2NUM(ptr->params.n_gpu_layers);
850
+ ptr->params.n_threads = NUM2INT(n_threads);
851
+ return INT2NUM(ptr->params.n_threads);
473
852
  }
474
853
 
475
- static VALUE _llama_context_params_get_n_gpu_layers(VALUE self) {
854
+ static VALUE _llama_context_params_get_n_threads(VALUE self) {
476
855
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
477
- return INT2NUM(ptr->params.n_gpu_layers);
856
+ return INT2NUM(ptr->params.n_threads);
478
857
  }
479
858
 
480
- // main_gpu
481
- static VALUE _llama_context_params_set_main_gpu(VALUE self, VALUE main_gpu) {
859
+ // n_threads_batch
860
+ static VALUE _llama_context_params_set_n_threads_batch(VALUE self, VALUE n_threads_batch) {
482
861
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
483
- ptr->params.main_gpu = NUM2INT(main_gpu);
484
- return INT2NUM(ptr->params.main_gpu);
862
+ ptr->params.n_threads_batch = NUM2INT(n_threads_batch);
863
+ return INT2NUM(ptr->params.n_threads_batch);
485
864
  }
486
865
 
487
- static VALUE _llama_context_params_get_main_gpu(VALUE self) {
866
+ static VALUE _llama_context_params_get_n_threads_batch(VALUE self) {
488
867
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
489
- return INT2NUM(ptr->params.main_gpu);
490
- }
491
-
492
- // tensor_split
493
- static VALUE _llama_context_params_get_tensor_split(VALUE self) {
494
- if (LLAMA_MAX_DEVICES < 1) {
495
- return rb_ary_new();
496
- }
497
- VALUE ret = rb_ary_new2(LLAMA_MAX_DEVICES);
498
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
499
- for (size_t i = 0; i < LLAMA_MAX_DEVICES; i++) {
500
- rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
501
- }
502
- return ret;
868
+ return INT2NUM(ptr->params.n_threads_batch);
503
869
  }
504
870
 
505
871
  // rope_freq_base
@@ -526,18 +892,6 @@ private:
526
892
  return DBL2NUM(ptr->params.rope_freq_scale);
527
893
  }
528
894
 
529
- // low_vram
530
- static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
531
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
532
- ptr->params.low_vram = RTEST(low_vram) ? true : false;
533
- return ptr->params.low_vram ? Qtrue : Qfalse;
534
- }
535
-
536
- static VALUE _llama_context_params_get_low_vram(VALUE self) {
537
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
538
- return ptr->params.low_vram ? Qtrue : Qfalse;
539
- }
540
-
541
895
  // mul_mat_q
542
896
  static VALUE _llama_context_params_set_mul_mat_q(VALUE self, VALUE mul_mat_q) {
543
897
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -550,22 +904,6 @@ private:
550
904
  return ptr->params.mul_mat_q ? Qtrue : Qfalse;
551
905
  }
552
906
 
553
- // seed
554
- static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
555
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
556
- if (NUM2INT(seed) < 0) {
557
- rb_raise(rb_eArgError, "seed must be positive");
558
- return Qnil;
559
- }
560
- ptr->params.seed = NUM2INT(seed);
561
- return INT2NUM(ptr->params.seed);
562
- }
563
-
564
- static VALUE _llama_context_params_get_seed(VALUE self) {
565
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
566
- return INT2NUM(ptr->params.seed);
567
- }
568
-
569
907
  // f16_kv
570
908
  static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
571
909
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -590,42 +928,6 @@ private:
590
928
  return ptr->params.logits_all ? Qtrue : Qfalse;
591
929
  }
592
930
 
593
- // vocab_only
594
- static VALUE _llama_context_params_set_vocab_only(VALUE self, VALUE vocab_only) {
595
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
596
- ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
597
- return ptr->params.vocab_only ? Qtrue : Qfalse;
598
- }
599
-
600
- static VALUE _llama_context_params_get_vocab_only(VALUE self) {
601
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
602
- return ptr->params.vocab_only ? Qtrue : Qfalse;
603
- }
604
-
605
- // use_mmap
606
- static VALUE _llama_context_params_set_use_mmap(VALUE self, VALUE use_mmap) {
607
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
608
- ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
609
- return ptr->params.use_mmap ? Qtrue : Qfalse;
610
- }
611
-
612
- static VALUE _llama_context_params_get_use_mmap(VALUE self) {
613
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
614
- return ptr->params.use_mmap ? Qtrue : Qfalse;
615
- }
616
-
617
- // use_mlock
618
- static VALUE _llama_context_params_set_use_mlock(VALUE self, VALUE use_mlock) {
619
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
620
- ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
621
- return ptr->params.use_mlock ? Qtrue : Qfalse;
622
- }
623
-
624
- static VALUE _llama_context_params_get_use_mlock(VALUE self) {
625
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
626
- return ptr->params.use_mlock ? Qtrue : Qfalse;
627
- }
628
-
629
931
  // embedding
630
932
  static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
631
933
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -823,11 +1125,11 @@ public:
823
1125
  rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
824
1126
  rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
825
1127
  rb_define_method(rb_cLLaMAModel, "n_vocab", RUBY_METHOD_FUNC(_llama_model_get_model_n_vocab), 0);
826
- rb_define_method(rb_cLLaMAModel, "n_ctx", RUBY_METHOD_FUNC(_llama_model_get_model_n_ctx), 0);
827
1128
  rb_define_method(rb_cLLaMAModel, "n_ctx_train", RUBY_METHOD_FUNC(_llama_model_get_model_n_ctx_train), 0);
828
1129
  rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_model_n_embd), 0);
829
- rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece_with_model), 1);
830
- rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize_with_model), -1);
1130
+ rb_define_method(rb_cLLaMAModel, "rope_freq_scale_train", RUBY_METHOD_FUNC(_llama_model_rope_freq_scale_train), 0);
1131
+ rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece), 1);
1132
+ rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize), -1);
831
1133
  rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
832
1134
  rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
833
1135
  rb_define_method(rb_cLLaMAModel, "n_params", RUBY_METHOD_FUNC(_llama_model_get_model_n_params), 0);
@@ -841,30 +1143,21 @@ private:
841
1143
  ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
842
1144
  VALUE kw_values[2] = { Qundef, Qundef };
843
1145
  rb_scan_args(argc, argv, ":", &kw_args);
844
- rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
845
-
846
- if (kw_values[0] == Qundef && kw_values[1] == Qundef) {
847
- rb_iv_set(self, "@params", Qnil);
848
- return Qnil;
849
- }
1146
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
850
1147
 
851
1148
  if (!RB_TYPE_P(kw_values[0], T_STRING)) {
852
1149
  rb_raise(rb_eArgError, "model_path must be a string");
853
1150
  return Qnil;
854
1151
  }
855
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
856
- rb_raise(rb_eArgError, "params must be a ContextParams");
1152
+ if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
1153
+ rb_raise(rb_eArgError, "params must be a ModelParams");
857
1154
  return Qnil;
858
1155
  }
859
1156
 
860
1157
  VALUE filename = kw_values[0];
861
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
1158
+ LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
862
1159
  LLaMAModelWrapper* model_ptr = get_llama_model(self);
863
1160
 
864
- if (prms_ptr->params.seed == LLAMA_DEFAULT_SEED) {
865
- prms_ptr->params.seed = time(NULL);
866
- }
867
-
868
1161
  try {
869
1162
  model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
870
1163
  } catch (const std::runtime_error& e) {
@@ -912,8 +1205,8 @@ private:
912
1205
  rb_raise(rb_eArgError, "model_path must be a string");
913
1206
  return Qnil;
914
1207
  }
915
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
916
- rb_raise(rb_eArgError, "params must be a LLaMAContextParams");
1208
+ if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
1209
+ rb_raise(rb_eArgError, "params must be a LLaMAModelParams");
917
1210
  return Qnil;
918
1211
  }
919
1212
 
@@ -924,7 +1217,7 @@ private:
924
1217
  }
925
1218
 
926
1219
  VALUE filename = kw_values[0];
927
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
1220
+ LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
928
1221
 
929
1222
  try {
930
1223
  model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
@@ -946,10 +1239,10 @@ private:
946
1239
 
947
1240
  static VALUE _llama_model_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
948
1241
  VALUE kw_args = Qnil;
949
- ID kw_table[3] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads") };
950
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1242
+ ID kw_table[4] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads"), rb_intern("scale") };
1243
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
951
1244
  rb_scan_args(argc, argv, ":", &kw_args);
952
- rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
1245
+ rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);
953
1246
 
954
1247
  if (!RB_TYPE_P(kw_values[0], T_STRING)) {
955
1248
  rb_raise(rb_eArgError, "lora_path must be a string");
@@ -963,13 +1256,18 @@ private:
963
1256
  rb_raise(rb_eArgError, "n_threads must be an integer");
964
1257
  return Qnil;
965
1258
  }
1259
+ if (kw_values[3] != Qundef && !RB_FLOAT_TYPE_P(kw_values[3])) {
1260
+ rb_raise(rb_eArgError, "scale must be a float");
1261
+ return Qnil;
1262
+ }
966
1263
 
967
1264
  const char* lora_path = StringValueCStr(kw_values[0]);
968
1265
  const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
969
1266
  const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
1267
+ const float scale = kw_values[3] == Qundef ? 1.0 : NUM2DBL(kw_values[3]);
970
1268
 
971
1269
  LLaMAModelWrapper* ptr = get_llama_model(self);
972
- if (llama_model_apply_lora_from_file(ptr->model, lora_path, base_model_path, n_threads) != 0) {
1270
+ if (llama_model_apply_lora_from_file(ptr->model, lora_path, scale, base_model_path, n_threads) != 0) {
973
1271
  rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
974
1272
  return Qnil;
975
1273
  }
@@ -978,25 +1276,25 @@ private:
978
1276
 
979
1277
  static VALUE _llama_model_get_model_n_vocab(VALUE self) {
980
1278
  LLaMAModelWrapper* ptr = get_llama_model(self);
981
- return INT2NUM(llama_model_n_vocab(ptr->model));
1279
+ return INT2NUM(llama_n_vocab(ptr->model));
982
1280
  }
983
1281
 
984
- static VALUE _llama_model_get_model_n_ctx(VALUE self) {
1282
+ static VALUE _llama_model_get_model_n_ctx_train(VALUE self) {
985
1283
  LLaMAModelWrapper* ptr = get_llama_model(self);
986
- return INT2NUM(llama_model_n_ctx(ptr->model));
1284
+ return INT2NUM(llama_n_ctx_train(ptr->model));
987
1285
  }
988
1286
 
989
- static VALUE _llama_model_get_model_n_ctx_train(VALUE self) {
1287
+ static VALUE _llama_model_get_model_n_embd(VALUE self) {
990
1288
  LLaMAModelWrapper* ptr = get_llama_model(self);
991
- return INT2NUM(llama_model_n_ctx_train(ptr->model));
1289
+ return INT2NUM(llama_n_embd(ptr->model));
992
1290
  }
993
1291
 
994
- static VALUE _llama_model_get_model_n_embd(VALUE self) {
1292
+ static VALUE _llama_model_rope_freq_scale_train(VALUE self) {
995
1293
  LLaMAModelWrapper* ptr = get_llama_model(self);
996
- return INT2NUM(llama_model_n_embd(ptr->model));
1294
+ return DBL2NUM(llama_rope_freq_scale_train(ptr->model));
997
1295
  }
998
1296
 
999
- static VALUE _llama_model_token_to_piece_with_model(VALUE self, VALUE token_) {
1297
+ static VALUE _llama_model_token_to_piece(VALUE self, VALUE token_) {
1000
1298
  if (!RB_INTEGER_TYPE_P(token_)) {
1001
1299
  rb_raise(rb_eArgError, "token must be an integer");
1002
1300
  return Qnil;
@@ -1004,10 +1302,10 @@ private:
1004
1302
  const llama_token token = NUM2INT(token_);
1005
1303
  LLaMAModelWrapper* ptr = get_llama_model(self);
1006
1304
  std::vector<char> result(8, 0);
1007
- const int n_tokens = llama_token_to_piece_with_model(ptr->model, token, result.data(), result.size());
1305
+ const int n_tokens = llama_token_to_piece(ptr->model, token, result.data(), result.size());
1008
1306
  if (n_tokens < 0) {
1009
1307
  result.resize(-n_tokens);
1010
- const int check = llama_token_to_piece_with_model(ptr->model, token, result.data(), result.size());
1308
+ const int check = llama_token_to_piece(ptr->model, token, result.data(), result.size());
1011
1309
  if (check != -n_tokens) {
1012
1310
  rb_raise(rb_eRuntimeError, "failed to convert");
1013
1311
  return Qnil;
@@ -1016,10 +1314,10 @@ private:
1016
1314
  result.resize(n_tokens);
1017
1315
  }
1018
1316
  std::string ret(result.data(), result.size());
1019
- return rb_str_new_cstr(ret.c_str());
1317
+ return rb_utf8_str_new_cstr(ret.c_str());
1020
1318
  }
1021
1319
 
1022
- static VALUE _llama_model_tokenize_with_model(int argc, VALUE* argv, VALUE self) {
1320
+ static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
1023
1321
  VALUE kw_args = Qnil;
1024
1322
  ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
1025
1323
  VALUE kw_values[3] = { Qundef, Qundef, Qundef };
@@ -1046,7 +1344,7 @@ private:
1046
1344
 
1047
1345
  llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
1048
1346
  LLaMAModelWrapper* ptr = get_llama_model(self);
1049
- const int n_tokens = llama_tokenize_with_model(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
1347
+ const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
1050
1348
 
1051
1349
  if (n_tokens < 0) {
1052
1350
  rb_raise(rb_eRuntimeError, "failed to tokenize. The number of tokens (%d) is greater than n_max_tokens.", -n_tokens);
@@ -1066,7 +1364,7 @@ private:
1066
1364
  LLaMAModelWrapper* ptr = get_llama_model(self);
1067
1365
  char buf[128];
1068
1366
  llama_model_desc(ptr->model, buf, sizeof(buf));
1069
- return rb_str_new_cstr(buf);
1367
+ return rb_utf8_str_new_cstr(buf);
1070
1368
  }
1071
1369
 
1072
1370
  static VALUE _llama_model_get_model_size(VALUE self) {
@@ -1345,11 +1643,11 @@ public:
1345
1643
  static void define_class(VALUE outer) {
1346
1644
  rb_cLLaMAContext = rb_define_class_under(outer, "Context", rb_cObject);
1347
1645
  rb_define_alloc_func(rb_cLLaMAContext, llama_context_alloc);
1646
+ rb_define_attr(rb_cLLaMAContext, "model", 1, 0);
1348
1647
  rb_define_method(rb_cLLaMAContext, "initialize", RUBY_METHOD_FUNC(_llama_context_initialize), -1);
1349
1648
  rb_define_method(rb_cLLaMAContext, "eval", RUBY_METHOD_FUNC(_llama_context_eval), -1);
1350
1649
  rb_define_method(rb_cLLaMAContext, "eval_embd", RUBY_METHOD_FUNC(_llama_context_eval_embd), -1);
1351
- rb_define_method(rb_cLLaMAContext, "eval_export", RUBY_METHOD_FUNC(_llama_context_eval_export), 1);
1352
- rb_define_method(rb_cLLaMAContext, "tokenize", RUBY_METHOD_FUNC(_llama_context_tokenize), -1);
1650
+ rb_define_method(rb_cLLaMAContext, "decode", RUBY_METHOD_FUNC(_llama_context_decode), 1);
1353
1651
  rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
1354
1652
  rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
1355
1653
  rb_define_method(rb_cLLaMAContext, "text", RUBY_METHOD_FUNC(_llama_context_text), 1);
@@ -1358,15 +1656,20 @@ public:
1358
1656
  rb_define_method(rb_cLLaMAContext, "token_bos", RUBY_METHOD_FUNC(_llama_context_token_bos), 0);
1359
1657
  rb_define_method(rb_cLLaMAContext, "token_eos", RUBY_METHOD_FUNC(_llama_context_token_eos), 0);
1360
1658
  rb_define_method(rb_cLLaMAContext, "token_nl", RUBY_METHOD_FUNC(_llama_context_token_nl), 0);
1361
- rb_define_method(rb_cLLaMAContext, "token_to_piece", RUBY_METHOD_FUNC(_llama_context_token_to_piece), 1);
1362
- rb_define_method(rb_cLLaMAContext, "n_vocab", RUBY_METHOD_FUNC(_llama_context_n_vocab), 0);
1659
+ rb_define_method(rb_cLLaMAContext, "token_prefix", RUBY_METHOD_FUNC(_llama_context_token_prefix), 0);
1660
+ rb_define_method(rb_cLLaMAContext, "token_middle", RUBY_METHOD_FUNC(_llama_context_token_middle), 0);
1661
+ rb_define_method(rb_cLLaMAContext, "token_suffix", RUBY_METHOD_FUNC(_llama_context_token_suffix), 0);
1662
+ rb_define_method(rb_cLLaMAContext, "token_eot", RUBY_METHOD_FUNC(_llama_context_token_eot), 0);
1363
1663
  rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
1364
- rb_define_method(rb_cLLaMAContext, "n_ctx_train", RUBY_METHOD_FUNC(_llama_context_n_ctx_train), 0);
1365
- rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
1366
1664
  rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
1367
1665
  rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
1368
1666
  rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
1369
1667
  rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
1668
+ rb_define_method(rb_cLLaMAContext, "kv_cache_tokens_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_tokens_rm), 2);
1669
+ rb_define_method(rb_cLLaMAContext, "kv_cache_seq_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_rm), 3);
1670
+ rb_define_method(rb_cLLaMAContext, "kv_cache_seq_cp", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_cp), 4);
1671
+ rb_define_method(rb_cLLaMAContext, "kv_cache_seq_keep", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_keep), 1);
1672
+ rb_define_method(rb_cLLaMAContext, "kv_cache_seq_shift", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_shift), 4);
1370
1673
  rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
1371
1674
  rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
1372
1675
  rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
@@ -1378,6 +1681,7 @@ public:
1378
1681
  rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
1379
1682
  rb_define_method(rb_cLLaMAContext, "sample_tail_free", RUBY_METHOD_FUNC(_llama_context_sample_tail_free), -1);
1380
1683
  rb_define_method(rb_cLLaMAContext, "sample_typical", RUBY_METHOD_FUNC(_llama_context_sample_typical), -1);
1684
+ rb_define_method(rb_cLLaMAContext, "sample_temp", RUBY_METHOD_FUNC(_llama_context_sample_temp), -1);
1381
1685
  rb_define_method(rb_cLLaMAContext, "sample_temperature", RUBY_METHOD_FUNC(_llama_context_sample_temperature), -1);
1382
1686
  rb_define_method(rb_cLLaMAContext, "sample_token_mirostat", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat), -1);
1383
1687
  rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
@@ -1392,24 +1696,27 @@ private:
1392
1696
 
1393
1697
  static VALUE _llama_context_initialize(int argc, VALUE* argv, VALUE self) {
1394
1698
  VALUE kw_args = Qnil;
1395
- ID kw_table[1] = { rb_intern("model") };
1396
- VALUE kw_values[1] = { Qundef };
1699
+ ID kw_table[2] = { rb_intern("model"), rb_intern("params") };
1700
+ VALUE kw_values[2] = { Qundef, Qundef };
1397
1701
  rb_scan_args(argc, argv, ":", &kw_args);
1398
- rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
1702
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
1399
1703
 
1400
1704
  VALUE model = kw_values[0];
1401
1705
  if (!rb_obj_is_kind_of(model, rb_cLLaMAModel)) {
1402
1706
  rb_raise(rb_eArgError, "model must be a Model");
1403
1707
  return Qnil;
1404
1708
  }
1709
+ VALUE params = kw_values[1];
1710
+ if (!rb_obj_is_kind_of(params, rb_cLLaMAContextParams)) {
1711
+ rb_raise(rb_eArgError, "params must be a ContextParams");
1712
+ return Qnil;
1713
+ }
1405
1714
 
1406
1715
  LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
1407
1716
  if (model_ptr->model == NULL) {
1408
1717
  rb_raise(rb_eRuntimeError, "Model is empty");
1409
1718
  return Qnil;
1410
1719
  }
1411
-
1412
- VALUE params = rb_iv_get(model, "@params");
1413
1720
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1414
1721
  LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1415
1722
 
@@ -1421,6 +1728,7 @@ private:
1421
1728
  }
1422
1729
 
1423
1730
  rb_iv_set(self, "@model", model);
1731
+ rb_iv_set(self, "@params", params);
1424
1732
  rb_iv_set(self, "@has_evaluated", Qfalse);
1425
1733
 
1426
1734
  return Qnil;
@@ -1428,8 +1736,8 @@ private:
1428
1736
 
1429
1737
  static VALUE _llama_context_eval(int argc, VALUE* argv, VALUE self) {
1430
1738
  VALUE kw_args = Qnil;
1431
- ID kw_table[4] = { rb_intern("tokens"), rb_intern("n_past"), rb_intern("n_tokens"), rb_intern("n_threads") };
1432
- VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
1739
+ ID kw_table[3] = { rb_intern("tokens"), rb_intern("n_past"), rb_intern("n_tokens") };
1740
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1433
1741
  rb_scan_args(argc, argv, ":", &kw_args);
1434
1742
  rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
1435
1743
 
@@ -1445,10 +1753,6 @@ private:
1445
1753
  rb_raise(rb_eArgError, "n_tokens must be an integer");
1446
1754
  return Qnil;
1447
1755
  }
1448
- if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
1449
- rb_raise(rb_eArgError, "n_threads must be an integer");
1450
- return Qnil;
1451
- }
1452
1756
 
1453
1757
  const size_t tokens_len = RARRAY_LEN(kw_values[0]);
1454
1758
  std::vector<llama_token> embd(tokens_len);
@@ -1463,14 +1767,13 @@ private:
1463
1767
 
1464
1768
  const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
1465
1769
  const int n_past = NUM2INT(kw_values[1]);
1466
- const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
1467
1770
 
1468
1771
  LLaMAContextWrapper* ptr = get_llama_context(self);
1469
1772
  if (ptr->ctx == NULL) {
1470
1773
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1471
1774
  return Qnil;
1472
1775
  }
1473
- if (llama_eval(ptr->ctx, embd.data(), n_tokens, n_past, n_threads) != 0) {
1776
+ if (llama_eval(ptr->ctx, embd.data(), n_tokens, n_past) != 0) {
1474
1777
  rb_raise(rb_eRuntimeError, "Failed to evaluate");
1475
1778
  return Qnil;
1476
1779
  }
@@ -1483,8 +1786,8 @@ private:
1483
1786
 
1484
1787
  static VALUE _llama_context_eval_embd(int argc, VALUE* argv, VALUE self) {
1485
1788
  VALUE kw_args = Qnil;
1486
- ID kw_table[4] = { rb_intern("embd"), rb_intern("n_past"), rb_intern("n_tokens"), rb_intern("n_threads") };
1487
- VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
1789
+ ID kw_table[3] = { rb_intern("embd"), rb_intern("n_past"), rb_intern("n_tokens") };
1790
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1488
1791
  rb_scan_args(argc, argv, ":", &kw_args);
1489
1792
  rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
1490
1793
 
@@ -1500,10 +1803,6 @@ private:
1500
1803
  rb_raise(rb_eArgError, "n_tokens must be an integer");
1501
1804
  return Qnil;
1502
1805
  }
1503
- if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
1504
- rb_raise(rb_eArgError, "n_threads must be an integer");
1505
- return Qnil;
1506
- }
1507
1806
 
1508
1807
  const size_t tokens_len = RARRAY_LEN(kw_values[0]);
1509
1808
  std::vector<float> embd(tokens_len);
@@ -1518,14 +1817,13 @@ private:
1518
1817
 
1519
1818
  const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
1520
1819
  const int n_past = NUM2INT(kw_values[1]);
1521
- const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
1522
1820
 
1523
1821
  LLaMAContextWrapper* ptr = get_llama_context(self);
1524
1822
  if (ptr->ctx == NULL) {
1525
1823
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1526
1824
  return Qnil;
1527
1825
  }
1528
- if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past, n_threads) != 0) {
1826
+ if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past) != 0) {
1529
1827
  rb_raise(rb_eRuntimeError, "Failed to evaluate");
1530
1828
  return Qnil;
1531
1829
  }
@@ -1536,91 +1834,22 @@ private:
1536
1834
  return Qnil;
1537
1835
  }
1538
1836
 
1539
- static VALUE _llama_context_eval_export(VALUE self, VALUE fname_) {
1837
+ static VALUE _llama_context_decode(VALUE self, VALUE batch) {
1540
1838
  LLaMAContextWrapper* ptr = get_llama_context(self);
1541
1839
  if (ptr->ctx == NULL) {
1542
1840
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1543
1841
  return Qnil;
1544
1842
  }
1545
- if (!RB_TYPE_P(fname_, T_STRING)) {
1546
- rb_raise(rb_eArgError, "fname must be a string");
1547
- return Qnil;
1548
- }
1549
- const char* fname = StringValueCStr(fname_);
1550
- if (llama_eval_export(ptr->ctx, fname) != 0) {
1551
- return Qfalse;
1552
- }
1553
- RB_GC_GUARD(fname_);
1554
- return Qtrue;
1555
- }
1556
-
1557
- static VALUE _llama_context_tokenize(int argc, VALUE* argv, VALUE self) {
1558
- VALUE kw_args = Qnil;
1559
- ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
1560
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1561
- rb_scan_args(argc, argv, ":", &kw_args);
1562
- rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
1563
-
1564
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1565
- rb_raise(rb_eArgError, "text must be a String");
1566
- return Qnil;
1567
- }
1568
- if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
1569
- rb_raise(rb_eArgError, "n_max_tokens must be an integer");
1843
+ if (!rb_obj_is_kind_of(batch, rb_cLLaMABatch)) {
1844
+ rb_raise(rb_eArgError, "batch must be a Batch");
1570
1845
  return Qnil;
1571
1846
  }
1572
- if (kw_values[2] != Qundef && (kw_values[2] != Qtrue && kw_values[2] != Qfalse)) {
1573
- rb_raise(rb_eArgError, "add_bos must be a boolean");
1847
+ LLaMABatchWrapper* batch_ptr = RbLLaMABatch::get_llama_batch(batch);
1848
+ if (llama_decode(ptr->ctx, batch_ptr->batch) < 0) {
1849
+ rb_raise(rb_eRuntimeError, "Failed to decode");
1574
1850
  return Qnil;
1575
1851
  }
1576
-
1577
- VALUE text_ = kw_values[0];
1578
- std::string text = StringValueCStr(text_);
1579
- const bool add_bos = kw_values[2] == Qtrue ? true : false;
1580
- const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
1581
-
1582
- std::vector<llama_token> tokens(n_max_tokens);
1583
- LLaMAContextWrapper* ptr = get_llama_context(self);
1584
- if (ptr->ctx == NULL) {
1585
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1586
- return Qnil;
1587
- }
1588
- const int n = llama_tokenize(ptr->ctx, text.c_str(), text.size(), tokens.data(), n_max_tokens, add_bos);
1589
- if (n < 0) {
1590
- rb_raise(rb_eRuntimeError, "Failed to tokenize");
1591
- return Qnil;
1592
- }
1593
-
1594
- VALUE output = rb_ary_new();
1595
- for (int i = 0; i < n; i++) {
1596
- rb_ary_push(output, INT2NUM(tokens[i]));
1597
- }
1598
-
1599
- RB_GC_GUARD(text_);
1600
- return output;
1601
- }
1602
-
1603
- static VALUE _llama_context_token_to_piece(VALUE self, VALUE token_) {
1604
- LLaMAContextWrapper* ptr = get_llama_context(self);
1605
- if (ptr->ctx == NULL) {
1606
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1607
- return Qnil;
1608
- }
1609
- const llama_token token = NUM2INT(token_);
1610
- std::vector<char> result(8, 0);
1611
- const int n_tokens = llama_token_to_piece(ptr->ctx, token, result.data(), result.size());
1612
- if (n_tokens < 0) {
1613
- result.resize(-n_tokens);
1614
- const int check = llama_token_to_piece(ptr->ctx, token, result.data(), result.size());
1615
- if (check != -n_tokens) {
1616
- rb_raise(rb_eRuntimeError, "failed to convert");
1617
- return Qnil;
1618
- }
1619
- } else {
1620
- result.resize(n_tokens);
1621
- }
1622
- std::string ret(result.data(), result.size());
1623
- return rb_str_new_cstr(ret.c_str());
1852
+ return Qnil;
1624
1853
  }
1625
1854
 
1626
1855
  static VALUE _llama_context_logits(VALUE self) {
@@ -1635,10 +1864,11 @@ private:
1635
1864
  }
1636
1865
 
1637
1866
  VALUE model = rb_iv_get(self, "@model");
1638
- VALUE params = rb_iv_get(model, "@params");
1867
+ LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
1868
+ VALUE params = rb_iv_get(self, "@params");
1639
1869
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1640
1870
  const int n_tokens = prms_ptr->params.logits_all ? NUM2INT(rb_iv_get(self, "@n_tokens")) : 1;
1641
- const int n_vocab = llama_n_vocab(ptr->ctx);
1871
+ const int n_vocab = llama_n_vocab(model_ptr->model);
1642
1872
  const float* logits = llama_get_logits(ptr->ctx);
1643
1873
  VALUE output = rb_ary_new();
1644
1874
  for (int i = 0; i < n_tokens * n_vocab; i++) {
@@ -1655,7 +1885,8 @@ private:
1655
1885
  return Qnil;
1656
1886
  }
1657
1887
  VALUE model = rb_iv_get(self, "@model");
1658
- VALUE params = rb_iv_get(model, "@params");
1888
+ LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
1889
+ VALUE params = rb_iv_get(self, "@params");
1659
1890
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1660
1891
  if (!prms_ptr->params.embedding) {
1661
1892
  rb_raise(rb_eRuntimeError, "embedding parameter is false");
@@ -1666,7 +1897,7 @@ private:
1666
1897
  return Qnil;
1667
1898
  }
1668
1899
 
1669
- const int n_embd = llama_n_embd(ptr->ctx);
1900
+ const int n_embd = llama_n_embd(model_ptr->model);
1670
1901
  const float* embd = llama_get_embeddings(ptr->ctx);
1671
1902
  VALUE output = rb_ary_new();
1672
1903
  for (int i = 0; i < n_embd; i++) {
@@ -1684,7 +1915,7 @@ private:
1684
1915
  }
1685
1916
  const llama_token token = NUM2INT(token_);
1686
1917
  const char* text = llama_token_get_text(ptr->ctx, token);
1687
- return rb_str_new_cstr(text);
1918
+ return rb_utf8_str_new_cstr(text);
1688
1919
  }
1689
1920
 
1690
1921
  static VALUE _llama_context_score(VALUE self, VALUE token_) {
@@ -1736,40 +1967,49 @@ private:
1736
1967
  return INT2NUM(llama_token_nl(ptr->ctx));
1737
1968
  }
1738
1969
 
1739
- static VALUE _llama_context_n_vocab(VALUE self) {
1970
+ static VALUE _llama_context_token_prefix(VALUE self) {
1740
1971
  LLaMAContextWrapper* ptr = get_llama_context(self);
1741
1972
  if (ptr->ctx == NULL) {
1742
1973
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1743
1974
  return Qnil;
1744
1975
  }
1745
- return INT2NUM(llama_n_vocab(ptr->ctx));
1976
+ return INT2NUM(llama_token_prefix(ptr->ctx));
1746
1977
  }
1747
1978
 
1748
- static VALUE _llama_context_n_ctx(VALUE self) {
1979
+ static VALUE _llama_context_token_middle(VALUE self) {
1749
1980
  LLaMAContextWrapper* ptr = get_llama_context(self);
1750
1981
  if (ptr->ctx == NULL) {
1751
1982
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1752
1983
  return Qnil;
1753
1984
  }
1754
- return INT2NUM(llama_n_ctx(ptr->ctx));
1985
+ return INT2NUM(llama_token_middle(ptr->ctx));
1986
+ }
1987
+
1988
+ static VALUE _llama_context_token_suffix(VALUE self) {
1989
+ LLaMAContextWrapper* ptr = get_llama_context(self);
1990
+ if (ptr->ctx == NULL) {
1991
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1992
+ return Qnil;
1993
+ }
1994
+ return INT2NUM(llama_token_suffix(ptr->ctx));
1755
1995
  }
1756
1996
 
1757
- static VALUE _llama_context_n_ctx_train(VALUE self) {
1997
+ static VALUE _llama_context_token_eot(VALUE self) {
1758
1998
  LLaMAContextWrapper* ptr = get_llama_context(self);
1759
1999
  if (ptr->ctx == NULL) {
1760
2000
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1761
2001
  return Qnil;
1762
2002
  }
1763
- return INT2NUM(llama_n_ctx_train(ptr->ctx));
2003
+ return INT2NUM(llama_token_eot(ptr->ctx));
1764
2004
  }
1765
2005
 
1766
- static VALUE _llama_context_n_embd(VALUE self) {
2006
+ static VALUE _llama_context_n_ctx(VALUE self) {
1767
2007
  LLaMAContextWrapper* ptr = get_llama_context(self);
1768
2008
  if (ptr->ctx == NULL) {
1769
2009
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1770
2010
  return Qnil;
1771
2011
  }
1772
- return INT2NUM(llama_n_embd(ptr->ctx));
2012
+ return INT2NUM(llama_n_ctx(ptr->ctx));
1773
2013
  }
1774
2014
 
1775
2015
  static VALUE _llama_context_get_timings(VALUE self) {
@@ -1813,6 +2053,56 @@ private:
1813
2053
  return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
1814
2054
  }
1815
2055
 
2056
+ static VALUE _llama_context_kv_cache_tokens_rm(VALUE self, VALUE c0, VALUE c1) {
2057
+ LLaMAContextWrapper* ptr = get_llama_context(self);
2058
+ if (ptr->ctx == NULL) {
2059
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2060
+ return Qnil;
2061
+ }
2062
+ llama_kv_cache_tokens_rm(ptr->ctx, NUM2INT(c0), NUM2INT(c1));
2063
+ return Qnil;
2064
+ }
2065
+
2066
+ static VALUE _llama_context_kv_cache_seq_rm(VALUE self, VALUE seq_id, VALUE p0, VALUE p1) {
2067
+ LLaMAContextWrapper* ptr = get_llama_context(self);
2068
+ if (ptr->ctx == NULL) {
2069
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2070
+ return Qnil;
2071
+ }
2072
+ llama_kv_cache_seq_rm(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1));
2073
+ return Qnil;
2074
+ }
2075
+
2076
+ static VALUE _llama_context_kv_cache_seq_cp(VALUE self, VALUE seq_id_src, VALUE seq_id_dst, VALUE p0, VALUE p1) {
2077
+ LLaMAContextWrapper* ptr = get_llama_context(self);
2078
+ if (ptr->ctx == NULL) {
2079
+ rb_raise(rb_eArgError, "LLaMA context is not initialized");
2080
+ return Qnil;
2081
+ }
2082
+ llama_kv_cache_seq_cp(ptr->ctx, NUM2INT(seq_id_src), NUM2INT(seq_id_dst), NUM2INT(p0), NUM2INT(p1));
2083
+ return Qnil;
2084
+ }
2085
+
2086
+ static VALUE _llama_context_kv_cache_seq_keep(VALUE self, VALUE seq_id) {
2087
+ LLaMAContextWrapper* ptr = get_llama_context(self);
2088
+ if (ptr->ctx == NULL) {
2089
+ rb_raise(rb_eArgError, "LLaMA context is not initialized");
2090
+ return Qnil;
2091
+ }
2092
+ llama_kv_cache_seq_keep(ptr->ctx, NUM2INT(seq_id));
2093
+ return Qnil;
2094
+ }
2095
+
2096
+ static VALUE _llama_context_kv_cache_seq_shift(VALUE self, VALUE seq_id, VALUE p0, VALUE p1, VALUE delta) {
2097
+ LLaMAContextWrapper* ptr = get_llama_context(self);
2098
+ if (ptr->ctx == NULL) {
2099
+ rb_raise(rb_eArgError, "LLaMA context is not initialized");
2100
+ return Qnil;
2101
+ }
2102
+ llama_kv_cache_seq_shift(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1), NUM2INT(delta));
2103
+ return Qnil;
2104
+ }
2105
+
1816
2106
  static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
1817
2107
  LLaMAContextWrapper* ptr = get_llama_context(self);
1818
2108
  if (ptr->ctx == NULL) {
@@ -1851,7 +2141,7 @@ private:
1851
2141
  }
1852
2142
 
1853
2143
  VALUE model = rb_iv_get(self, "@model");
1854
- VALUE params = rb_iv_get(model, "@params");
2144
+ VALUE params = rb_iv_get(self, "@params");
1855
2145
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1856
2146
  const int n_ctx = prms_ptr->params.n_ctx;
1857
2147
 
@@ -2235,6 +2525,40 @@ private:
2235
2525
  return Qnil;
2236
2526
  }
2237
2527
 
2528
+ static VALUE _llama_context_sample_temp(int argc, VALUE* argv, VALUE self) {
2529
+ VALUE kw_args = Qnil;
2530
+ ID kw_table[1] = { rb_intern("temp") };
2531
+ VALUE kw_values[1] = { Qundef };
2532
+ VALUE candidates = Qnil;
2533
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
2534
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
2535
+
2536
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
2537
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
2538
+ return Qnil;
2539
+ }
2540
+ if (!RB_FLOAT_TYPE_P(kw_values[0])) {
2541
+ rb_raise(rb_eArgError, "temp must be a float");
2542
+ return Qnil;
2543
+ }
2544
+
2545
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2546
+ if (ctx_ptr->ctx == NULL) {
2547
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2548
+ return Qnil;
2549
+ }
2550
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
2551
+ if (cnd_ptr->array.data == nullptr) {
2552
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
2553
+ return Qnil;
2554
+ }
2555
+ const float temp = NUM2DBL(kw_values[0]);
2556
+
2557
+ llama_sample_temp(ctx_ptr->ctx, &(cnd_ptr->array), temp);
2558
+
2559
+ return Qnil;
2560
+ }
2561
+
2238
2562
  static VALUE _llama_context_sample_temperature(int argc, VALUE* argv, VALUE self) {
2239
2563
  VALUE kw_args = Qnil;
2240
2564
  ID kw_table[1] = { rb_intern("temperature") };
@@ -2560,6 +2884,7 @@ extern "C" void Init_llama_cpp(void) {
2560
2884
  RbLLaMATokenData::define_class(rb_mLLaMACpp);
2561
2885
  RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
2562
2886
  RbLLaMAModel::define_class(rb_mLLaMACpp);
2887
+ RbLLaMAModelParams::define_class(rb_mLLaMACpp);
2563
2888
  RbLLaMATimings::define_class(rb_mLLaMACpp);
2564
2889
  RbLLaMAContext::define_class(rb_mLLaMACpp);
2565
2890
  RbLLaMAContextParams::define_class(rb_mLLaMACpp);
@@ -2578,10 +2903,6 @@ extern "C" void Init_llama_cpp(void) {
2578
2903
 
2579
2904
  rb_define_const(rb_mLLaMACpp, "LLAMA_MAX_DEVICES", INT2NUM(LLAMA_MAX_DEVICES));
2580
2905
 
2581
- rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_ERROR", INT2NUM(LLAMA_LOG_LEVEL_ERROR));
2582
- rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_WARN", INT2NUM(LLAMA_LOG_LEVEL_WARN));
2583
- rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_INFO", INT2NUM(LLAMA_LOG_LEVEL_INFO));
2584
-
2585
2906
  rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_SPM", INT2NUM(LLAMA_VOCAB_TYPE_SPM));
2586
2907
  rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_BPE", INT2NUM(LLAMA_VOCAB_TYPE_BPE));
2587
2908