llama_cpp 0.0.7 → 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -4,6 +4,255 @@
4
4
  VALUE rb_mLLaMACpp;
5
5
  VALUE rb_cLLaMAContext;
6
6
  VALUE rb_cLLaMAContextParams;
7
+ VALUE rb_cLLaMATokenData;
8
+ VALUE rb_cLLaMATokenDataArray;
9
+
10
+ class LLaMATokenDataWrapper {
11
+ public:
12
+ llama_token_data data;
13
+
14
+ LLaMATokenDataWrapper() {
15
+ data.id = 0;
16
+ data.logit = 0.0;
17
+ data.p = 0.0;
18
+ };
19
+
20
+ ~LLaMATokenDataWrapper(){};
21
+ };
22
+
23
+ class RbLLaMATokenData {
24
+ public:
25
+ static VALUE llama_token_data_alloc(VALUE self) {
26
+ LLaMATokenDataWrapper* ptr = (LLaMATokenDataWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataWrapper));
27
+ new (ptr) LLaMATokenDataWrapper();
28
+ return TypedData_Wrap_Struct(self, &llama_token_data_type, ptr);
29
+ };
30
+
31
+ static void llama_token_data_free(void* ptr) {
32
+ ((LLaMATokenDataWrapper*)ptr)->~LLaMATokenDataWrapper();
33
+ ruby_xfree(ptr);
34
+ };
35
+
36
+ static size_t llama_token_data_size(const void* ptr) {
37
+ return sizeof(*((LLaMATokenDataWrapper*)ptr));
38
+ };
39
+
40
+ static LLaMATokenDataWrapper* get_llama_token_data(VALUE self) {
41
+ LLaMATokenDataWrapper* ptr;
42
+ TypedData_Get_Struct(self, LLaMATokenDataWrapper, &llama_token_data_type, ptr);
43
+ return ptr;
44
+ };
45
+
46
+ static void define_class(VALUE outer) {
47
+ rb_cLLaMATokenData = rb_define_class_under(outer, "TokenData", rb_cObject);
48
+ rb_define_alloc_func(rb_cLLaMATokenData, llama_token_data_alloc);
49
+ rb_define_method(rb_cLLaMATokenData, "initialize", RUBY_METHOD_FUNC(_llama_token_data_init), -1);
50
+ rb_define_method(rb_cLLaMATokenData, "id=", RUBY_METHOD_FUNC(_llama_token_data_set_id), 1);
51
+ rb_define_method(rb_cLLaMATokenData, "id", RUBY_METHOD_FUNC(_llama_token_data_get_id), 0);
52
+ rb_define_method(rb_cLLaMATokenData, "logit=", RUBY_METHOD_FUNC(_llama_token_data_set_logit), 1);
53
+ rb_define_method(rb_cLLaMATokenData, "logit", RUBY_METHOD_FUNC(_llama_token_data_get_logit), 0);
54
+ rb_define_method(rb_cLLaMATokenData, "p=", RUBY_METHOD_FUNC(_llama_token_data_set_p), 1);
55
+ rb_define_method(rb_cLLaMATokenData, "p", RUBY_METHOD_FUNC(_llama_token_data_get_p), 0);
56
+ }
57
+
58
+ private:
59
+ static const rb_data_type_t llama_token_data_type;
60
+
61
+ static VALUE _llama_token_data_init(int argc, VALUE* argv, VALUE self) {
62
+ VALUE kw_args = Qnil;
63
+ ID kw_table[3] = { rb_intern("id"), rb_intern("logit"), rb_intern("p") };
64
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
65
+ rb_scan_args(argc, argv, ":", &kw_args);
66
+ rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
67
+
68
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
69
+ rb_raise(rb_eArgError, "id must be an integer");
70
+ return Qnil;
71
+ }
72
+ if (!RB_FLOAT_TYPE_P(kw_values[1])) {
73
+ rb_raise(rb_eArgError, "logit must be a float");
74
+ return Qnil;
75
+ }
76
+ if (!RB_FLOAT_TYPE_P(kw_values[2])) {
77
+ rb_raise(rb_eArgError, "p must be a float");
78
+ return Qnil;
79
+ }
80
+
81
+ LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
82
+ new (ptr) LLaMATokenDataWrapper();
83
+
84
+ ptr->data.id = NUM2INT(kw_values[0]);
85
+ ptr->data.logit = NUM2DBL(kw_values[1]);
86
+ ptr->data.p = NUM2DBL(kw_values[2]);
87
+
88
+ return self;
89
+ }
90
+
91
+ // id
92
+ static VALUE _llama_token_data_set_id(VALUE self, VALUE id) {
93
+ LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
94
+ ptr->data.id = NUM2INT(id);
95
+ return INT2NUM(ptr->data.id);
96
+ };
97
+
98
+ static VALUE _llama_token_data_get_id(VALUE self) {
99
+ LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
100
+ return INT2NUM(ptr->data.id);
101
+ };
102
+
103
+ // logit
104
+ static VALUE _llama_token_data_set_logit(VALUE self, VALUE logit) {
105
+ LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
106
+ ptr->data.logit = NUM2DBL(logit);
107
+ return DBL2NUM(ptr->data.logit);
108
+ };
109
+
110
+ static VALUE _llama_token_data_get_logit(VALUE self) {
111
+ LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
112
+ return DBL2NUM(ptr->data.logit);
113
+ };
114
+
115
+ // p
116
+ static VALUE _llama_token_data_set_p(VALUE self, VALUE p) {
117
+ LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
118
+ ptr->data.p = NUM2DBL(p);
119
+ return DBL2NUM(ptr->data.p);
120
+ };
121
+
122
+ static VALUE _llama_token_data_get_p(VALUE self) {
123
+ LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
124
+ return DBL2NUM(ptr->data.p);
125
+ };
126
+ };
127
+
128
+ const rb_data_type_t RbLLaMATokenData::llama_token_data_type = {
129
+ "RbLLaMATokenData",
130
+ { NULL,
131
+ RbLLaMATokenData::llama_token_data_free,
132
+ RbLLaMATokenData::llama_token_data_size },
133
+ NULL,
134
+ NULL,
135
+ RUBY_TYPED_FREE_IMMEDIATELY
136
+ };
137
+
138
+ class LLaMATokenDataArrayWrapper {
139
+ public:
140
+ llama_token_data_array array;
141
+
142
+ LLaMATokenDataArrayWrapper() {
143
+ array.data = nullptr;
144
+ array.size = 0;
145
+ array.sorted = false;
146
+ };
147
+
148
+ ~LLaMATokenDataArrayWrapper() {
149
+ if (array.data) {
150
+ ruby_xfree(array.data);
151
+ array.data = nullptr;
152
+ }
153
+ };
154
+ };
155
+
156
+ class RbLLaMATokenDataArray {
157
+ public:
158
+ static VALUE llama_token_data_array_alloc(VALUE self) {
159
+ LLaMATokenDataArrayWrapper* ptr = (LLaMATokenDataArrayWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataArrayWrapper));
160
+ new (ptr) LLaMATokenDataArrayWrapper();
161
+ return TypedData_Wrap_Struct(self, &llama_token_data_array_type, ptr);
162
+ };
163
+
164
+ static void llama_token_data_array_free(void* ptr) {
165
+ ((LLaMATokenDataArrayWrapper*)ptr)->~LLaMATokenDataArrayWrapper();
166
+ ruby_xfree(ptr);
167
+ };
168
+
169
+ static size_t llama_token_data_array_size(const void* ptr) {
170
+ return sizeof(*((LLaMATokenDataArrayWrapper*)ptr));
171
+ };
172
+
173
+ static LLaMATokenDataArrayWrapper* get_llama_token_data_array(VALUE self) {
174
+ LLaMATokenDataArrayWrapper* ptr;
175
+ TypedData_Get_Struct(self, LLaMATokenDataArrayWrapper, &llama_token_data_array_type, ptr);
176
+ return ptr;
177
+ };
178
+
179
+ static void define_class(VALUE outer) {
180
+ rb_cLLaMATokenDataArray = rb_define_class_under(outer, "TokenDataArray", rb_cObject);
181
+ rb_define_alloc_func(rb_cLLaMATokenDataArray, llama_token_data_array_alloc);
182
+ rb_define_method(rb_cLLaMATokenDataArray, "initialize", RUBY_METHOD_FUNC(_llama_token_data_array_init), -1);
183
+ rb_define_method(rb_cLLaMATokenDataArray, "size", RUBY_METHOD_FUNC(_llama_token_data_array_get_size), 0);
184
+ rb_define_method(rb_cLLaMATokenDataArray, "sorted", RUBY_METHOD_FUNC(_llama_token_data_array_get_sorted), 0);
185
+ };
186
+
187
+ private:
188
+ static const rb_data_type_t llama_token_data_array_type;
189
+
190
+ static VALUE _llama_token_data_array_init(int argc, VALUE* argv, VALUE self) {
191
+ VALUE kw_args = Qnil;
192
+ ID kw_table[1] = { rb_intern("sorted") };
193
+ VALUE kw_values[1] = { Qundef };
194
+ VALUE arr = Qnil;
195
+ rb_scan_args(argc, argv, "1:", &arr, &kw_args);
196
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
197
+
198
+ if (!RB_TYPE_P(arr, T_ARRAY)) {
199
+ rb_raise(rb_eArgError, "1st argument must be an array");
200
+ return Qnil;
201
+ }
202
+ size_t sz_array = RARRAY_LEN(arr);
203
+ if (sz_array == 0) {
204
+ rb_raise(rb_eArgError, "array must not be empty");
205
+ return Qnil;
206
+ }
207
+ if (kw_values[0] != Qundef && !RB_TYPE_P(kw_values[0], T_TRUE) && !RB_TYPE_P(kw_values[0], T_FALSE)) {
208
+ rb_raise(rb_eArgError, "sorted must be a boolean");
209
+ return Qnil;
210
+ }
211
+
212
+ LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
213
+ new (ptr) LLaMATokenDataArrayWrapper();
214
+
215
+ ptr->array.data = (llama_token_data*)ruby_xmalloc(sizeof(llama_token_data) * sz_array);
216
+ for (size_t i = 0; i < sz_array; ++i) {
217
+ VALUE el = rb_ary_entry(arr, i);
218
+ if (!rb_obj_is_kind_of(el, rb_cLLaMATokenData)) {
219
+ rb_raise(rb_eArgError, "array element must be a TokenData");
220
+ xfree(ptr->array.data);
221
+ ptr->array.data = nullptr;
222
+ return Qnil;
223
+ }
224
+ llama_token_data token_data = RbLLaMATokenData::get_llama_token_data(el)->data;
225
+ ptr->array.data[i].id = token_data.id;
226
+ ptr->array.data[i].logit = token_data.logit;
227
+ ptr->array.data[i].p = token_data.p;
228
+ }
229
+
230
+ ptr->array.size = sz_array;
231
+ ptr->array.sorted = kw_values[0] == Qtrue;
232
+
233
+ return self;
234
+ };
235
+
236
+ static VALUE _llama_token_data_array_get_size(VALUE self) {
237
+ LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
238
+ return SIZET2NUM(ptr->array.size);
239
+ };
240
+
241
+ static VALUE _llama_token_data_array_get_sorted(VALUE self) {
242
+ LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
243
+ return ptr->array.sorted ? Qtrue : Qfalse;
244
+ };
245
+ };
246
+
247
+ const rb_data_type_t RbLLaMATokenDataArray::llama_token_data_array_type = {
248
+ "RbLLaMATokenDataArray",
249
+ { NULL,
250
+ RbLLaMATokenDataArray::llama_token_data_array_free,
251
+ RbLLaMATokenDataArray::llama_token_data_array_size },
252
+ NULL,
253
+ NULL,
254
+ RUBY_TYPED_FREE_IMMEDIATELY
255
+ };
7
256
 
