llama_cpp 0.5.3 → 0.6.0

This diff shows the changes between publicly released versions of this package as they appear in its public registry. It is provided for informational purposes only.
@@ -1,7 +1,9 @@
1
1
  #include "llama_cpp.h"
2
2
 
3
3
  VALUE rb_mLLaMACpp;
4
+ VALUE rb_cLLaMABatch;
4
5
  VALUE rb_cLLaMAModel;
6
+ VALUE rb_cLLaMAModelParams;
5
7
  VALUE rb_cLLaMATimings;
6
8
  VALUE rb_cLLaMAContext;
7
9
  VALUE rb_cLLaMAContextParams;
@@ -11,6 +13,238 @@ VALUE rb_cLLaMATokenDataArray;
11
13
  VALUE rb_cLLaMAGrammarElement;
12
14
  VALUE rb_cLLaMAGrammar;
13
15
 
16
+ class LLaMABatchWrapper {
17
+ public:
18
+ llama_batch batch;
19
+
20
+ LLaMABatchWrapper() {}
21
+
22
+ ~LLaMABatchWrapper() {
23
+ llama_batch_free(batch);
24
+ }
25
+ };
26
+
27
+ class RbLLaMABatch {
28
+ public:
29
+ static VALUE llama_batch_alloc(VALUE self) {
30
+ LLaMABatchWrapper* ptr = (LLaMABatchWrapper*)ruby_xmalloc(sizeof(LLaMABatchWrapper));
31
+ new (ptr) LLaMABatchWrapper();
32
+ return TypedData_Wrap_Struct(self, &llama_batch_type, ptr);
33
+ }
34
+
35
+ static void llama_batch_free(void* ptr) {
36
+ ((LLaMABatchWrapper*)ptr)->~LLaMABatchWrapper();
37
+ ruby_xfree(ptr);
38
+ }
39
+
40
+ static size_t llama_batch_size(const void* ptr) {
41
+ return sizeof(*((LLaMABatchWrapper*)ptr));
42
+ }
43
+
44
+ static LLaMABatchWrapper* get_llama_batch(VALUE self) {
45
+ LLaMABatchWrapper* ptr;
46
+ TypedData_Get_Struct(self, LLaMABatchWrapper, &llama_batch_type, ptr);
47
+ return ptr;
48
+ }
49
+
50
+ static void define_class(VALUE outer) {
51
+ rb_cLLaMABatch = rb_define_class_under(outer, "Batch", rb_cObject);
52
+ rb_define_alloc_func(rb_cLLaMABatch, llama_batch_alloc);
53
+ rb_define_method(rb_cLLaMABatch, "initialize", RUBY_METHOD_FUNC(_llama_batch_initialize), -1);
54
+ rb_define_method(rb_cLLaMABatch, "n_tokens=", RUBY_METHOD_FUNC(_llama_batch_set_n_tokens), 1);
55
+ rb_define_method(rb_cLLaMABatch, "n_tokens", RUBY_METHOD_FUNC(_llama_batch_get_n_tokens), 0);
56
+ rb_define_method(rb_cLLaMABatch, "all_pos_zero=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_zero), 1);
57
+ rb_define_method(rb_cLLaMABatch, "all_pos_zero", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_zero), 0);
58
+ rb_define_method(rb_cLLaMABatch, "all_pos_one=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_one), 1);
59
+ rb_define_method(rb_cLLaMABatch, "all_pos_one", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_one), 0);
60
+ rb_define_method(rb_cLLaMABatch, "all_seq_id=", RUBY_METHOD_FUNC(_llama_batch_set_all_seq_id), 1);
61
+ rb_define_method(rb_cLLaMABatch, "all_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_all_seq_id), 0);
62
+ rb_define_method(rb_cLLaMABatch, "set_token", RUBY_METHOD_FUNC(_llama_batch_set_token), 2);
63
+ rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
64
+ rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
65
+ rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
66
+ rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 2);
67
+ rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 1);
68
+ rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
69
+ rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
70
+ }
71
+
72
+ private:
73
+ static const rb_data_type_t llama_batch_type;
74
+
75
+ static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
76
+ VALUE kw_args = Qnil;
77
+ ID kw_table[2] = { rb_intern("n_tokens"), rb_intern("embd") };
78
+ VALUE kw_values[2] = { Qundef, Qundef };
79
+ rb_scan_args(argc, argv, ":", &kw_args);
80
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
81
+
82
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
83
+ rb_raise(rb_eArgError, "n_tokens must be an integer");
84
+ return Qnil;
85
+ }
86
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
87
+ rb_raise(rb_eArgError, "embd must be an integer");
88
+ return Qnil;
89
+ }
90
+
91
+ const int32_t n_tokens = NUM2INT(kw_values[0]);
92
+ const int32_t embd = NUM2INT(kw_values[1]);
93
+
94
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
95
+ ptr->batch = llama_batch_init(n_tokens, embd);
96
+
97
+ return Qnil;
98
+ }
99
+
100
+ // n_tokens
101
+ static VALUE _llama_batch_set_n_tokens(VALUE self, VALUE n_tokens) {
102
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
103
+ ptr->batch.n_tokens = NUM2INT(n_tokens);
104
+ return INT2NUM(ptr->batch.n_tokens);
105
+ }
106
+
107
+ static VALUE _llama_batch_get_n_tokens(VALUE self) {
108
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
109
+ return INT2NUM(ptr->batch.n_tokens);
110
+ }
111
+
112
+ // all_pos_0
113
+ static VALUE _llama_batch_set_all_pos_zero(VALUE self, VALUE all_pos_0) {
114
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
115
+ ptr->batch.all_pos_0 = NUM2INT(all_pos_0);
116
+ return INT2NUM(ptr->batch.all_pos_0);
117
+ }
118
+
119
+ static VALUE _llama_batch_get_all_pos_zero(VALUE self) {
120
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
121
+ return INT2NUM(ptr->batch.all_pos_0);
122
+ }
123
+
124
+ // all_pos_1
125
+ static VALUE _llama_batch_set_all_pos_one(VALUE self, VALUE all_pos_1) {
126
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
127
+ ptr->batch.all_pos_1 = NUM2INT(all_pos_1);
128
+ return INT2NUM(ptr->batch.all_pos_1);
129
+ }
130
+
131
+ static VALUE _llama_batch_get_all_pos_one(VALUE self) {
132
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
133
+ return INT2NUM(ptr->batch.all_pos_1);
134
+ }
135
+
136
+ // all_seq_id
137
+ static VALUE _llama_batch_set_all_seq_id(VALUE self, VALUE all_seq_id) {
138
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
139
+ ptr->batch.all_seq_id = NUM2INT(all_seq_id);
140
+ return INT2NUM(ptr->batch.all_seq_id);
141
+ }
142
+
143
+ static VALUE _llama_batch_get_all_seq_id(VALUE self) {
144
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
145
+ return INT2NUM(ptr->batch.all_seq_id);
146
+ }
147
+
148
+ // token
149
+ static VALUE _llama_batch_set_token(VALUE self, VALUE idx, VALUE value) {
150
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
151
+ const int32_t id = NUM2INT(idx);
152
+ if (id < 0 || id >= ptr->batch.n_tokens) {
153
+ rb_raise(rb_eArgError, "idx must be in [0, n_tokens)");
154
+ return Qnil;
155
+ }
156
+ ptr->batch.token[id] = NUM2INT(value);
157
+ return INT2NUM(ptr->batch.token[id]);
158
+ }
159
+
160
+ static VALUE _llama_batch_get_token(VALUE self, VALUE idx) {
161
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
162
+ const int32_t id = NUM2INT(idx);
163
+ if (id < 0 || id >= ptr->batch.n_tokens) {
164
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
165
+ return Qnil;
166
+ }
167
+ return INT2NUM(ptr->batch.token[id]);
168
+ }
169
+
170
+ // pos
171
+ static VALUE _llama_batch_set_pos(VALUE self, VALUE idx, VALUE value) {
172
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
173
+ const int32_t id = NUM2INT(idx);
174
+ if (id < 0 || id >= ptr->batch.n_tokens) {
175
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
176
+ return Qnil;
177
+ }
178
+ ptr->batch.pos[id] = NUM2INT(value);
179
+ return INT2NUM(ptr->batch.pos[id]);
180
+ }
181
+
182
+ static VALUE _llama_batch_get_pos(VALUE self, VALUE idx) {
183
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
184
+ const int32_t id = NUM2INT(idx);
185
+ if (id < 0 || id >= ptr->batch.n_tokens) {
186
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
187
+ return Qnil;
188
+ }
189
+ return INT2NUM(ptr->batch.pos[id]);
190
+ }
191
+
192
+ // seq_id
193
+ static VALUE _llama_batch_set_seq_id(VALUE self, VALUE idx, VALUE value) {
194
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
195
+ const int32_t id = NUM2INT(idx);
196
+ if (id < 0 || id >= ptr->batch.n_tokens) {
197
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
198
+ return Qnil;
199
+ }
200
+ ptr->batch.seq_id[id] = NUM2INT(value);
201
+ return INT2NUM(ptr->batch.seq_id[id]);
202
+ }
203
+
204
+ static VALUE _llama_batch_get_seq_id(VALUE self, VALUE idx) {
205
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
206
+ const int32_t id = NUM2INT(idx);
207
+ if (id < 0 || id >= ptr->batch.n_tokens) {
208
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
209
+ return Qnil;
210
+ }
211
+ return INT2NUM(ptr->batch.seq_id[id]);
212
+ }
213
+
214
+ // logits
215
+ static VALUE _llama_batch_set_logits(VALUE self, VALUE idx, VALUE value) {
216
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
217
+ const int32_t id = NUM2INT(idx);
218
+ if (id < 0 || id >= ptr->batch.n_tokens) {
219
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
220
+ return Qnil;
221
+ }
222
+ ptr->batch.logits[id] = RTEST(value) ? true : false;
223
+ return ptr->batch.logits[id] ? Qtrue : Qfalse;
224
+ }
225
+
226
+ static VALUE _llama_batch_get_logits(VALUE self, VALUE idx) {
227
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
228
+ const int32_t id = NUM2INT(idx);
229
+ if (id < 0 || id >= ptr->batch.n_tokens) {
230
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
231
+ return Qnil;
232
+ }
233
+ return ptr->batch.logits[id] ? Qtrue : Qfalse;
234
+ }
235
+ };
236
+
237
+ const rb_data_type_t RbLLaMABatch::llama_batch_type = {
238
+ "RbLLaMABatch",
239
+ { NULL,
240
+ RbLLaMABatch::llama_batch_free,
241
+ RbLLaMABatch::llama_batch_size },
242
+ NULL,
243
+ NULL,
244
+ RUBY_TYPED_FREE_IMMEDIATELY
245
+
246
+ };
247
+
14
248
  class LLaMATokenDataWrapper {
15
249
  public:
16
250
  llama_token_data data;
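
The Batch class added above wraps llama.cpp's llama_batch struct together with llama_batch_init and llama_batch_free. As a rough sketch of how the underlying C API fits together (llama_decode appears in a later hunk of this diff; exact signatures belong to the llama.h bundled with this release):

    #include "llama.h"

    // Fill a token batch and evaluate it with a single llama_decode call.
    // C-side counterpart of Batch#set_token / #set_pos / #set_seq_id / #set_logits.
    static int decode_tokens(llama_context * ctx, const llama_token * tokens, int n) {
      llama_batch batch = llama_batch_init(n, /* embd = */ 0); // token batch, no raw embeddings
      batch.n_tokens = n;
      for (int i = 0; i < n; i++) {
        batch.token[i]  = tokens[i];
        batch.pos[i]    = i;     // absolute position in the sequence
        batch.seq_id[i] = 0;     // everything belongs to sequence 0
        batch.logits[i] = false; // request logits only for the last token
      }
      batch.logits[n - 1] = true;
      const int rc = llama_decode(ctx, batch); // negative return means failure, as checked in _llama_context_decode
      llama_batch_free(batch);
      return rc;
    }
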
@@ -363,6 +597,144 @@ const rb_data_type_t RbLLaMATimings::llama_timings_type = {
363
597
  RUBY_TYPED_FREE_IMMEDIATELY
364
598
  };
365
599
 
600
+ class LLaMAModelParamsWrapper {
601
+ public:
602
+ struct llama_model_params params;
603
+
604
+ LLaMAModelParamsWrapper() : params(llama_model_default_params()) {}
605
+
606
+ ~LLaMAModelParamsWrapper() {}
607
+ };
608
+
609
+ class RbLLaMAModelParams {
610
+ public:
611
+ static VALUE llama_model_params_alloc(VALUE self) {
612
+ LLaMAModelParamsWrapper* ptr = (LLaMAModelParamsWrapper*)ruby_xmalloc(sizeof(LLaMAModelParamsWrapper));
613
+ new (ptr) LLaMAModelParamsWrapper();
614
+ return TypedData_Wrap_Struct(self, &llama_model_params_type, ptr);
615
+ }
616
+
617
+ static void llama_model_params_free(void* ptr) {
618
+ ((LLaMAModelParamsWrapper*)ptr)->~LLaMAModelParamsWrapper();
619
+ ruby_xfree(ptr);
620
+ }
621
+
622
+ static size_t llama_model_params_size(const void* ptr) {
623
+ return sizeof(*((LLaMAModelParamsWrapper*)ptr));
624
+ }
625
+
626
+ static LLaMAModelParamsWrapper* get_llama_model_params(VALUE self) {
627
+ LLaMAModelParamsWrapper* ptr;
628
+ TypedData_Get_Struct(self, LLaMAModelParamsWrapper, &llama_model_params_type, ptr);
629
+ return ptr;
630
+ }
631
+
632
+ static void define_class(VALUE outer) {
633
+ rb_cLLaMAModelParams = rb_define_class_under(outer, "ModelParams", rb_cObject);
634
+ rb_define_alloc_func(rb_cLLaMAModelParams, llama_model_params_alloc);
635
+ rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers=", RUBY_METHOD_FUNC(_llama_model_params_set_n_gpu_layers), 1);
636
+ rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers", RUBY_METHOD_FUNC(_llama_model_params_get_n_gpu_layers), 0);
637
+ rb_define_method(rb_cLLaMAModelParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_model_params_set_main_gpu), 1);
638
+ rb_define_method(rb_cLLaMAModelParams, "main_gpu", RUBY_METHOD_FUNC(_llama_model_params_get_main_gpu), 0);
639
+ rb_define_method(rb_cLLaMAModelParams, "tensor_split", RUBY_METHOD_FUNC(_llama_model_params_get_tensor_split), 0);
640
+ rb_define_method(rb_cLLaMAModelParams, "vocab_only=", RUBY_METHOD_FUNC(_llama_model_params_set_vocab_only), 1);
641
+ rb_define_method(rb_cLLaMAModelParams, "vocab_only", RUBY_METHOD_FUNC(_llama_model_params_get_vocab_only), 0);
642
+ rb_define_method(rb_cLLaMAModelParams, "use_mmap=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mmap), 1);
643
+ rb_define_method(rb_cLLaMAModelParams, "use_mmap", RUBY_METHOD_FUNC(_llama_model_params_get_use_mmap), 0);
644
+ rb_define_method(rb_cLLaMAModelParams, "use_mlock=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mlock), 1);
645
+ rb_define_method(rb_cLLaMAModelParams, "use_mlock", RUBY_METHOD_FUNC(_llama_model_params_get_use_mlock), 0);
646
+ }
647
+
648
+ private:
649
+ static const rb_data_type_t llama_model_params_type;
650
+
651
+ // n_gpu_layers
652
+ static VALUE _llama_model_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
653
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
654
+ ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
655
+ return INT2NUM(ptr->params.n_gpu_layers);
656
+ }
657
+
658
+ static VALUE _llama_model_params_get_n_gpu_layers(VALUE self) {
659
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
660
+ return INT2NUM(ptr->params.n_gpu_layers);
661
+ }
662
+
663
+ // main_gpu
664
+ static VALUE _llama_model_params_set_main_gpu(VALUE self, VALUE main_gpu) {
665
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
666
+ ptr->params.main_gpu = NUM2INT(main_gpu);
667
+ return INT2NUM(ptr->params.main_gpu);
668
+ }
669
+
670
+ static VALUE _llama_model_params_get_main_gpu(VALUE self) {
671
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
672
+ return INT2NUM(ptr->params.main_gpu);
673
+ }
674
+
675
+ // tensor_split
676
+ static VALUE _llama_model_params_get_tensor_split(VALUE self) {
677
+ if (LLAMA_MAX_DEVICES < 1) {
678
+ return rb_ary_new();
679
+ }
680
+ VALUE ret = rb_ary_new2(LLAMA_MAX_DEVICES);
681
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
682
+ if (ptr->params.tensor_split == nullptr) {
683
+ return rb_ary_new();
684
+ }
685
+ for (size_t i = 0; i < LLAMA_MAX_DEVICES; i++) {
686
+ rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
687
+ }
688
+ return ret;
689
+ }
690
+
691
+ // vocab_only
692
+ static VALUE _llama_model_params_set_vocab_only(VALUE self, VALUE vocab_only) {
693
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
694
+ ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
695
+ return ptr->params.vocab_only ? Qtrue : Qfalse;
696
+ }
697
+
698
+ static VALUE _llama_model_params_get_vocab_only(VALUE self) {
699
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
700
+ return ptr->params.vocab_only ? Qtrue : Qfalse;
701
+ }
702
+
703
+ // use_mmap
704
+ static VALUE _llama_model_params_set_use_mmap(VALUE self, VALUE use_mmap) {
705
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
706
+ ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
707
+ return ptr->params.use_mmap ? Qtrue : Qfalse;
708
+ }
709
+
710
+ static VALUE _llama_model_params_get_use_mmap(VALUE self) {
711
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
712
+ return ptr->params.use_mmap ? Qtrue : Qfalse;
713
+ }
714
+
715
+ // use_mlock
716
+ static VALUE _llama_model_params_set_use_mlock(VALUE self, VALUE use_mlock) {
717
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
718
+ ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
719
+ return ptr->params.use_mlock ? Qtrue : Qfalse;
720
+ }
721
+
722
+ static VALUE _llama_model_params_get_use_mlock(VALUE self) {
723
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
724
+ return ptr->params.use_mlock ? Qtrue : Qfalse;
725
+ }
726
+ };
727
+
728
+ const rb_data_type_t RbLLaMAModelParams::llama_model_params_type = {
729
+ "RbLLaMAModelParams",
730
+ { NULL,
731
+ RbLLaMAModelParams::llama_model_params_free,
732
+ RbLLaMAModelParams::llama_model_params_size },
733
+ NULL,
734
+ NULL,
735
+ RUBY_TYPED_FREE_IMMEDIATELY
736
+ };
737
+
366
738
  class LLaMAContextParamsWrapper {
367
739
  public:
368
740
  struct llama_context_params params;
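
The ModelParams class above mirrors llama.cpp's llama_model_params, which this release splits out of the old all-in-one context parameters. A minimal sketch of the C-side flow these accessors map onto (llama_load_model_from_file is used later in this diff; the field names are the ones exposed above):

    #include "llama.h"

    // Load a model with explicit model-level options; the counterpart of
    // building a LLaMACpp::ModelParams and passing it to Model.new.
    static llama_model * load_model(const char * path) {
      llama_model_params mparams = llama_model_default_params();
      mparams.n_gpu_layers = 32;    // layers to offload if a GPU backend is compiled in
      mparams.main_gpu     = 0;
      mparams.use_mmap     = true;  // map the weights instead of copying them into RAM
      mparams.use_mlock    = false;
      return llama_load_model_from_file(path, mparams); // NULL on failure
    }
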
@@ -399,35 +771,26 @@ public:
399
771
  rb_cLLaMAContextParams = rb_define_class_under(outer, "ContextParams", rb_cObject);
400
772
  rb_define_alloc_func(rb_cLLaMAContextParams, llama_context_params_alloc);
401
773
  // rb_define_method(rb_cLLaMAContextParams, "initialize", RUBY_METHOD_FUNC(_llama_context_params_init), 0);
774
+ rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
775
+ rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
402
776
  rb_define_method(rb_cLLaMAContextParams, "n_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_n_ctx), 1);
403
777
  rb_define_method(rb_cLLaMAContextParams, "n_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_n_ctx), 0);
404
778
  rb_define_method(rb_cLLaMAContextParams, "n_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_batch), 1);
405
779
  rb_define_method(rb_cLLaMAContextParams, "n_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_batch), 0);
406
- rb_define_method(rb_cLLaMAContextParams, "n_gpu_layers=", RUBY_METHOD_FUNC(_llama_context_params_set_n_gpu_layers), 1);
407
- rb_define_method(rb_cLLaMAContextParams, "n_gpu_layers", RUBY_METHOD_FUNC(_llama_context_params_get_n_gpu_layers), 0);
408
- rb_define_method(rb_cLLaMAContextParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_context_params_set_main_gpu), 1);
409
- rb_define_method(rb_cLLaMAContextParams, "main_gpu", RUBY_METHOD_FUNC(_llama_context_params_get_main_gpu), 0);
410
- rb_define_method(rb_cLLaMAContextParams, "tensor_split", RUBY_METHOD_FUNC(_llama_context_params_get_tensor_split), 0);
780
+ rb_define_method(rb_cLLaMAContextParams, "n_threads=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads), 1);
781
+ rb_define_method(rb_cLLaMAContextParams, "n_threads", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads), 0);
782
+ rb_define_method(rb_cLLaMAContextParams, "n_threads_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads_batch), 1);
783
+ rb_define_method(rb_cLLaMAContextParams, "n_threads_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads_batch), 0);
411
784
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
412
785
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
413
786
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
414
787
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
415
- rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
416
- rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
417
788
  rb_define_method(rb_cLLaMAContextParams, "mul_mat_q=", RUBY_METHOD_FUNC(_llama_context_params_set_mul_mat_q), 1);
418
789
  rb_define_method(rb_cLLaMAContextParams, "mul_mat_q", RUBY_METHOD_FUNC(_llama_context_params_get_mul_mat_q), 0);
419
- rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
420
- rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
421
790
  rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
422
791
  rb_define_method(rb_cLLaMAContextParams, "f16_kv", RUBY_METHOD_FUNC(_llama_context_params_get_f16_kv), 0);
423
792
  rb_define_method(rb_cLLaMAContextParams, "logits_all=", RUBY_METHOD_FUNC(_llama_context_params_set_logits_all), 1);
424
793
  rb_define_method(rb_cLLaMAContextParams, "logits_all", RUBY_METHOD_FUNC(_llama_context_params_get_logits_all), 0);
425
- rb_define_method(rb_cLLaMAContextParams, "vocab_only=", RUBY_METHOD_FUNC(_llama_context_params_set_vocab_only), 1);
426
- rb_define_method(rb_cLLaMAContextParams, "vocab_only", RUBY_METHOD_FUNC(_llama_context_params_get_vocab_only), 0);
427
- rb_define_method(rb_cLLaMAContextParams, "use_mmap=", RUBY_METHOD_FUNC(_llama_context_params_set_use_mmap), 1);
428
- rb_define_method(rb_cLLaMAContextParams, "use_mmap", RUBY_METHOD_FUNC(_llama_context_params_get_use_mmap), 0);
429
- rb_define_method(rb_cLLaMAContextParams, "use_mlock=", RUBY_METHOD_FUNC(_llama_context_params_set_use_mlock), 1);
430
- rb_define_method(rb_cLLaMAContextParams, "use_mlock", RUBY_METHOD_FUNC(_llama_context_params_get_use_mlock), 0);
431
794
  rb_define_method(rb_cLLaMAContextParams, "embedding=", RUBY_METHOD_FUNC(_llama_context_params_set_embedding), 1);
432
795
  rb_define_method(rb_cLLaMAContextParams, "embedding", RUBY_METHOD_FUNC(_llama_context_params_get_embedding), 0);
433
796
  }
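
With the model-level options moved out, ContextParams is left with per-context settings only: seed, n_ctx, n_batch, the new n_threads/n_threads_batch pair, and the RoPE/precision flags. A sketch of the corresponding C structs after the split (llama_context_default_params and llama_new_context_with_model do not appear in the hunks shown here and are assumed from the bundled llama.h):

    #include "llama.h"

    // Create a context for an already-loaded model; thread counts now live here
    // instead of being passed to every eval call.
    static llama_context * make_context(llama_model * model) {
      llama_context_params cparams = llama_context_default_params();
      cparams.seed            = 1234;
      cparams.n_ctx           = 2048;
      cparams.n_batch         = 512;
      cparams.n_threads       = 8; // threads for single-token generation
      cparams.n_threads_batch = 8; // threads for prompt/batch evaluation
      return llama_new_context_with_model(model, cparams);
    }
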
@@ -441,6 +804,22 @@ private:
441
804
  // return self;
442
805
  // }
443
806
 
807
+ // seed
808
+ static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
809
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
810
+ if (NUM2INT(seed) < 0) {
811
+ rb_raise(rb_eArgError, "seed must be positive");
812
+ return Qnil;
813
+ }
814
+ ptr->params.seed = NUM2INT(seed);
815
+ return INT2NUM(ptr->params.seed);
816
+ }
817
+
818
+ static VALUE _llama_context_params_get_seed(VALUE self) {
819
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
820
+ return INT2NUM(ptr->params.seed);
821
+ }
822
+
444
823
  // n_ctx
445
824
  static VALUE _llama_context_params_set_n_ctx(VALUE self, VALUE n_ctx) {
446
825
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -465,41 +844,28 @@ private:
465
844
  return INT2NUM(ptr->params.n_batch);
466
845
  }
467
846
 
468
- // n_gpu_layers
469
- static VALUE _llama_context_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
847
+ // n_threads
848
+ static VALUE _llama_context_params_set_n_threads(VALUE self, VALUE n_threads) {
470
849
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
471
- ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
472
- return INT2NUM(ptr->params.n_gpu_layers);
850
+ ptr->params.n_threads = NUM2INT(n_threads);
851
+ return INT2NUM(ptr->params.n_threads);
473
852
  }
474
853
 
475
- static VALUE _llama_context_params_get_n_gpu_layers(VALUE self) {
854
+ static VALUE _llama_context_params_get_n_threads(VALUE self) {
476
855
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
477
- return INT2NUM(ptr->params.n_gpu_layers);
856
+ return INT2NUM(ptr->params.n_threads);
478
857
  }
479
858
 
480
- // main_gpu
481
- static VALUE _llama_context_params_set_main_gpu(VALUE self, VALUE main_gpu) {
859
+ // n_threads_batch
860
+ static VALUE _llama_context_params_set_n_threads_batch(VALUE self, VALUE n_threads_batch) {
482
861
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
483
- ptr->params.main_gpu = NUM2INT(main_gpu);
484
- return INT2NUM(ptr->params.main_gpu);
862
+ ptr->params.n_threads_batch = NUM2INT(n_threads_batch);
863
+ return INT2NUM(ptr->params.n_threads_batch);
485
864
  }
486
865
 
487
- static VALUE _llama_context_params_get_main_gpu(VALUE self) {
866
+ static VALUE _llama_context_params_get_n_threads_batch(VALUE self) {
488
867
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
489
- return INT2NUM(ptr->params.main_gpu);
490
- }
491
-
492
- // tensor_split
493
- static VALUE _llama_context_params_get_tensor_split(VALUE self) {
494
- if (LLAMA_MAX_DEVICES < 1) {
495
- return rb_ary_new();
496
- }
497
- VALUE ret = rb_ary_new2(LLAMA_MAX_DEVICES);
498
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
499
- for (size_t i = 0; i < LLAMA_MAX_DEVICES; i++) {
500
- rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
501
- }
502
- return ret;
868
+ return INT2NUM(ptr->params.n_threads_batch);
503
869
  }
504
870
 
505
871
  // rope_freq_base
@@ -526,18 +892,6 @@ private:
526
892
  return DBL2NUM(ptr->params.rope_freq_scale);
527
893
  }
528
894
 
529
- // low_vram
530
- static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
531
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
532
- ptr->params.low_vram = RTEST(low_vram) ? true : false;
533
- return ptr->params.low_vram ? Qtrue : Qfalse;
534
- }
535
-
536
- static VALUE _llama_context_params_get_low_vram(VALUE self) {
537
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
538
- return ptr->params.low_vram ? Qtrue : Qfalse;
539
- }
540
-
541
895
  // mul_mat_q
542
896
  static VALUE _llama_context_params_set_mul_mat_q(VALUE self, VALUE mul_mat_q) {
543
897
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -550,22 +904,6 @@ private:
550
904
  return ptr->params.mul_mat_q ? Qtrue : Qfalse;
551
905
  }
552
906
 
553
- // seed
554
- static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
555
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
556
- if (NUM2INT(seed) < 0) {
557
- rb_raise(rb_eArgError, "seed must be positive");
558
- return Qnil;
559
- }
560
- ptr->params.seed = NUM2INT(seed);
561
- return INT2NUM(ptr->params.seed);
562
- }
563
-
564
- static VALUE _llama_context_params_get_seed(VALUE self) {
565
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
566
- return INT2NUM(ptr->params.seed);
567
- }
568
-
569
907
  // f16_kv
570
908
  static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
571
909
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -590,42 +928,6 @@ private:
590
928
  return ptr->params.logits_all ? Qtrue : Qfalse;
591
929
  }
592
930
 
593
- // vocab_only
594
- static VALUE _llama_context_params_set_vocab_only(VALUE self, VALUE vocab_only) {
595
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
596
- ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
597
- return ptr->params.vocab_only ? Qtrue : Qfalse;
598
- }
599
-
600
- static VALUE _llama_context_params_get_vocab_only(VALUE self) {
601
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
602
- return ptr->params.vocab_only ? Qtrue : Qfalse;
603
- }
604
-
605
- // use_mmap
606
- static VALUE _llama_context_params_set_use_mmap(VALUE self, VALUE use_mmap) {
607
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
608
- ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
609
- return ptr->params.use_mmap ? Qtrue : Qfalse;
610
- }
611
-
612
- static VALUE _llama_context_params_get_use_mmap(VALUE self) {
613
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
614
- return ptr->params.use_mmap ? Qtrue : Qfalse;
615
- }
616
-
617
- // use_mlock
618
- static VALUE _llama_context_params_set_use_mlock(VALUE self, VALUE use_mlock) {
619
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
620
- ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
621
- return ptr->params.use_mlock ? Qtrue : Qfalse;
622
- }
623
-
624
- static VALUE _llama_context_params_get_use_mlock(VALUE self) {
625
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
626
- return ptr->params.use_mlock ? Qtrue : Qfalse;
627
- }
628
-
629
931
  // embedding
630
932
  static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
631
933
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -823,11 +1125,10 @@ public:
823
1125
  rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
824
1126
  rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
825
1127
  rb_define_method(rb_cLLaMAModel, "n_vocab", RUBY_METHOD_FUNC(_llama_model_get_model_n_vocab), 0);
826
- rb_define_method(rb_cLLaMAModel, "n_ctx", RUBY_METHOD_FUNC(_llama_model_get_model_n_ctx), 0);
827
1128
  rb_define_method(rb_cLLaMAModel, "n_ctx_train", RUBY_METHOD_FUNC(_llama_model_get_model_n_ctx_train), 0);
828
1129
  rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_model_n_embd), 0);
829
- rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece_with_model), 1);
830
- rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize_with_model), -1);
1130
+ rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece), 1);
1131
+ rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize), -1);
831
1132
  rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
832
1133
  rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
833
1134
  rb_define_method(rb_cLLaMAModel, "n_params", RUBY_METHOD_FUNC(_llama_model_get_model_n_params), 0);
@@ -841,30 +1142,21 @@ private:
841
1142
  ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
842
1143
  VALUE kw_values[2] = { Qundef, Qundef };
843
1144
  rb_scan_args(argc, argv, ":", &kw_args);
844
- rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
845
-
846
- if (kw_values[0] == Qundef && kw_values[1] == Qundef) {
847
- rb_iv_set(self, "@params", Qnil);
848
- return Qnil;
849
- }
1145
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
850
1146
 
851
1147
  if (!RB_TYPE_P(kw_values[0], T_STRING)) {
852
1148
  rb_raise(rb_eArgError, "model_path must be a string");
853
1149
  return Qnil;
854
1150
  }
855
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
856
- rb_raise(rb_eArgError, "params must be a ContextParams");
1151
+ if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
1152
+ rb_raise(rb_eArgError, "params must be a ModelParams");
857
1153
  return Qnil;
858
1154
  }
859
1155
 
860
1156
  VALUE filename = kw_values[0];
861
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
1157
+ LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
862
1158
  LLaMAModelWrapper* model_ptr = get_llama_model(self);
863
1159
 
864
- if (prms_ptr->params.seed == LLAMA_DEFAULT_SEED) {
865
- prms_ptr->params.seed = time(NULL);
866
- }
867
-
868
1160
  try {
869
1161
  model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
870
1162
  } catch (const std::runtime_error& e) {
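
Model#initialize also tightens its keyword handling: in rb_get_kwargs the third argument is the number of required keywords and the fourth the number of optional ones, so switching from (0, 2) to (2, 0) removes the old "empty model" path and makes both model_path: and params: mandatory. A small sketch of that Ruby C API call with the keyword names used above:

    #include <ruby.h>

    // rb_get_kwargs(hash, table, required, optional, values):
    // with required = 2 and optional = 0, omitting :model_path or :params raises ArgumentError.
    static void parse_model_kwargs(VALUE kw_args, VALUE * out_path, VALUE * out_params) {
      ID kw_table[2]     = { rb_intern("model_path"), rb_intern("params") };
      VALUE kw_values[2] = { Qundef, Qundef };
      rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
      *out_path   = kw_values[0];
      *out_params = kw_values[1];
    }
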
@@ -912,8 +1204,8 @@ private:
912
1204
  rb_raise(rb_eArgError, "model_path must be a string");
913
1205
  return Qnil;
914
1206
  }
915
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
916
- rb_raise(rb_eArgError, "params must be a LLaMAContextParams");
1207
+ if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
1208
+ rb_raise(rb_eArgError, "params must be a LLaMAModelParams");
917
1209
  return Qnil;
918
1210
  }
919
1211
 
@@ -924,7 +1216,7 @@ private:
924
1216
  }
925
1217
 
926
1218
  VALUE filename = kw_values[0];
927
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
1219
+ LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
928
1220
 
929
1221
  try {
930
1222
  model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
@@ -946,10 +1238,10 @@ private:
946
1238
 
947
1239
  static VALUE _llama_model_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
948
1240
  VALUE kw_args = Qnil;
949
- ID kw_table[3] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads") };
950
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1241
+ ID kw_table[4] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads"), rb_intern("scale") };
1242
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
951
1243
  rb_scan_args(argc, argv, ":", &kw_args);
952
- rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
1244
+ rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);
953
1245
 
954
1246
  if (!RB_TYPE_P(kw_values[0], T_STRING)) {
955
1247
  rb_raise(rb_eArgError, "lora_path must be a string");
@@ -963,13 +1255,18 @@ private:
963
1255
  rb_raise(rb_eArgError, "n_threads must be an integer");
964
1256
  return Qnil;
965
1257
  }
1258
+ if (kw_values[3] != Qundef && !RB_FLOAT_TYPE_P(kw_values[3])) {
1259
+ rb_raise(rb_eArgError, "scale must be a float");
1260
+ return Qnil;
1261
+ }
966
1262
 
967
1263
  const char* lora_path = StringValueCStr(kw_values[0]);
968
1264
  const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
969
1265
  const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
1266
+ const float scale = kw_values[3] == Qundef ? 1.0 : NUM2DBL(kw_values[3]);
970
1267
 
971
1268
  LLaMAModelWrapper* ptr = get_llama_model(self);
972
- if (llama_model_apply_lora_from_file(ptr->model, lora_path, base_model_path, n_threads) != 0) {
1269
+ if (llama_model_apply_lora_from_file(ptr->model, lora_path, scale, base_model_path, n_threads) != 0) {
973
1270
  rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
974
1271
  return Qnil;
975
1272
  }
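
apply_lora_from_file gains an optional scale keyword, matching the extra float that llama_model_apply_lora_from_file now takes between the LoRA path and the base-model path. A sketch of the C call as invoked above:

    #include "llama.h"

    // Apply a LoRA adapter at half strength; a null base path patches the loaded
    // model in place. Returns 0 on success, which is what the binding checks.
    static int apply_lora(llama_model * model, const char * lora_path) {
      return llama_model_apply_lora_from_file(model, lora_path, /* scale = */ 0.5f,
                                              /* path_base_model = */ nullptr, /* n_threads = */ 4);
    }
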
@@ -978,25 +1275,20 @@ private:
978
1275
 
979
1276
  static VALUE _llama_model_get_model_n_vocab(VALUE self) {
980
1277
  LLaMAModelWrapper* ptr = get_llama_model(self);
981
- return INT2NUM(llama_model_n_vocab(ptr->model));
982
- }
983
-
984
- static VALUE _llama_model_get_model_n_ctx(VALUE self) {
985
- LLaMAModelWrapper* ptr = get_llama_model(self);
986
- return INT2NUM(llama_model_n_ctx(ptr->model));
1278
+ return INT2NUM(llama_n_vocab(ptr->model));
987
1279
  }
988
1280
 
989
1281
  static VALUE _llama_model_get_model_n_ctx_train(VALUE self) {
990
1282
  LLaMAModelWrapper* ptr = get_llama_model(self);
991
- return INT2NUM(llama_model_n_ctx_train(ptr->model));
1283
+ return INT2NUM(llama_n_ctx_train(ptr->model));
992
1284
  }
993
1285
 
994
1286
  static VALUE _llama_model_get_model_n_embd(VALUE self) {
995
1287
  LLaMAModelWrapper* ptr = get_llama_model(self);
996
- return INT2NUM(llama_model_n_embd(ptr->model));
1288
+ return INT2NUM(llama_n_embd(ptr->model));
997
1289
  }
998
1290
 
999
- static VALUE _llama_model_token_to_piece_with_model(VALUE self, VALUE token_) {
1291
+ static VALUE _llama_model_token_to_piece(VALUE self, VALUE token_) {
1000
1292
  if (!RB_INTEGER_TYPE_P(token_)) {
1001
1293
  rb_raise(rb_eArgError, "token must be an integer");
1002
1294
  return Qnil;
@@ -1004,10 +1296,10 @@ private:
1004
1296
  const llama_token token = NUM2INT(token_);
1005
1297
  LLaMAModelWrapper* ptr = get_llama_model(self);
1006
1298
  std::vector<char> result(8, 0);
1007
- const int n_tokens = llama_token_to_piece_with_model(ptr->model, token, result.data(), result.size());
1299
+ const int n_tokens = llama_token_to_piece(ptr->model, token, result.data(), result.size());
1008
1300
  if (n_tokens < 0) {
1009
1301
  result.resize(-n_tokens);
1010
- const int check = llama_token_to_piece_with_model(ptr->model, token, result.data(), result.size());
1302
+ const int check = llama_token_to_piece(ptr->model, token, result.data(), result.size());
1011
1303
  if (check != -n_tokens) {
1012
1304
  rb_raise(rb_eRuntimeError, "failed to convert");
1013
1305
  return Qnil;
@@ -1019,7 +1311,7 @@ private:
1019
1311
  return rb_str_new_cstr(ret.c_str());
1020
1312
  }
1021
1313
 
1022
- static VALUE _llama_model_tokenize_with_model(int argc, VALUE* argv, VALUE self) {
1314
+ static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
1023
1315
  VALUE kw_args = Qnil;
1024
1316
  ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
1025
1317
  VALUE kw_values[3] = { Qundef, Qundef, Qundef };
@@ -1046,7 +1338,7 @@ private:
1046
1338
 
1047
1339
  llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
1048
1340
  LLaMAModelWrapper* ptr = get_llama_model(self);
1049
- const int n_tokens = llama_tokenize_with_model(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
1341
+ const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
1050
1342
 
1051
1343
  if (n_tokens < 0) {
1052
1344
  rb_raise(rb_eRuntimeError, "failed to tokenize. The number of tokens (%d) is greater than n_max_tokens.", -n_tokens);
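
token_to_piece and tokenize drop their *_with_model suffixes and operate directly on the model, so neither needs a live context any more. A sketch of the two C calls used in this hunk, including the negative-return convention for sizing the output buffer:

    #include "llama.h"
    #include <string>
    #include <vector>

    // Tokenize text against the model vocabulary; a negative return value is the
    // required buffer size (the binding raises instead of growing the buffer).
    static std::vector<llama_token> tokenize(const llama_model * model, const std::string & text, bool add_bos) {
      std::vector<llama_token> tokens(text.size() + (add_bos ? 1 : 0));
      const int n = llama_tokenize(model, text.c_str(), text.size(), tokens.data(), tokens.size(), add_bos);
      if (n < 0) return {};
      tokens.resize(n);
      return tokens;
    }

    // Convert one token id back into its text piece, retrying with a larger buffer if needed.
    static std::string token_to_piece(const llama_model * model, llama_token token) {
      std::vector<char> buf(8, 0);
      int n = llama_token_to_piece(model, token, buf.data(), buf.size());
      if (n < 0) {
        buf.resize(-n);
        n = llama_token_to_piece(model, token, buf.data(), buf.size());
      }
      return n < 0 ? std::string() : std::string(buf.data(), n);
    }
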
@@ -1345,11 +1637,11 @@ public:
1345
1637
  static void define_class(VALUE outer) {
1346
1638
  rb_cLLaMAContext = rb_define_class_under(outer, "Context", rb_cObject);
1347
1639
  rb_define_alloc_func(rb_cLLaMAContext, llama_context_alloc);
1640
+ rb_define_attr(rb_cLLaMAContext, "model", 1, 0);
1348
1641
  rb_define_method(rb_cLLaMAContext, "initialize", RUBY_METHOD_FUNC(_llama_context_initialize), -1);
1349
1642
  rb_define_method(rb_cLLaMAContext, "eval", RUBY_METHOD_FUNC(_llama_context_eval), -1);
1350
1643
  rb_define_method(rb_cLLaMAContext, "eval_embd", RUBY_METHOD_FUNC(_llama_context_eval_embd), -1);
1351
- rb_define_method(rb_cLLaMAContext, "eval_export", RUBY_METHOD_FUNC(_llama_context_eval_export), 1);
1352
- rb_define_method(rb_cLLaMAContext, "tokenize", RUBY_METHOD_FUNC(_llama_context_tokenize), -1);
1644
+ rb_define_method(rb_cLLaMAContext, "decode", RUBY_METHOD_FUNC(_llama_context_decode), 1);
1353
1645
  rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
1354
1646
  rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
1355
1647
  rb_define_method(rb_cLLaMAContext, "text", RUBY_METHOD_FUNC(_llama_context_text), 1);
@@ -1358,15 +1650,16 @@ public:
1358
1650
  rb_define_method(rb_cLLaMAContext, "token_bos", RUBY_METHOD_FUNC(_llama_context_token_bos), 0);
1359
1651
  rb_define_method(rb_cLLaMAContext, "token_eos", RUBY_METHOD_FUNC(_llama_context_token_eos), 0);
1360
1652
  rb_define_method(rb_cLLaMAContext, "token_nl", RUBY_METHOD_FUNC(_llama_context_token_nl), 0);
1361
- rb_define_method(rb_cLLaMAContext, "token_to_piece", RUBY_METHOD_FUNC(_llama_context_token_to_piece), 1);
1362
- rb_define_method(rb_cLLaMAContext, "n_vocab", RUBY_METHOD_FUNC(_llama_context_n_vocab), 0);
1363
1653
  rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
1364
- rb_define_method(rb_cLLaMAContext, "n_ctx_train", RUBY_METHOD_FUNC(_llama_context_n_ctx_train), 0);
1365
- rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
1366
1654
  rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
1367
1655
  rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
1368
1656
  rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
1369
1657
  rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
1658
+ rb_define_method(rb_cLLaMAContext, "kv_cache_tokens_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_tokens_rm), 2);
1659
+ rb_define_method(rb_cLLaMAContext, "kv_cache_seq_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_rm), 3);
1660
+ rb_define_method(rb_cLLaMAContext, "kv_cache_seq_cp", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_cp), 4);
1661
+ rb_define_method(rb_cLLaMAContext, "kv_cache_seq_keep", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_keep), 1);
1662
+ rb_define_method(rb_cLLaMAContext, "kv_cache_seq_shift", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_shift), 4);
1370
1663
  rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
1371
1664
  rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
1372
1665
  rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
@@ -1378,6 +1671,7 @@ public:
1378
1671
  rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
1379
1672
  rb_define_method(rb_cLLaMAContext, "sample_tail_free", RUBY_METHOD_FUNC(_llama_context_sample_tail_free), -1);
1380
1673
  rb_define_method(rb_cLLaMAContext, "sample_typical", RUBY_METHOD_FUNC(_llama_context_sample_typical), -1);
1674
+ rb_define_method(rb_cLLaMAContext, "sample_temp", RUBY_METHOD_FUNC(_llama_context_sample_temp), -1);
1381
1675
  rb_define_method(rb_cLLaMAContext, "sample_temperature", RUBY_METHOD_FUNC(_llama_context_sample_temperature), -1);
1382
1676
  rb_define_method(rb_cLLaMAContext, "sample_token_mirostat", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat), -1);
1383
1677
  rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
@@ -1392,24 +1686,27 @@ private:
1392
1686
 
1393
1687
  static VALUE _llama_context_initialize(int argc, VALUE* argv, VALUE self) {
1394
1688
  VALUE kw_args = Qnil;
1395
- ID kw_table[1] = { rb_intern("model") };
1396
- VALUE kw_values[1] = { Qundef };
1689
+ ID kw_table[2] = { rb_intern("model"), rb_intern("params") };
1690
+ VALUE kw_values[2] = { Qundef, Qundef };
1397
1691
  rb_scan_args(argc, argv, ":", &kw_args);
1398
- rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
1692
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
1399
1693
 
1400
1694
  VALUE model = kw_values[0];
1401
1695
  if (!rb_obj_is_kind_of(model, rb_cLLaMAModel)) {
1402
1696
  rb_raise(rb_eArgError, "model must be a Model");
1403
1697
  return Qnil;
1404
1698
  }
1699
+ VALUE params = kw_values[1];
1700
+ if (!rb_obj_is_kind_of(params, rb_cLLaMAContextParams)) {
1701
+ rb_raise(rb_eArgError, "params must be a ContextParams");
1702
+ return Qnil;
1703
+ }
1405
1704
 
1406
1705
  LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
1407
1706
  if (model_ptr->model == NULL) {
1408
1707
  rb_raise(rb_eRuntimeError, "Model is empty");
1409
1708
  return Qnil;
1410
1709
  }
1411
-
1412
- VALUE params = rb_iv_get(model, "@params");
1413
1710
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1414
1711
  LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1415
1712
 
@@ -1421,6 +1718,7 @@ private:
1421
1718
  }
1422
1719
 
1423
1720
  rb_iv_set(self, "@model", model);
1721
+ rb_iv_set(self, "@params", params);
1424
1722
  rb_iv_set(self, "@has_evaluated", Qfalse);
1425
1723
 
1426
1724
  return Qnil;
@@ -1428,8 +1726,8 @@ private:
1428
1726
 
1429
1727
  static VALUE _llama_context_eval(int argc, VALUE* argv, VALUE self) {
1430
1728
  VALUE kw_args = Qnil;
1431
- ID kw_table[4] = { rb_intern("tokens"), rb_intern("n_past"), rb_intern("n_tokens"), rb_intern("n_threads") };
1432
- VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
1729
+ ID kw_table[3] = { rb_intern("tokens"), rb_intern("n_past"), rb_intern("n_tokens") };
1730
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1433
1731
  rb_scan_args(argc, argv, ":", &kw_args);
1434
1732
  rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
1435
1733
 
@@ -1445,10 +1743,6 @@ private:
1445
1743
  rb_raise(rb_eArgError, "n_tokens must be an integer");
1446
1744
  return Qnil;
1447
1745
  }
1448
- if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
1449
- rb_raise(rb_eArgError, "n_threads must be an integer");
1450
- return Qnil;
1451
- }
1452
1746
 
1453
1747
  const size_t tokens_len = RARRAY_LEN(kw_values[0]);
1454
1748
  std::vector<llama_token> embd(tokens_len);
@@ -1463,14 +1757,13 @@ private:
1463
1757
 
1464
1758
  const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
1465
1759
  const int n_past = NUM2INT(kw_values[1]);
1466
- const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
1467
1760
 
1468
1761
  LLaMAContextWrapper* ptr = get_llama_context(self);
1469
1762
  if (ptr->ctx == NULL) {
1470
1763
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1471
1764
  return Qnil;
1472
1765
  }
1473
- if (llama_eval(ptr->ctx, embd.data(), n_tokens, n_past, n_threads) != 0) {
1766
+ if (llama_eval(ptr->ctx, embd.data(), n_tokens, n_past) != 0) {
1474
1767
  rb_raise(rb_eRuntimeError, "Failed to evaluate");
1475
1768
  return Qnil;
1476
1769
  }
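
eval and eval_embd lose their n_threads keyword because llama_eval no longer takes a per-call thread count; threads are configured once on the context through n_threads / n_threads_batch (see the ContextParams hunk above). The call now reduces to:

    #include "llama.h"

    // Evaluate n_tokens tokens starting at position n_past; thread counts come from
    // the llama_context_params used to create ctx, not from this call.
    static int eval_tokens(llama_context * ctx, llama_token * tokens, int n_tokens, int n_past) {
      return llama_eval(ctx, tokens, n_tokens, n_past); // 0 on success
    }
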
@@ -1483,8 +1776,8 @@ private:
1483
1776
 
1484
1777
  static VALUE _llama_context_eval_embd(int argc, VALUE* argv, VALUE self) {
1485
1778
  VALUE kw_args = Qnil;
1486
- ID kw_table[4] = { rb_intern("embd"), rb_intern("n_past"), rb_intern("n_tokens"), rb_intern("n_threads") };
1487
- VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
1779
+ ID kw_table[3] = { rb_intern("embd"), rb_intern("n_past"), rb_intern("n_tokens") };
1780
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1488
1781
  rb_scan_args(argc, argv, ":", &kw_args);
1489
1782
  rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
1490
1783
 
@@ -1500,10 +1793,6 @@ private:
1500
1793
  rb_raise(rb_eArgError, "n_tokens must be an integer");
1501
1794
  return Qnil;
1502
1795
  }
1503
- if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
1504
- rb_raise(rb_eArgError, "n_threads must be an integer");
1505
- return Qnil;
1506
- }
1507
1796
 
1508
1797
  const size_t tokens_len = RARRAY_LEN(kw_values[0]);
1509
1798
  std::vector<float> embd(tokens_len);
@@ -1518,14 +1807,13 @@ private:
1518
1807
 
1519
1808
  const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
1520
1809
  const int n_past = NUM2INT(kw_values[1]);
1521
- const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
1522
1810
 
1523
1811
  LLaMAContextWrapper* ptr = get_llama_context(self);
1524
1812
  if (ptr->ctx == NULL) {
1525
1813
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1526
1814
  return Qnil;
1527
1815
  }
1528
- if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past, n_threads) != 0) {
1816
+ if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past) != 0) {
1529
1817
  rb_raise(rb_eRuntimeError, "Failed to evaluate");
1530
1818
  return Qnil;
1531
1819
  }
@@ -1536,91 +1824,22 @@ private:
1536
1824
  return Qnil;
1537
1825
  }
1538
1826
 
1539
- static VALUE _llama_context_eval_export(VALUE self, VALUE fname_) {
1827
+ static VALUE _llama_context_decode(VALUE self, VALUE batch) {
1540
1828
  LLaMAContextWrapper* ptr = get_llama_context(self);
1541
1829
  if (ptr->ctx == NULL) {
1542
1830
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1543
1831
  return Qnil;
1544
1832
  }
1545
- if (!RB_TYPE_P(fname_, T_STRING)) {
1546
- rb_raise(rb_eArgError, "fname must be a string");
1833
+ if (!rb_obj_is_kind_of(batch, rb_cLLaMABatch)) {
1834
+ rb_raise(rb_eArgError, "batch must be a Batch");
1547
1835
  return Qnil;
1548
1836
  }
1549
- const char* fname = StringValueCStr(fname_);
1550
- if (llama_eval_export(ptr->ctx, fname) != 0) {
1551
- return Qfalse;
1552
- }
1553
- RB_GC_GUARD(fname_);
1554
- return Qtrue;
1555
- }
1556
-
1557
- static VALUE _llama_context_tokenize(int argc, VALUE* argv, VALUE self) {
1558
- VALUE kw_args = Qnil;
1559
- ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
1560
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1561
- rb_scan_args(argc, argv, ":", &kw_args);
1562
- rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
1563
-
1564
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1565
- rb_raise(rb_eArgError, "text must be a String");
1837
+ LLaMABatchWrapper* batch_ptr = RbLLaMABatch::get_llama_batch(batch);
1838
+ if (llama_decode(ptr->ctx, batch_ptr->batch) < 0) {
1839
+ rb_raise(rb_eRuntimeError, "Failed to decode");
1566
1840
  return Qnil;
1567
1841
  }
1568
- if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
1569
- rb_raise(rb_eArgError, "n_max_tokens must be an integer");
1570
- return Qnil;
1571
- }
1572
- if (kw_values[2] != Qundef && (kw_values[2] != Qtrue && kw_values[2] != Qfalse)) {
1573
- rb_raise(rb_eArgError, "add_bos must be a boolean");
1574
- return Qnil;
1575
- }
1576
-
1577
- VALUE text_ = kw_values[0];
1578
- std::string text = StringValueCStr(text_);
1579
- const bool add_bos = kw_values[2] == Qtrue ? true : false;
1580
- const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
1581
-
1582
- std::vector<llama_token> tokens(n_max_tokens);
1583
- LLaMAContextWrapper* ptr = get_llama_context(self);
1584
- if (ptr->ctx == NULL) {
1585
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1586
- return Qnil;
1587
- }
1588
- const int n = llama_tokenize(ptr->ctx, text.c_str(), text.size(), tokens.data(), n_max_tokens, add_bos);
1589
- if (n < 0) {
1590
- rb_raise(rb_eRuntimeError, "Failed to tokenize");
1591
- return Qnil;
1592
- }
1593
-
1594
- VALUE output = rb_ary_new();
1595
- for (int i = 0; i < n; i++) {
1596
- rb_ary_push(output, INT2NUM(tokens[i]));
1597
- }
1598
-
1599
- RB_GC_GUARD(text_);
1600
- return output;
1601
- }
1602
-
1603
- static VALUE _llama_context_token_to_piece(VALUE self, VALUE token_) {
1604
- LLaMAContextWrapper* ptr = get_llama_context(self);
1605
- if (ptr->ctx == NULL) {
1606
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1607
- return Qnil;
1608
- }
1609
- const llama_token token = NUM2INT(token_);
1610
- std::vector<char> result(8, 0);
1611
- const int n_tokens = llama_token_to_piece(ptr->ctx, token, result.data(), result.size());
1612
- if (n_tokens < 0) {
1613
- result.resize(-n_tokens);
1614
- const int check = llama_token_to_piece(ptr->ctx, token, result.data(), result.size());
1615
- if (check != -n_tokens) {
1616
- rb_raise(rb_eRuntimeError, "failed to convert");
1617
- return Qnil;
1618
- }
1619
- } else {
1620
- result.resize(n_tokens);
1621
- }
1622
- std::string ret(result.data(), result.size());
1623
- return rb_str_new_cstr(ret.c_str());
1842
+ return Qnil;
1624
1843
  }
1625
1844
 
1626
1845
  static VALUE _llama_context_logits(VALUE self) {
@@ -1635,10 +1854,11 @@ private:
1635
1854
  }
1636
1855
 
1637
1856
  VALUE model = rb_iv_get(self, "@model");
1638
- VALUE params = rb_iv_get(model, "@params");
1857
+ LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
1858
+ VALUE params = rb_iv_get(self, "@params");
1639
1859
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1640
1860
  const int n_tokens = prms_ptr->params.logits_all ? NUM2INT(rb_iv_get(self, "@n_tokens")) : 1;
1641
- const int n_vocab = llama_n_vocab(ptr->ctx);
1861
+ const int n_vocab = llama_n_vocab(model_ptr->model);
1642
1862
  const float* logits = llama_get_logits(ptr->ctx);
1643
1863
  VALUE output = rb_ary_new();
1644
1864
  for (int i = 0; i < n_tokens * n_vocab; i++) {
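
logits (and embeddings below) now read the vocabulary and embedding sizes from the attached model instead of the context. For reference, the buffer returned by llama_get_logits is laid out row-major, one n_vocab-sized row per evaluated token when logits_all is set and a single row otherwise, which is what the n_tokens * n_vocab loop above walks over:

    #include "llama.h"
    #include <cstddef>

    // Pointer to the logits row for token index i of the last decoded batch.
    // With logits_all = false only the final row is populated (n_tokens == 1 above).
    static const float * logits_row(llama_context * ctx, const llama_model * model, int i) {
      const int n_vocab = llama_n_vocab(model);
      return llama_get_logits(ctx) + (size_t)i * n_vocab;
    }
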
@@ -1655,7 +1875,8 @@ private:
1655
1875
  return Qnil;
1656
1876
  }
1657
1877
  VALUE model = rb_iv_get(self, "@model");
1658
- VALUE params = rb_iv_get(model, "@params");
1878
+ LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
1879
+ VALUE params = rb_iv_get(self, "@params");
1659
1880
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1660
1881
  if (!prms_ptr->params.embedding) {
1661
1882
  rb_raise(rb_eRuntimeError, "embedding parameter is false");
@@ -1666,7 +1887,7 @@ private:
1666
1887
  return Qnil;
1667
1888
  }
1668
1889
 
1669
- const int n_embd = llama_n_embd(ptr->ctx);
1890
+ const int n_embd = llama_n_embd(model_ptr->model);
1670
1891
  const float* embd = llama_get_embeddings(ptr->ctx);
1671
1892
  VALUE output = rb_ary_new();
1672
1893
  for (int i = 0; i < n_embd; i++) {
@@ -1736,81 +1957,104 @@ private:
1736
1957
  return INT2NUM(llama_token_nl(ptr->ctx));
1737
1958
  }
1738
1959
 
1739
- static VALUE _llama_context_n_vocab(VALUE self) {
1960
+ static VALUE _llama_context_n_ctx(VALUE self) {
1740
1961
  LLaMAContextWrapper* ptr = get_llama_context(self);
1741
1962
  if (ptr->ctx == NULL) {
1742
1963
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1743
1964
  return Qnil;
1744
1965
  }
1745
- return INT2NUM(llama_n_vocab(ptr->ctx));
1966
+ return INT2NUM(llama_n_ctx(ptr->ctx));
1746
1967
  }
1747
1968
 
1748
- static VALUE _llama_context_n_ctx(VALUE self) {
1969
+ static VALUE _llama_context_get_timings(VALUE self) {
1749
1970
  LLaMAContextWrapper* ptr = get_llama_context(self);
1750
1971
  if (ptr->ctx == NULL) {
1751
1972
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1752
1973
  return Qnil;
1753
1974
  }
1754
- return INT2NUM(llama_n_ctx(ptr->ctx));
1975
+ VALUE tm_obj = rb_funcall(rb_cLLaMATimings, rb_intern("new"), 0);
1976
+ LLaMATimingsWrapper* tm_ptr = RbLLaMATimings::get_llama_timings(tm_obj);
1977
+ tm_ptr->timings = llama_get_timings(ptr->ctx);
1978
+ return tm_obj;
1755
1979
  }
1756
1980
 
1757
- static VALUE _llama_context_n_ctx_train(VALUE self) {
1981
+ static VALUE _llama_context_print_timings(VALUE self) {
1758
1982
  LLaMAContextWrapper* ptr = get_llama_context(self);
1759
1983
  if (ptr->ctx == NULL) {
1760
1984
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1761
1985
  return Qnil;
1762
1986
  }
1763
- return INT2NUM(llama_n_ctx_train(ptr->ctx));
1987
+ llama_print_timings(ptr->ctx);
1988
+ return Qnil;
1764
1989
  }
1765
1990
 
1766
- static VALUE _llama_context_n_embd(VALUE self) {
1991
+ static VALUE _llama_context_reset_timings(VALUE self) {
1767
1992
  LLaMAContextWrapper* ptr = get_llama_context(self);
1768
1993
  if (ptr->ctx == NULL) {
1769
1994
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1770
1995
  return Qnil;
1771
1996
  }
1772
- return INT2NUM(llama_n_embd(ptr->ctx));
1997
+ llama_reset_timings(ptr->ctx);
1998
+ return Qnil;
1773
1999
  }
1774
2000
 
1775
- static VALUE _llama_context_get_timings(VALUE self) {
2001
+ static VALUE _llama_context_kv_cache_token_count(VALUE self) {
1776
2002
  LLaMAContextWrapper* ptr = get_llama_context(self);
1777
2003
  if (ptr->ctx == NULL) {
1778
2004
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1779
2005
  return Qnil;
1780
2006
  }
1781
- VALUE tm_obj = rb_funcall(rb_cLLaMATimings, rb_intern("new"), 0);
1782
- LLaMATimingsWrapper* tm_ptr = RbLLaMATimings::get_llama_timings(tm_obj);
1783
- tm_ptr->timings = llama_get_timings(ptr->ctx);
1784
- return tm_obj;
2007
+ return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
1785
2008
  }
1786
2009
 
1787
- static VALUE _llama_context_print_timings(VALUE self) {
2010
+ static VALUE _llama_context_kv_cache_tokens_rm(VALUE self, VALUE c0, VALUE c1) {
1788
2011
  LLaMAContextWrapper* ptr = get_llama_context(self);
1789
2012
  if (ptr->ctx == NULL) {
1790
2013
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1791
2014
  return Qnil;
1792
2015
  }
1793
- llama_print_timings(ptr->ctx);
2016
+ llama_kv_cache_tokens_rm(ptr->ctx, NUM2INT(c0), NUM2INT(c1));
1794
2017
  return Qnil;
1795
2018
  }
1796
2019
 
1797
- static VALUE _llama_context_reset_timings(VALUE self) {
2020
+ static VALUE _llama_context_kv_cache_seq_rm(VALUE self, VALUE seq_id, VALUE p0, VALUE p1) {
1798
2021
  LLaMAContextWrapper* ptr = get_llama_context(self);
1799
2022
  if (ptr->ctx == NULL) {
1800
2023
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1801
2024
  return Qnil;
1802
2025
  }
1803
- llama_reset_timings(ptr->ctx);
2026
+ llama_kv_cache_seq_rm(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1));
1804
2027
  return Qnil;
1805
2028
  }
1806
2029
 
1807
- static VALUE _llama_context_kv_cache_token_count(VALUE self) {
2030
+ static VALUE _llama_context_kv_cache_seq_cp(VALUE self, VALUE seq_id_src, VALUE seq_id_dst, VALUE p0, VALUE p1) {
1808
2031
  LLaMAContextWrapper* ptr = get_llama_context(self);
1809
2032
  if (ptr->ctx == NULL) {
1810
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2033
+ rb_raise(rb_eArgError, "LLaMA context is not initialized");
1811
2034
  return Qnil;
1812
2035
  }
1813
- return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
2036
+ llama_kv_cache_seq_cp(ptr->ctx, NUM2INT(seq_id_src), NUM2INT(seq_id_dst), NUM2INT(p0), NUM2INT(p1));
2037
+ return Qnil;
2038
+ }
2039
+
2040
+ static VALUE _llama_context_kv_cache_seq_keep(VALUE self, VALUE seq_id) {
2041
+ LLaMAContextWrapper* ptr = get_llama_context(self);
2042
+ if (ptr->ctx == NULL) {
2043
+ rb_raise(rb_eArgError, "LLaMA context is not initialized");
2044
+ return Qnil;
2045
+ }
2046
+ llama_kv_cache_seq_keep(ptr->ctx, NUM2INT(seq_id));
2047
+ return Qnil;
2048
+ }
2049
+
2050
+ static VALUE _llama_context_kv_cache_seq_shift(VALUE self, VALUE seq_id, VALUE p0, VALUE p1, VALUE delta) {
2051
+ LLaMAContextWrapper* ptr = get_llama_context(self);
2052
+ if (ptr->ctx == NULL) {
2053
+ rb_raise(rb_eArgError, "LLaMA context is not initialized");
2054
+ return Qnil;
2055
+ }
2056
+ llama_kv_cache_seq_shift(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1), NUM2INT(delta));
2057
+ return Qnil;
1814
2058
  }
1815
2059
 
1816
2060
  static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
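
The kv_cache_* methods added above are thin wrappers over llama.cpp's sequence-aware KV-cache management, which replaces the old whole-context swap. A sketch of how the underlying calls (the same ones invoked in this hunk) combine in a typical "discard old tokens and keep generating" pattern:

    #include "llama.h"

    // Keep the first n_keep positions of sequence 0, drop the next n_discard, and
    // shift the remainder back so generation can continue without re-evaluating the prompt.
    static void shrink_sequence(llama_context * ctx, int n_keep, int n_past, int n_discard) {
      llama_kv_cache_seq_rm(ctx, /* seq_id = */ 0, n_keep, n_keep + n_discard);  // evict a window
      llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_past, -n_discard);  // close the gap
    }

    // Fork sequence 0 into sequence 1 (they share the cached prefix), then keep only the fork.
    static void fork_and_prune(llama_context * ctx, int n_past) {
      llama_kv_cache_seq_cp(ctx, /* src = */ 0, /* dst = */ 1, 0, n_past);
      llama_kv_cache_seq_keep(ctx, /* seq_id = */ 1);
    }
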
@@ -1851,7 +2095,7 @@ private:
1851
2095
  }
1852
2096
 
1853
2097
  VALUE model = rb_iv_get(self, "@model");
1854
- VALUE params = rb_iv_get(model, "@params");
2098
+ VALUE params = rb_iv_get(self, "@params");
1855
2099
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1856
2100
  const int n_ctx = prms_ptr->params.n_ctx;
1857
2101
 
@@ -2235,6 +2479,40 @@ private:
2235
2479
  return Qnil;
2236
2480
  }
2237
2481
 
2482
+ static VALUE _llama_context_sample_temp(int argc, VALUE* argv, VALUE self) {
2483
+ VALUE kw_args = Qnil;
2484
+ ID kw_table[1] = { rb_intern("temp") };
2485
+ VALUE kw_values[1] = { Qundef };
2486
+ VALUE candidates = Qnil;
2487
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
2488
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
2489
+
2490
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
2491
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
2492
+ return Qnil;
2493
+ }
2494
+ if (!RB_FLOAT_TYPE_P(kw_values[0])) {
2495
+ rb_raise(rb_eArgError, "temp must be a float");
2496
+ return Qnil;
2497
+ }
2498
+
2499
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2500
+ if (ctx_ptr->ctx == NULL) {
2501
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2502
+ return Qnil;
2503
+ }
2504
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
2505
+ if (cnd_ptr->array.data == nullptr) {
2506
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
2507
+ return Qnil;
2508
+ }
2509
+ const float temp = NUM2DBL(kw_values[0]);
2510
+
2511
+ llama_sample_temp(ctx_ptr->ctx, &(cnd_ptr->array), temp);
2512
+
2513
+ return Qnil;
2514
+ }
2515
+
2238
2516
  static VALUE _llama_context_sample_temperature(int argc, VALUE* argv, VALUE self) {
2239
2517
  VALUE kw_args = Qnil;
2240
2518
  ID kw_table[1] = { rb_intern("temperature") };
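
sample_temp is the new name for temperature scaling and calls llama_sample_temp; the existing sample_temperature method is kept alongside it for backward compatibility. A sketch of the C-level effect:

    #include "llama.h"

    // Scale the candidate logits by 1/temp before sampling; lower temperatures
    // sharpen the distribution, higher ones flatten it.
    static void apply_temperature(llama_context * ctx, llama_token_data_array * candidates, float temp) {
      llama_sample_temp(ctx, candidates, temp);
    }
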
@@ -2560,6 +2838,7 @@ extern "C" void Init_llama_cpp(void) {
2560
2838
  RbLLaMATokenData::define_class(rb_mLLaMACpp);
2561
2839
  RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
2562
2840
  RbLLaMAModel::define_class(rb_mLLaMACpp);
2841
+ RbLLaMAModelParams::define_class(rb_mLLaMACpp);
2563
2842
  RbLLaMATimings::define_class(rb_mLLaMACpp);
2564
2843
  RbLLaMAContext::define_class(rb_mLLaMACpp);
2565
2844
  RbLLaMAContextParams::define_class(rb_mLLaMACpp);
@@ -2578,10 +2857,6 @@ extern "C" void Init_llama_cpp(void) {
2578
2857
 
2579
2858
  rb_define_const(rb_mLLaMACpp, "LLAMA_MAX_DEVICES", INT2NUM(LLAMA_MAX_DEVICES));
2580
2859
 
2581
- rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_ERROR", INT2NUM(LLAMA_LOG_LEVEL_ERROR));
2582
- rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_WARN", INT2NUM(LLAMA_LOG_LEVEL_WARN));
2583
- rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_INFO", INT2NUM(LLAMA_LOG_LEVEL_INFO));
2584
-
2585
2860
  rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_SPM", INT2NUM(LLAMA_VOCAB_TYPE_SPM));
2586
2861
  rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_BPE", INT2NUM(LLAMA_VOCAB_TYPE_BPE));
2587
2862