llama_cpp 0.5.3 → 0.6.0

@@ -1,7 +1,9 @@
1
1
  #include "llama_cpp.h"
2
2
 
3
3
  VALUE rb_mLLaMACpp;
4
+ VALUE rb_cLLaMABatch;
4
5
  VALUE rb_cLLaMAModel;
6
+ VALUE rb_cLLaMAModelParams;
5
7
  VALUE rb_cLLaMATimings;
6
8
  VALUE rb_cLLaMAContext;
7
9
  VALUE rb_cLLaMAContextParams;
@@ -11,6 +13,238 @@ VALUE rb_cLLaMATokenDataArray;
11
13
  VALUE rb_cLLaMAGrammarElement;
12
14
  VALUE rb_cLLaMAGrammar;
13
15
 
16
+ class LLaMABatchWrapper {
17
+ public:
18
+ llama_batch batch;
19
+
20
+ LLaMABatchWrapper() {}
21
+
22
+ ~LLaMABatchWrapper() {
23
+ llama_batch_free(batch);
24
+ }
25
+ };
26
+
27
+ class RbLLaMABatch {
28
+ public:
29
+ static VALUE llama_batch_alloc(VALUE self) {
30
+ LLaMABatchWrapper* ptr = (LLaMABatchWrapper*)ruby_xmalloc(sizeof(LLaMABatchWrapper));
31
+ new (ptr) LLaMABatchWrapper();
32
+ return TypedData_Wrap_Struct(self, &llama_batch_type, ptr);
33
+ }
34
+
35
+ static void llama_batch_free(void* ptr) {
36
+ ((LLaMABatchWrapper*)ptr)->~LLaMABatchWrapper();
37
+ ruby_xfree(ptr);
38
+ }
39
+
40
+ static size_t llama_batch_size(const void* ptr) {
41
+ return sizeof(*((LLaMABatchWrapper*)ptr));
42
+ }
43
+
44
+ static LLaMABatchWrapper* get_llama_batch(VALUE self) {
45
+ LLaMABatchWrapper* ptr;
46
+ TypedData_Get_Struct(self, LLaMABatchWrapper, &llama_batch_type, ptr);
47
+ return ptr;
48
+ }
49
+
50
+ static void define_class(VALUE outer) {
51
+ rb_cLLaMABatch = rb_define_class_under(outer, "Batch", rb_cObject);
52
+ rb_define_alloc_func(rb_cLLaMABatch, llama_batch_alloc);
53
+ rb_define_method(rb_cLLaMABatch, "initialize", RUBY_METHOD_FUNC(_llama_batch_initialize), -1);
54
+ rb_define_method(rb_cLLaMABatch, "n_tokens=", RUBY_METHOD_FUNC(_llama_batch_set_n_tokens), 1);
55
+ rb_define_method(rb_cLLaMABatch, "n_tokens", RUBY_METHOD_FUNC(_llama_batch_get_n_tokens), 0);
56
+ rb_define_method(rb_cLLaMABatch, "all_pos_zero=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_zero), 1);
57
+ rb_define_method(rb_cLLaMABatch, "all_pos_zero", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_zero), 0);
58
+ rb_define_method(rb_cLLaMABatch, "all_pos_one=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_one), 1);
59
+ rb_define_method(rb_cLLaMABatch, "all_pos_one", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_one), 0);
60
+ rb_define_method(rb_cLLaMABatch, "all_seq_id=", RUBY_METHOD_FUNC(_llama_batch_set_all_seq_id), 1);
61
+ rb_define_method(rb_cLLaMABatch, "all_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_all_seq_id), 0);
62
+ rb_define_method(rb_cLLaMABatch, "set_token", RUBY_METHOD_FUNC(_llama_batch_set_token), 2);
63
+ rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
64
+ rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
65
+ rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
66
+ rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 2);
67
+ rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 1);
68
+ rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
69
+ rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
70
+ }
71
+
72
+ private:
73
+ static const rb_data_type_t llama_batch_type;
74
+
75
+ static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
76
+ VALUE kw_args = Qnil;
77
+ ID kw_table[2] = { rb_intern("n_tokens"), rb_intern("embd") };
78
+ VALUE kw_values[2] = { Qundef, Qundef };
79
+ rb_scan_args(argc, argv, ":", &kw_args);
80
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
81
+
82
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
83
+ rb_raise(rb_eArgError, "n_tokens must be an integer");
84
+ return Qnil;
85
+ }
86
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
87
+ rb_raise(rb_eArgError, "embd must be an integer");
88
+ return Qnil;
89
+ }
90
+
91
+ const int32_t n_tokens = NUM2INT(kw_values[0]);
92
+ const int32_t embd = NUM2INT(kw_values[1]);
93
+
94
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
95
+ ptr->batch = llama_batch_init(n_tokens, embd);
96
+
97
+ return Qnil;
98
+ }
99
+
100
+ // n_tokens
101
+ static VALUE _llama_batch_set_n_tokens(VALUE self, VALUE n_tokens) {
102
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
103
+ ptr->batch.n_tokens = NUM2INT(n_tokens);
104
+ return INT2NUM(ptr->batch.n_tokens);
105
+ }
106
+
107
+ static VALUE _llama_batch_get_n_tokens(VALUE self) {
108
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
109
+ return INT2NUM(ptr->batch.n_tokens);
110
+ }
111
+
112
+ // all_pos_0
113
+ static VALUE _llama_batch_set_all_pos_zero(VALUE self, VALUE all_pos_0) {
114
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
115
+ ptr->batch.all_pos_0 = NUM2INT(all_pos_0);
116
+ return INT2NUM(ptr->batch.all_pos_0);
117
+ }
118
+
119
+ static VALUE _llama_batch_get_all_pos_zero(VALUE self) {
120
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
121
+ return INT2NUM(ptr->batch.all_pos_0);
122
+ }
123
+
124
+ // all_pos_1
125
+ static VALUE _llama_batch_set_all_pos_one(VALUE self, VALUE all_pos_1) {
126
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
127
+ ptr->batch.all_pos_1 = NUM2INT(all_pos_1);
128
+ return INT2NUM(ptr->batch.all_pos_1);
129
+ }
130
+
131
+ static VALUE _llama_batch_get_all_pos_one(VALUE self) {
132
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
133
+ return INT2NUM(ptr->batch.all_pos_1);
134
+ }
135
+
136
+ // all_seq_id
137
+ static VALUE _llama_batch_set_all_seq_id(VALUE self, VALUE all_seq_id) {
138
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
139
+ ptr->batch.all_seq_id = NUM2INT(all_seq_id);
140
+ return INT2NUM(ptr->batch.all_seq_id);
141
+ }
142
+
143
+ static VALUE _llama_batch_get_all_seq_id(VALUE self) {
144
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
145
+ return INT2NUM(ptr->batch.all_seq_id);
146
+ }
147
+
148
+ // token
149
+ static VALUE _llama_batch_set_token(VALUE self, VALUE idx, VALUE value) {
150
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
151
+ const int32_t id = NUM2INT(idx);
152
+ if (id < 0 || id >= ptr->batch.n_tokens) {
153
+ rb_raise(rb_eArgError, "idx must be in [0, n_tokens)");
154
+ return Qnil;
155
+ }
156
+ ptr->batch.token[id] = NUM2INT(value);
157
+ return INT2NUM(ptr->batch.token[id]);
158
+ }
159
+
160
+ static VALUE _llama_batch_get_token(VALUE self, VALUE idx) {
161
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
162
+ const int32_t id = NUM2INT(idx);
163
+ if (id < 0 || id >= ptr->batch.n_tokens) {
164
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
165
+ return Qnil;
166
+ }
167
+ return INT2NUM(ptr->batch.token[id]);
168
+ }
169
+
170
+ // pos
171
+ static VALUE _llama_batch_set_pos(VALUE self, VALUE idx, VALUE value) {
172
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
173
+ const int32_t id = NUM2INT(idx);
174
+ if (id < 0 || id >= ptr->batch.n_tokens) {
175
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
176
+ return Qnil;
177
+ }
178
+ ptr->batch.pos[id] = NUM2INT(value);
179
+ return INT2NUM(ptr->batch.pos[id]);
180
+ }
181
+
182
+ static VALUE _llama_batch_get_pos(VALUE self, VALUE idx) {
183
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
184
+ const int32_t id = NUM2INT(idx);
185
+ if (id < 0 || id >= ptr->batch.n_tokens) {
186
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
187
+ return Qnil;
188
+ }
189
+ return INT2NUM(ptr->batch.pos[id]);
190
+ }
191
+
192
+ // seq_id
193
+ static VALUE _llama_batch_set_seq_id(VALUE self, VALUE idx, VALUE value) {
194
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
195
+ const int32_t id = NUM2INT(idx);
196
+ if (id < 0 || id >= ptr->batch.n_tokens) {
197
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
198
+ return Qnil;
199
+ }
200
+ ptr->batch.seq_id[id] = NUM2INT(value);
201
+ return INT2NUM(ptr->batch.seq_id[id]);
202
+ }
203
+
204
+ static VALUE _llama_batch_get_seq_id(VALUE self, VALUE idx) {
205
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
206
+ const int32_t id = NUM2INT(idx);
207
+ if (id < 0 || id >= ptr->batch.n_tokens) {
208
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
209
+ return Qnil;
210
+ }
211
+ return INT2NUM(ptr->batch.seq_id[id]);
212
+ }
213
+
214
+ // logits
215
+ static VALUE _llama_batch_set_logits(VALUE self, VALUE idx, VALUE value) {
216
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
217
+ const int32_t id = NUM2INT(idx);
218
+ if (id < 0 || id >= ptr->batch.n_tokens) {
219
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
220
+ return Qnil;
221
+ }
222
+ ptr->batch.logits[id] = RTEST(value) ? true : false;
223
+ return ptr->batch.logits[id] ? Qtrue : Qfalse;
224
+ }
225
+
226
+ static VALUE _llama_batch_get_logits(VALUE self, VALUE idx) {
227
+ LLaMABatchWrapper* ptr = get_llama_batch(self);
228
+ const int32_t id = NUM2INT(idx);
229
+ if (id < 0 || id >= ptr->batch.n_tokens) {
230
+ rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
231
+ return Qnil;
232
+ }
233
+ return ptr->batch.logits[id] ? Qtrue : Qfalse;
234
+ }
235
+ };
236
+
237
+ const rb_data_type_t RbLLaMABatch::llama_batch_type = {
238
+ "RbLLaMABatch",
239
+ { NULL,
240
+ RbLLaMABatch::llama_batch_free,
241
+ RbLLaMABatch::llama_batch_size },
242
+ NULL,
243
+ NULL,
244
+ RUBY_TYPED_FREE_IMMEDIATELY
245
+
246
+ };
247
+
14
248
  class LLaMATokenDataWrapper {
15
249
  public:
16
250
  llama_token_data data;
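The Batch bindings added above expose llama_batch to Ruby. A minimal usage sketch follows (hedged: it assumes the class is registered as LLaMACpp::Batch, and the token ids are placeholders):

```ruby
require "llama_cpp"

# Sketch only: LLaMACpp::Batch and the token ids below are illustrative assumptions.
batch = LLaMACpp::Batch.new(n_tokens: 512, embd: 0) # embd: 0 allocates a token batch

tokens = [1, 15043, 3186] # placeholder token ids
batch.n_tokens = tokens.size
tokens.each_with_index do |tok, i|
  batch.set_token(i, tok)                    # token id at slot i
  batch.set_pos(i, i)                        # position within the sequence
  batch.set_seq_id(i, 0)                     # everything belongs to sequence 0
  batch.set_logits(i, i == tokens.size - 1)  # ask for logits only at the last token
end
```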
@@ -363,6 +597,144 @@ const rb_data_type_t RbLLaMATimings::llama_timings_type = {
363
597
  RUBY_TYPED_FREE_IMMEDIATELY
364
598
  };
365
599
 
600
+ class LLaMAModelParamsWrapper {
601
+ public:
602
+ struct llama_model_params params;
603
+
604
+ LLaMAModelParamsWrapper() : params(llama_model_default_params()) {}
605
+
606
+ ~LLaMAModelParamsWrapper() {}
607
+ };
608
+
609
+ class RbLLaMAModelParams {
610
+ public:
611
+ static VALUE llama_model_params_alloc(VALUE self) {
612
+ LLaMAModelParamsWrapper* ptr = (LLaMAModelParamsWrapper*)ruby_xmalloc(sizeof(LLaMAModelParamsWrapper));
613
+ new (ptr) LLaMAModelParamsWrapper();
614
+ return TypedData_Wrap_Struct(self, &llama_model_params_type, ptr);
615
+ }
616
+
617
+ static void llama_model_params_free(void* ptr) {
618
+ ((LLaMAModelParamsWrapper*)ptr)->~LLaMAModelParamsWrapper();
619
+ ruby_xfree(ptr);
620
+ }
621
+
622
+ static size_t llama_model_params_size(const void* ptr) {
623
+ return sizeof(*((LLaMAModelParamsWrapper*)ptr));
624
+ }
625
+
626
+ static LLaMAModelParamsWrapper* get_llama_model_params(VALUE self) {
627
+ LLaMAModelParamsWrapper* ptr;
628
+ TypedData_Get_Struct(self, LLaMAModelParamsWrapper, &llama_model_params_type, ptr);
629
+ return ptr;
630
+ }
631
+
632
+ static void define_class(VALUE outer) {
633
+ rb_cLLaMAModelParams = rb_define_class_under(outer, "ModelParams", rb_cObject);
634
+ rb_define_alloc_func(rb_cLLaMAModelParams, llama_model_params_alloc);
635
+ rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers=", RUBY_METHOD_FUNC(_llama_model_params_set_n_gpu_layers), 1);
636
+ rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers", RUBY_METHOD_FUNC(_llama_model_params_get_n_gpu_layers), 0);
637
+ rb_define_method(rb_cLLaMAModelParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_model_params_set_main_gpu), 1);
638
+ rb_define_method(rb_cLLaMAModelParams, "main_gpu", RUBY_METHOD_FUNC(_llama_model_params_get_main_gpu), 0);
639
+ rb_define_method(rb_cLLaMAModelParams, "tensor_split", RUBY_METHOD_FUNC(_llama_model_params_get_tensor_split), 0);
640
+ rb_define_method(rb_cLLaMAModelParams, "vocab_only=", RUBY_METHOD_FUNC(_llama_model_params_set_vocab_only), 1);
641
+ rb_define_method(rb_cLLaMAModelParams, "vocab_only", RUBY_METHOD_FUNC(_llama_model_params_get_vocab_only), 0);
642
+ rb_define_method(rb_cLLaMAModelParams, "use_mmap=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mmap), 1);
643
+ rb_define_method(rb_cLLaMAModelParams, "use_mmap", RUBY_METHOD_FUNC(_llama_model_params_get_use_mmap), 0);
644
+ rb_define_method(rb_cLLaMAModelParams, "use_mlock=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mlock), 1);
645
+ rb_define_method(rb_cLLaMAModelParams, "use_mlock", RUBY_METHOD_FUNC(_llama_model_params_get_use_mlock), 0);
646
+ }
647
+
648
+ private:
649
+ static const rb_data_type_t llama_model_params_type;
650
+
651
+ // n_gpu_layers
652
+ static VALUE _llama_model_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
653
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
654
+ ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
655
+ return INT2NUM(ptr->params.n_gpu_layers);
656
+ }
657
+
658
+ static VALUE _llama_model_params_get_n_gpu_layers(VALUE self) {
659
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
660
+ return INT2NUM(ptr->params.n_gpu_layers);
661
+ }
662
+
663
+ // main_gpu
664
+ static VALUE _llama_model_params_set_main_gpu(VALUE self, VALUE main_gpu) {
665
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
666
+ ptr->params.main_gpu = NUM2INT(main_gpu);
667
+ return INT2NUM(ptr->params.main_gpu);
668
+ }
669
+
670
+ static VALUE _llama_model_params_get_main_gpu(VALUE self) {
671
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
672
+ return INT2NUM(ptr->params.main_gpu);
673
+ }
674
+
675
+ // tensor_split
676
+ static VALUE _llama_model_params_get_tensor_split(VALUE self) {
677
+ if (LLAMA_MAX_DEVICES < 1) {
678
+ return rb_ary_new();
679
+ }
680
+ VALUE ret = rb_ary_new2(LLAMA_MAX_DEVICES);
681
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
682
+ if (ptr->params.tensor_split == nullptr) {
683
+ return rb_ary_new();
684
+ }
685
+ for (size_t i = 0; i < LLAMA_MAX_DEVICES; i++) {
686
+ rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
687
+ }
688
+ return ret;
689
+ }
690
+
691
+ // vocab_only
692
+ static VALUE _llama_model_params_set_vocab_only(VALUE self, VALUE vocab_only) {
693
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
694
+ ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
695
+ return ptr->params.vocab_only ? Qtrue : Qfalse;
696
+ }
697
+
698
+ static VALUE _llama_model_params_get_vocab_only(VALUE self) {
699
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
700
+ return ptr->params.vocab_only ? Qtrue : Qfalse;
701
+ }
702
+
703
+ // use_mmap
704
+ static VALUE _llama_model_params_set_use_mmap(VALUE self, VALUE use_mmap) {
705
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
706
+ ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
707
+ return ptr->params.use_mmap ? Qtrue : Qfalse;
708
+ }
709
+
710
+ static VALUE _llama_model_params_get_use_mmap(VALUE self) {
711
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
712
+ return ptr->params.use_mmap ? Qtrue : Qfalse;
713
+ }
714
+
715
+ // use_mlock
716
+ static VALUE _llama_model_params_set_use_mlock(VALUE self, VALUE use_mlock) {
717
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
718
+ ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
719
+ return ptr->params.use_mlock ? Qtrue : Qfalse;
720
+ }
721
+
722
+ static VALUE _llama_model_params_get_use_mlock(VALUE self) {
723
+ LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
724
+ return ptr->params.use_mlock ? Qtrue : Qfalse;
725
+ }
726
+ };
727
+
728
+ const rb_data_type_t RbLLaMAModelParams::llama_model_params_type = {
729
+ "RbLLaMAModelParams",
730
+ { NULL,
731
+ RbLLaMAModelParams::llama_model_params_free,
732
+ RbLLaMAModelParams::llama_model_params_size },
733
+ NULL,
734
+ NULL,
735
+ RUBY_TYPED_FREE_IMMEDIATELY
736
+ };
737
+
366
738
  class LLaMAContextParamsWrapper {
367
739
  public:
368
740
  struct llama_context_params params;
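The new ModelParams class above takes over the model-level options (n_gpu_layers, main_gpu, tensor_split, vocab_only, use_mmap, use_mlock) that previously lived in ContextParams. A hedged sketch of loading a model with it (the LLaMACpp class names and the GGUF path are assumptions):

```ruby
model_params = LLaMACpp::ModelParams.new
model_params.n_gpu_layers = 32   # layers to offload to the GPU
model_params.use_mmap = true
model_params.use_mlock = false

# Model.new now requires a ModelParams (not a ContextParams); the path is a placeholder.
model = LLaMACpp::Model.new(model_path: "path/to/model.gguf", params: model_params)
puts model.n_vocab      # vocabulary size
puts model.n_ctx_train  # training context length
```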
@@ -399,35 +771,26 @@ public:
399
771
  rb_cLLaMAContextParams = rb_define_class_under(outer, "ContextParams", rb_cObject);
400
772
  rb_define_alloc_func(rb_cLLaMAContextParams, llama_context_params_alloc);
401
773
  // rb_define_method(rb_cLLaMAContextParams, "initialize", RUBY_METHOD_FUNC(_llama_context_params_init), 0);
774
+ rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
775
+ rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
402
776
  rb_define_method(rb_cLLaMAContextParams, "n_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_n_ctx), 1);
403
777
  rb_define_method(rb_cLLaMAContextParams, "n_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_n_ctx), 0);
404
778
  rb_define_method(rb_cLLaMAContextParams, "n_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_batch), 1);
405
779
  rb_define_method(rb_cLLaMAContextParams, "n_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_batch), 0);
406
- rb_define_method(rb_cLLaMAContextParams, "n_gpu_layers=", RUBY_METHOD_FUNC(_llama_context_params_set_n_gpu_layers), 1);
407
- rb_define_method(rb_cLLaMAContextParams, "n_gpu_layers", RUBY_METHOD_FUNC(_llama_context_params_get_n_gpu_layers), 0);
408
- rb_define_method(rb_cLLaMAContextParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_context_params_set_main_gpu), 1);
409
- rb_define_method(rb_cLLaMAContextParams, "main_gpu", RUBY_METHOD_FUNC(_llama_context_params_get_main_gpu), 0);
410
- rb_define_method(rb_cLLaMAContextParams, "tensor_split", RUBY_METHOD_FUNC(_llama_context_params_get_tensor_split), 0);
780
+ rb_define_method(rb_cLLaMAContextParams, "n_threads=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads), 1);
781
+ rb_define_method(rb_cLLaMAContextParams, "n_threads", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads), 0);
782
+ rb_define_method(rb_cLLaMAContextParams, "n_threads_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads_batch), 1);
783
+ rb_define_method(rb_cLLaMAContextParams, "n_threads_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads_batch), 0);
411
784
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
412
785
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
413
786
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
414
787
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
415
- rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
416
- rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
417
788
  rb_define_method(rb_cLLaMAContextParams, "mul_mat_q=", RUBY_METHOD_FUNC(_llama_context_params_set_mul_mat_q), 1);
418
789
  rb_define_method(rb_cLLaMAContextParams, "mul_mat_q", RUBY_METHOD_FUNC(_llama_context_params_get_mul_mat_q), 0);
419
- rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
420
- rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
421
790
  rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
422
791
  rb_define_method(rb_cLLaMAContextParams, "f16_kv", RUBY_METHOD_FUNC(_llama_context_params_get_f16_kv), 0);
423
792
  rb_define_method(rb_cLLaMAContextParams, "logits_all=", RUBY_METHOD_FUNC(_llama_context_params_set_logits_all), 1);
424
793
  rb_define_method(rb_cLLaMAContextParams, "logits_all", RUBY_METHOD_FUNC(_llama_context_params_get_logits_all), 0);
425
- rb_define_method(rb_cLLaMAContextParams, "vocab_only=", RUBY_METHOD_FUNC(_llama_context_params_set_vocab_only), 1);
426
- rb_define_method(rb_cLLaMAContextParams, "vocab_only", RUBY_METHOD_FUNC(_llama_context_params_get_vocab_only), 0);
427
- rb_define_method(rb_cLLaMAContextParams, "use_mmap=", RUBY_METHOD_FUNC(_llama_context_params_set_use_mmap), 1);
428
- rb_define_method(rb_cLLaMAContextParams, "use_mmap", RUBY_METHOD_FUNC(_llama_context_params_get_use_mmap), 0);
429
- rb_define_method(rb_cLLaMAContextParams, "use_mlock=", RUBY_METHOD_FUNC(_llama_context_params_set_use_mlock), 1);
430
- rb_define_method(rb_cLLaMAContextParams, "use_mlock", RUBY_METHOD_FUNC(_llama_context_params_get_use_mlock), 0);
431
794
  rb_define_method(rb_cLLaMAContextParams, "embedding=", RUBY_METHOD_FUNC(_llama_context_params_set_embedding), 1);
432
795
  rb_define_method(rb_cLLaMAContextParams, "embedding", RUBY_METHOD_FUNC(_llama_context_params_get_embedding), 0);
433
796
  }
@@ -441,6 +804,22 @@ private:
441
804
  // return self;
442
805
  // }
443
806
 
807
+ // seed
808
+ static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
809
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
810
+ if (NUM2INT(seed) < 0) {
811
+ rb_raise(rb_eArgError, "seed must be positive");
812
+ return Qnil;
813
+ }
814
+ ptr->params.seed = NUM2INT(seed);
815
+ return INT2NUM(ptr->params.seed);
816
+ }
817
+
818
+ static VALUE _llama_context_params_get_seed(VALUE self) {
819
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
820
+ return INT2NUM(ptr->params.seed);
821
+ }
822
+
444
823
  // n_ctx
445
824
  static VALUE _llama_context_params_set_n_ctx(VALUE self, VALUE n_ctx) {
446
825
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -465,41 +844,28 @@ private:
465
844
  return INT2NUM(ptr->params.n_batch);
466
845
  }
467
846
 
468
- // n_gpu_layers
469
- static VALUE _llama_context_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
847
+ // n_threads
848
+ static VALUE _llama_context_params_set_n_threads(VALUE self, VALUE n_threads) {
470
849
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
471
- ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
472
- return INT2NUM(ptr->params.n_gpu_layers);
850
+ ptr->params.n_threads = NUM2INT(n_threads);
851
+ return INT2NUM(ptr->params.n_threads);
473
852
  }
474
853
 
475
- static VALUE _llama_context_params_get_n_gpu_layers(VALUE self) {
854
+ static VALUE _llama_context_params_get_n_threads(VALUE self) {
476
855
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
477
- return INT2NUM(ptr->params.n_gpu_layers);
856
+ return INT2NUM(ptr->params.n_threads);
478
857
  }
479
858
 
480
- // main_gpu
481
- static VALUE _llama_context_params_set_main_gpu(VALUE self, VALUE main_gpu) {
859
+ // n_threads_batch
860
+ static VALUE _llama_context_params_set_n_threads_batch(VALUE self, VALUE n_threads_batch) {
482
861
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
483
- ptr->params.main_gpu = NUM2INT(main_gpu);
484
- return INT2NUM(ptr->params.main_gpu);
862
+ ptr->params.n_threads_batch = NUM2INT(n_threads_batch);
863
+ return INT2NUM(ptr->params.n_threads_batch);
485
864
  }
486
865
 
487
- static VALUE _llama_context_params_get_main_gpu(VALUE self) {
866
+ static VALUE _llama_context_params_get_n_threads_batch(VALUE self) {
488
867
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
489
- return INT2NUM(ptr->params.main_gpu);
490
- }
491
-
492
- // tensor_split
493
- static VALUE _llama_context_params_get_tensor_split(VALUE self) {
494
- if (LLAMA_MAX_DEVICES < 1) {
495
- return rb_ary_new();
496
- }
497
- VALUE ret = rb_ary_new2(LLAMA_MAX_DEVICES);
498
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
499
- for (size_t i = 0; i < LLAMA_MAX_DEVICES; i++) {
500
- rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
501
- }
502
- return ret;
868
+ return INT2NUM(ptr->params.n_threads_batch);
503
869
  }
504
870
 
505
871
  // rope_freq_base
@@ -526,18 +892,6 @@ private:
526
892
  return DBL2NUM(ptr->params.rope_freq_scale);
527
893
  }
528
894
 
529
- // low_vram
530
- static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
531
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
532
- ptr->params.low_vram = RTEST(low_vram) ? true : false;
533
- return ptr->params.low_vram ? Qtrue : Qfalse;
534
- }
535
-
536
- static VALUE _llama_context_params_get_low_vram(VALUE self) {
537
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
538
- return ptr->params.low_vram ? Qtrue : Qfalse;
539
- }
540
-
541
895
  // mul_mat_q
542
896
  static VALUE _llama_context_params_set_mul_mat_q(VALUE self, VALUE mul_mat_q) {
543
897
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -550,22 +904,6 @@ private:
550
904
  return ptr->params.mul_mat_q ? Qtrue : Qfalse;
551
905
  }
552
906
 
553
- // seed
554
- static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
555
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
556
- if (NUM2INT(seed) < 0) {
557
- rb_raise(rb_eArgError, "seed must be positive");
558
- return Qnil;
559
- }
560
- ptr->params.seed = NUM2INT(seed);
561
- return INT2NUM(ptr->params.seed);
562
- }
563
-
564
- static VALUE _llama_context_params_get_seed(VALUE self) {
565
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
566
- return INT2NUM(ptr->params.seed);
567
- }
568
-
569
907
  // f16_kv
570
908
  static VALUE _llama_context_params_set_f16_kv(VALUE self, VALUE f16_kv) {
571
909
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -590,42 +928,6 @@ private:
590
928
  return ptr->params.logits_all ? Qtrue : Qfalse;
591
929
  }
592
930
 
593
- // vocab_only
594
- static VALUE _llama_context_params_set_vocab_only(VALUE self, VALUE vocab_only) {
595
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
596
- ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
597
- return ptr->params.vocab_only ? Qtrue : Qfalse;
598
- }
599
-
600
- static VALUE _llama_context_params_get_vocab_only(VALUE self) {
601
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
602
- return ptr->params.vocab_only ? Qtrue : Qfalse;
603
- }
604
-
605
- // use_mmap
606
- static VALUE _llama_context_params_set_use_mmap(VALUE self, VALUE use_mmap) {
607
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
608
- ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
609
- return ptr->params.use_mmap ? Qtrue : Qfalse;
610
- }
611
-
612
- static VALUE _llama_context_params_get_use_mmap(VALUE self) {
613
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
614
- return ptr->params.use_mmap ? Qtrue : Qfalse;
615
- }
616
-
617
- // use_mlock
618
- static VALUE _llama_context_params_set_use_mlock(VALUE self, VALUE use_mlock) {
619
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
620
- ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
621
- return ptr->params.use_mlock ? Qtrue : Qfalse;
622
- }
623
-
624
- static VALUE _llama_context_params_get_use_mlock(VALUE self) {
625
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
626
- return ptr->params.use_mlock ? Qtrue : Qfalse;
627
- }
628
-
629
931
  // embedding
630
932
  static VALUE _llama_context_params_set_embedding(VALUE self, VALUE embedding) {
631
933
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
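ContextParams now leads with seed and adds n_threads / n_threads_batch, the thread counts that used to be keyword arguments to eval. A sketch, assuming the class is exposed as LLaMACpp::ContextParams:

```ruby
ctx_params = LLaMACpp::ContextParams.new
ctx_params.seed = 42
ctx_params.n_ctx = 2048
ctx_params.n_batch = 512
ctx_params.n_threads = 8        # replaces the old n_threads: keyword on eval/eval_embd
ctx_params.n_threads_batch = 8  # threads used for batch/prompt processing
ctx_params.embedding = false
```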
@@ -823,11 +1125,10 @@ public:
823
1125
  rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
824
1126
  rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
825
1127
  rb_define_method(rb_cLLaMAModel, "n_vocab", RUBY_METHOD_FUNC(_llama_model_get_model_n_vocab), 0);
826
- rb_define_method(rb_cLLaMAModel, "n_ctx", RUBY_METHOD_FUNC(_llama_model_get_model_n_ctx), 0);
827
1128
  rb_define_method(rb_cLLaMAModel, "n_ctx_train", RUBY_METHOD_FUNC(_llama_model_get_model_n_ctx_train), 0);
828
1129
  rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_model_n_embd), 0);
829
- rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece_with_model), 1);
830
- rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize_with_model), -1);
1130
+ rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece), 1);
1131
+ rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize), -1);
831
1132
  rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
832
1133
  rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
833
1134
  rb_define_method(rb_cLLaMAModel, "n_params", RUBY_METHOD_FUNC(_llama_model_get_model_n_params), 0);
@@ -841,30 +1142,21 @@ private:
841
1142
  ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
842
1143
  VALUE kw_values[2] = { Qundef, Qundef };
843
1144
  rb_scan_args(argc, argv, ":", &kw_args);
844
- rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
845
-
846
- if (kw_values[0] == Qundef && kw_values[1] == Qundef) {
847
- rb_iv_set(self, "@params", Qnil);
848
- return Qnil;
849
- }
1145
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
850
1146
 
851
1147
  if (!RB_TYPE_P(kw_values[0], T_STRING)) {
852
1148
  rb_raise(rb_eArgError, "model_path must be a string");
853
1149
  return Qnil;
854
1150
  }
855
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
856
- rb_raise(rb_eArgError, "params must be a ContextParams");
1151
+ if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
1152
+ rb_raise(rb_eArgError, "params must be a ModelParams");
857
1153
  return Qnil;
858
1154
  }
859
1155
 
860
1156
  VALUE filename = kw_values[0];
861
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
1157
+ LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
862
1158
  LLaMAModelWrapper* model_ptr = get_llama_model(self);
863
1159
 
864
- if (prms_ptr->params.seed == LLAMA_DEFAULT_SEED) {
865
- prms_ptr->params.seed = time(NULL);
866
- }
867
-
868
1160
  try {
869
1161
  model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
870
1162
  } catch (const std::runtime_error& e) {
@@ -912,8 +1204,8 @@ private:
912
1204
  rb_raise(rb_eArgError, "model_path must be a string");
913
1205
  return Qnil;
914
1206
  }
915
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
916
- rb_raise(rb_eArgError, "params must be a LLaMAContextParams");
1207
+ if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
1208
+ rb_raise(rb_eArgError, "params must be a LLaMAModelParams");
917
1209
  return Qnil;
918
1210
  }
919
1211
 
@@ -924,7 +1216,7 @@ private:
924
1216
  }
925
1217
 
926
1218
  VALUE filename = kw_values[0];
927
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
1219
+ LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
928
1220
 
929
1221
  try {
930
1222
  model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
@@ -946,10 +1238,10 @@ private:
946
1238
 
947
1239
  static VALUE _llama_model_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
948
1240
  VALUE kw_args = Qnil;
949
- ID kw_table[3] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads") };
950
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1241
+ ID kw_table[4] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads"), rb_intern("scale") };
1242
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
951
1243
  rb_scan_args(argc, argv, ":", &kw_args);
952
- rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
1244
+ rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);
953
1245
 
954
1246
  if (!RB_TYPE_P(kw_values[0], T_STRING)) {
955
1247
  rb_raise(rb_eArgError, "lora_path must be a string");
@@ -963,13 +1255,18 @@ private:
963
1255
  rb_raise(rb_eArgError, "n_threads must be an integer");
964
1256
  return Qnil;
965
1257
  }
1258
+ if (kw_values[3] != Qundef && !RB_FLOAT_TYPE_P(kw_values[3])) {
1259
+ rb_raise(rb_eArgError, "scale must be a float");
1260
+ return Qnil;
1261
+ }
966
1262
 
967
1263
  const char* lora_path = StringValueCStr(kw_values[0]);
968
1264
  const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
969
1265
  const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
1266
+ const float scale = kw_values[3] == Qundef ? 1.0 : NUM2DBL(kw_values[3]);
970
1267
 
971
1268
  LLaMAModelWrapper* ptr = get_llama_model(self);
972
- if (llama_model_apply_lora_from_file(ptr->model, lora_path, base_model_path, n_threads) != 0) {
1269
+ if (llama_model_apply_lora_from_file(ptr->model, lora_path, scale, base_model_path, n_threads) != 0) {
973
1270
  rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
974
1271
  return Qnil;
975
1272
  }
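apply_lora_from_file gains an optional scale keyword (a Float, defaulting to 1.0). A one-line sketch with placeholder paths:

```ruby
# lora_path is a placeholder; base_model_path may also be given, as before.
model.apply_lora_from_file(lora_path: "path/to/lora.bin", scale: 0.8, n_threads: 4)
```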
@@ -978,25 +1275,20 @@ private:
978
1275
 
979
1276
  static VALUE _llama_model_get_model_n_vocab(VALUE self) {
980
1277
  LLaMAModelWrapper* ptr = get_llama_model(self);
981
- return INT2NUM(llama_model_n_vocab(ptr->model));
982
- }
983
-
984
- static VALUE _llama_model_get_model_n_ctx(VALUE self) {
985
- LLaMAModelWrapper* ptr = get_llama_model(self);
986
- return INT2NUM(llama_model_n_ctx(ptr->model));
1278
+ return INT2NUM(llama_n_vocab(ptr->model));
987
1279
  }
988
1280
 
989
1281
  static VALUE _llama_model_get_model_n_ctx_train(VALUE self) {
990
1282
  LLaMAModelWrapper* ptr = get_llama_model(self);
991
- return INT2NUM(llama_model_n_ctx_train(ptr->model));
1283
+ return INT2NUM(llama_n_ctx_train(ptr->model));
992
1284
  }
993
1285
 
994
1286
  static VALUE _llama_model_get_model_n_embd(VALUE self) {
995
1287
  LLaMAModelWrapper* ptr = get_llama_model(self);
996
- return INT2NUM(llama_model_n_embd(ptr->model));
1288
+ return INT2NUM(llama_n_embd(ptr->model));
997
1289
  }
998
1290
 
999
- static VALUE _llama_model_token_to_piece_with_model(VALUE self, VALUE token_) {
1291
+ static VALUE _llama_model_token_to_piece(VALUE self, VALUE token_) {
1000
1292
  if (!RB_INTEGER_TYPE_P(token_)) {
1001
1293
  rb_raise(rb_eArgError, "token must be an integer");
1002
1294
  return Qnil;
@@ -1004,10 +1296,10 @@ private:
1004
1296
  const llama_token token = NUM2INT(token_);
1005
1297
  LLaMAModelWrapper* ptr = get_llama_model(self);
1006
1298
  std::vector<char> result(8, 0);
1007
- const int n_tokens = llama_token_to_piece_with_model(ptr->model, token, result.data(), result.size());
1299
+ const int n_tokens = llama_token_to_piece(ptr->model, token, result.data(), result.size());
1008
1300
  if (n_tokens < 0) {
1009
1301
  result.resize(-n_tokens);
1010
- const int check = llama_token_to_piece_with_model(ptr->model, token, result.data(), result.size());
1302
+ const int check = llama_token_to_piece(ptr->model, token, result.data(), result.size());
1011
1303
  if (check != -n_tokens) {
1012
1304
  rb_raise(rb_eRuntimeError, "failed to convert");
1013
1305
  return Qnil;
@@ -1019,7 +1311,7 @@ private:
1019
1311
  return rb_str_new_cstr(ret.c_str());
1020
1312
  }
1021
1313
 
1022
- static VALUE _llama_model_tokenize_with_model(int argc, VALUE* argv, VALUE self) {
1314
+ static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
1023
1315
  VALUE kw_args = Qnil;
1024
1316
  ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
1025
1317
  VALUE kw_values[3] = { Qundef, Qundef, Qundef };
@@ -1046,7 +1338,7 @@ private:
1046
1338
 
1047
1339
  llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
1048
1340
  LLaMAModelWrapper* ptr = get_llama_model(self);
1049
- const int n_tokens = llama_tokenize_with_model(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
1341
+ const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
1050
1342
 
1051
1343
  if (n_tokens < 0) {
1052
1344
  rb_raise(rb_eRuntimeError, "failed to tokenize. The number of tokens (%d) is greater than n_max_tokens.", -n_tokens);
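On the Ruby side, Model#tokenize and Model#token_to_piece keep their names (the bindings now call llama_tokenize / llama_token_to_piece directly), while the Context-level variants are removed later in this diff. A hedged round-trip sketch:

```ruby
# model comes from the ModelParams sketch above.
tokens = model.tokenize(text: "Hello world", n_max_tokens: 32, add_bos: true)
pieces = tokens.map { |t| model.token_to_piece(t) }
puts pieces.join
```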
@@ -1345,11 +1637,11 @@ public:
1345
1637
  static void define_class(VALUE outer) {
1346
1638
  rb_cLLaMAContext = rb_define_class_under(outer, "Context", rb_cObject);
1347
1639
  rb_define_alloc_func(rb_cLLaMAContext, llama_context_alloc);
1640
+ rb_define_attr(rb_cLLaMAContext, "model", 1, 0);
1348
1641
  rb_define_method(rb_cLLaMAContext, "initialize", RUBY_METHOD_FUNC(_llama_context_initialize), -1);
1349
1642
  rb_define_method(rb_cLLaMAContext, "eval", RUBY_METHOD_FUNC(_llama_context_eval), -1);
1350
1643
  rb_define_method(rb_cLLaMAContext, "eval_embd", RUBY_METHOD_FUNC(_llama_context_eval_embd), -1);
1351
- rb_define_method(rb_cLLaMAContext, "eval_export", RUBY_METHOD_FUNC(_llama_context_eval_export), 1);
1352
- rb_define_method(rb_cLLaMAContext, "tokenize", RUBY_METHOD_FUNC(_llama_context_tokenize), -1);
1644
+ rb_define_method(rb_cLLaMAContext, "decode", RUBY_METHOD_FUNC(_llama_context_decode), 1);
1353
1645
  rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
1354
1646
  rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
1355
1647
  rb_define_method(rb_cLLaMAContext, "text", RUBY_METHOD_FUNC(_llama_context_text), 1);
@@ -1358,15 +1650,16 @@ public:
1358
1650
  rb_define_method(rb_cLLaMAContext, "token_bos", RUBY_METHOD_FUNC(_llama_context_token_bos), 0);
1359
1651
  rb_define_method(rb_cLLaMAContext, "token_eos", RUBY_METHOD_FUNC(_llama_context_token_eos), 0);
1360
1652
  rb_define_method(rb_cLLaMAContext, "token_nl", RUBY_METHOD_FUNC(_llama_context_token_nl), 0);
1361
- rb_define_method(rb_cLLaMAContext, "token_to_piece", RUBY_METHOD_FUNC(_llama_context_token_to_piece), 1);
1362
- rb_define_method(rb_cLLaMAContext, "n_vocab", RUBY_METHOD_FUNC(_llama_context_n_vocab), 0);
1363
1653
  rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
1364
- rb_define_method(rb_cLLaMAContext, "n_ctx_train", RUBY_METHOD_FUNC(_llama_context_n_ctx_train), 0);
1365
- rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
1366
1654
  rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
1367
1655
  rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
1368
1656
  rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
1369
1657
  rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
1658
+ rb_define_method(rb_cLLaMAContext, "kv_cache_tokens_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_tokens_rm), 2);
1659
+ rb_define_method(rb_cLLaMAContext, "kv_cache_seq_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_rm), 3);
1660
+ rb_define_method(rb_cLLaMAContext, "kv_cache_seq_cp", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_cp), 4);
1661
+ rb_define_method(rb_cLLaMAContext, "kv_cache_seq_keep", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_keep), 1);
1662
+ rb_define_method(rb_cLLaMAContext, "kv_cache_seq_shift", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_shift), 4);
1370
1663
  rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
1371
1664
  rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
1372
1665
  rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
@@ -1378,6 +1671,7 @@ public:
1378
1671
  rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
1379
1672
  rb_define_method(rb_cLLaMAContext, "sample_tail_free", RUBY_METHOD_FUNC(_llama_context_sample_tail_free), -1);
1380
1673
  rb_define_method(rb_cLLaMAContext, "sample_typical", RUBY_METHOD_FUNC(_llama_context_sample_typical), -1);
1674
+ rb_define_method(rb_cLLaMAContext, "sample_temp", RUBY_METHOD_FUNC(_llama_context_sample_temp), -1);
1381
1675
  rb_define_method(rb_cLLaMAContext, "sample_temperature", RUBY_METHOD_FUNC(_llama_context_sample_temperature), -1);
1382
1676
  rb_define_method(rb_cLLaMAContext, "sample_token_mirostat", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat), -1);
1383
1677
  rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
@@ -1392,24 +1686,27 @@ private:
1392
1686
 
1393
1687
  static VALUE _llama_context_initialize(int argc, VALUE* argv, VALUE self) {
1394
1688
  VALUE kw_args = Qnil;
1395
- ID kw_table[1] = { rb_intern("model") };
1396
- VALUE kw_values[1] = { Qundef };
1689
+ ID kw_table[2] = { rb_intern("model"), rb_intern("params") };
1690
+ VALUE kw_values[2] = { Qundef, Qundef };
1397
1691
  rb_scan_args(argc, argv, ":", &kw_args);
1398
- rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
1692
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
1399
1693
 
1400
1694
  VALUE model = kw_values[0];
1401
1695
  if (!rb_obj_is_kind_of(model, rb_cLLaMAModel)) {
1402
1696
  rb_raise(rb_eArgError, "model must be a Model");
1403
1697
  return Qnil;
1404
1698
  }
1699
+ VALUE params = kw_values[1];
1700
+ if (!rb_obj_is_kind_of(params, rb_cLLaMAContextParams)) {
1701
+ rb_raise(rb_eArgError, "params must be a ContextParams");
1702
+ return Qnil;
1703
+ }
1405
1704
 
1406
1705
  LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
1407
1706
  if (model_ptr->model == NULL) {
1408
1707
  rb_raise(rb_eRuntimeError, "Model is empty");
1409
1708
  return Qnil;
1410
1709
  }
1411
-
1412
- VALUE params = rb_iv_get(model, "@params");
1413
1710
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1414
1711
  LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1415
1712
 
@@ -1421,6 +1718,7 @@ private:
1421
1718
  }
1422
1719
 
1423
1720
  rb_iv_set(self, "@model", model);
1721
+ rb_iv_set(self, "@params", params);
1424
1722
  rb_iv_set(self, "@has_evaluated", Qfalse);
1425
1723
 
1426
1724
  return Qnil;
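Context.new now takes the context parameters explicitly instead of reading @params off the model; both keywords are required:

```ruby
# model and ctx_params come from the earlier sketches.
context = LLaMACpp::Context.new(model: model, params: ctx_params)
```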
@@ -1428,8 +1726,8 @@ private:
1428
1726
 
1429
1727
  static VALUE _llama_context_eval(int argc, VALUE* argv, VALUE self) {
1430
1728
  VALUE kw_args = Qnil;
1431
- ID kw_table[4] = { rb_intern("tokens"), rb_intern("n_past"), rb_intern("n_tokens"), rb_intern("n_threads") };
1432
- VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
1729
+ ID kw_table[3] = { rb_intern("tokens"), rb_intern("n_past"), rb_intern("n_tokens") };
1730
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1433
1731
  rb_scan_args(argc, argv, ":", &kw_args);
1434
1732
  rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
1435
1733
 
@@ -1445,10 +1743,6 @@ private:
1445
1743
  rb_raise(rb_eArgError, "n_tokens must be an integer");
1446
1744
  return Qnil;
1447
1745
  }
1448
- if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
1449
- rb_raise(rb_eArgError, "n_threads must be an integer");
1450
- return Qnil;
1451
- }
1452
1746
 
1453
1747
  const size_t tokens_len = RARRAY_LEN(kw_values[0]);
1454
1748
  std::vector<llama_token> embd(tokens_len);
@@ -1463,14 +1757,13 @@ private:
1463
1757
 
1464
1758
  const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
1465
1759
  const int n_past = NUM2INT(kw_values[1]);
1466
- const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
1467
1760
 
1468
1761
  LLaMAContextWrapper* ptr = get_llama_context(self);
1469
1762
  if (ptr->ctx == NULL) {
1470
1763
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1471
1764
  return Qnil;
1472
1765
  }
1473
- if (llama_eval(ptr->ctx, embd.data(), n_tokens, n_past, n_threads) != 0) {
1766
+ if (llama_eval(ptr->ctx, embd.data(), n_tokens, n_past) != 0) {
1474
1767
  rb_raise(rb_eRuntimeError, "Failed to evaluate");
1475
1768
  return Qnil;
1476
1769
  }
@@ -1483,8 +1776,8 @@ private:
1483
1776
 
1484
1777
  static VALUE _llama_context_eval_embd(int argc, VALUE* argv, VALUE self) {
1485
1778
  VALUE kw_args = Qnil;
1486
- ID kw_table[4] = { rb_intern("embd"), rb_intern("n_past"), rb_intern("n_tokens"), rb_intern("n_threads") };
1487
- VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
1779
+ ID kw_table[3] = { rb_intern("embd"), rb_intern("n_past"), rb_intern("n_tokens") };
1780
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1488
1781
  rb_scan_args(argc, argv, ":", &kw_args);
1489
1782
  rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
1490
1783
 
@@ -1500,10 +1793,6 @@ private:
1500
1793
  rb_raise(rb_eArgError, "n_tokens must be an integer");
1501
1794
  return Qnil;
1502
1795
  }
1503
- if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
1504
- rb_raise(rb_eArgError, "n_threads must be an integer");
1505
- return Qnil;
1506
- }
1507
1796
 
1508
1797
  const size_t tokens_len = RARRAY_LEN(kw_values[0]);
1509
1798
  std::vector<float> embd(tokens_len);
@@ -1518,14 +1807,13 @@ private:
1518
1807
 
1519
1808
  const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
1520
1809
  const int n_past = NUM2INT(kw_values[1]);
1521
- const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
1522
1810
 
1523
1811
  LLaMAContextWrapper* ptr = get_llama_context(self);
1524
1812
  if (ptr->ctx == NULL) {
1525
1813
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1526
1814
  return Qnil;
1527
1815
  }
1528
- if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past, n_threads) != 0) {
1816
+ if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past) != 0) {
1529
1817
  rb_raise(rb_eRuntimeError, "Failed to evaluate");
1530
1818
  return Qnil;
1531
1819
  }
@@ -1536,91 +1824,22 @@ private:
1536
1824
  return Qnil;
1537
1825
  }
1538
1826
 
1539
- static VALUE _llama_context_eval_export(VALUE self, VALUE fname_) {
1827
+ static VALUE _llama_context_decode(VALUE self, VALUE batch) {
1540
1828
  LLaMAContextWrapper* ptr = get_llama_context(self);
1541
1829
  if (ptr->ctx == NULL) {
1542
1830
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1543
1831
  return Qnil;
1544
1832
  }
1545
- if (!RB_TYPE_P(fname_, T_STRING)) {
1546
- rb_raise(rb_eArgError, "fname must be a string");
1833
+ if (!rb_obj_is_kind_of(batch, rb_cLLaMABatch)) {
1834
+ rb_raise(rb_eArgError, "batch must be a Batch");
1547
1835
  return Qnil;
1548
1836
  }
1549
- const char* fname = StringValueCStr(fname_);
1550
- if (llama_eval_export(ptr->ctx, fname) != 0) {
1551
- return Qfalse;
1552
- }
1553
- RB_GC_GUARD(fname_);
1554
- return Qtrue;
1555
- }
1556
-
1557
- static VALUE _llama_context_tokenize(int argc, VALUE* argv, VALUE self) {
1558
- VALUE kw_args = Qnil;
1559
- ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
1560
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1561
- rb_scan_args(argc, argv, ":", &kw_args);
1562
- rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
1563
-
1564
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1565
- rb_raise(rb_eArgError, "text must be a String");
1837
+ LLaMABatchWrapper* batch_ptr = RbLLaMABatch::get_llama_batch(batch);
1838
+ if (llama_decode(ptr->ctx, batch_ptr->batch) < 0) {
1839
+ rb_raise(rb_eRuntimeError, "Failed to decode");
1566
1840
  return Qnil;
1567
1841
  }
1568
- if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
1569
- rb_raise(rb_eArgError, "n_max_tokens must be an integer");
1570
- return Qnil;
1571
- }
1572
- if (kw_values[2] != Qundef && (kw_values[2] != Qtrue && kw_values[2] != Qfalse)) {
1573
- rb_raise(rb_eArgError, "add_bos must be a boolean");
1574
- return Qnil;
1575
- }
1576
-
1577
- VALUE text_ = kw_values[0];
1578
- std::string text = StringValueCStr(text_);
1579
- const bool add_bos = kw_values[2] == Qtrue ? true : false;
1580
- const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
1581
-
1582
- std::vector<llama_token> tokens(n_max_tokens);
1583
- LLaMAContextWrapper* ptr = get_llama_context(self);
1584
- if (ptr->ctx == NULL) {
1585
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1586
- return Qnil;
1587
- }
1588
- const int n = llama_tokenize(ptr->ctx, text.c_str(), text.size(), tokens.data(), n_max_tokens, add_bos);
1589
- if (n < 0) {
1590
- rb_raise(rb_eRuntimeError, "Failed to tokenize");
1591
- return Qnil;
1592
- }
1593
-
1594
- VALUE output = rb_ary_new();
1595
- for (int i = 0; i < n; i++) {
1596
- rb_ary_push(output, INT2NUM(tokens[i]));
1597
- }
1598
-
1599
- RB_GC_GUARD(text_);
1600
- return output;
1601
- }
1602
-
1603
- static VALUE _llama_context_token_to_piece(VALUE self, VALUE token_) {
1604
- LLaMAContextWrapper* ptr = get_llama_context(self);
1605
- if (ptr->ctx == NULL) {
1606
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1607
- return Qnil;
1608
- }
1609
- const llama_token token = NUM2INT(token_);
1610
- std::vector<char> result(8, 0);
1611
- const int n_tokens = llama_token_to_piece(ptr->ctx, token, result.data(), result.size());
1612
- if (n_tokens < 0) {
1613
- result.resize(-n_tokens);
1614
- const int check = llama_token_to_piece(ptr->ctx, token, result.data(), result.size());
1615
- if (check != -n_tokens) {
1616
- rb_raise(rb_eRuntimeError, "failed to convert");
1617
- return Qnil;
1618
- }
1619
- } else {
1620
- result.resize(n_tokens);
1621
- }
1622
- std::string ret(result.data(), result.size());
1623
- return rb_str_new_cstr(ret.c_str());
1842
+ return Qnil;
1624
1843
  }
1625
1844
 
1626
1845
  static VALUE _llama_context_logits(VALUE self) {
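eval_export and the Context-level tokenize/token_to_piece are dropped; new-style decoding goes through Batch and Context#decode. A hedged sketch building on the earlier Batch example:

```ruby
# batch is a LLaMACpp::Batch filled as in the Batch sketch above.
context.decode(batch)
logits = context.logits  # flat Float array of vocabulary logits
```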
@@ -1635,10 +1854,11 @@ private:
1635
1854
  }
1636
1855
 
1637
1856
  VALUE model = rb_iv_get(self, "@model");
1638
- VALUE params = rb_iv_get(model, "@params");
1857
+ LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
1858
+ VALUE params = rb_iv_get(self, "@params");
1639
1859
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1640
1860
  const int n_tokens = prms_ptr->params.logits_all ? NUM2INT(rb_iv_get(self, "@n_tokens")) : 1;
1641
- const int n_vocab = llama_n_vocab(ptr->ctx);
1861
+ const int n_vocab = llama_n_vocab(model_ptr->model);
1642
1862
  const float* logits = llama_get_logits(ptr->ctx);
1643
1863
  VALUE output = rb_ary_new();
1644
1864
  for (int i = 0; i < n_tokens * n_vocab; i++) {
@@ -1655,7 +1875,8 @@ private:
1655
1875
  return Qnil;
1656
1876
  }
1657
1877
  VALUE model = rb_iv_get(self, "@model");
1658
- VALUE params = rb_iv_get(model, "@params");
1878
+ LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
1879
+ VALUE params = rb_iv_get(self, "@params");
1659
1880
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1660
1881
  if (!prms_ptr->params.embedding) {
1661
1882
  rb_raise(rb_eRuntimeError, "embedding parameter is false");
@@ -1666,7 +1887,7 @@ private:
1666
1887
  return Qnil;
1667
1888
  }
1668
1889
 
1669
- const int n_embd = llama_n_embd(ptr->ctx);
1890
+ const int n_embd = llama_n_embd(model_ptr->model);
1670
1891
  const float* embd = llama_get_embeddings(ptr->ctx);
1671
1892
  VALUE output = rb_ary_new();
1672
1893
  for (int i = 0; i < n_embd; i++) {
@@ -1736,81 +1957,104 @@ private:
1736
1957
  return INT2NUM(llama_token_nl(ptr->ctx));
1737
1958
  }
1738
1959
 
1739
- static VALUE _llama_context_n_vocab(VALUE self) {
1960
+ static VALUE _llama_context_n_ctx(VALUE self) {
1740
1961
  LLaMAContextWrapper* ptr = get_llama_context(self);
1741
1962
  if (ptr->ctx == NULL) {
1742
1963
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1743
1964
  return Qnil;
1744
1965
  }
1745
- return INT2NUM(llama_n_vocab(ptr->ctx));
1966
+ return INT2NUM(llama_n_ctx(ptr->ctx));
1746
1967
  }
1747
1968
 
1748
- static VALUE _llama_context_n_ctx(VALUE self) {
1969
+ static VALUE _llama_context_get_timings(VALUE self) {
1749
1970
  LLaMAContextWrapper* ptr = get_llama_context(self);
1750
1971
  if (ptr->ctx == NULL) {
1751
1972
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1752
1973
  return Qnil;
1753
1974
  }
1754
- return INT2NUM(llama_n_ctx(ptr->ctx));
1975
+ VALUE tm_obj = rb_funcall(rb_cLLaMATimings, rb_intern("new"), 0);
1976
+ LLaMATimingsWrapper* tm_ptr = RbLLaMATimings::get_llama_timings(tm_obj);
1977
+ tm_ptr->timings = llama_get_timings(ptr->ctx);
1978
+ return tm_obj;
1755
1979
  }
1756
1980
 
1757
- static VALUE _llama_context_n_ctx_train(VALUE self) {
1981
+ static VALUE _llama_context_print_timings(VALUE self) {
1758
1982
  LLaMAContextWrapper* ptr = get_llama_context(self);
1759
1983
  if (ptr->ctx == NULL) {
1760
1984
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1761
1985
  return Qnil;
1762
1986
  }
1763
- return INT2NUM(llama_n_ctx_train(ptr->ctx));
1987
+ llama_print_timings(ptr->ctx);
1988
+ return Qnil;
1764
1989
  }
1765
1990
 
1766
- static VALUE _llama_context_n_embd(VALUE self) {
1991
+ static VALUE _llama_context_reset_timings(VALUE self) {
1767
1992
  LLaMAContextWrapper* ptr = get_llama_context(self);
1768
1993
  if (ptr->ctx == NULL) {
1769
1994
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1770
1995
  return Qnil;
1771
1996
  }
1772
- return INT2NUM(llama_n_embd(ptr->ctx));
1997
+ llama_reset_timings(ptr->ctx);
1998
+ return Qnil;
1773
1999
  }
1774
2000
 
1775
- static VALUE _llama_context_get_timings(VALUE self) {
2001
+ static VALUE _llama_context_kv_cache_token_count(VALUE self) {
1776
2002
  LLaMAContextWrapper* ptr = get_llama_context(self);
1777
2003
  if (ptr->ctx == NULL) {
1778
2004
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1779
2005
  return Qnil;
1780
2006
  }
1781
- VALUE tm_obj = rb_funcall(rb_cLLaMATimings, rb_intern("new"), 0);
1782
- LLaMATimingsWrapper* tm_ptr = RbLLaMATimings::get_llama_timings(tm_obj);
1783
- tm_ptr->timings = llama_get_timings(ptr->ctx);
1784
- return tm_obj;
2007
+ return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
1785
2008
  }
1786
2009
 
1787
- static VALUE _llama_context_print_timings(VALUE self) {
2010
+ static VALUE _llama_context_kv_cache_tokens_rm(VALUE self, VALUE c0, VALUE c1) {
1788
2011
  LLaMAContextWrapper* ptr = get_llama_context(self);
1789
2012
  if (ptr->ctx == NULL) {
1790
2013
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1791
2014
  return Qnil;
1792
2015
  }
1793
- llama_print_timings(ptr->ctx);
2016
+ llama_kv_cache_tokens_rm(ptr->ctx, NUM2INT(c0), NUM2INT(c1));
1794
2017
  return Qnil;
1795
2018
  }
1796
2019
 
1797
- static VALUE _llama_context_reset_timings(VALUE self) {
2020
+ static VALUE _llama_context_kv_cache_seq_rm(VALUE self, VALUE seq_id, VALUE p0, VALUE p1) {
1798
2021
  LLaMAContextWrapper* ptr = get_llama_context(self);
1799
2022
  if (ptr->ctx == NULL) {
1800
2023
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1801
2024
  return Qnil;
1802
2025
  }
1803
- llama_reset_timings(ptr->ctx);
2026
+ llama_kv_cache_seq_rm(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1));
1804
2027
  return Qnil;
1805
2028
  }
1806
2029
 
1807
- static VALUE _llama_context_kv_cache_token_count(VALUE self) {
2030
+ static VALUE _llama_context_kv_cache_seq_cp(VALUE self, VALUE seq_id_src, VALUE seq_id_dst, VALUE p0, VALUE p1) {
1808
2031
  LLaMAContextWrapper* ptr = get_llama_context(self);
1809
2032
  if (ptr->ctx == NULL) {
1810
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2033
+ rb_raise(rb_eArgError, "LLaMA context is not initialized");
1811
2034
  return Qnil;
1812
2035
  }
1813
- return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
2036
+ llama_kv_cache_seq_cp(ptr->ctx, NUM2INT(seq_id_src), NUM2INT(seq_id_dst), NUM2INT(p0), NUM2INT(p1));
2037
+ return Qnil;
2038
+ }
2039
+
2040
+ static VALUE _llama_context_kv_cache_seq_keep(VALUE self, VALUE seq_id) {
2041
+ LLaMAContextWrapper* ptr = get_llama_context(self);
2042
+ if (ptr->ctx == NULL) {
2043
+ rb_raise(rb_eArgError, "LLaMA context is not initialized");
2044
+ return Qnil;
2045
+ }
2046
+ llama_kv_cache_seq_keep(ptr->ctx, NUM2INT(seq_id));
2047
+ return Qnil;
2048
+ }
2049
+
2050
+ static VALUE _llama_context_kv_cache_seq_shift(VALUE self, VALUE seq_id, VALUE p0, VALUE p1, VALUE delta) {
2051
+ LLaMAContextWrapper* ptr = get_llama_context(self);
2052
+ if (ptr->ctx == NULL) {
2053
+ rb_raise(rb_eArgError, "LLaMA context is not initialized");
2054
+ return Qnil;
2055
+ }
2056
+ llama_kv_cache_seq_shift(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1), NUM2INT(delta));
2057
+ return Qnil;
1814
2058
  }
1815
2059
 
1816
2060
  static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
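The kv_cache_* methods map directly onto the llama.cpp KV-cache sequence API. A sketch with placeholder sequence ids and positions:

```ruby
context.kv_cache_seq_cp(0, 1, 0, 32)       # copy sequence 0 into sequence 1 over positions 0...32
context.kv_cache_seq_rm(1, 16, 32)         # drop positions 16...32 from sequence 1
context.kv_cache_seq_keep(0)               # keep only sequence 0 in the cache
context.kv_cache_seq_shift(0, 32, 64, -8)  # shift positions 32...64 of sequence 0 back by 8
```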
@@ -1851,7 +2095,7 @@ private:
1851
2095
  }
1852
2096
 
1853
2097
  VALUE model = rb_iv_get(self, "@model");
1854
- VALUE params = rb_iv_get(model, "@params");
2098
+ VALUE params = rb_iv_get(self, "@params");
1855
2099
  LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1856
2100
  const int n_ctx = prms_ptr->params.n_ctx;
1857
2101
 
@@ -2235,6 +2479,40 @@ private:
2235
2479
  return Qnil;
2236
2480
  }
2237
2481
 
2482
+ static VALUE _llama_context_sample_temp(int argc, VALUE* argv, VALUE self) {
2483
+ VALUE kw_args = Qnil;
2484
+ ID kw_table[1] = { rb_intern("temp") };
2485
+ VALUE kw_values[1] = { Qundef };
2486
+ VALUE candidates = Qnil;
2487
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
2488
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
2489
+
2490
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
2491
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
2492
+ return Qnil;
2493
+ }
2494
+ if (!RB_FLOAT_TYPE_P(kw_values[0])) {
2495
+ rb_raise(rb_eArgError, "temp must be a float");
2496
+ return Qnil;
2497
+ }
2498
+
2499
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2500
+ if (ctx_ptr->ctx == NULL) {
2501
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2502
+ return Qnil;
2503
+ }
2504
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
2505
+ if (cnd_ptr->array.data == nullptr) {
2506
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
2507
+ return Qnil;
2508
+ }
2509
+ const float temp = NUM2DBL(kw_values[0]);
2510
+
2511
+ llama_sample_temp(ctx_ptr->ctx, &(cnd_ptr->array), temp);
2512
+
2513
+ return Qnil;
2514
+ }
2515
+
2238
2516
  static VALUE _llama_context_sample_temperature(int argc, VALUE* argv, VALUE self) {
2239
2517
  VALUE kw_args = Qnil;
2240
2518
  ID kw_table[1] = { rb_intern("temperature") };
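sample_temp is added alongside the existing sample_temperature; it takes the candidates positionally and temp: as a required Float keyword. A sketch (candidates is assumed to be a LLaMACpp::TokenDataArray built from the logits):

```ruby
context.sample_temp(candidates, temp: 0.8)  # divides each candidate logit by temp, in place
```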
@@ -2560,6 +2838,7 @@ extern "C" void Init_llama_cpp(void) {
2560
2838
  RbLLaMATokenData::define_class(rb_mLLaMACpp);
2561
2839
  RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
2562
2840
  RbLLaMAModel::define_class(rb_mLLaMACpp);
2841
+ RbLLaMAModelParams::define_class(rb_mLLaMACpp);
2563
2842
  RbLLaMATimings::define_class(rb_mLLaMACpp);
2564
2843
  RbLLaMAContext::define_class(rb_mLLaMACpp);
2565
2844
  RbLLaMAContextParams::define_class(rb_mLLaMACpp);
@@ -2578,10 +2857,6 @@ extern "C" void Init_llama_cpp(void) {
2578
2857
 
2579
2858
  rb_define_const(rb_mLLaMACpp, "LLAMA_MAX_DEVICES", INT2NUM(LLAMA_MAX_DEVICES));
2580
2859
 
2581
- rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_ERROR", INT2NUM(LLAMA_LOG_LEVEL_ERROR));
2582
- rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_WARN", INT2NUM(LLAMA_LOG_LEVEL_WARN));
2583
- rb_define_const(rb_mLLaMACpp, "LLAMA_LOG_LEVEL_INFO", INT2NUM(LLAMA_LOG_LEVEL_INFO));
2584
-
2585
2860
  rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_SPM", INT2NUM(LLAMA_VOCAB_TYPE_SPM));
2586
2861
  rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_BPE", INT2NUM(LLAMA_VOCAB_TYPE_BPE));
2587
2862