8
257
  class LLaMAContextParamsWrapper {
9
258
  public:
@@ -234,7 +483,6 @@ public:
234
483
  rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
235
484
  rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
236
485
  rb_define_method(rb_cLLaMAContext, "token_to_str", RUBY_METHOD_FUNC(_llama_context_token_to_str), 1);
237
- rb_define_method(rb_cLLaMAContext, "sample_top_p_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_p_top_k), -1);
238
486
  rb_define_method(rb_cLLaMAContext, "n_vocab", RUBY_METHOD_FUNC(_llama_context_n_vocab), 0);
239
487
  rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
240
488
  rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
@@ -244,6 +492,20 @@ public:
244
492
  rb_define_method(rb_cLLaMAContext, "free", RUBY_METHOD_FUNC(_llama_context_free), 0);
245
493
  rb_define_method(rb_cLLaMAContext, "load", RUBY_METHOD_FUNC(_llama_context_load), -1);
246
494
  rb_define_method(rb_cLLaMAContext, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_context_apply_lora_from_file), -1);
495
+ rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
496
+ rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
497
+ rb_define_method(rb_cLLaMAContext, "sample_repetition_penalty", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalty), -1);
498
+ rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
499
+ rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
500
+ rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
501
+ rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
502
+ rb_define_method(rb_cLLaMAContext, "sample_tail_free", RUBY_METHOD_FUNC(_llama_context_sample_tail_free), -1);
503
+ rb_define_method(rb_cLLaMAContext, "sample_typical", RUBY_METHOD_FUNC(_llama_context_sample_typical), -1);
504
+ rb_define_method(rb_cLLaMAContext, "sample_temperature", RUBY_METHOD_FUNC(_llama_context_sample_temperature), -1);
505
+ rb_define_method(rb_cLLaMAContext, "sample_token_mirostat", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat), -1);
506
+ rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
507
+ rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
508
+ rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
247
509
  };
