llama_cpp 0.0.7 → 0.1.1

@@ -4,6 +4,255 @@
  VALUE rb_mLLaMACpp;
  VALUE rb_cLLaMAContext;
  VALUE rb_cLLaMAContextParams;
+ VALUE rb_cLLaMATokenData;
+ VALUE rb_cLLaMATokenDataArray;
+
+ class LLaMATokenDataWrapper {
+ public:
+   llama_token_data data;
+
+   LLaMATokenDataWrapper() {
+     data.id = 0;
+     data.logit = 0.0;
+     data.p = 0.0;
+   };
+
+   ~LLaMATokenDataWrapper(){};
+ };
+
+ class RbLLaMATokenData {
+ public:
+   static VALUE llama_token_data_alloc(VALUE self) {
+     LLaMATokenDataWrapper* ptr = (LLaMATokenDataWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataWrapper));
+     new (ptr) LLaMATokenDataWrapper();
+     return TypedData_Wrap_Struct(self, &llama_token_data_type, ptr);
+   };
+
+   static void llama_token_data_free(void* ptr) {
+     ((LLaMATokenDataWrapper*)ptr)->~LLaMATokenDataWrapper();
+     ruby_xfree(ptr);
+   };
+
+   static size_t llama_token_data_size(const void* ptr) {
+     return sizeof(*((LLaMATokenDataWrapper*)ptr));
+   };
+
+   static LLaMATokenDataWrapper* get_llama_token_data(VALUE self) {
+     LLaMATokenDataWrapper* ptr;
+     TypedData_Get_Struct(self, LLaMATokenDataWrapper, &llama_token_data_type, ptr);
+     return ptr;
+   };
+
+   static void define_class(VALUE outer) {
+     rb_cLLaMATokenData = rb_define_class_under(outer, "TokenData", rb_cObject);
+     rb_define_alloc_func(rb_cLLaMATokenData, llama_token_data_alloc);
+     rb_define_method(rb_cLLaMATokenData, "initialize", RUBY_METHOD_FUNC(_llama_token_data_init), -1);
+     rb_define_method(rb_cLLaMATokenData, "id=", RUBY_METHOD_FUNC(_llama_token_data_set_id), 1);
+     rb_define_method(rb_cLLaMATokenData, "id", RUBY_METHOD_FUNC(_llama_token_data_get_id), 0);
+     rb_define_method(rb_cLLaMATokenData, "logit=", RUBY_METHOD_FUNC(_llama_token_data_set_logit), 1);
+     rb_define_method(rb_cLLaMATokenData, "logit", RUBY_METHOD_FUNC(_llama_token_data_get_logit), 0);
+     rb_define_method(rb_cLLaMATokenData, "p=", RUBY_METHOD_FUNC(_llama_token_data_set_p), 1);
+     rb_define_method(rb_cLLaMATokenData, "p", RUBY_METHOD_FUNC(_llama_token_data_get_p), 0);
+   }
+
+ private:
+   static const rb_data_type_t llama_token_data_type;
+
+   static VALUE _llama_token_data_init(int argc, VALUE* argv, VALUE self) {
+     VALUE kw_args = Qnil;
+     ID kw_table[3] = { rb_intern("id"), rb_intern("logit"), rb_intern("p") };
+     VALUE kw_values[3] = { Qundef, Qundef, Qundef };
+     rb_scan_args(argc, argv, ":", &kw_args);
+     rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
+
+     if (!RB_INTEGER_TYPE_P(kw_values[0])) {
+       rb_raise(rb_eArgError, "id must be an integer");
+       return Qnil;
+     }
+     if (!RB_FLOAT_TYPE_P(kw_values[1])) {
+       rb_raise(rb_eArgError, "logit must be a float");
+       return Qnil;
+     }
+     if (!RB_FLOAT_TYPE_P(kw_values[2])) {
+       rb_raise(rb_eArgError, "p must be a float");
+       return Qnil;
+     }
+
+     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
+     new (ptr) LLaMATokenDataWrapper();
+
+     ptr->data.id = NUM2INT(kw_values[0]);
+     ptr->data.logit = NUM2DBL(kw_values[1]);
+     ptr->data.p = NUM2DBL(kw_values[2]);
+
+     return self;
+   }
+
+   // id
+   static VALUE _llama_token_data_set_id(VALUE self, VALUE id) {
+     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
+     ptr->data.id = NUM2INT(id);
+     return INT2NUM(ptr->data.id);
+   };
+
+   static VALUE _llama_token_data_get_id(VALUE self) {
+     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
+     return INT2NUM(ptr->data.id);
+   };
+
+   // logit
+   static VALUE _llama_token_data_set_logit(VALUE self, VALUE logit) {
+     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
+     ptr->data.logit = NUM2DBL(logit);
+     return DBL2NUM(ptr->data.logit);
+   };
+
+   static VALUE _llama_token_data_get_logit(VALUE self) {
+     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
+     return DBL2NUM(ptr->data.logit);
+   };
+
+   // p
+   static VALUE _llama_token_data_set_p(VALUE self, VALUE p) {
+     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
+     ptr->data.p = NUM2DBL(p);
+     return DBL2NUM(ptr->data.p);
+   };
+
+   static VALUE _llama_token_data_get_p(VALUE self) {
+     LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
+     return DBL2NUM(ptr->data.p);
+   };
+ };
+
+ const rb_data_type_t RbLLaMATokenData::llama_token_data_type = {
+   "RbLLaMATokenData",
+   { NULL,
+     RbLLaMATokenData::llama_token_data_free,
+     RbLLaMATokenData::llama_token_data_size },
+   NULL,
+   NULL,
+   RUBY_TYPED_FREE_IMMEDIATELY
+ };
+
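The new TokenData class wraps a single llama_token_data entry (token id, raw logit, probability), and the initializer above requires all three keywords. A minimal Ruby sketch (the values here are placeholders):

  td = LLaMACpp::TokenData.new(id: 0, logit: 0.0, p: 0.0)
  td.logit = -1.25   # accessors defined above
  td.id              # => 0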
+ class LLaMATokenDataArrayWrapper {
+ public:
+   llama_token_data_array array;
+
+   LLaMATokenDataArrayWrapper() {
+     array.data = nullptr;
+     array.size = 0;
+     array.sorted = false;
+   };
+
+   ~LLaMATokenDataArrayWrapper() {
+     if (array.data) {
+       ruby_xfree(array.data);
+       array.data = nullptr;
+     }
+   };
+ };
+
+ class RbLLaMATokenDataArray {
+ public:
+   static VALUE llama_token_data_array_alloc(VALUE self) {
+     LLaMATokenDataArrayWrapper* ptr = (LLaMATokenDataArrayWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataArrayWrapper));
+     new (ptr) LLaMATokenDataArrayWrapper();
+     return TypedData_Wrap_Struct(self, &llama_token_data_array_type, ptr);
+   };
+
+   static void llama_token_data_array_free(void* ptr) {
+     ((LLaMATokenDataArrayWrapper*)ptr)->~LLaMATokenDataArrayWrapper();
+     ruby_xfree(ptr);
+   };
+
+   static size_t llama_token_data_array_size(const void* ptr) {
+     return sizeof(*((LLaMATokenDataArrayWrapper*)ptr));
+   };
+
+   static LLaMATokenDataArrayWrapper* get_llama_token_data_array(VALUE self) {
+     LLaMATokenDataArrayWrapper* ptr;
+     TypedData_Get_Struct(self, LLaMATokenDataArrayWrapper, &llama_token_data_array_type, ptr);
+     return ptr;
+   };
+
+   static void define_class(VALUE outer) {
+     rb_cLLaMATokenDataArray = rb_define_class_under(outer, "TokenDataArray", rb_cObject);
+     rb_define_alloc_func(rb_cLLaMATokenDataArray, llama_token_data_array_alloc);
+     rb_define_method(rb_cLLaMATokenDataArray, "initialize", RUBY_METHOD_FUNC(_llama_token_data_array_init), -1);
+     rb_define_method(rb_cLLaMATokenDataArray, "size", RUBY_METHOD_FUNC(_llama_token_data_array_get_size), 0);
+     rb_define_method(rb_cLLaMATokenDataArray, "sorted", RUBY_METHOD_FUNC(_llama_token_data_array_get_sorted), 0);
+   };
+
+ private:
+   static const rb_data_type_t llama_token_data_array_type;
+
+   static VALUE _llama_token_data_array_init(int argc, VALUE* argv, VALUE self) {
+     VALUE kw_args = Qnil;
+     ID kw_table[1] = { rb_intern("sorted") };
+     VALUE kw_values[1] = { Qundef };
+     VALUE arr = Qnil;
+     rb_scan_args(argc, argv, "1:", &arr, &kw_args);
+     rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
+
+     if (!RB_TYPE_P(arr, T_ARRAY)) {
+       rb_raise(rb_eArgError, "1st argument must be an array");
+       return Qnil;
+     }
+     size_t sz_array = RARRAY_LEN(arr);
+     if (sz_array == 0) {
+       rb_raise(rb_eArgError, "array must not be empty");
+       return Qnil;
+     }
+     if (kw_values[0] != Qundef && !RB_TYPE_P(kw_values[0], T_TRUE) && !RB_TYPE_P(kw_values[0], T_FALSE)) {
+       rb_raise(rb_eArgError, "sorted must be a boolean");
+       return Qnil;
+     }
+
+     LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
+     new (ptr) LLaMATokenDataArrayWrapper();
+
+     ptr->array.data = (llama_token_data*)ruby_xmalloc(sizeof(llama_token_data) * sz_array);
+     for (size_t i = 0; i < sz_array; ++i) {
+       VALUE el = rb_ary_entry(arr, i);
+       if (!rb_obj_is_kind_of(el, rb_cLLaMATokenData)) {
+         // free the buffer before raising: rb_raise does not return.
+         ruby_xfree(ptr->array.data);
+         ptr->array.data = nullptr;
+         rb_raise(rb_eArgError, "array element must be a TokenData");
+         return Qnil;
+       }
+       llama_token_data token_data = RbLLaMATokenData::get_llama_token_data(el)->data;
+       ptr->array.data[i].id = token_data.id;
+       ptr->array.data[i].logit = token_data.logit;
+       ptr->array.data[i].p = token_data.p;
+     }
+
+     ptr->array.size = sz_array;
+     ptr->array.sorted = kw_values[0] == Qtrue;
+
+     return self;
+   };
+
+   static VALUE _llama_token_data_array_get_size(VALUE self) {
+     LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
+     return SIZET2NUM(ptr->array.size);
+   };
+
+   static VALUE _llama_token_data_array_get_sorted(VALUE self) {
+     LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
+     return ptr->array.sorted ? Qtrue : Qfalse;
+   };
+ };
+
+ const rb_data_type_t RbLLaMATokenDataArray::llama_token_data_array_type = {
+   "RbLLaMATokenDataArray",
+   { NULL,
+     RbLLaMATokenDataArray::llama_token_data_array_free,
+     RbLLaMATokenDataArray::llama_token_data_array_size },
+   NULL,
+   NULL,
+   RUBY_TYPED_FREE_IMMEDIATELY
+ };
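TokenDataArray is the Ruby counterpart of llama.cpp's llama_token_data_array: it copies its TokenData elements into one contiguous C buffer that the sampling methods below mutate in place. A construction sketch, assuming an initialized LLaMACpp::Context named context (logit values are placeholders):

  candidates = LLaMACpp::TokenDataArray.new(
    Array.new(context.n_vocab) { |i| LLaMACpp::TokenData.new(id: i, logit: 0.0, p: 0.0) },
    sorted: false
  )
  candidates.size    # => context.n_vocab
  candidates.sorted  # => false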
 
  class LLaMAContextParamsWrapper {
  public:
@@ -43,8 +292,6 @@ public:
      // rb_define_method(rb_cLLaMAContextParams, "initialize", RUBY_METHOD_FUNC(_llama_context_params_init), 0);
      rb_define_method(rb_cLLaMAContextParams, "n_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_n_ctx), 1);
      rb_define_method(rb_cLLaMAContextParams, "n_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_n_ctx), 0);
-     rb_define_method(rb_cLLaMAContextParams, "n_parts=", RUBY_METHOD_FUNC(_llama_context_params_set_n_parts), 1);
-     rb_define_method(rb_cLLaMAContextParams, "n_parts", RUBY_METHOD_FUNC(_llama_context_params_get_n_parts), 0);
      rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
      rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
      rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
@@ -82,18 +329,6 @@ private:
      return INT2NUM(ptr->params.n_ctx);
    };
 
-   // n_parts
-   static VALUE _llama_context_params_set_n_parts(VALUE self, VALUE n_parts) {
-     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-     ptr->params.n_parts = NUM2INT(n_parts);
-     return INT2NUM(ptr->params.n_parts);
-   };
-
-   static VALUE _llama_context_params_get_n_parts(VALUE self) {
-     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-     return INT2NUM(ptr->params.n_parts);
-   };
-
    // seed
    static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
      LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -234,7 +469,6 @@ public:
      rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
      rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
      rb_define_method(rb_cLLaMAContext, "token_to_str", RUBY_METHOD_FUNC(_llama_context_token_to_str), 1);
-     rb_define_method(rb_cLLaMAContext, "sample_top_p_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_p_top_k), -1);
      rb_define_method(rb_cLLaMAContext, "n_vocab", RUBY_METHOD_FUNC(_llama_context_n_vocab), 0);
      rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
      rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
@@ -244,6 +478,22 @@ public:
      rb_define_method(rb_cLLaMAContext, "free", RUBY_METHOD_FUNC(_llama_context_free), 0);
      rb_define_method(rb_cLLaMAContext, "load", RUBY_METHOD_FUNC(_llama_context_load), -1);
      rb_define_method(rb_cLLaMAContext, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_context_apply_lora_from_file), -1);
+     rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
+     rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
+     rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
+     rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
+     rb_define_method(rb_cLLaMAContext, "sample_repetition_penalty", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalty), -1);
+     rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
+     rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
+     rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
+     rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
+     rb_define_method(rb_cLLaMAContext, "sample_tail_free", RUBY_METHOD_FUNC(_llama_context_sample_tail_free), -1);
+     rb_define_method(rb_cLLaMAContext, "sample_typical", RUBY_METHOD_FUNC(_llama_context_sample_typical), -1);
+     rb_define_method(rb_cLLaMAContext, "sample_temperature", RUBY_METHOD_FUNC(_llama_context_sample_temperature), -1);
+     rb_define_method(rb_cLLaMAContext, "sample_token_mirostat", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat), -1);
+     rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
+     rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
+     rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
    };
 
  private:
@@ -448,40 +698,6 @@ private:
      return output;
    };
 
-   static VALUE _llama_context_sample_top_p_top_k(int argc, VALUE* argv, VALUE self) {
-     VALUE last_n_tokens = Qnil;
-     VALUE kw_args = Qnil;
-     ID kw_table[4] = { rb_intern("top_k"), rb_intern("top_p"), rb_intern("temp"), rb_intern("penalty") };
-     VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
-     rb_scan_args(argc, argv, "1:", &last_n_tokens, &kw_args);
-     rb_get_kwargs(kw_args, kw_table, 4, 0, kw_values);
-
-     if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
-       rb_raise(rb_eArgError, "last_n_tokens must be an Array");
-       return Qnil;
-     }
-
-     const int last_n_tokens_size = RARRAY_LEN(last_n_tokens);
-     const int top_k = NUM2INT(kw_values[0]);
-     const double top_p = NUM2DBL(kw_values[1]);
-     const double temp = NUM2DBL(kw_values[2]);
-     const double penalty = NUM2DBL(kw_values[3]);
-
-     std::vector<llama_token> last_n_tokens_data(last_n_tokens_size);
-     for (int i = 0; i < last_n_tokens_size; i++) {
-       last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
-     }
-
-     LLaMAContextWrapper* ptr = get_llama_context(self);
-     if (ptr->ctx == NULL) {
-       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-       return Qnil;
-     }
-     llama_token token = llama_sample_top_p_top_k(ptr->ctx, last_n_tokens_data.data(), last_n_tokens_size, top_k, top_p, temp, penalty);
-
-     return INT2NUM(token);
-   }
-
    static VALUE _llama_context_n_vocab(VALUE self) {
      LLaMAContextWrapper* ptr = get_llama_context(self);
      if (ptr->ctx == NULL) {
@@ -621,6 +837,562 @@ private:
      }
      return Qnil;
    };
+
+   static VALUE _llama_context_kv_cache_token_count(VALUE self) {
+     LLaMAContextWrapper* ptr = get_llama_context(self);
+     if (ptr->ctx == NULL) {
+       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+       return Qnil;
+     }
+     return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
+   };
+
+   static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
+     LLaMAContextWrapper* ptr = get_llama_context(self);
+     if (ptr->ctx == NULL) {
+       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+       return Qnil;
+     }
+     const int seed = NUM2INT(seed_);
+     llama_set_rng_seed(ptr->ctx, seed);
+     return Qnil;
+   };
+
+ static VALUE _llama_context_load_session_file(int argc, VALUE* argv, VALUE self) {
862
+ VALUE kw_args = Qnil;
863
+ ID kw_table[1] = { rb_intern("session_path") };
864
+ VALUE kw_values[1] = { Qundef };
865
+ VALUE candidates = Qnil;
866
+ VALUE last_n_tokens = Qnil;
867
+ rb_scan_args(argc, argv, ":", &kw_args);
868
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
869
+
870
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
871
+ rb_raise(rb_eArgError, "session_path must be a String");
872
+ return Qnil;
873
+ }
874
+
875
+ VALUE filename = kw_values[0];
876
+
877
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
878
+ if (ctx_ptr->ctx == NULL) {
879
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
880
+ return Qnil;
881
+ }
882
+
883
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(rb_iv_get(self, "@params"));
884
+ const int n_ctx = prms_ptr->params.n_ctx;
885
+
886
+ std::vector<llama_token> session_tokens(n_ctx);
887
+ size_t n_token_count_out = 0;
888
+
889
+ try {
890
+ bool res = llama_load_session_file(ctx_ptr->ctx, StringValueCStr(filename), session_tokens.data(), session_tokens.capacity(), &n_token_count_out);
891
+ if (!res) {
892
+ rb_raise(rb_eRuntimeError, "Failed to load session file");
893
+ return Qnil;
894
+ }
895
+ session_tokens.resize(n_token_count_out);
896
+ } catch (const std::runtime_error& e) {
897
+ rb_raise(rb_eRuntimeError, "%s", e.what());
898
+ return Qnil;
899
+ }
900
+
901
+ VALUE ary_session_tokens = rb_ary_new2(n_token_count_out);
902
+ for (size_t i = 0; i < n_token_count_out; i++) {
903
+ rb_ary_store(ary_session_tokens, i, INT2NUM(session_tokens[i]));
904
+ }
905
+
906
+ RB_GC_GUARD(filename);
907
+ return ary_session_tokens;
908
+ }
909
+
910
+ static VALUE _llama_context_save_session_file(int argc, VALUE* argv, VALUE self) {
911
+ VALUE kw_args = Qnil;
912
+ ID kw_table[2] = { rb_intern("session_path"), rb_intern("session_tokens") };
913
+ VALUE kw_values[2] = { Qundef, Qundef };
914
+ VALUE candidates = Qnil;
915
+ VALUE last_n_tokens = Qnil;
916
+ rb_scan_args(argc, argv, ":", &kw_args);
917
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
918
+
919
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
920
+ rb_raise(rb_eArgError, "session_path must be a String");
921
+ return Qnil;
922
+ }
923
+ if (!RB_TYPE_P(kw_values[1], T_ARRAY)) {
924
+ rb_raise(rb_eArgError, "session_tokens must be an Array");
925
+ return Qnil;
926
+ }
927
+
928
+ VALUE filename = kw_values[0];
929
+ const size_t sz_session_tokens = RARRAY_LEN(kw_values[1]);
930
+ std::vector<llama_token> session_tokens(sz_session_tokens);
931
+ for (size_t i = 0; i < sz_session_tokens; i++) {
932
+ session_tokens[i] = NUM2INT(rb_ary_entry(kw_values[1], i));
933
+ }
934
+
935
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
936
+ if (ctx_ptr->ctx == NULL) {
937
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
938
+ return Qnil;
939
+ }
940
+
941
+ bool res = llama_save_session_file(ctx_ptr->ctx, StringValueCStr(filename), session_tokens.data(), sz_session_tokens);
942
+
943
+ if (!res) {
944
+ rb_raise(rb_eRuntimeError, "Failed to save session file");
945
+ return Qnil;
946
+ }
947
+
948
+ RB_GC_GUARD(filename);
949
+ return Qnil;
950
+ }
951
+
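Both session-file methods take keyword arguments only, matching the rb_scan_args(":") signatures above. A hedged sketch (the path is a placeholder; tokens is an Array of Integer token ids):

  context.save_session_file(session_path: '/tmp/session.bin', session_tokens: tokens)
  restored = context.load_session_file(session_path: '/tmp/session.bin')
  # load_session_file returns the Array of token ids recorded in the session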
+   static VALUE _llama_context_sample_repetition_penalty(int argc, VALUE* argv, VALUE self) {
+     VALUE kw_args = Qnil;
+     ID kw_table[1] = { rb_intern("penalty") };
+     VALUE kw_values[1] = { Qundef };
+     VALUE candidates = Qnil;
+     VALUE last_n_tokens = Qnil;
+     rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
+     rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
+
+     if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
+       rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
+       return Qnil;
+     }
+     if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
+       rb_raise(rb_eArgError, "last_n_tokens must be an Array");
+       return Qnil;
+     }
+     if (!RB_FLOAT_TYPE_P(kw_values[0])) {
+       rb_raise(rb_eArgError, "penalty must be a float");
+       return Qnil;
+     }
+
+     const size_t last_tokens_size = RARRAY_LEN(last_n_tokens);
+     std::vector<llama_token> last_n_tokens_data(last_tokens_size);
+     for (size_t i = 0; i < last_tokens_size; i++) {
+       last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
+     }
+
+     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+     if (ctx_ptr->ctx == NULL) {
+       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+       return Qnil;
+     }
+     LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+     if (cnd_ptr->array.data == nullptr) {
+       rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+       return Qnil;
+     }
+     const float penalty = NUM2DBL(kw_values[0]);
+
+     llama_sample_repetition_penalty(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, penalty);
+
+     return Qnil;
+   };
+
+   static VALUE _llama_context_sample_frequency_and_presence_penalties(int argc, VALUE* argv, VALUE self) {
+     VALUE kw_args = Qnil;
+     ID kw_table[2] = { rb_intern("frequency"), rb_intern("presence") };
+     VALUE kw_values[2] = { Qundef, Qundef };
+     VALUE candidates = Qnil;
+     VALUE last_n_tokens = Qnil;
+     rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
+     rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
+
+     if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
+       rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
+       return Qnil;
+     }
+     if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
+       rb_raise(rb_eArgError, "last_n_tokens must be an Array");
+       return Qnil;
+     }
+     if (!RB_FLOAT_TYPE_P(kw_values[0])) {
+       rb_raise(rb_eArgError, "frequency must be a float");
+       return Qnil;
+     }
+     if (!RB_FLOAT_TYPE_P(kw_values[1])) {
+       rb_raise(rb_eArgError, "presence must be a float");
+       return Qnil;
+     }
+
+     const size_t last_tokens_size = RARRAY_LEN(last_n_tokens);
+     std::vector<llama_token> last_n_tokens_data(last_tokens_size);
+     for (size_t i = 0; i < last_tokens_size; i++) {
+       last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
+     }
+
+     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+     if (ctx_ptr->ctx == NULL) {
+       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+       return Qnil;
+     }
+     LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+     if (cnd_ptr->array.data == nullptr) {
+       rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+       return Qnil;
+     }
+
+     const float alpha_frequency = NUM2DBL(kw_values[0]);
+     const float alpha_presence = NUM2DBL(kw_values[1]);
+
+     llama_sample_frequency_and_presence_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, alpha_frequency, alpha_presence);
+
+     return Qnil;
+   };
+
+   static VALUE _llama_context_sample_softmax(VALUE self, VALUE candidates) {
+     if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
+       rb_raise(rb_eArgError, "argument must be a TokenDataArray");
+       return Qnil;
+     }
+
+     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+     if (ctx_ptr->ctx == NULL) {
+       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+       return Qnil;
+     }
+     LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+     if (cnd_ptr->array.data == nullptr) {
+       rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+       return Qnil;
+     }
+
+     llama_sample_softmax(ctx_ptr->ctx, &(cnd_ptr->array));
+
+     return Qnil;
+   };
+
+   static VALUE _llama_context_sample_top_k(int argc, VALUE* argv, VALUE self) {
+     VALUE kw_args = Qnil;
+     ID kw_table[2] = { rb_intern("k"), rb_intern("min_keep") };
+     VALUE kw_values[2] = { Qundef, Qundef };
+     VALUE candidates = Qnil;
+     rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
+     rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
+
+     if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
+       rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
+       return Qnil;
+     }
+     if (!RB_INTEGER_TYPE_P(kw_values[0])) {
+       rb_raise(rb_eArgError, "k must be an integer");
+       return Qnil;
+     }
+     if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
+       rb_raise(rb_eArgError, "min_keep must be an integer");
+       return Qnil;
+     }
+
+     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+     if (ctx_ptr->ctx == NULL) {
+       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+       return Qnil;
+     }
+     LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+     if (cnd_ptr->array.data == nullptr) {
+       rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+       return Qnil;
+     }
+     // k is an integer; NUM2INT, not NUM2DBL.
+     const int k = NUM2INT(kw_values[0]);
+     const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
+
+     llama_sample_top_k(ctx_ptr->ctx, &(cnd_ptr->array), k, min_keep);
+
+     return Qnil;
+   };
+
+   static VALUE _llama_context_sample_top_p(int argc, VALUE* argv, VALUE self) {
+     VALUE kw_args = Qnil;
+     ID kw_table[2] = { rb_intern("prob"), rb_intern("min_keep") };
+     VALUE kw_values[2] = { Qundef, Qundef };
+     VALUE candidates = Qnil;
+     rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
+     rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
+
+     if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
+       rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
+       return Qnil;
+     }
+     if (!RB_FLOAT_TYPE_P(kw_values[0])) {
+       rb_raise(rb_eArgError, "prob must be a float");
+       return Qnil;
+     }
+     if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
+       rb_raise(rb_eArgError, "min_keep must be an integer");
+       return Qnil;
+     }
+
+     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+     if (ctx_ptr->ctx == NULL) {
+       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+       return Qnil;
+     }
+     LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+     if (cnd_ptr->array.data == nullptr) {
+       rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+       return Qnil;
+     }
+     const float prob = NUM2DBL(kw_values[0]);
+     const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
+
+     llama_sample_top_p(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
+
+     return Qnil;
+   };
+
+   static VALUE _llama_context_sample_tail_free(int argc, VALUE* argv, VALUE self) {
+     VALUE kw_args = Qnil;
+     ID kw_table[2] = { rb_intern("z"), rb_intern("min_keep") };
+     VALUE kw_values[2] = { Qundef, Qundef };
+     VALUE candidates = Qnil;
+     rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
+     rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
+
+     if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
+       rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
+       return Qnil;
+     }
+     if (!RB_FLOAT_TYPE_P(kw_values[0])) {
+       rb_raise(rb_eArgError, "z must be a float");
+       return Qnil;
+     }
+     if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
+       rb_raise(rb_eArgError, "min_keep must be an integer");
+       return Qnil;
+     }
+
+     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+     if (ctx_ptr->ctx == NULL) {
+       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+       return Qnil;
+     }
+     LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+     if (cnd_ptr->array.data == nullptr) {
+       rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+       return Qnil;
+     }
+     const float z = NUM2DBL(kw_values[0]);
+     const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
+
+     llama_sample_tail_free(ctx_ptr->ctx, &(cnd_ptr->array), z, min_keep);
+
+     return Qnil;
+   };
+
+   static VALUE _llama_context_sample_typical(int argc, VALUE* argv, VALUE self) {
+     VALUE kw_args = Qnil;
+     ID kw_table[2] = { rb_intern("prob"), rb_intern("min_keep") };
+     VALUE kw_values[2] = { Qundef, Qundef };
+     VALUE candidates = Qnil;
+     rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
+     rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
+
+     if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
+       rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
+       return Qnil;
+     }
+     if (!RB_FLOAT_TYPE_P(kw_values[0])) {
+       rb_raise(rb_eArgError, "prob must be a float");
+       return Qnil;
+     }
+     if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
+       rb_raise(rb_eArgError, "min_keep must be an integer");
+       return Qnil;
+     }
+
+     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+     if (ctx_ptr->ctx == NULL) {
+       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+       return Qnil;
+     }
+     LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+     if (cnd_ptr->array.data == nullptr) {
+       rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+       return Qnil;
+     }
+     const float prob = NUM2DBL(kw_values[0]);
+     const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
+
+     llama_sample_typical(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
+
+     return Qnil;
+   };
+
+   static VALUE _llama_context_sample_temperature(int argc, VALUE* argv, VALUE self) {
+     VALUE kw_args = Qnil;
+     ID kw_table[1] = { rb_intern("temperature") };
+     VALUE kw_values[1] = { Qundef };
+     VALUE candidates = Qnil;
+     rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
+     rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
+
+     if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
+       rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
+       return Qnil;
+     }
+     if (!RB_FLOAT_TYPE_P(kw_values[0])) {
+       rb_raise(rb_eArgError, "temperature must be a float");
+       return Qnil;
+     }
+
+     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+     if (ctx_ptr->ctx == NULL) {
+       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+       return Qnil;
+     }
+     LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+     if (cnd_ptr->array.data == nullptr) {
+       rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+       return Qnil;
+     }
+     const float temperature = NUM2DBL(kw_values[0]);
+
+     llama_sample_temperature(ctx_ptr->ctx, &(cnd_ptr->array), temperature);
+
+     return Qnil;
+   };
+
+   static VALUE _llama_context_sample_token_mirostat(int argc, VALUE* argv, VALUE self) {
+     VALUE kw_args = Qnil;
+     ID kw_table[4] = { rb_intern("tau"), rb_intern("eta"), rb_intern("m"), rb_intern("mu") };
+     VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
+     VALUE candidates = Qnil;
+     rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
+     rb_get_kwargs(kw_args, kw_table, 4, 0, kw_values);
+
+     if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
+       rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
+       return Qnil;
+     }
+     if (!RB_FLOAT_TYPE_P(kw_values[0])) {
+       rb_raise(rb_eArgError, "tau must be a float");
+       return Qnil;
+     }
+     if (!RB_FLOAT_TYPE_P(kw_values[1])) {
+       rb_raise(rb_eArgError, "eta must be a float");
+       return Qnil;
+     }
+     if (!RB_INTEGER_TYPE_P(kw_values[2])) {
+       rb_raise(rb_eArgError, "m must be an integer");
+       return Qnil;
+     }
+     if (!RB_FLOAT_TYPE_P(kw_values[3])) {
+       rb_raise(rb_eArgError, "mu must be a float");
+       return Qnil;
+     }
+
+     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+     if (ctx_ptr->ctx == NULL) {
+       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+       return Qnil;
+     }
+     LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+     if (cnd_ptr->array.data == nullptr) {
+       rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+       return Qnil;
+     }
+     const float tau = NUM2DBL(kw_values[0]);
+     const float eta = NUM2DBL(kw_values[1]);
+     const int m = NUM2INT(kw_values[2]);
+     float mu = NUM2DBL(kw_values[3]);
+
+     llama_token id = llama_sample_token_mirostat(ctx_ptr->ctx, &(cnd_ptr->array), tau, eta, m, &mu);
+
+     VALUE ret = rb_ary_new2(2);
+     rb_ary_store(ret, 0, INT2NUM(id));
+     rb_ary_store(ret, 1, DBL2NUM(mu));
+     return ret;
+   };
+
+   static VALUE _llama_context_sample_token_mirostat_v2(int argc, VALUE* argv, VALUE self) {
+     VALUE kw_args = Qnil;
+     ID kw_table[3] = { rb_intern("tau"), rb_intern("eta"), rb_intern("mu") };
+     VALUE kw_values[3] = { Qundef, Qundef, Qundef };
+     VALUE candidates = Qnil;
+     rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
+     rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
+
+     if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
+       rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
+       return Qnil;
+     }
+     if (!RB_FLOAT_TYPE_P(kw_values[0])) {
+       rb_raise(rb_eArgError, "tau must be a float");
+       return Qnil;
+     }
+     if (!RB_FLOAT_TYPE_P(kw_values[1])) {
+       rb_raise(rb_eArgError, "eta must be a float");
+       return Qnil;
+     }
+     if (!RB_FLOAT_TYPE_P(kw_values[2])) {
+       rb_raise(rb_eArgError, "mu must be a float");
+       return Qnil;
+     }
+
+     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+     if (ctx_ptr->ctx == NULL) {
+       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+       return Qnil;
+     }
+     LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+     if (cnd_ptr->array.data == nullptr) {
+       rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+       return Qnil;
+     }
+     const float tau = NUM2DBL(kw_values[0]);
+     const float eta = NUM2DBL(kw_values[1]);
+     float mu = NUM2DBL(kw_values[2]);
+
+     llama_token id = llama_sample_token_mirostat_v2(ctx_ptr->ctx, &(cnd_ptr->array), tau, eta, &mu);
+
+     VALUE ret = rb_ary_new2(2);
+     rb_ary_store(ret, 0, INT2NUM(id));
+     rb_ary_store(ret, 1, DBL2NUM(mu));
+     return ret;
+   };
+
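Both mirostat samplers return a two-element array [token_id, mu], since the underlying C API updates mu in place and the caller must feed it back on the next step. A sketch, given a candidates TokenDataArray as above (the tau/eta/initial-mu values follow common llama.cpp defaults and are illustrative):

  mu = 2.0 * 5.0  # conventional starting point: 2 * tau
  token, mu = context.sample_token_mirostat_v2(candidates, tau: 5.0, eta: 0.1, mu: mu)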
+   static VALUE _llama_context_sample_token_greedy(VALUE self, VALUE candidates) {
+     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+     if (ctx_ptr->ctx == NULL) {
+       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+       return Qnil;
+     }
+     if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
+       rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
+       return Qnil;
+     }
+     LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+     if (cnd_ptr->array.data == nullptr) {
+       rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+       return Qnil;
+     }
+     llama_token id = llama_sample_token_greedy(ctx_ptr->ctx, &(cnd_ptr->array));
+     return INT2NUM(id);
+   };
+
+   static VALUE _llama_context_sample_token(VALUE self, VALUE candidates) {
+     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+     if (ctx_ptr->ctx == NULL) {
+       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+       return Qnil;
+     }
+     if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
+       rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
+       return Qnil;
+     }
+     LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
+     if (cnd_ptr->array.data == nullptr) {
+       rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
+       return Qnil;
+     }
+     llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
+     return INT2NUM(id);
+   };
  };
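Taken together, these methods replace the removed sample_top_p_top_k with a composable pipeline: fill a TokenDataArray from the current logits, apply penalties and truncation filters in place, then pick a token. A hedged sketch of one decoding step (context, recent_tokens, and all parameter values are illustrative; logits is assumed to hold n_vocab floats for the last evaluated position):

  logits = context.logits
  candidates = LLaMACpp::TokenDataArray.new(
    Array.new(context.n_vocab) { |i| LLaMACpp::TokenData.new(id: i, logit: logits[i], p: 0.0) }
  )
  context.sample_repetition_penalty(candidates, recent_tokens, penalty: 1.1)
  context.sample_top_k(candidates, k: 40)
  context.sample_top_p(candidates, prob: 0.95)
  context.sample_temperature(candidates, temperature: 0.8)
  token = context.sample_token(candidates)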
 
  const rb_data_type_t RbLLaMAContext::llama_context_type = {
@@ -680,6 +1452,10 @@ static VALUE rb_llama_token_eos(VALUE self) {
    return INT2NUM(llama_token_eos());
  }
 
+ static VALUE rb_llama_token_nl(VALUE self) {
+   return INT2NUM(llama_token_nl());
+ }
+
  static VALUE rb_llama_print_system_info(VALUE self) {
    const char* result = llama_print_system_info();
    return rb_utf8_str_new_cstr(result);
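rb_llama_token_nl mirrors the existing token_bos/token_eos helpers; once registered in the next hunk it is callable as:

  LLaMACpp.token_nl  # => Integer id of the newline token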
@@ -695,12 +1471,16 @@ static VALUE rb_llama_mlock_supported(VALUE self) {
 
  extern "C" void Init_llama_cpp(void) {
    rb_mLLaMACpp = rb_define_module("LLaMACpp");
+
+   RbLLaMATokenData::define_class(rb_mLLaMACpp);
+   RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
    RbLLaMAContext::define_class(rb_mLLaMACpp);
    RbLLaMAContextParams::define_class(rb_mLLaMACpp);
 
    rb_define_module_function(rb_mLLaMACpp, "model_quantize", rb_llama_model_quantize, -1);
    rb_define_module_function(rb_mLLaMACpp, "token_bos", rb_llama_token_bos, 0);
    rb_define_module_function(rb_mLLaMACpp, "token_eos", rb_llama_token_eos, 0);
+   rb_define_module_function(rb_mLLaMACpp, "token_nl", rb_llama_token_nl, 0);
    rb_define_module_function(rb_mLLaMACpp, "print_system_info", rb_llama_print_system_info, 0);
    rb_define_module_function(rb_mLLaMACpp, "mmap_supported?", rb_llama_mmap_supported, 0);
    rb_define_module_function(rb_mLLaMACpp, "mlock_supported?", rb_llama_mlock_supported, 0);
@@ -710,8 +1490,6 @@ extern "C" void Init_llama_cpp(void) {
    rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_0));
    rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_1", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_1));
    rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16));
-   rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_2", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_2));
-   rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_3", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_3));
    rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q8_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q8_0));
    rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_0));
    rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_1", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_1));