248
510
 
249
511
  private:
@@ -448,40 +710,6 @@ private:
448
710
  return output;
449
711
  };
450
712
 
451
- static VALUE _llama_context_sample_top_p_top_k(int argc, VALUE* argv, VALUE self) {
452
- VALUE last_n_tokens = Qnil;
453
- VALUE kw_args = Qnil;
454
- ID kw_table[4] = { rb_intern("top_k"), rb_intern("top_p"), rb_intern("temp"), rb_intern("penalty") };
455
- VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
456
- rb_scan_args(argc, argv, "1:", &last_n_tokens, &kw_args);
457
- rb_get_kwargs(kw_args, kw_table, 4, 0, kw_values);
458
-
459
- if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
460
- rb_raise(rb_eArgError, "last_n_tokens must be an Array");
461
- return Qnil;
462
- }
463
-
464
- const int last_n_tokens_size = RARRAY_LEN(last_n_tokens);
465
- const int top_k = NUM2INT(kw_values[0]);
466
- const double top_p = NUM2DBL(kw_values[1]);
467
- const double temp = NUM2DBL(kw_values[2]);
468
- const double penalty = NUM2DBL(kw_values[3]);
469
-
470
- std::vector<llama_token> last_n_tokens_data(last_n_tokens_size);
471
- for (int i = 0; i < last_n_tokens_size; i++) {
472
- last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
473
- }
474
-
475
- LLaMAContextWrapper* ptr = get_llama_context(self);
476
- if (ptr->ctx == NULL) {
477
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
478
- return Qnil;
479
- }
480
- llama_token token = llama_sample_top_p_top_k(ptr->ctx, last_n_tokens_data.data(), last_n_tokens_size, top_k, top_p, temp, penalty);
481
-
482
- return INT2NUM(token);
483
- }
484
-
485
713
  static VALUE _llama_context_n_vocab(VALUE self) {
486
714
  LLaMAContextWrapper* ptr = get_llama_context(self);
487
715
  if (ptr->ctx == NULL) {
@@ -621,6 +849,471 @@ private:
621
849
  }
622
850
  return Qnil;
623
851
  };
852
+
853
+ static VALUE _llama_context_kv_cache_token_count(VALUE self) {
854
+ LLaMAContextWrapper* ptr = get_llama_context(self);
855
+ if (ptr->ctx == NULL) {
856
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
857
+ return Qnil;
858
+ }
859
+ return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
860
+ };
861
+
862
+ static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
863
+ LLaMAContextWrapper* ptr = get_llama_context(self);
864
+ if (ptr->ctx == NULL) {
865
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
866
+ return Qnil;
867
+ }
868
+ const int seed = NUM2INT(seed_);
869
+ llama_set_rng_seed(ptr->ctx, seed);
870
+ return Qnil;
871
+ };
872
+
873
+ static VALUE _llama_context_sample_repetition_penalty(int argc, VALUE* argv, VALUE self) {
874
+ VALUE kw_args = Qnil;
875
+ ID kw_table[1] = { rb_intern("penalty") };
876
+ VALUE kw_values[1] = { Qundef };
877
+ VALUE candidates = Qnil;
878
+ VALUE last_n_tokens = Qnil;
879
+ rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
880
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
881
+
882
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
883
+ rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
884
+ return Qnil;
885
+ }
886
+ if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
887
+ rb_raise(rb_eArgError, "last_n_tokens must be an Array");
888
+ return Qnil;
889
+ }
890
+ if (!RB_FLOAT_TYPE_P(kw_values[0])) {
891
+ rb_raise(rb_eArgError, "penalty must be a float");
892
+ return Qnil;
893
+ }
894
+
895
+ const size_t last_tokens_size = RARRAY_LEN(last_n_tokens);
896
+ std::vector<llama_token> last_n_tokens_data(last_tokens_size);
897
+ for (size_t i = 0; i < last_tokens_size; i++) {
898
+ last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
899
+ }
900
+
901
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
902
+ if (ctx_ptr->ctx == NULL) {
903
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
904
+ return Qnil;
905
+ }
906
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
907
+ if (cnd_ptr->array.data == nullptr) {
908
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
909
+ return Qnil;
910
+ }
911
+ const float penalty = NUM2DBL(kw_values[0]);
912
+
913
+ llama_sample_repetition_penalty(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, penalty);
914
+
915
+ return Qnil;
916
+ };
917
+
918
+ static VALUE _llama_context_sample_frequency_and_presence_penalties(int argc, VALUE* argv, VALUE self) {
919
+ VALUE kw_args = Qnil;
920
+ ID kw_table[2] = { rb_intern("frequency"), rb_intern("presence") };
921
+ VALUE kw_values[2] = { Qundef, Qundef };
922
+ VALUE candidates = Qnil;
923
+ VALUE last_n_tokens = Qnil;
924
+ rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
925
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
926
+
927
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
928
+ rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
929
+ return Qnil;
930
+ }
931
+ if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
932
+ rb_raise(rb_eArgError, "last_n_tokens must be an Array");
933
+ return Qnil;
934
+ }
935
+ if (!RB_FLOAT_TYPE_P(kw_values[0])) {
936
+ rb_raise(rb_eArgError, "frequency must be a float");
937
+ return Qnil;
938
+ }
939
+ if (!RB_FLOAT_TYPE_P(kw_values[1])) {
940
+ rb_raise(rb_eArgError, "presence must be a float");
941
+ return Qnil;
942
+ }
943
+
944
+ const size_t last_tokens_size = RARRAY_LEN(last_n_tokens);
945
+ std::vector<llama_token> last_n_tokens_data(last_tokens_size);
946
+ for (size_t i = 0; i < last_tokens_size; i++) {
947
+ last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
948
+ }
949
+
950
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
951
+ if (ctx_ptr->ctx == NULL) {
952
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
953
+ return Qnil;
954
+ }
955
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
956
+ if (cnd_ptr->array.data == nullptr) {
957
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
958
+ return Qnil;
959
+ }
960
+
961
+ const float alpha_frequency = NUM2DBL(kw_values[0]);
962
+ const float alpha_presence = NUM2DBL(kw_values[1]);
963
+
964
+ llama_sample_frequency_and_presence_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, alpha_frequency, alpha_presence);
965
+
966
+ return Qnil;
967
+ };
968
+
969
+ static VALUE _llama_context_sample_softmax(VALUE self, VALUE candidates) {
970
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
971
+ rb_raise(rb_eArgError, "argument must be a TokenDataArray");
972
+ return Qnil;
973
+ }
974
+
975
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
976
+ if (ctx_ptr->ctx == NULL) {
977
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
978
+ return Qnil;
979
+ }
980
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
981
+ if (cnd_ptr->array.data == nullptr) {
982
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
983
+ return Qnil;
984
+ }
985
+
986
+ llama_sample_softmax(ctx_ptr->ctx, &(cnd_ptr->array));
987
+
988
+ return Qnil;
989
+ };
990
+
991
+ static VALUE _llama_context_sample_top_k(int argc, VALUE* argv, VALUE self) {
992
+ VALUE kw_args = Qnil;
993
+ ID kw_table[2] = { rb_intern("k"), rb_intern("min_keep") };
994
+ VALUE kw_values[2] = { Qundef, Qundef };
995
+ VALUE candidates = Qnil;
996
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
997
+ rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
998
+
999
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
1000
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
1001
+ return Qnil;
1002
+ }
1003
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
1004
+ rb_raise(rb_eArgError, "k must be an integer");
1005
+ return Qnil;
1006
+ }
1007
+ if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
1008
+ rb_raise(rb_eArgError, "min_keep must be an integer");
1009
+ return Qnil;
1010
+ }
1011
+
1012
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1013
+ if (ctx_ptr->ctx == NULL) {
1014
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1015
+ return Qnil;
1016
+ }
1017
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
1018
+ if (cnd_ptr->array.data == nullptr) {
1019
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
1020
+ return Qnil;
1021
+ }
1022
+ const int k = NUM2DBL(kw_values[0]);
1023
+ const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
1024
+
1025
+ llama_sample_top_k(ctx_ptr->ctx, &(cnd_ptr->array), k, min_keep);
1026
+
1027
+ return Qnil;
1028
+ };
1029
+
1030
+ static VALUE _llama_context_sample_top_p(int argc, VALUE* argv, VALUE self) {
1031
+ VALUE kw_args = Qnil;
1032
+ ID kw_table[2] = { rb_intern("prob"), rb_intern("min_keep") };
1033
+ VALUE kw_values[2] = { Qundef, Qundef };
1034
+ VALUE candidates = Qnil;
1035
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
1036
+ rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
1037
+
1038
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
1039
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
1040
+ return Qnil;
1041
+ }
1042
+ if (!RB_FLOAT_TYPE_P(kw_values[0])) {
1043
+ rb_raise(rb_eArgError, "prob must be a float");
1044
+ return Qnil;
1045
+ }
1046
+ if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
1047
+ rb_raise(rb_eArgError, "min_keep must be an integer");
1048
+ return Qnil;
1049
+ }
1050
+
1051
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1052
+ if (ctx_ptr->ctx == NULL) {
1053
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1054
+ return Qnil;
1055
+ }
1056
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
1057
+ if (cnd_ptr->array.data == nullptr) {
1058
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
1059
+ return Qnil;
1060
+ }
1061
+ const float prob = NUM2DBL(kw_values[0]);
1062
+ const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
1063
+
1064
+ llama_sample_top_p(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
1065
+
1066
+ return Qnil;
1067
+ };
1068
+
1069
+ static VALUE _llama_context_sample_tail_free(int argc, VALUE* argv, VALUE self) {
1070
+ VALUE kw_args = Qnil;
1071
+ ID kw_table[2] = { rb_intern("z"), rb_intern("min_keep") };
1072
+ VALUE kw_values[2] = { Qundef, Qundef };
1073
+ VALUE candidates = Qnil;
1074
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
1075
+ rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
1076
+
1077
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
1078
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
1079
+ return Qnil;
1080
+ }
1081
+ if (!RB_FLOAT_TYPE_P(kw_values[0])) {
1082
+ rb_raise(rb_eArgError, "prob must be a float");
1083
+ return Qnil;
1084
+ }
1085
+ if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
1086
+ rb_raise(rb_eArgError, "min_keep must be an integer");
1087
+ return Qnil;
1088
+ }
1089
+
1090
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1091
+ if (ctx_ptr->ctx == NULL) {
1092
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1093
+ return Qnil;
1094
+ }
1095
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
1096
+ if (cnd_ptr->array.data == nullptr) {
1097
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
1098
+ return Qnil;
1099
+ }
1100
+ const float z = NUM2DBL(kw_values[0]);
1101
+ const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
1102
+
1103
+ llama_sample_tail_free(ctx_ptr->ctx, &(cnd_ptr->array), z, min_keep);
1104
+
1105
+ return Qnil;
1106
+ };
1107
+
1108
+ static VALUE _llama_context_sample_typical(int argc, VALUE* argv, VALUE self) {
1109
+ VALUE kw_args = Qnil;
1110
+ ID kw_table[2] = { rb_intern("prob"), rb_intern("min_keep") };
1111
+ VALUE kw_values[2] = { Qundef, Qundef };
1112
+ VALUE candidates = Qnil;
1113
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
1114
+ rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
1115
+
1116
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
1117
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
1118
+ return Qnil;
1119
+ }
1120
+ if (!RB_FLOAT_TYPE_P(kw_values[0])) {
1121
+ rb_raise(rb_eArgError, "prob must be a float");
1122
+ return Qnil;
1123
+ }
1124
+ if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
1125
+ rb_raise(rb_eArgError, "min_keep must be an integer");
1126
+ return Qnil;
1127
+ }
1128
+
1129
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1130
+ if (ctx_ptr->ctx == NULL) {
1131
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1132
+ return Qnil;
1133
+ }
1134
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
1135
+ if (cnd_ptr->array.data == nullptr) {
1136
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
1137
+ return Qnil;
1138
+ }
1139
+ const float prob = NUM2DBL(kw_values[0]);
1140
+ const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
1141
+
1142
+ llama_sample_typical(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
1143
+
1144
+ return Qnil;
1145
+ };
1146
+
1147
+ static VALUE _llama_context_sample_temperature(int argc, VALUE* argv, VALUE self) {
1148
+ VALUE kw_args = Qnil;
1149
+ ID kw_table[1] = { rb_intern("temperature") };
1150
+ VALUE kw_values[1] = { Qundef };
1151
+ VALUE candidates = Qnil;
1152
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
1153
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
1154
+
1155
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
1156
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
1157
+ return Qnil;
1158
+ }
1159
+ if (!RB_FLOAT_TYPE_P(kw_values[0])) {
1160
+ rb_raise(rb_eArgError, "temperature must be a float");
1161
+ return Qnil;
1162
+ }
1163
+
1164
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1165
+ if (ctx_ptr->ctx == NULL) {
1166
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1167
+ return Qnil;
1168
+ }
1169
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
1170
+ if (cnd_ptr->array.data == nullptr) {
1171
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
1172
+ return Qnil;
1173
+ }
1174
+ const float temperature = NUM2DBL(kw_values[0]);
1175
+
1176
+ llama_sample_temperature(ctx_ptr->ctx, &(cnd_ptr->array), temperature);
1177
+
1178
+ return Qnil;
1179
+ };
1180
+
1181
+ static VALUE _llama_context_sample_token_mirostat(int argc, VALUE* argv, VALUE self) {
1182
+ VALUE kw_args = Qnil;
1183
+ ID kw_table[4] = { rb_intern("tau"), rb_intern("eta"), rb_intern("m"), rb_intern("mu") };
1184
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
1185
+ VALUE candidates = Qnil;
1186
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
1187
+ rb_get_kwargs(kw_args, kw_table, 4, 0, kw_values);
1188
+
1189
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
1190
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
1191
+ return Qnil;
1192
+ }
1193
+ if (!RB_FLOAT_TYPE_P(kw_values[0])) {
1194
+ rb_raise(rb_eArgError, "tau must be a float");
1195
+ return Qnil;
1196
+ }
1197
+ if (!RB_FLOAT_TYPE_P(kw_values[1])) {
1198
+ rb_raise(rb_eArgError, "eta must be a float");
1199
+ return Qnil;
1200
+ }
1201
+ if (!RB_INTEGER_TYPE_P(kw_values[2])) {
1202
+ rb_raise(rb_eArgError, "m must be an integer");
1203
+ return Qnil;
1204
+ }
1205
+ if (!RB_FLOAT_TYPE_P(kw_values[3])) {
1206
+ rb_raise(rb_eArgError, "mu must be a float");
1207
+ return Qnil;
1208
+ }
1209
+
1210
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1211
+ if (ctx_ptr->ctx == NULL) {
1212
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1213
+ return Qnil;
1214
+ }
1215
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
1216
+ if (cnd_ptr->array.data == nullptr) {
1217
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
1218
+ return Qnil;
1219
+ }
1220
+ const float tau = NUM2DBL(kw_values[0]);
1221
+ const float eta = NUM2DBL(kw_values[1]);
1222
+ const int m = NUM2INT(kw_values[2]);
1223
+ float mu = NUM2DBL(kw_values[3]);
1224
+
1225
+ llama_token id = llama_sample_token_mirostat(ctx_ptr->ctx, &(cnd_ptr->array), tau, eta, m, &mu);
1226
+
1227
+ VALUE ret = rb_ary_new2(2);
1228
+ rb_ary_store(ret, 0, INT2NUM(id));
1229
+ rb_ary_store(ret, 1, DBL2NUM(mu));
1230
+ return ret;
1231
+ };
1232
+
1233
+ static VALUE _llama_context_sample_token_mirostat_v2(int argc, VALUE* argv, VALUE self) {
1234
+ VALUE kw_args = Qnil;
1235
+ ID kw_table[3] = { rb_intern("tau"), rb_intern("eta"), rb_intern("mu") };
1236
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1237
+ VALUE candidates = Qnil;
1238
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
1239
+ rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
1240
+
1241
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
1242
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
1243
+ return Qnil;
1244
+ }
1245
+ if (!RB_FLOAT_TYPE_P(kw_values[0])) {
1246
+ rb_raise(rb_eArgError, "tau must be a float");
1247
+ return Qnil;
1248
+ }
1249
+ if (!RB_FLOAT_TYPE_P(kw_values[1])) {
1250
+ rb_raise(rb_eArgError, "eta must be a float");
1251
+ return Qnil;
1252
+ }
1253
+ if (!RB_FLOAT_TYPE_P(kw_values[2])) {
1254
+ rb_raise(rb_eArgError, "mu must be a float");
1255
+ return Qnil;
1256
+ }
1257
+
1258
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1259
+ if (ctx_ptr->ctx == NULL) {
1260
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1261
+ return Qnil;
1262
+ }
1263
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
1264
+ if (cnd_ptr->array.data == nullptr) {
1265
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
1266
+ return Qnil;
1267
+ }
1268
+ const float tau = NUM2DBL(kw_values[0]);
1269
+ const float eta = NUM2DBL(kw_values[1]);
1270
+ float mu = NUM2DBL(kw_values[2]);
1271
+
1272
+ llama_token id = llama_sample_token_mirostat_v2(ctx_ptr->ctx, &(cnd_ptr->array), tau, eta, &mu);
1273
+
1274
+ VALUE ret = rb_ary_new2(2);
1275
+ rb_ary_store(ret, 0, INT2NUM(id));
1276
+ rb_ary_store(ret, 1, DBL2NUM(mu));
1277
+ return ret;
1278
+ };
1279
+
1280
+ static VALUE _llama_context_sample_token_greedy(VALUE self, VALUE candidates) {
1281
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1282
+ if (ctx_ptr->ctx == NULL) {
1283
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1284
+ return Qnil;
1285
+ }
1286
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
1287
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
1288
+ return Qnil;
1289
+ }
1290
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
1291
+ if (cnd_ptr->array.data == nullptr) {
1292
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
1293
+ return Qnil;
1294
+ }
1295
+ llama_token id = llama_sample_token_greedy(ctx_ptr->ctx, &(cnd_ptr->array));
1296
+ return INT2NUM(id);
1297
+ };
1298
+
1299
+ static VALUE _llama_context_sample_token(VALUE self, VALUE candidates) {
1300
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1301
+ if (ctx_ptr->ctx == NULL) {
1302
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1303
+ return Qnil;
1304
+ }
1305
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
1306
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
1307
+ return Qnil;
1308
+ }
1309
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
1310
+ if (cnd_ptr->array.data == nullptr) {
1311
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
1312
+ return Qnil;
1313
+ }
1314
+ llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
1315
+ return INT2NUM(id);
1316
+ };
624
1317
  };
625
1318
 
626
1319
  const rb_data_type_t RbLLaMAContext::llama_context_type = {
@@ -680,6 +1373,10 @@ static VALUE rb_llama_token_eos(VALUE self) {
680
1373
  return INT2NUM(llama_token_eos());
681
1374
  }
682
1375
 
1376
+ static VALUE rb_llama_token_nl(VALUE self) {
1377
+ return INT2NUM(llama_token_nl());
1378
+ }
1379
+
683
1380
  static VALUE rb_llama_print_system_info(VALUE self) {
684
1381
  const char* result = llama_print_system_info();
685
1382
  return rb_utf8_str_new_cstr(result);
@@ -695,12 +1392,16 @@ static VALUE rb_llama_mlock_supported(VALUE self) {
695
1392
 
696
1393
  extern "C" void Init_llama_cpp(void) {
697
1394
  rb_mLLaMACpp = rb_define_module("LLaMACpp");
1395
+
1396
+ RbLLaMATokenData::define_class(rb_mLLaMACpp);
1397
+ RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
698
1398
  RbLLaMAContext::define_class(rb_mLLaMACpp);
699
1399
  RbLLaMAContextParams::define_class(rb_mLLaMACpp);
700
1400
 
701
1401
  rb_define_module_function(rb_mLLaMACpp, "model_quantize", rb_llama_model_quantize, -1);
702
1402
  rb_define_module_function(rb_mLLaMACpp, "token_bos", rb_llama_token_bos, 0);
703
1403
  rb_define_module_function(rb_mLLaMACpp, "token_eos", rb_llama_token_eos, 0);
1404
+ rb_define_module_function(rb_mLLaMACpp, "token_nl", rb_llama_token_nl, 0);
704
1405
  rb_define_module_function(rb_mLLaMACpp, "print_system_info", rb_llama_print_system_info, 0);
705
1406
  rb_define_module_function(rb_mLLaMACpp, "mmap_supported?", rb_llama_mmap_supported, 0);
706
1407
  rb_define_module_function(rb_mLLaMACpp, "mlock_supported?", rb_llama_mlock_supported, 0);
@@ -711,7 +1412,6 @@ extern "C" void Init_llama_cpp(void) {
711
1412
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_1", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_1));
712
1413
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16));
713
1414
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_2", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_2));
714
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_3", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_3));
715
1415
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q8_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q8_0));
716
1416
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_0));
717
1417
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_1", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_1));