llama_cpp 0.17.9 → 0.18.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3761 +0,0 @@
1
- #include "llama_cpp.h"
2
-
3
- VALUE rb_mLLaMACpp;
4
- VALUE rb_cLLaMABatch;
5
- VALUE rb_cLLaMAModel;
6
- VALUE rb_cLLaMAModelKVOverride;
7
- VALUE rb_cLLaMAModelParams;
8
- VALUE rb_cLLaMATimings;
9
- VALUE rb_cLLaMAContext;
10
- VALUE rb_cLLaMAContextParams;
11
- VALUE rb_cLLaMAModelQuantizeParams;
12
- VALUE rb_cLLaMATokenData;
13
- VALUE rb_cLLaMATokenDataArray;
14
- VALUE rb_cLLaMAGrammarElement;
15
- VALUE rb_cLLaMAGrammar;
16
-
17
- class LLaMABatchWrapper {
18
- public:
19
- llama_batch batch;
20
-
21
- LLaMABatchWrapper() {}
22
-
23
- ~LLaMABatchWrapper() {
24
- llama_batch_free(batch);
25
- }
26
- };
27
-
28
- class RbLLaMABatch {
29
- public:
30
- static VALUE llama_batch_alloc(VALUE self) {
31
- LLaMABatchWrapper* ptr = (LLaMABatchWrapper*)ruby_xmalloc(sizeof(LLaMABatchWrapper));
32
- new (ptr) LLaMABatchWrapper();
33
- return TypedData_Wrap_Struct(self, &llama_batch_type, ptr);
34
- }
35
-
36
- static void llama_batch_free(void* ptr) {
37
- ((LLaMABatchWrapper*)ptr)->~LLaMABatchWrapper();
38
- ruby_xfree(ptr);
39
- }
40
-
41
- static size_t llama_batch_size(const void* ptr) {
42
- return sizeof(*((LLaMABatchWrapper*)ptr));
43
- }
44
-
45
- static LLaMABatchWrapper* get_llama_batch(VALUE self) {
46
- LLaMABatchWrapper* ptr;
47
- TypedData_Get_Struct(self, LLaMABatchWrapper, &llama_batch_type, ptr);
48
- return ptr;
49
- }
50
-
51
- static void define_class(VALUE outer) {
52
- rb_cLLaMABatch = rb_define_class_under(outer, "Batch", rb_cObject);
53
- rb_define_alloc_func(rb_cLLaMABatch, llama_batch_alloc);
54
- rb_define_singleton_method(rb_cLLaMABatch, "get_one", RUBY_METHOD_FUNC(_llama_batch_get_one), -1);
55
- rb_define_method(rb_cLLaMABatch, "initialize", RUBY_METHOD_FUNC(_llama_batch_initialize), -1);
56
- rb_define_method(rb_cLLaMABatch, "n_tokens=", RUBY_METHOD_FUNC(_llama_batch_set_n_tokens), 1);
57
- rb_define_method(rb_cLLaMABatch, "n_tokens", RUBY_METHOD_FUNC(_llama_batch_get_n_tokens), 0);
58
- rb_define_method(rb_cLLaMABatch, "all_pos_zero=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_zero), 1);
59
- rb_define_method(rb_cLLaMABatch, "all_pos_zero", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_zero), 0);
60
- rb_define_method(rb_cLLaMABatch, "all_pos_one=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_one), 1);
61
- rb_define_method(rb_cLLaMABatch, "all_pos_one", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_one), 0);
62
- rb_define_method(rb_cLLaMABatch, "all_seq_id=", RUBY_METHOD_FUNC(_llama_batch_set_all_seq_id), 1);
63
- rb_define_method(rb_cLLaMABatch, "all_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_all_seq_id), 0);
64
- rb_define_method(rb_cLLaMABatch, "set_token", RUBY_METHOD_FUNC(_llama_batch_set_token), 2);
65
- rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
66
- rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
67
- rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
68
- rb_define_method(rb_cLLaMABatch, "set_n_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_n_seq_id), 2);
69
- rb_define_method(rb_cLLaMABatch, "get_n_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_n_seq_id), 1);
70
- rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 3);
71
- rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 2);
72
- rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
73
- rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
74
- }
75
-
76
- private:
77
- static const rb_data_type_t llama_batch_type;
78
-
79
- static VALUE _llama_batch_get_one(int argc, VALUE* argv, VALUE klass) {
80
- VALUE kw_args = Qnil;
81
- ID kw_table[4] = { rb_intern("tokens"), rb_intern("n_tokens"), rb_intern("pos_zero"), rb_intern("seq_id") };
82
- VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
83
- rb_scan_args(argc, argv, ":", &kw_args);
84
- rb_get_kwargs(kw_args, kw_table, 4, 0, kw_values);
85
-
86
- if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
87
- rb_raise(rb_eArgError, "tokens must be an array");
88
- return Qnil;
89
- }
90
- if (!RB_INTEGER_TYPE_P(kw_values[1])) {
91
- rb_raise(rb_eArgError, "n_tokens must be an integer");
92
- return Qnil;
93
- }
94
- if (!RB_INTEGER_TYPE_P(kw_values[2])) {
95
- rb_raise(rb_eArgError, "pos_zero must be an integer");
96
- return Qnil;
97
- }
98
- if (!RB_INTEGER_TYPE_P(kw_values[3])) {
99
- rb_raise(rb_eArgError, "seq_id must be an integer");
100
- return Qnil;
101
- }
102
-
103
- const size_t sz_array = RARRAY_LEN(kw_values[0]);
104
- const int32_t n_tokens = NUM2INT(kw_values[1]);
105
- const llama_pos pos_zero = NUM2INT(kw_values[2]);
106
- const llama_seq_id seq_id = NUM2INT(kw_values[3]);
107
-
108
- LLaMABatchWrapper* ptr = (LLaMABatchWrapper*)ruby_xmalloc(sizeof(LLaMABatchWrapper));
109
- new (ptr) LLaMABatchWrapper();
110
- ptr->batch = llama_batch_get_one(nullptr, n_tokens, pos_zero, seq_id);
111
-
112
- ptr->batch.token = (llama_token*)malloc(sizeof(llama_token) * sz_array);
113
- for (size_t i = 0; i < sz_array; i++) {
114
- VALUE el = rb_ary_entry(kw_values[0], i);
115
- ptr->batch.token[i] = NUM2INT(el);
116
- }
117
-
118
- return TypedData_Wrap_Struct(klass, &llama_batch_type, ptr);
119
- }
120
-
121
- static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
122
- VALUE kw_args = Qnil;
123
- ID kw_table[3] = { rb_intern("max_n_token"), rb_intern("n_embd"), rb_intern("max_n_seq") };
124
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
125
- rb_scan_args(argc, argv, ":", &kw_args);
126
- rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
127
-
128
- if (!RB_INTEGER_TYPE_P(kw_values[0])) {
129
- rb_raise(rb_eArgError, "max_n_token must be an integer");
130
- return Qnil;
131
- }
132
- if (!RB_INTEGER_TYPE_P(kw_values[1])) {
133
- rb_raise(rb_eArgError, "n_embd must be an integer");
134
- return Qnil;
135
- }
136
- if (!RB_INTEGER_TYPE_P(kw_values[2])) {
137
- rb_raise(rb_eArgError, "max_n_seq must be an integer");
138
- return Qnil;
139
- }
140
-
141
- const int32_t max_n_token = NUM2INT(kw_values[0]);
142
- const int32_t n_embd = NUM2INT(kw_values[1]);
143
- const int32_t max_n_seq = NUM2INT(kw_values[2]);
144
-
145
- LLaMABatchWrapper* ptr = get_llama_batch(self);
146
- ptr->batch = llama_batch_init(max_n_token, n_embd, max_n_seq);
147
-
148
- return Qnil;
149
- }
150
-
151
- // n_tokens
152
- static VALUE _llama_batch_set_n_tokens(VALUE self, VALUE n_tokens) {
153
- LLaMABatchWrapper* ptr = get_llama_batch(self);
154
- ptr->batch.n_tokens = NUM2INT(n_tokens);
155
- return INT2NUM(ptr->batch.n_tokens);
156
- }
157
-
158
- static VALUE _llama_batch_get_n_tokens(VALUE self) {
159
- LLaMABatchWrapper* ptr = get_llama_batch(self);
160
- return INT2NUM(ptr->batch.n_tokens);
161
- }
162
-
163
- // all_pos_0
164
- static VALUE _llama_batch_set_all_pos_zero(VALUE self, VALUE all_pos_0) {
165
- LLaMABatchWrapper* ptr = get_llama_batch(self);
166
- ptr->batch.all_pos_0 = NUM2INT(all_pos_0);
167
- return INT2NUM(ptr->batch.all_pos_0);
168
- }
169
-
170
- static VALUE _llama_batch_get_all_pos_zero(VALUE self) {
171
- LLaMABatchWrapper* ptr = get_llama_batch(self);
172
- return INT2NUM(ptr->batch.all_pos_0);
173
- }
174
-
175
- // all_pos_1
176
- static VALUE _llama_batch_set_all_pos_one(VALUE self, VALUE all_pos_1) {
177
- LLaMABatchWrapper* ptr = get_llama_batch(self);
178
- ptr->batch.all_pos_1 = NUM2INT(all_pos_1);
179
- return INT2NUM(ptr->batch.all_pos_1);
180
- }
181
-
182
- static VALUE _llama_batch_get_all_pos_one(VALUE self) {
183
- LLaMABatchWrapper* ptr = get_llama_batch(self);
184
- return INT2NUM(ptr->batch.all_pos_1);
185
- }
186
-
187
- // all_seq_id
188
- static VALUE _llama_batch_set_all_seq_id(VALUE self, VALUE all_seq_id) {
189
- LLaMABatchWrapper* ptr = get_llama_batch(self);
190
- ptr->batch.all_seq_id = NUM2INT(all_seq_id);
191
- return INT2NUM(ptr->batch.all_seq_id);
192
- }
193
-
194
- static VALUE _llama_batch_get_all_seq_id(VALUE self) {
195
- LLaMABatchWrapper* ptr = get_llama_batch(self);
196
- return INT2NUM(ptr->batch.all_seq_id);
197
- }
198
-
199
- // token
200
- static VALUE _llama_batch_set_token(VALUE self, VALUE idx, VALUE value) {
201
- LLaMABatchWrapper* ptr = get_llama_batch(self);
202
- const int32_t id = NUM2INT(idx);
203
- if (id < 0) {
204
- rb_raise(rb_eArgError, "id must be greater or equal to 0");
205
- return Qnil;
206
- }
207
- ptr->batch.token[id] = NUM2INT(value);
208
- return INT2NUM(ptr->batch.token[id]);
209
- }
210
-
211
- static VALUE _llama_batch_get_token(VALUE self, VALUE idx) {
212
- LLaMABatchWrapper* ptr = get_llama_batch(self);
213
- const int32_t id = NUM2INT(idx);
214
- if (id < 0) {
215
- rb_raise(rb_eArgError, "id must be greater or equal to 0");
216
- return Qnil;
217
- }
218
- return INT2NUM(ptr->batch.token[id]);
219
- }
220
-
221
- // pos
222
- static VALUE _llama_batch_set_pos(VALUE self, VALUE idx, VALUE value) {
223
- LLaMABatchWrapper* ptr = get_llama_batch(self);
224
- const int32_t id = NUM2INT(idx);
225
- if (id < 0) {
226
- rb_raise(rb_eArgError, "id must be greater or equal to 0");
227
- return Qnil;
228
- }
229
- ptr->batch.pos[id] = NUM2INT(value);
230
- return INT2NUM(ptr->batch.pos[id]);
231
- }
232
-
233
- static VALUE _llama_batch_get_pos(VALUE self, VALUE idx) {
234
- LLaMABatchWrapper* ptr = get_llama_batch(self);
235
- const int32_t id = NUM2INT(idx);
236
- if (id < 0) {
237
- rb_raise(rb_eArgError, "id must be greater or equal to 0");
238
- return Qnil;
239
- }
240
- return INT2NUM(ptr->batch.pos[id]);
241
- }
242
-
243
- // n_seq_id
244
- static VALUE _llama_batch_set_n_seq_id(VALUE self, VALUE idx, VALUE value) {
245
- LLaMABatchWrapper* ptr = get_llama_batch(self);
246
- const int32_t id = NUM2INT(idx);
247
- if (id < 0) {
248
- rb_raise(rb_eArgError, "id must be greater or equal to 0");
249
- return Qnil;
250
- }
251
- ptr->batch.n_seq_id[id] = NUM2INT(value);
252
- return INT2NUM(ptr->batch.n_seq_id[id]);
253
- }
254
-
255
- static VALUE _llama_batch_get_n_seq_id(VALUE self, VALUE idx) {
256
- LLaMABatchWrapper* ptr = get_llama_batch(self);
257
- const int32_t id = NUM2INT(idx);
258
- if (id < 0) {
259
- rb_raise(rb_eArgError, "id must be greater or equal to 0");
260
- return Qnil;
261
- }
262
- return INT2NUM(ptr->batch.n_seq_id[id]);
263
- }
264
-
265
- // seq_id
266
- static VALUE _llama_batch_set_seq_id(VALUE self, VALUE i_, VALUE j_, VALUE value) {
267
- LLaMABatchWrapper* ptr = get_llama_batch(self);
268
- const int32_t i = NUM2INT(i_);
269
- if (i < 0) {
270
- rb_raise(rb_eArgError, "i must be greater or equal to 0");
271
- return Qnil;
272
- }
273
- const int32_t j = NUM2INT(j_);
274
- if (j < 0) {
275
- rb_raise(rb_eArgError, "j must be greater or equal to 0");
276
- return Qnil;
277
- }
278
- ptr->batch.seq_id[i][j] = NUM2INT(value);
279
- return INT2NUM(ptr->batch.seq_id[i][j]);
280
- }
281
-
282
- static VALUE _llama_batch_get_seq_id(VALUE self, VALUE i_, VALUE j_) {
283
- LLaMABatchWrapper* ptr = get_llama_batch(self);
284
- const int32_t i = NUM2INT(i_);
285
- if (i < 0) {
286
- rb_raise(rb_eArgError, "i must be greater or equal to 0");
287
- return Qnil;
288
- }
289
- const int32_t j = NUM2INT(j_);
290
- if (j < 0) {
291
- rb_raise(rb_eArgError, "j must be greater or equal to 0");
292
- return Qnil;
293
- }
294
- return INT2NUM(ptr->batch.seq_id[i][j]);
295
- }
296
-
297
- // logits
298
- static VALUE _llama_batch_set_logits(VALUE self, VALUE idx, VALUE value) {
299
- LLaMABatchWrapper* ptr = get_llama_batch(self);
300
- const int32_t id = NUM2INT(idx);
301
- if (id < 0) {
302
- rb_raise(rb_eArgError, "id must be greater or equal to 0");
303
- return Qnil;
304
- }
305
- ptr->batch.logits[id] = RTEST(value) ? true : false;
306
- return ptr->batch.logits[id] ? Qtrue : Qfalse;
307
- }
308
-
309
- static VALUE _llama_batch_get_logits(VALUE self, VALUE idx) {
310
- LLaMABatchWrapper* ptr = get_llama_batch(self);
311
- const int32_t id = NUM2INT(idx);
312
- if (id < 0) {
313
- rb_raise(rb_eArgError, "id must be greater or equal to 0");
314
- return Qnil;
315
- }
316
- return ptr->batch.logits[id] ? Qtrue : Qfalse;
317
- }
318
- };
319
-
320
- const rb_data_type_t RbLLaMABatch::llama_batch_type = {
321
- "RbLLaMABatch",
322
- { NULL,
323
- RbLLaMABatch::llama_batch_free,
324
- RbLLaMABatch::llama_batch_size },
325
- NULL,
326
- NULL,
327
- RUBY_TYPED_FREE_IMMEDIATELY
328
-
329
- };
330
-
331
- class LLaMATokenDataWrapper {
332
- public:
333
- llama_token_data data;
334
-
335
- LLaMATokenDataWrapper() {
336
- data.id = 0;
337
- data.logit = 0.0;
338
- data.p = 0.0;
339
- }
340
-
341
- ~LLaMATokenDataWrapper() {}
342
- };
343
-
344
- class RbLLaMATokenData {
345
- public:
346
- static VALUE llama_token_data_alloc(VALUE self) {
347
- LLaMATokenDataWrapper* ptr = (LLaMATokenDataWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataWrapper));
348
- new (ptr) LLaMATokenDataWrapper();
349
- return TypedData_Wrap_Struct(self, &llama_token_data_type, ptr);
350
- }
351
-
352
- static void llama_token_data_free(void* ptr) {
353
- ((LLaMATokenDataWrapper*)ptr)->~LLaMATokenDataWrapper();
354
- ruby_xfree(ptr);
355
- }
356
-
357
- static size_t llama_token_data_size(const void* ptr) {
358
- return sizeof(*((LLaMATokenDataWrapper*)ptr));
359
- }
360
-
361
- static LLaMATokenDataWrapper* get_llama_token_data(VALUE self) {
362
- LLaMATokenDataWrapper* ptr;
363
- TypedData_Get_Struct(self, LLaMATokenDataWrapper, &llama_token_data_type, ptr);
364
- return ptr;
365
- }
366
-
367
- static void define_class(VALUE outer) {
368
- rb_cLLaMATokenData = rb_define_class_under(outer, "TokenData", rb_cObject);
369
- rb_define_alloc_func(rb_cLLaMATokenData, llama_token_data_alloc);
370
- rb_define_method(rb_cLLaMATokenData, "initialize", RUBY_METHOD_FUNC(_llama_token_data_init), -1);
371
- rb_define_method(rb_cLLaMATokenData, "id=", RUBY_METHOD_FUNC(_llama_token_data_set_id), 1);
372
- rb_define_method(rb_cLLaMATokenData, "id", RUBY_METHOD_FUNC(_llama_token_data_get_id), 0);
373
- rb_define_method(rb_cLLaMATokenData, "logit=", RUBY_METHOD_FUNC(_llama_token_data_set_logit), 1);
374
- rb_define_method(rb_cLLaMATokenData, "logit", RUBY_METHOD_FUNC(_llama_token_data_get_logit), 0);
375
- rb_define_method(rb_cLLaMATokenData, "p=", RUBY_METHOD_FUNC(_llama_token_data_set_p), 1);
376
- rb_define_method(rb_cLLaMATokenData, "p", RUBY_METHOD_FUNC(_llama_token_data_get_p), 0);
377
- }
378
-
379
- private:
380
- static const rb_data_type_t llama_token_data_type;
381
-
382
- static VALUE _llama_token_data_init(int argc, VALUE* argv, VALUE self) {
383
- VALUE kw_args = Qnil;
384
- ID kw_table[3] = { rb_intern("id"), rb_intern("logit"), rb_intern("p") };
385
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
386
- rb_scan_args(argc, argv, ":", &kw_args);
387
- rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
388
-
389
- if (!RB_INTEGER_TYPE_P(kw_values[0])) {
390
- rb_raise(rb_eArgError, "id must be an integer");
391
- return Qnil;
392
- }
393
- if (!RB_FLOAT_TYPE_P(kw_values[1])) {
394
- rb_raise(rb_eArgError, "logit must be a float");
395
- return Qnil;
396
- }
397
- if (!RB_FLOAT_TYPE_P(kw_values[2])) {
398
- rb_raise(rb_eArgError, "p must be a float");
399
- return Qnil;
400
- }
401
-
402
- LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
403
- new (ptr) LLaMATokenDataWrapper();
404
-
405
- ptr->data.id = NUM2INT(kw_values[0]);
406
- ptr->data.logit = NUM2DBL(kw_values[1]);
407
- ptr->data.p = NUM2DBL(kw_values[2]);
408
-
409
- return self;
410
- }
411
-
412
- // id
413
- static VALUE _llama_token_data_set_id(VALUE self, VALUE id) {
414
- LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
415
- ptr->data.id = NUM2INT(id);
416
- return INT2NUM(ptr->data.id);
417
- }
418
-
419
- static VALUE _llama_token_data_get_id(VALUE self) {
420
- LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
421
- return INT2NUM(ptr->data.id);
422
- }
423
-
424
- // logit
425
- static VALUE _llama_token_data_set_logit(VALUE self, VALUE logit) {
426
- LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
427
- ptr->data.logit = NUM2DBL(logit);
428
- return DBL2NUM(ptr->data.logit);
429
- }
430
-
431
- static VALUE _llama_token_data_get_logit(VALUE self) {
432
- LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
433
- return DBL2NUM(ptr->data.logit);
434
- }
435
-
436
- // p
437
- static VALUE _llama_token_data_set_p(VALUE self, VALUE p) {
438
- LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
439
- ptr->data.p = NUM2DBL(p);
440
- return DBL2NUM(ptr->data.p);
441
- }
442
-
443
- static VALUE _llama_token_data_get_p(VALUE self) {
444
- LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
445
- return DBL2NUM(ptr->data.p);
446
- }
447
- };
448
-
449
- const rb_data_type_t RbLLaMATokenData::llama_token_data_type = {
450
- "RbLLaMATokenData",
451
- { NULL,
452
- RbLLaMATokenData::llama_token_data_free,
453
- RbLLaMATokenData::llama_token_data_size },
454
- NULL,
455
- NULL,
456
- RUBY_TYPED_FREE_IMMEDIATELY
457
- };
458
-
459
- class LLaMATokenDataArrayWrapper {
460
- public:
461
- llama_token_data_array array;
462
-
463
- LLaMATokenDataArrayWrapper() {
464
- array.data = nullptr;
465
- array.size = 0;
466
- array.sorted = false;
467
- }
468
-
469
- ~LLaMATokenDataArrayWrapper() {
470
- if (array.data) {
471
- ruby_xfree(array.data);
472
- array.data = nullptr;
473
- }
474
- }
475
- };
476
-
477
- class RbLLaMATokenDataArray {
478
- public:
479
- static VALUE llama_token_data_array_alloc(VALUE self) {
480
- LLaMATokenDataArrayWrapper* ptr = (LLaMATokenDataArrayWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataArrayWrapper));
481
- new (ptr) LLaMATokenDataArrayWrapper();
482
- return TypedData_Wrap_Struct(self, &llama_token_data_array_type, ptr);
483
- }
484
-
485
- static void llama_token_data_array_free(void* ptr) {
486
- ((LLaMATokenDataArrayWrapper*)ptr)->~LLaMATokenDataArrayWrapper();
487
- ruby_xfree(ptr);
488
- }
489
-
490
- static size_t llama_token_data_array_size(const void* ptr) {
491
- return sizeof(*((LLaMATokenDataArrayWrapper*)ptr));
492
- }
493
-
494
- static LLaMATokenDataArrayWrapper* get_llama_token_data_array(VALUE self) {
495
- LLaMATokenDataArrayWrapper* ptr;
496
- TypedData_Get_Struct(self, LLaMATokenDataArrayWrapper, &llama_token_data_array_type, ptr);
497
- return ptr;
498
- }
499
-
500
- static void define_class(VALUE outer) {
501
- rb_cLLaMATokenDataArray = rb_define_class_under(outer, "TokenDataArray", rb_cObject);
502
- rb_define_alloc_func(rb_cLLaMATokenDataArray, llama_token_data_array_alloc);
503
- rb_define_method(rb_cLLaMATokenDataArray, "initialize", RUBY_METHOD_FUNC(_llama_token_data_array_init), -1);
504
- rb_define_method(rb_cLLaMATokenDataArray, "size", RUBY_METHOD_FUNC(_llama_token_data_array_get_size), 0);
505
- rb_define_method(rb_cLLaMATokenDataArray, "sorted", RUBY_METHOD_FUNC(_llama_token_data_array_get_sorted), 0);
506
- }
507
-
508
- private:
509
- static const rb_data_type_t llama_token_data_array_type;
510
-
511
- static VALUE _llama_token_data_array_init(int argc, VALUE* argv, VALUE self) {
512
- VALUE kw_args = Qnil;
513
- ID kw_table[1] = { rb_intern("sorted") };
514
- VALUE kw_values[1] = { Qundef };
515
- VALUE arr = Qnil;
516
- rb_scan_args(argc, argv, "1:", &arr, &kw_args);
517
- rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
518
-
519
- if (!RB_TYPE_P(arr, T_ARRAY)) {
520
- rb_raise(rb_eArgError, "1st argument must be an array");
521
- return Qnil;
522
- }
523
- size_t sz_array = RARRAY_LEN(arr);
524
- if (sz_array == 0) {
525
- rb_raise(rb_eArgError, "array must not be empty");
526
- return Qnil;
527
- }
528
- if (kw_values[0] != Qundef && !RB_TYPE_P(kw_values[0], T_TRUE) && !RB_TYPE_P(kw_values[0], T_FALSE)) {
529
- rb_raise(rb_eArgError, "sorted must be a boolean");
530
- return Qnil;
531
- }
532
-
533
- LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
534
- new (ptr) LLaMATokenDataArrayWrapper();
535
-
536
- ptr->array.data = (llama_token_data*)ruby_xmalloc(sizeof(llama_token_data) * sz_array);
537
- for (size_t i = 0; i < sz_array; ++i) {
538
- VALUE el = rb_ary_entry(arr, i);
539
- if (!rb_obj_is_kind_of(el, rb_cLLaMATokenData)) {
540
- rb_raise(rb_eArgError, "array element must be a TokenData");
541
- xfree(ptr->array.data);
542
- ptr->array.data = nullptr;
543
- return Qnil;
544
- }
545
- llama_token_data token_data = RbLLaMATokenData::get_llama_token_data(el)->data;
546
- ptr->array.data[i].id = token_data.id;
547
- ptr->array.data[i].logit = token_data.logit;
548
- ptr->array.data[i].p = token_data.p;
549
- }
550
-
551
- ptr->array.size = sz_array;
552
- ptr->array.sorted = kw_values[0] == Qtrue;
553
-
554
- return self;
555
- }
556
-
557
- static VALUE _llama_token_data_array_get_size(VALUE self) {
558
- LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
559
- return SIZET2NUM(ptr->array.size);
560
- }
561
-
562
- static VALUE _llama_token_data_array_get_sorted(VALUE self) {
563
- LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
564
- return ptr->array.sorted ? Qtrue : Qfalse;
565
- }
566
- };
567
-
568
- const rb_data_type_t RbLLaMATokenDataArray::llama_token_data_array_type = {
569
- "RbLLaMATokenDataArray",
570
- { NULL,
571
- RbLLaMATokenDataArray::llama_token_data_array_free,
572
- RbLLaMATokenDataArray::llama_token_data_array_size },
573
- NULL,
574
- NULL,
575
- RUBY_TYPED_FREE_IMMEDIATELY
576
- };
577
-
578
- class LLaMATimingsWrapper {
579
- public:
580
- struct llama_timings timings;
581
-
582
- LLaMATimingsWrapper() {}
583
-
584
- ~LLaMATimingsWrapper() {}
585
- };
586
-
587
- class RbLLaMATimings {
588
- public:
589
- static VALUE llama_timings_alloc(VALUE self) {
590
- LLaMATimingsWrapper* ptr = (LLaMATimingsWrapper*)ruby_xmalloc(sizeof(LLaMATimingsWrapper));
591
- new (ptr) LLaMATimingsWrapper();
592
- return TypedData_Wrap_Struct(self, &llama_timings_type, ptr);
593
- }
594
-
595
- static void llama_timings_free(void* ptr) {
596
- ((LLaMATimingsWrapper*)ptr)->~LLaMATimingsWrapper();
597
- ruby_xfree(ptr);
598
- }
599
-
600
- static size_t llama_timings_size(const void* ptr) {
601
- return sizeof(*((LLaMATimingsWrapper*)ptr));
602
- }
603
-
604
- static LLaMATimingsWrapper* get_llama_timings(VALUE self) {
605
- LLaMATimingsWrapper* ptr;
606
- TypedData_Get_Struct(self, LLaMATimingsWrapper, &llama_timings_type, ptr);
607
- return ptr;
608
- }
609
-
610
- static void define_class(VALUE outer) {
611
- rb_cLLaMATimings = rb_define_class_under(outer, "Timings", rb_cObject);
612
- rb_define_alloc_func(rb_cLLaMATimings, llama_timings_alloc);
613
- rb_define_method(rb_cLLaMATimings, "t_start_ms", RUBY_METHOD_FUNC(_llama_timings_get_t_start_ms), 0);
614
- rb_define_method(rb_cLLaMATimings, "t_end_ms", RUBY_METHOD_FUNC(_llama_timings_get_t_end_ms), 0);
615
- rb_define_method(rb_cLLaMATimings, "t_load_ms", RUBY_METHOD_FUNC(_llama_timings_get_t_load_ms), 0);
616
- rb_define_method(rb_cLLaMATimings, "t_sample_ms", RUBY_METHOD_FUNC(_llama_timings_get_t_sample_ms), 0);
617
- rb_define_method(rb_cLLaMATimings, "t_p_eval_ms", RUBY_METHOD_FUNC(_llama_timings_get_t_p_eval_ms), 0);
618
- rb_define_method(rb_cLLaMATimings, "t_eval_ms", RUBY_METHOD_FUNC(_llama_timings_get_t_eval_ms), 0);
619
- rb_define_method(rb_cLLaMATimings, "n_sample", RUBY_METHOD_FUNC(_llama_timings_get_n_sample), 0);
620
- rb_define_method(rb_cLLaMATimings, "n_p_eval", RUBY_METHOD_FUNC(_llama_timings_get_n_p_eval), 0);
621
- rb_define_method(rb_cLLaMATimings, "n_eval", RUBY_METHOD_FUNC(_llama_timings_get_n_eval), 0);
622
- }
623
-
624
- private:
625
- static const rb_data_type_t llama_timings_type;
626
-
627
- static VALUE _llama_timings_get_t_start_ms(VALUE self) {
628
- LLaMATimingsWrapper* ptr = get_llama_timings(self);
629
- return DBL2NUM(ptr->timings.t_start_ms);
630
- }
631
-
632
- static VALUE _llama_timings_get_t_end_ms(VALUE self) {
633
- LLaMATimingsWrapper* ptr = get_llama_timings(self);
634
- return DBL2NUM(ptr->timings.t_end_ms);
635
- }
636
-
637
- static VALUE _llama_timings_get_t_load_ms(VALUE self) {
638
- LLaMATimingsWrapper* ptr = get_llama_timings(self);
639
- return DBL2NUM(ptr->timings.t_load_ms);
640
- }
641
-
642
- static VALUE _llama_timings_get_t_sample_ms(VALUE self) {
643
- LLaMATimingsWrapper* ptr = get_llama_timings(self);
644
- return DBL2NUM(ptr->timings.t_sample_ms);
645
- }
646
-
647
- static VALUE _llama_timings_get_t_p_eval_ms(VALUE self) {
648
- LLaMATimingsWrapper* ptr = get_llama_timings(self);
649
- return DBL2NUM(ptr->timings.t_p_eval_ms);
650
- }
651
-
652
- static VALUE _llama_timings_get_t_eval_ms(VALUE self) {
653
- LLaMATimingsWrapper* ptr = get_llama_timings(self);
654
- return DBL2NUM(ptr->timings.t_eval_ms);
655
- }
656
-
657
- static VALUE _llama_timings_get_n_sample(VALUE self) {
658
- LLaMATimingsWrapper* ptr = get_llama_timings(self);
659
- return INT2NUM(ptr->timings.n_sample);
660
- }
661
-
662
- static VALUE _llama_timings_get_n_p_eval(VALUE self) {
663
- LLaMATimingsWrapper* ptr = get_llama_timings(self);
664
- return INT2NUM(ptr->timings.n_p_eval);
665
- }
666
-
667
- static VALUE _llama_timings_get_n_eval(VALUE self) {
668
- LLaMATimingsWrapper* ptr = get_llama_timings(self);
669
- return INT2NUM(ptr->timings.n_eval);
670
- }
671
- };
672
-
673
- const rb_data_type_t RbLLaMATimings::llama_timings_type = {
674
- "RbLLaMATimings",
675
- { NULL,
676
- RbLLaMATimings::llama_timings_free,
677
- RbLLaMATimings::llama_timings_size },
678
- NULL,
679
- NULL,
680
- RUBY_TYPED_FREE_IMMEDIATELY
681
- };
682
-
683
- class RbLLaMAModelKVOverride {
684
- public:
685
- static VALUE llama_model_kv_override_alloc(VALUE self) {
686
- llama_model_kv_override* ptr = (llama_model_kv_override*)ruby_xmalloc(sizeof(llama_model_kv_override));
687
- new (ptr) llama_model_kv_override();
688
- return TypedData_Wrap_Struct(self, &llama_model_kv_override_type, ptr);
689
- }
690
-
691
- static void llama_model_kv_override_free(void* ptr) {
692
- ((llama_model_kv_override*)ptr)->~llama_model_kv_override();
693
- ruby_xfree(ptr);
694
- }
695
-
696
- static size_t llama_model_kv_override_size(const void* ptr) {
697
- return sizeof(*((llama_model_kv_override*)ptr));
698
- }
699
-
700
- static llama_model_kv_override* get_llama_model_kv_override(VALUE self) {
701
- llama_model_kv_override* ptr;
702
- TypedData_Get_Struct(self, llama_model_kv_override, &llama_model_kv_override_type, ptr);
703
- return ptr;
704
- }
705
-
706
- static void define_class(VALUE outer) {
707
- rb_cLLaMAModelKVOverride = rb_define_class_under(outer, "ModelKVOverride", rb_cObject);
708
- rb_define_alloc_func(rb_cLLaMAModelKVOverride, llama_model_kv_override_alloc);
709
- rb_define_method(rb_cLLaMAModelKVOverride, "key", RUBY_METHOD_FUNC(_llama_model_kv_override_get_key), 0);
710
- rb_define_method(rb_cLLaMAModelKVOverride, "tag", RUBY_METHOD_FUNC(_llama_model_kv_override_get_tag), 0);
711
- rb_define_method(rb_cLLaMAModelKVOverride, "val_i64", RUBY_METHOD_FUNC(_llama_model_kv_override_get_val_i64), 0);
712
- rb_define_method(rb_cLLaMAModelKVOverride, "val_f64", RUBY_METHOD_FUNC(_llama_model_kv_override_get_val_f64), 0);
713
- rb_define_method(rb_cLLaMAModelKVOverride, "val_bool", RUBY_METHOD_FUNC(_llama_model_kv_override_get_val_bool), 0);
714
- rb_define_method(rb_cLLaMAModelKVOverride, "val_str", RUBY_METHOD_FUNC(_llama_model_kv_override_get_val_str), 0);
715
- }
716
-
717
- static const rb_data_type_t llama_model_kv_override_type;
718
-
719
- private:
720
- static VALUE _llama_model_kv_override_get_key(VALUE self) {
721
- llama_model_kv_override* ptr = get_llama_model_kv_override(self);
722
- return rb_utf8_str_new_cstr(ptr->key);
723
- }
724
-
725
- static VALUE _llama_model_kv_override_get_tag(VALUE self) {
726
- llama_model_kv_override* ptr = get_llama_model_kv_override(self);
727
- return INT2NUM(ptr->tag);
728
- }
729
-
730
- static VALUE _llama_model_kv_override_get_val_i64(VALUE self) {
731
- llama_model_kv_override* ptr = get_llama_model_kv_override(self);
732
- return INT2NUM(ptr->val_i64);
733
- }
734
-
735
- static VALUE _llama_model_kv_override_get_val_f64(VALUE self) {
736
- llama_model_kv_override* ptr = get_llama_model_kv_override(self);
737
- return DBL2NUM(ptr->val_f64);
738
- }
739
-
740
- static VALUE _llama_model_kv_override_get_val_bool(VALUE self) {
741
- llama_model_kv_override* ptr = get_llama_model_kv_override(self);
742
- return ptr->val_bool ? Qtrue : Qfalse;
743
- }
744
-
745
- static VALUE _llama_model_kv_override_get_val_str(VALUE self) {
746
- llama_model_kv_override* ptr = get_llama_model_kv_override(self);
747
- return rb_utf8_str_new_cstr(ptr->val_str);
748
- }
749
- };
750
-
751
- const rb_data_type_t RbLLaMAModelKVOverride::llama_model_kv_override_type = {
752
- "RbLLaMAModelKVOverride",
753
- { NULL,
754
- RbLLaMAModelKVOverride::llama_model_kv_override_free,
755
- RbLLaMAModelKVOverride::llama_model_kv_override_size },
756
- NULL,
757
- NULL,
758
- RUBY_TYPED_FREE_IMMEDIATELY
759
- };
760
-
761
- class LLaMAModelParamsWrapper {
762
- public:
763
- struct llama_model_params params;
764
-
765
- LLaMAModelParamsWrapper() : params(llama_model_default_params()) {}
766
-
767
- ~LLaMAModelParamsWrapper() {}
768
- };
769
-
770
- class RbLLaMAModelParams {
771
- public:
772
- static VALUE llama_model_params_alloc(VALUE self) {
773
- LLaMAModelParamsWrapper* ptr = (LLaMAModelParamsWrapper*)ruby_xmalloc(sizeof(LLaMAModelParamsWrapper));
774
- new (ptr) LLaMAModelParamsWrapper();
775
- return TypedData_Wrap_Struct(self, &llama_model_params_type, ptr);
776
- }
777
-
778
- static void llama_model_params_free(void* ptr) {
779
- ((LLaMAModelParamsWrapper*)ptr)->~LLaMAModelParamsWrapper();
780
- ruby_xfree(ptr);
781
- }
782
-
783
- static size_t llama_model_params_size(const void* ptr) {
784
- return sizeof(*((LLaMAModelParamsWrapper*)ptr));
785
- }
786
-
787
- static LLaMAModelParamsWrapper* get_llama_model_params(VALUE self) {
788
- LLaMAModelParamsWrapper* ptr;
789
- TypedData_Get_Struct(self, LLaMAModelParamsWrapper, &llama_model_params_type, ptr);
790
- return ptr;
791
- }
792
-
793
- static void define_class(VALUE outer) {
794
- rb_cLLaMAModelParams = rb_define_class_under(outer, "ModelParams", rb_cObject);
795
- rb_define_alloc_func(rb_cLLaMAModelParams, llama_model_params_alloc);
796
- rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers=", RUBY_METHOD_FUNC(_llama_model_params_set_n_gpu_layers), 1);
797
- rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers", RUBY_METHOD_FUNC(_llama_model_params_get_n_gpu_layers), 0);
798
- rb_define_method(rb_cLLaMAModelParams, "split_mode=", RUBY_METHOD_FUNC(_llama_model_params_set_split_mode), 1);
799
- rb_define_method(rb_cLLaMAModelParams, "split_mode", RUBY_METHOD_FUNC(_llama_model_params_get_split_mode), 0);
800
- rb_define_method(rb_cLLaMAModelParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_model_params_set_main_gpu), 1);
801
- rb_define_method(rb_cLLaMAModelParams, "main_gpu", RUBY_METHOD_FUNC(_llama_model_params_get_main_gpu), 0);
802
- rb_define_method(rb_cLLaMAModelParams, "tensor_split", RUBY_METHOD_FUNC(_llama_model_params_get_tensor_split), 0);
803
- rb_define_method(rb_cLLaMAModelParams, "vocab_only=", RUBY_METHOD_FUNC(_llama_model_params_set_vocab_only), 1);
804
- rb_define_method(rb_cLLaMAModelParams, "vocab_only", RUBY_METHOD_FUNC(_llama_model_params_get_vocab_only), 0);
805
- rb_define_method(rb_cLLaMAModelParams, "use_mmap=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mmap), 1);
806
- rb_define_method(rb_cLLaMAModelParams, "use_mmap", RUBY_METHOD_FUNC(_llama_model_params_get_use_mmap), 0);
807
- rb_define_method(rb_cLLaMAModelParams, "use_mlock=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mlock), 1);
808
- rb_define_method(rb_cLLaMAModelParams, "use_mlock", RUBY_METHOD_FUNC(_llama_model_params_get_use_mlock), 0);
809
- rb_define_method(rb_cLLaMAModelParams, "check_tensors=", RUBY_METHOD_FUNC(_llama_model_params_set_check_tensors), 1);
810
- rb_define_method(rb_cLLaMAModelParams, "check_tensors", RUBY_METHOD_FUNC(_llama_model_params_get_check_tensors), 0);
811
- }
812
-
813
- private:
814
- static const rb_data_type_t llama_model_params_type;
815
-
816
- // n_gpu_layers
817
- static VALUE _llama_model_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
818
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
819
- ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
820
- return INT2NUM(ptr->params.n_gpu_layers);
821
- }
822
-
823
- static VALUE _llama_model_params_get_n_gpu_layers(VALUE self) {
824
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
825
- return INT2NUM(ptr->params.n_gpu_layers);
826
- }
827
-
828
- // split_mode
829
- static VALUE _llama_model_params_set_split_mode(VALUE self, VALUE split_mode) {
830
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
831
- ptr->params.split_mode = static_cast<enum llama_split_mode>(NUM2INT(split_mode));
832
- return INT2NUM(ptr->params.split_mode);
833
- }
834
-
835
- static VALUE _llama_model_params_get_split_mode(VALUE self) {
836
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
837
- return INT2NUM(ptr->params.split_mode);
838
- }
839
-
840
- // main_gpu
841
- static VALUE _llama_model_params_set_main_gpu(VALUE self, VALUE main_gpu) {
842
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
843
- ptr->params.main_gpu = NUM2INT(main_gpu);
844
- return INT2NUM(ptr->params.main_gpu);
845
- }
846
-
847
- static VALUE _llama_model_params_get_main_gpu(VALUE self) {
848
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
849
- return INT2NUM(ptr->params.main_gpu);
850
- }
851
-
852
- // tensor_split
853
- static VALUE _llama_model_params_get_tensor_split(VALUE self) {
854
- if (llama_max_devices() < 1) {
855
- return rb_ary_new();
856
- }
857
- VALUE ret = rb_ary_new2(llama_max_devices());
858
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
859
- if (ptr->params.tensor_split == nullptr) {
860
- return rb_ary_new();
861
- }
862
- for (size_t i = 0; i < llama_max_devices(); i++) {
863
- rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
864
- }
865
- return ret;
866
- }
867
-
868
- // vocab_only
869
- static VALUE _llama_model_params_set_vocab_only(VALUE self, VALUE vocab_only) {
870
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
871
- ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
872
- return ptr->params.vocab_only ? Qtrue : Qfalse;
873
- }
874
-
875
- static VALUE _llama_model_params_get_vocab_only(VALUE self) {
876
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
877
- return ptr->params.vocab_only ? Qtrue : Qfalse;
878
- }
879
-
880
- // use_mmap
881
- static VALUE _llama_model_params_set_use_mmap(VALUE self, VALUE use_mmap) {
882
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
883
- ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
884
- return ptr->params.use_mmap ? Qtrue : Qfalse;
885
- }
886
-
887
- static VALUE _llama_model_params_get_use_mmap(VALUE self) {
888
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
889
- return ptr->params.use_mmap ? Qtrue : Qfalse;
890
- }
891
-
892
- // use_mlock
893
- static VALUE _llama_model_params_set_use_mlock(VALUE self, VALUE use_mlock) {
894
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
895
- ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
896
- return ptr->params.use_mlock ? Qtrue : Qfalse;
897
- }
898
-
899
- static VALUE _llama_model_params_get_use_mlock(VALUE self) {
900
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
901
- return ptr->params.use_mlock ? Qtrue : Qfalse;
902
- }
903
-
904
- // check_tensors
905
- static VALUE _llama_model_params_set_check_tensors(VALUE self, VALUE check_tensors) {
906
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
907
- ptr->params.check_tensors = RTEST(check_tensors) ? true : false;
908
- return ptr->params.check_tensors ? Qtrue : Qfalse;
909
- }
910
-
911
- static VALUE _llama_model_params_get_check_tensors(VALUE self) {
912
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
913
- return ptr->params.check_tensors ? Qtrue : Qfalse;
914
- }
915
- };
916
-
917
- const rb_data_type_t RbLLaMAModelParams::llama_model_params_type = {
918
- "RbLLaMAModelParams",
919
- { NULL,
920
- RbLLaMAModelParams::llama_model_params_free,
921
- RbLLaMAModelParams::llama_model_params_size },
922
- NULL,
923
- NULL,
924
- RUBY_TYPED_FREE_IMMEDIATELY
925
- };
926
-
927
- class LLaMAContextParamsWrapper {
928
- public:
929
- struct llama_context_params params;
930
-
931
- LLaMAContextParamsWrapper() : params(llama_context_default_params()) {}
932
-
933
- ~LLaMAContextParamsWrapper() {}
934
- };
935
-
936
- class RbLLaMAContextParams {
937
- public:
938
- static VALUE llama_context_params_alloc(VALUE self) {
939
- LLaMAContextParamsWrapper* ptr = (LLaMAContextParamsWrapper*)ruby_xmalloc(sizeof(LLaMAContextParamsWrapper));
940
- new (ptr) LLaMAContextParamsWrapper();
941
- return TypedData_Wrap_Struct(self, &llama_context_params_type, ptr);
942
- }
943
-
944
- static void llama_context_params_free(void* ptr) {
945
- ((LLaMAContextParamsWrapper*)ptr)->~LLaMAContextParamsWrapper();
946
- ruby_xfree(ptr);
947
- }
948
-
949
- static size_t llama_context_params_size(const void* ptr) {
950
- return sizeof(*((LLaMAContextParamsWrapper*)ptr));
951
- }
952
-
953
- static LLaMAContextParamsWrapper* get_llama_context_params(VALUE self) {
954
- LLaMAContextParamsWrapper* ptr;
955
- TypedData_Get_Struct(self, LLaMAContextParamsWrapper, &llama_context_params_type, ptr);
956
- return ptr;
957
- }
958
-
959
- static void define_class(VALUE outer) {
960
- rb_cLLaMAContextParams = rb_define_class_under(outer, "ContextParams", rb_cObject);
961
- rb_define_alloc_func(rb_cLLaMAContextParams, llama_context_params_alloc);
962
- // rb_define_method(rb_cLLaMAContextParams, "initialize", RUBY_METHOD_FUNC(_llama_context_params_init), 0);
963
- rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
964
- rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
965
- rb_define_method(rb_cLLaMAContextParams, "n_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_n_ctx), 1);
966
- rb_define_method(rb_cLLaMAContextParams, "n_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_n_ctx), 0);
967
- rb_define_method(rb_cLLaMAContextParams, "n_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_batch), 1);
968
- rb_define_method(rb_cLLaMAContextParams, "n_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_batch), 0);
969
- rb_define_method(rb_cLLaMAContextParams, "n_ubatch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_ubatch), 1);
970
- rb_define_method(rb_cLLaMAContextParams, "n_ubatch", RUBY_METHOD_FUNC(_llama_context_params_get_n_ubatch), 0);
971
- rb_define_method(rb_cLLaMAContextParams, "n_seq_max=", RUBY_METHOD_FUNC(_llama_context_params_set_n_seq_max), 1);
972
- rb_define_method(rb_cLLaMAContextParams, "n_seq_max", RUBY_METHOD_FUNC(_llama_context_params_get_n_seq_max), 0);
973
- rb_define_method(rb_cLLaMAContextParams, "n_threads=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads), 1);
974
- rb_define_method(rb_cLLaMAContextParams, "n_threads", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads), 0);
975
- rb_define_method(rb_cLLaMAContextParams, "n_threads_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads_batch), 1);
976
- rb_define_method(rb_cLLaMAContextParams, "n_threads_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads_batch), 0);
977
- rb_define_method(rb_cLLaMAContextParams, "rope_scaling_type=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_scaling_type), 1);
978
- rb_define_method(rb_cLLaMAContextParams, "rope_scaling_type", RUBY_METHOD_FUNC(_llama_context_params_get_rope_scaling_type), 0);
979
- rb_define_method(rb_cLLaMAContextParams, "pooling_type=", RUBY_METHOD_FUNC(_llama_context_params_set_pooling_type), 1);
980
- rb_define_method(rb_cLLaMAContextParams, "pooling_type", RUBY_METHOD_FUNC(_llama_context_params_get_pooling_type), 0);
981
- rb_define_method(rb_cLLaMAContextParams, "attention_type=", RUBY_METHOD_FUNC(_llama_context_params_set_attention_type), 1);
982
- rb_define_method(rb_cLLaMAContextParams, "attention_type", RUBY_METHOD_FUNC(_llama_context_params_get_attention_type), 0);
983
- rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
984
- rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
985
- rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
986
- rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
987
- rb_define_method(rb_cLLaMAContextParams, "yarn_ext_factor=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_ext_factor), 1);
988
- rb_define_method(rb_cLLaMAContextParams, "yarn_ext_factor", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_ext_factor), 0);
989
- rb_define_method(rb_cLLaMAContextParams, "yarn_attn_factor=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_attn_factor), 1);
990
- rb_define_method(rb_cLLaMAContextParams, "yarn_attn_factor", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_attn_factor), 0);
991
- rb_define_method(rb_cLLaMAContextParams, "yarn_beta_fast=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_beta_fast), 1);
992
- rb_define_method(rb_cLLaMAContextParams, "yarn_beta_fast", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_beta_fast), 0);
993
- rb_define_method(rb_cLLaMAContextParams, "yarn_beta_slow=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_beta_slow), 1);
994
- rb_define_method(rb_cLLaMAContextParams, "yarn_beta_slow", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_beta_slow), 0);
995
- rb_define_method(rb_cLLaMAContextParams, "yarn_orig_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_orig_ctx), 1);
996
- rb_define_method(rb_cLLaMAContextParams, "yarn_orig_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_orig_ctx), 0);
997
- rb_define_method(rb_cLLaMAContextParams, "defrag_thold=", RUBY_METHOD_FUNC(_llama_context_params_set_defrag_thold), 1);
998
- rb_define_method(rb_cLLaMAContextParams, "defrag_thold", RUBY_METHOD_FUNC(_llama_context_params_get_defrag_thold), 0);
999
- rb_define_method(rb_cLLaMAContextParams, "type_k=", RUBY_METHOD_FUNC(_llama_context_params_set_type_k), 1);
1000
- rb_define_method(rb_cLLaMAContextParams, "type_k", RUBY_METHOD_FUNC(_llama_context_params_get_type_k), 0);
1001
- rb_define_method(rb_cLLaMAContextParams, "type_v=", RUBY_METHOD_FUNC(_llama_context_params_set_type_v), 1);
1002
- rb_define_method(rb_cLLaMAContextParams, "type_v", RUBY_METHOD_FUNC(_llama_context_params_get_type_v), 0);
1003
- rb_define_method(rb_cLLaMAContextParams, "logits_all=", RUBY_METHOD_FUNC(_llama_context_params_set_logits_all), 1);
1004
- rb_define_method(rb_cLLaMAContextParams, "logits_all", RUBY_METHOD_FUNC(_llama_context_params_get_logits_all), 0);
1005
- rb_define_method(rb_cLLaMAContextParams, "embeddings=", RUBY_METHOD_FUNC(_llama_context_params_set_embeddings), 1);
1006
- rb_define_method(rb_cLLaMAContextParams, "embeddings", RUBY_METHOD_FUNC(_llama_context_params_get_embeddings), 0);
1007
- rb_define_method(rb_cLLaMAContextParams, "offload_kqv=", RUBY_METHOD_FUNC(_llama_context_params_set_offload_kqv), 1);
1008
- rb_define_method(rb_cLLaMAContextParams, "offload_kqv", RUBY_METHOD_FUNC(_llama_context_params_get_offload_kqv), 0);
1009
- rb_define_method(rb_cLLaMAContextParams, "flash_attn=", RUBY_METHOD_FUNC(_llama_context_params_set_flash_attn), 1);
1010
- rb_define_method(rb_cLLaMAContextParams, "flash_attn", RUBY_METHOD_FUNC(_llama_context_params_get_flash_attn), 0);
1011
- }
1012
-
1013
- private:
1014
- static const rb_data_type_t llama_context_params_type;
1015
-
1016
- // static VALUE _llama_context_params_init(VALUE self, VALUE seed) {
1017
- // LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1018
- // new (ptr) LLaMAContextParamsWrapper();
1019
- // return self;
1020
- // }
1021
-
1022
- // seed
1023
- static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
1024
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1025
- if (NUM2INT(seed) < 0) {
1026
- rb_raise(rb_eArgError, "seed must be positive");
1027
- return Qnil;
1028
- }
1029
- ptr->params.seed = NUM2INT(seed);
1030
- return INT2NUM(ptr->params.seed);
1031
- }
1032
-
1033
- static VALUE _llama_context_params_get_seed(VALUE self) {
1034
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1035
- return INT2NUM(ptr->params.seed);
1036
- }
1037
-
1038
- // n_ctx
1039
- static VALUE _llama_context_params_set_n_ctx(VALUE self, VALUE n_ctx) {
1040
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1041
- ptr->params.n_ctx = NUM2INT(n_ctx);
1042
- return INT2NUM(ptr->params.n_ctx);
1043
- }
1044
-
1045
- static VALUE _llama_context_params_get_n_ctx(VALUE self) {
1046
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1047
- return INT2NUM(ptr->params.n_ctx);
1048
- }
1049
-
1050
- // n_batch
1051
- static VALUE _llama_context_params_set_n_batch(VALUE self, VALUE n_batch) {
1052
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1053
- ptr->params.n_batch = NUM2INT(n_batch);
1054
- return INT2NUM(ptr->params.n_batch);
1055
- }
1056
-
1057
- static VALUE _llama_context_params_get_n_batch(VALUE self) {
1058
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1059
- return INT2NUM(ptr->params.n_batch);
1060
- }
1061
-
1062
- // n_ubatch
1063
- static VALUE _llama_context_params_set_n_ubatch(VALUE self, VALUE n_ubatch) {
1064
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1065
- ptr->params.n_ubatch = NUM2INT(n_ubatch);
1066
- return INT2NUM(ptr->params.n_ubatch);
1067
- }
1068
-
1069
- static VALUE _llama_context_params_get_n_ubatch(VALUE self) {
1070
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1071
- return INT2NUM(ptr->params.n_ubatch);
1072
- }
1073
-
1074
- // n_seq_max
1075
- static VALUE _llama_context_params_set_n_seq_max(VALUE self, VALUE n_seq_max) {
1076
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1077
- ptr->params.n_seq_max = NUM2INT(n_seq_max);
1078
- return INT2NUM(ptr->params.n_seq_max);
1079
- }
1080
-
1081
- static VALUE _llama_context_params_get_n_seq_max(VALUE self) {
1082
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1083
- return INT2NUM(ptr->params.n_seq_max);
1084
- }
1085
-
1086
- // n_threads
1087
- static VALUE _llama_context_params_set_n_threads(VALUE self, VALUE n_threads) {
1088
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1089
- ptr->params.n_threads = NUM2INT(n_threads);
1090
- return INT2NUM(ptr->params.n_threads);
1091
- }
1092
-
1093
- static VALUE _llama_context_params_get_n_threads(VALUE self) {
1094
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1095
- return INT2NUM(ptr->params.n_threads);
1096
- }
1097
-
1098
- // n_threads_batch
1099
- static VALUE _llama_context_params_set_n_threads_batch(VALUE self, VALUE n_threads_batch) {
1100
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1101
- ptr->params.n_threads_batch = NUM2INT(n_threads_batch);
1102
- return INT2NUM(ptr->params.n_threads_batch);
1103
- }
1104
-
1105
- static VALUE _llama_context_params_get_n_threads_batch(VALUE self) {
1106
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1107
- return INT2NUM(ptr->params.n_threads_batch);
1108
- }
1109
-
1110
- // rope_scaling_type
1111
- static VALUE _llama_context_params_set_rope_scaling_type(VALUE self, VALUE scaling_type) {
1112
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1113
- ptr->params.rope_scaling_type = static_cast<enum llama_rope_scaling_type>(NUM2INT(scaling_type));
1114
- return INT2NUM(ptr->params.rope_scaling_type);
1115
- }
1116
-
1117
- static VALUE _llama_context_params_get_rope_scaling_type(VALUE self) {
1118
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1119
- return INT2NUM(ptr->params.rope_scaling_type);
1120
- }
1121
-
1122
- // pooling_type
1123
- static VALUE _llama_context_params_set_pooling_type(VALUE self, VALUE scaling_type) {
1124
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1125
- ptr->params.pooling_type = static_cast<enum llama_pooling_type>(NUM2INT(scaling_type));
1126
- return INT2NUM(ptr->params.pooling_type);
1127
- }
1128
-
1129
- static VALUE _llama_context_params_get_pooling_type(VALUE self) {
1130
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1131
- return INT2NUM(ptr->params.pooling_type);
1132
- }
1133
-
1134
- // attention_type
1135
- static VALUE _llama_context_params_set_attention_type(VALUE self, VALUE scaling_type) {
1136
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1137
- ptr->params.attention_type = static_cast<enum llama_attention_type>(NUM2INT(scaling_type));
1138
- return INT2NUM(ptr->params.attention_type);
1139
- }
1140
-
1141
- static VALUE _llama_context_params_get_attention_type(VALUE self) {
1142
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1143
- return INT2NUM(ptr->params.attention_type);
1144
- }
1145
-
1146
- // rope_freq_base
1147
- static VALUE _llama_context_params_set_rope_freq_base(VALUE self, VALUE rope_freq_base) {
1148
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1149
- ptr->params.rope_freq_base = NUM2DBL(rope_freq_base);
1150
- return DBL2NUM(ptr->params.rope_freq_base);
1151
- }
1152
-
1153
- static VALUE _llama_context_params_get_rope_freq_base(VALUE self) {
1154
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1155
- return DBL2NUM(ptr->params.rope_freq_base);
1156
- }
1157
-
1158
- // rope_freq_scale
1159
- static VALUE _llama_context_params_set_rope_freq_scale(VALUE self, VALUE rope_freq_scale) {
1160
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1161
- ptr->params.rope_freq_scale = NUM2DBL(rope_freq_scale);
1162
- return DBL2NUM(ptr->params.rope_freq_scale);
1163
- }
1164
-
1165
- static VALUE _llama_context_params_get_rope_freq_scale(VALUE self) {
1166
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1167
- return DBL2NUM(ptr->params.rope_freq_scale);
1168
- }
1169
-
1170
- // yarn_ext_factor
1171
- static VALUE _llama_context_params_set_yarn_ext_factor(VALUE self, VALUE yarn_ext_factor) {
1172
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1173
- ptr->params.yarn_ext_factor = NUM2DBL(yarn_ext_factor);
1174
- return DBL2NUM(ptr->params.yarn_ext_factor);
1175
- }
1176
-
1177
- static VALUE _llama_context_params_get_yarn_ext_factor(VALUE self) {
1178
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1179
- return DBL2NUM(ptr->params.yarn_ext_factor);
1180
- }
1181
-
1182
- // yarn_attn_factor
1183
- static VALUE _llama_context_params_set_yarn_attn_factor(VALUE self, VALUE yarn_attn_factor) {
1184
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1185
- ptr->params.yarn_attn_factor = NUM2DBL(yarn_attn_factor);
1186
- return DBL2NUM(ptr->params.yarn_attn_factor);
1187
- }
1188
-
1189
- static VALUE _llama_context_params_get_yarn_attn_factor(VALUE self) {
1190
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1191
- return DBL2NUM(ptr->params.yarn_attn_factor);
1192
- }
1193
-
1194
- // yarn_beta_fast
1195
- static VALUE _llama_context_params_set_yarn_beta_fast(VALUE self, VALUE yarn_beta_fast) {
1196
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1197
- ptr->params.yarn_beta_fast = NUM2DBL(yarn_beta_fast);
1198
- return DBL2NUM(ptr->params.yarn_beta_fast);
1199
- }
1200
-
1201
- static VALUE _llama_context_params_get_yarn_beta_fast(VALUE self) {
1202
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1203
- return DBL2NUM(ptr->params.yarn_beta_fast);
1204
- }
1205
-
1206
- // yarn_beta_slow
1207
- static VALUE _llama_context_params_set_yarn_beta_slow(VALUE self, VALUE yarn_beta_slow) {
1208
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1209
- ptr->params.yarn_beta_slow = NUM2DBL(yarn_beta_slow);
1210
- return DBL2NUM(ptr->params.yarn_beta_slow);
1211
- }
1212
-
1213
- static VALUE _llama_context_params_get_yarn_beta_slow(VALUE self) {
1214
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1215
- return DBL2NUM(ptr->params.yarn_beta_slow);
1216
- }
1217
-
1218
- // yarn_orig_ctx
1219
- static VALUE _llama_context_params_set_yarn_orig_ctx(VALUE self, VALUE yarn_orig_ctx) {
1220
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1221
- ptr->params.yarn_orig_ctx = NUM2UINT(yarn_orig_ctx);
1222
- return UINT2NUM(ptr->params.yarn_orig_ctx);
1223
- }
1224
-
1225
- // defrag_thold
1226
- static VALUE _llama_context_params_set_defrag_thold(VALUE self, VALUE defrag_thold) {
1227
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1228
- ptr->params.defrag_thold = NUM2DBL(defrag_thold);
1229
- return DBL2NUM(ptr->params.defrag_thold);
1230
- }
1231
-
1232
- static VALUE _llama_context_params_get_defrag_thold(VALUE self) {
1233
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1234
- return DBL2NUM(ptr->params.defrag_thold);
1235
- }
1236
-
1237
- static VALUE _llama_context_params_get_yarn_orig_ctx(VALUE self) {
1238
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1239
- return UINT2NUM(ptr->params.yarn_orig_ctx);
1240
- }
1241
-
1242
- // type_k
1243
- static VALUE _llama_context_params_set_type_k(VALUE self, VALUE type_k) {
1244
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1245
- ptr->params.type_k = static_cast<enum ggml_type>(NUM2INT(type_k));
1246
- return INT2NUM(ptr->params.type_k);
1247
- }
1248
-
1249
- static VALUE _llama_context_params_get_type_k(VALUE self) {
1250
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1251
- return INT2NUM(ptr->params.type_k);
1252
- }
1253
-
1254
- // type_v
1255
- static VALUE _llama_context_params_set_type_v(VALUE self, VALUE type_v) {
1256
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1257
- ptr->params.type_v = static_cast<enum ggml_type>(NUM2INT(type_v));
1258
- return INT2NUM(ptr->params.type_v);
1259
- }
1260
-
1261
- static VALUE _llama_context_params_get_type_v(VALUE self) {
1262
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1263
- return INT2NUM(ptr->params.type_v);
1264
- }
1265
-
1266
- // logits_all
1267
- static VALUE _llama_context_params_set_logits_all(VALUE self, VALUE logits_all) {
1268
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1269
- ptr->params.logits_all = RTEST(logits_all) ? true : false;
1270
- return ptr->params.logits_all ? Qtrue : Qfalse;
1271
- }
1272
-
1273
- static VALUE _llama_context_params_get_logits_all(VALUE self) {
1274
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1275
- return ptr->params.logits_all ? Qtrue : Qfalse;
1276
- }
1277
-
1278
- // embeddings
1279
- static VALUE _llama_context_params_set_embeddings(VALUE self, VALUE embeddings) {
1280
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1281
- ptr->params.embeddings = RTEST(embeddings) ? true : false;
1282
- return ptr->params.embeddings ? Qtrue : Qfalse;
1283
- }
1284
-
1285
- static VALUE _llama_context_params_get_embeddings(VALUE self) {
1286
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1287
- return ptr->params.embeddings ? Qtrue : Qfalse;
1288
- }
1289
-
1290
- // offload_kqv
1291
- static VALUE _llama_context_params_set_offload_kqv(VALUE self, VALUE offload_kqv) {
1292
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1293
- ptr->params.offload_kqv = RTEST(offload_kqv) ? true : false;
1294
- return ptr->params.offload_kqv ? Qtrue : Qfalse;
1295
- }
1296
-
1297
- static VALUE _llama_context_params_get_offload_kqv(VALUE self) {
1298
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1299
- return ptr->params.offload_kqv ? Qtrue : Qfalse;
1300
- }
1301
-
1302
- // flash_attn
1303
- static VALUE _llama_context_params_set_flash_attn(VALUE self, VALUE flash_attn) {
1304
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1305
- ptr->params.flash_attn = RTEST(flash_attn) ? true : false;
1306
- return ptr->params.flash_attn ? Qtrue : Qfalse;
1307
- }
1308
-
1309
- static VALUE _llama_context_params_get_flash_attn(VALUE self) {
1310
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1311
- return ptr->params.flash_attn ? Qtrue : Qfalse;
1312
- }
1313
- };
1314
-
1315
- const rb_data_type_t RbLLaMAContextParams::llama_context_params_type = {
1316
- "RbLLaMAContextParams",
1317
- { NULL,
1318
- RbLLaMAContextParams::llama_context_params_free,
1319
- RbLLaMAContextParams::llama_context_params_size },
1320
- NULL,
1321
- NULL,
1322
- RUBY_TYPED_FREE_IMMEDIATELY
1323
- };
1324
-
1325
- class LLaMAModelQuantizeParamsWrapper {
1326
- public:
1327
- llama_model_quantize_params params;
1328
-
1329
- LLaMAModelQuantizeParamsWrapper() : params(llama_model_quantize_default_params()) {}
1330
-
1331
- ~LLaMAModelQuantizeParamsWrapper() {}
1332
- };
1333
-
1334
- class RbLLaMAModelQuantizeParams {
1335
- public:
1336
- static VALUE llama_model_quantize_params_alloc(VALUE self) {
1337
- LLaMAModelQuantizeParamsWrapper* ptr = (LLaMAModelQuantizeParamsWrapper*)ruby_xmalloc(sizeof(LLaMAModelQuantizeParamsWrapper));
1338
- new (ptr) LLaMAModelQuantizeParamsWrapper();
1339
- return TypedData_Wrap_Struct(self, &llama_model_quantize_params_type, ptr);
1340
- }
1341
-
1342
- static void llama_model_quantize_params_free(void* ptr) {
1343
- ((LLaMAModelQuantizeParamsWrapper*)ptr)->~LLaMAModelQuantizeParamsWrapper();
1344
- ruby_xfree(ptr);
1345
- }
1346
-
1347
- static size_t llama_model_quantize_params_size(const void* ptr) {
1348
- return sizeof(*((LLaMAModelQuantizeParamsWrapper*)ptr));
1349
- }
1350
-
1351
- static LLaMAModelQuantizeParamsWrapper* get_llama_model_quantize_params(VALUE self) {
1352
- LLaMAModelQuantizeParamsWrapper* ptr;
1353
- TypedData_Get_Struct(self, LLaMAModelQuantizeParamsWrapper, &llama_model_quantize_params_type, ptr);
1354
- return ptr;
1355
- }
1356
-
1357
- static void define_class(VALUE outer) {
1358
- rb_cLLaMAModelQuantizeParams = rb_define_class_under(outer, "ModelQuantizeParams", rb_cObject);
1359
- rb_define_alloc_func(rb_cLLaMAModelQuantizeParams, llama_model_quantize_params_alloc);
1360
- rb_define_method(rb_cLLaMAModelQuantizeParams, "n_thread=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_n_thread), 1);
1361
- rb_define_method(rb_cLLaMAModelQuantizeParams, "n_thread", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_n_thread), 0);
1362
- rb_define_method(rb_cLLaMAModelQuantizeParams, "ftype=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_ftype), 1);
1363
- rb_define_method(rb_cLLaMAModelQuantizeParams, "ftype", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_ftype), 0);
1364
- rb_define_method(rb_cLLaMAModelQuantizeParams, "allow_requantize=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_allow_requantize), 1);
1365
- rb_define_method(rb_cLLaMAModelQuantizeParams, "allow_requantize", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_allow_requantize), 0);
1366
- rb_define_method(rb_cLLaMAModelQuantizeParams, "quantize_output_tensor=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_quantize_output_tensor), 1);
1367
- rb_define_method(rb_cLLaMAModelQuantizeParams, "quantize_output_tensor", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_quantize_output_tensor), 0);
1368
- rb_define_method(rb_cLLaMAModelQuantizeParams, "only_copy=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_only_copy), 1);
1369
- rb_define_method(rb_cLLaMAModelQuantizeParams, "only_copy", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_only_copy), 0);
1370
- rb_define_method(rb_cLLaMAModelQuantizeParams, "pure=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_pure), 1);
1371
- rb_define_method(rb_cLLaMAModelQuantizeParams, "pure", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_pure), 0);
1372
- rb_define_method(rb_cLLaMAModelQuantizeParams, "keep_split=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_keep_split), 1);
1373
- rb_define_method(rb_cLLaMAModelQuantizeParams, "keep_split", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_keep_split), 0);
1374
- }
1375
-
1376
- private:
1377
- static const rb_data_type_t llama_model_quantize_params_type;
1378
-
1379
- // n_thread
1380
- static VALUE _llama_model_quantize_params_set_n_thread(VALUE self, VALUE n_thread) {
1381
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1382
- ptr->params.nthread = NUM2INT(n_thread);
1383
- return INT2NUM(ptr->params.nthread);
1384
- }
1385
-
1386
- static VALUE _llama_model_quantize_params_get_n_thread(VALUE self) {
1387
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1388
- return INT2NUM(ptr->params.nthread);
1389
- }
1390
-
1391
- // ftype
1392
- static VALUE _llama_model_quantize_params_set_ftype(VALUE self, VALUE ftype) {
1393
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1394
- ptr->params.ftype = static_cast<enum llama_ftype>(NUM2INT(ftype));
1395
- return INT2NUM(ptr->params.ftype);
1396
- }
1397
-
1398
- static VALUE _llama_model_quantize_params_get_ftype(VALUE self) {
1399
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1400
- return INT2NUM(ptr->params.ftype);
1401
- }
1402
-
1403
- // allow_requantize
1404
- static VALUE _llama_model_quantize_params_set_allow_requantize(VALUE self, VALUE allow_requantize) {
1405
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1406
- if (NIL_P(allow_requantize) || allow_requantize == Qfalse) {
1407
- ptr->params.allow_requantize = false;
1408
- } else {
1409
- ptr->params.allow_requantize = true;
1410
- }
1411
- return ptr->params.allow_requantize ? Qtrue : Qfalse;
1412
- }
1413
-
1414
- static VALUE _llama_model_quantize_params_get_allow_requantize(VALUE self) {
1415
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1416
- return ptr->params.allow_requantize ? Qtrue : Qfalse;
1417
- }
1418
-
1419
- // quantize_output_tensor
1420
- static VALUE _llama_model_quantize_params_set_quantize_output_tensor(VALUE self, VALUE quantize_output_tensor) {
1421
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1422
- if (NIL_P(quantize_output_tensor) || quantize_output_tensor == Qfalse) {
1423
- ptr->params.quantize_output_tensor = false;
1424
- } else {
1425
- ptr->params.quantize_output_tensor = true;
1426
- }
1427
- return ptr->params.quantize_output_tensor ? Qtrue : Qfalse;
1428
- }
1429
-
1430
- static VALUE _llama_model_quantize_params_get_quantize_output_tensor(VALUE self) {
1431
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1432
- return ptr->params.quantize_output_tensor ? Qtrue : Qfalse;
1433
- }
1434
-
1435
- // only_copy
1436
- static VALUE _llama_model_quantize_params_set_only_copy(VALUE self, VALUE only_copy) {
1437
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1438
- ptr->params.only_copy = RTEST(only_copy) ? true : false;
1439
- return ptr->params.only_copy ? Qtrue : Qfalse;
1440
- }
1441
-
1442
- static VALUE _llama_model_quantize_params_get_only_copy(VALUE self) {
1443
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1444
- return ptr->params.only_copy ? Qtrue : Qfalse;
1445
- }
1446
-
1447
- // pure
1448
- static VALUE _llama_model_quantize_params_set_pure(VALUE self, VALUE pure) {
1449
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1450
- ptr->params.pure = RTEST(pure) ? true : false;
1451
- return ptr->params.pure ? Qtrue : Qfalse;
1452
- }
1453
-
1454
- static VALUE _llama_model_quantize_params_get_pure(VALUE self) {
1455
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1456
- return ptr->params.pure ? Qtrue : Qfalse;
1457
- }
1458
-
1459
- // keep_split
1460
- static VALUE _llama_model_quantize_params_set_keep_split(VALUE self, VALUE keep_split) {
1461
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1462
- ptr->params.keep_split = RTEST(keep_split) ? true : false;
1463
- return ptr->params.keep_split ? Qtrue : Qfalse;
1464
- }
1465
-
1466
- static VALUE _llama_model_quantize_params_get_keep_split(VALUE self) {
1467
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1468
- return ptr->params.keep_split ? Qtrue : Qfalse;
1469
- }
1470
- };
1471
-
1472
- const rb_data_type_t RbLLaMAModelQuantizeParams::llama_model_quantize_params_type = {
1473
- "RbLLaMAModelQuantizeParams",
1474
- { NULL,
1475
- RbLLaMAModelQuantizeParams::llama_model_quantize_params_free,
1476
- RbLLaMAModelQuantizeParams::llama_model_quantize_params_size },
1477
- NULL,
1478
- NULL,
1479
- RUBY_TYPED_FREE_IMMEDIATELY
1480
- };
1481
-
1482
- class LLaMAModelWrapper {
1483
- public:
1484
- struct llama_model* model;
1485
-
1486
- LLaMAModelWrapper() : model(NULL) {}
1487
-
1488
- ~LLaMAModelWrapper() {
1489
- if (model != NULL) {
1490
- llama_free_model(model);
1491
- }
1492
- }
1493
- };
1494
-
1495
- class RbLLaMAModel {
1496
- public:
1497
- static VALUE llama_model_alloc(VALUE self) {
1498
- LLaMAModelWrapper* ptr = (LLaMAModelWrapper*)ruby_xmalloc(sizeof(LLaMAModelWrapper));
1499
- new (ptr) LLaMAModelWrapper();
1500
- return TypedData_Wrap_Struct(self, &llama_model_type, ptr);
1501
- }
1502
-
1503
- static void llama_model_free(void* ptr) {
1504
- ((LLaMAModelWrapper*)ptr)->~LLaMAModelWrapper();
1505
- ruby_xfree(ptr);
1506
- }
1507
-
1508
- static size_t llama_model_size(const void* ptr) {
1509
- return sizeof(*((LLaMAModelWrapper*)ptr));
1510
- }
1511
-
1512
- static LLaMAModelWrapper* get_llama_model(VALUE self) {
1513
- LLaMAModelWrapper* ptr;
1514
- TypedData_Get_Struct(self, LLaMAModelWrapper, &llama_model_type, ptr);
1515
- return ptr;
1516
- }
1517
-
1518
- static void define_class(VALUE outer) {
1519
- rb_cLLaMAModel = rb_define_class_under(outer, "Model", rb_cObject);
1520
- rb_define_alloc_func(rb_cLLaMAModel, llama_model_alloc);
1521
- rb_define_attr(rb_cLLaMAModel, "params", 1, 0);
1522
- rb_define_method(rb_cLLaMAModel, "initialize", RUBY_METHOD_FUNC(_llama_model_initialize), -1);
1523
- rb_define_method(rb_cLLaMAModel, "empty?", RUBY_METHOD_FUNC(_llama_model_empty), 0);
1524
- rb_define_method(rb_cLLaMAModel, "free", RUBY_METHOD_FUNC(_llama_model_free), 0);
1525
- rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
1526
- rb_define_method(rb_cLLaMAModel, "vocab_type", RUBY_METHOD_FUNC(_llama_model_get_model_vocab_type), 0);
1527
- rb_define_method(rb_cLLaMAModel, "rope_type", RUBY_METHOD_FUNC(_llama_model_get_model_rope_type), 0);
1528
- rb_define_method(rb_cLLaMAModel, "n_vocab", RUBY_METHOD_FUNC(_llama_model_get_model_n_vocab), 0);
1529
- rb_define_method(rb_cLLaMAModel, "n_ctx_train", RUBY_METHOD_FUNC(_llama_model_get_model_n_ctx_train), 0);
1530
- rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_model_n_embd), 0);
1531
- rb_define_method(rb_cLLaMAModel, "n_layer", RUBY_METHOD_FUNC(_llama_model_get_model_n_layer), 0);
1532
- rb_define_method(rb_cLLaMAModel, "rope_freq_scale_train", RUBY_METHOD_FUNC(_llama_model_rope_freq_scale_train), 0);
1533
- rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece), -1);
1534
- rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize), -1);
1535
- rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
1536
- rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
1537
- rb_define_method(rb_cLLaMAModel, "n_params", RUBY_METHOD_FUNC(_llama_model_get_model_n_params), 0);
1538
- rb_define_method(rb_cLLaMAModel, "text", RUBY_METHOD_FUNC(_llama_model_get_text), 1);
1539
- rb_define_method(rb_cLLaMAModel, "score", RUBY_METHOD_FUNC(_llama_model_get_score), 1);
1540
- rb_define_method(rb_cLLaMAModel, "token_attr", RUBY_METHOD_FUNC(_llama_model_get_token_attr), 1);
1541
- rb_define_method(rb_cLLaMAModel, "token_bos", RUBY_METHOD_FUNC(_llama_model_token_bos), 0);
1542
- rb_define_method(rb_cLLaMAModel, "token_eos", RUBY_METHOD_FUNC(_llama_model_token_eos), 0);
1543
- rb_define_method(rb_cLLaMAModel, "token_cls", RUBY_METHOD_FUNC(_llama_model_token_cls), 0);
1544
- rb_define_method(rb_cLLaMAModel, "token_sep", RUBY_METHOD_FUNC(_llama_model_token_sep), 0);
1545
- rb_define_method(rb_cLLaMAModel, "token_nl", RUBY_METHOD_FUNC(_llama_model_token_nl), 0);
1546
- rb_define_method(rb_cLLaMAModel, "token_pad", RUBY_METHOD_FUNC(_llama_model_token_pad), 0);
1547
- rb_define_method(rb_cLLaMAModel, "add_bos_token?", RUBY_METHOD_FUNC(_llama_model_add_bos_token), 0);
1548
- rb_define_method(rb_cLLaMAModel, "add_eos_token?", RUBY_METHOD_FUNC(_llama_model_add_eos_token), 0);
1549
- rb_define_method(rb_cLLaMAModel, "token_prefix", RUBY_METHOD_FUNC(_llama_model_token_prefix), 0);
1550
- rb_define_method(rb_cLLaMAModel, "token_middle", RUBY_METHOD_FUNC(_llama_model_token_middle), 0);
1551
- rb_define_method(rb_cLLaMAModel, "token_suffix", RUBY_METHOD_FUNC(_llama_model_token_suffix), 0);
1552
- rb_define_method(rb_cLLaMAModel, "token_eot", RUBY_METHOD_FUNC(_llama_model_token_eot), 0);
1553
- rb_define_method(rb_cLLaMAModel, "token_is_eog?", RUBY_METHOD_FUNC(_llama_model_token_is_eog), 1);
1554
- rb_define_method(rb_cLLaMAModel, "token_is_control?", RUBY_METHOD_FUNC(_llama_model_token_is_control), 1);
1555
- rb_define_method(rb_cLLaMAModel, "has_encoder?", RUBY_METHOD_FUNC(_llama_model_has_encoder), 0);
1556
- rb_define_method(rb_cLLaMAModel, "has_decoder?", RUBY_METHOD_FUNC(_llama_model_has_decoder), 0);
1557
- rb_define_method(rb_cLLaMAModel, "decoder_start_token", RUBY_METHOD_FUNC(_llama_model_decoder_start_token), 0);
1558
- rb_define_method(rb_cLLaMAModel, "is_recurrent?", RUBY_METHOD_FUNC(_llama_model_is_recurrent), 0);
1559
- rb_define_method(rb_cLLaMAModel, "detokenize", RUBY_METHOD_FUNC(_llama_model_detokenize), -1);
1560
- }
1561
-
1562
- private:
1563
- static const rb_data_type_t llama_model_type;
1564
-
1565
- static VALUE _llama_model_initialize(int argc, VALUE* argv, VALUE self) {
1566
- VALUE kw_args = Qnil;
1567
- ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
1568
- VALUE kw_values[2] = { Qundef, Qundef };
1569
- rb_scan_args(argc, argv, ":", &kw_args);
1570
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
1571
-
1572
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1573
- rb_raise(rb_eArgError, "model_path must be a string");
1574
- return Qnil;
1575
- }
1576
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
1577
- rb_raise(rb_eArgError, "params must be a ModelParams");
1578
- return Qnil;
1579
- }
1580
-
1581
- VALUE filename = kw_values[0];
1582
- LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
1583
- LLaMAModelWrapper* model_ptr = get_llama_model(self);
1584
-
1585
- try {
1586
- model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
1587
- } catch (const std::runtime_error& e) {
1588
- rb_raise(rb_eRuntimeError, "%s", e.what());
1589
- return Qnil;
1590
- }
1591
-
1592
- if (model_ptr->model == NULL) {
1593
- rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA model");
1594
- return Qnil;
1595
- }
1596
-
1597
- rb_iv_set(self, "@params", kw_values[1]);
1598
-
1599
- RB_GC_GUARD(filename);
1600
- return Qnil;
1601
- }
1602
-
1603
- static VALUE _llama_model_empty(VALUE self) {
1604
- LLaMAModelWrapper* ptr = get_llama_model(self);
1605
- if (ptr->model != NULL) {
1606
- return Qfalse;
1607
- }
1608
- return Qtrue;
1609
- }
1610
-
1611
- static VALUE _llama_model_free(VALUE self) {
1612
- LLaMAModelWrapper* ptr = get_llama_model(self);
1613
- if (ptr->model != NULL) {
1614
- llama_free_model(ptr->model);
1615
- ptr->model = NULL;
1616
- rb_iv_set(self, "@params", Qnil);
1617
- }
1618
- return Qnil;
1619
- }
1620
-
1621
- static VALUE _llama_model_load(int argc, VALUE* argv, VALUE self) {
1622
- VALUE kw_args = Qnil;
1623
- ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
1624
- VALUE kw_values[2] = { Qundef, Qundef };
1625
- rb_scan_args(argc, argv, ":", &kw_args);
1626
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
1627
-
1628
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1629
- rb_raise(rb_eArgError, "model_path must be a string");
1630
- return Qnil;
1631
- }
1632
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
1633
- rb_raise(rb_eArgError, "params must be a LLaMAModelParams");
1634
- return Qnil;
1635
- }
1636
-
1637
- LLaMAModelWrapper* model_ptr = get_llama_model(self);
1638
- if (model_ptr->model != NULL) {
1639
- rb_raise(rb_eRuntimeError, "LLaMA model is already loaded");
1640
- return Qnil;
1641
- }
1642
-
1643
- VALUE filename = kw_values[0];
1644
- LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
1645
-
1646
- try {
1647
- model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
1648
- } catch (const std::runtime_error& e) {
1649
- rb_raise(rb_eRuntimeError, "%s", e.what());
1650
- return Qnil;
1651
- }
1652
-
1653
- if (model_ptr->model == NULL) {
1654
- rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA model");
1655
- return Qnil;
1656
- }
1657
-
1658
- rb_iv_set(self, "@params", kw_values[1]);
1659
-
1660
- RB_GC_GUARD(filename);
1661
- return Qnil;
1662
- }
1663
-
1664
- static VALUE _llama_model_get_model_vocab_type(VALUE self) {
1665
- LLaMAModelWrapper* ptr = get_llama_model(self);
1666
- return INT2NUM(llama_vocab_type(ptr->model));
1667
- }
1668
-
1669
- static VALUE _llama_model_get_model_rope_type(VALUE self) {
1670
- LLaMAModelWrapper* ptr = get_llama_model(self);
1671
- return INT2NUM(llama_rope_type(ptr->model));
1672
- }
1673
-
1674
- static VALUE _llama_model_get_model_n_vocab(VALUE self) {
1675
- LLaMAModelWrapper* ptr = get_llama_model(self);
1676
- return INT2NUM(llama_n_vocab(ptr->model));
1677
- }
1678
-
1679
- static VALUE _llama_model_get_model_n_ctx_train(VALUE self) {
1680
- LLaMAModelWrapper* ptr = get_llama_model(self);
1681
- return INT2NUM(llama_n_ctx_train(ptr->model));
1682
- }
1683
-
1684
- static VALUE _llama_model_get_model_n_embd(VALUE self) {
1685
- LLaMAModelWrapper* ptr = get_llama_model(self);
1686
- return INT2NUM(llama_n_embd(ptr->model));
1687
- }
1688
-
1689
- static VALUE _llama_model_get_model_n_layer(VALUE self) {
1690
- LLaMAModelWrapper* ptr = get_llama_model(self);
1691
- return INT2NUM(llama_n_layer(ptr->model));
1692
- }
1693
-
1694
- static VALUE _llama_model_rope_freq_scale_train(VALUE self) {
1695
- LLaMAModelWrapper* ptr = get_llama_model(self);
1696
- return DBL2NUM(llama_rope_freq_scale_train(ptr->model));
1697
- }
1698
-
1699
- static VALUE _llama_model_token_to_piece(int argc, VALUE* argv, VALUE self) {
1700
- VALUE kw_args = Qnil;
1701
- ID kw_table[2] = { rb_intern("lstrip"), rb_intern("special") };
1702
- VALUE kw_values[2] = { Qundef, Qundef };
1703
- VALUE token_ = Qnil;
1704
- rb_scan_args(argc, argv, "1:", &token_, &kw_args);
1705
- rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
1706
-
1707
- if (!RB_INTEGER_TYPE_P(token_)) {
1708
- rb_raise(rb_eArgError, "token must be an integer");
1709
- return Qnil;
1710
- }
1711
- if (kw_values[0] != Qundef && !RB_INTEGER_TYPE_P(kw_values[0])) {
1712
- rb_raise(rb_eArgError, "lstrip must be an integer");
1713
- return Qnil;
1714
- }
1715
-
1716
- const llama_token token = NUM2INT(token_);
1717
- const int32_t lstrip = kw_values[0] != Qundef ? NUM2INT(kw_values[0]) : 0;
1718
- const bool special = kw_values[1] != Qundef ? RTEST(kw_values[1]) : false;
1719
-
1720
- LLaMAModelWrapper* ptr = get_llama_model(self);
1721
- std::vector<char> result(8, 0);
1722
- const int n_tokens = llama_token_to_piece(ptr->model, token, result.data(), result.size(), lstrip, special);
1723
- if (n_tokens < 0) {
1724
- result.resize(-n_tokens);
1725
- const int check = llama_token_to_piece(ptr->model, token, result.data(), result.size(), lstrip, special);
1726
- if (check != -n_tokens) {
1727
- rb_raise(rb_eRuntimeError, "failed to convert");
1728
- return Qnil;
1729
- }
1730
- } else {
1731
- result.resize(n_tokens);
1732
- }
1733
- std::string ret(result.data(), result.size());
1734
- return rb_utf8_str_new_cstr(ret.c_str());
1735
- }
1736
-
1737
- static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
1738
- VALUE kw_args = Qnil;
1739
- ID kw_table[4] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos"), rb_intern("special") };
1740
- VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
1741
- rb_scan_args(argc, argv, ":", &kw_args);
1742
- rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);
1743
-
1744
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1745
- rb_raise(rb_eArgError, "text must be a String");
1746
- return Qnil;
1747
- }
1748
- if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
1749
- rb_raise(rb_eArgError, "n_max_tokens must be an integer");
1750
- return Qnil;
1751
- }
1752
- if (kw_values[2] != Qundef && (kw_values[2] != Qtrue && kw_values[2] != Qfalse)) {
1753
- rb_raise(rb_eArgError, "add_bos must be a boolean");
1754
- return Qnil;
1755
- }
1756
- if (kw_values[3] != Qundef && (kw_values[3] != Qtrue && kw_values[3] != Qfalse)) {
1757
- rb_raise(rb_eArgError, "special must be a boolean");
1758
- return Qnil;
1759
- }
1760
-
1761
- VALUE text_ = kw_values[0];
1762
- std::string text = StringValueCStr(text_);
1763
- const bool add_bos = kw_values[2] == Qtrue ? true : false;
1764
- const bool special = kw_values[3] == Qtrue ? true : false;
1765
- const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
1766
-
1767
- llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
1768
- LLaMAModelWrapper* ptr = get_llama_model(self);
1769
- const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos, special);
1770
-
1771
- if (n_tokens < 0) {
1772
- rb_raise(rb_eRuntimeError, "failed to tokenize. The numebr of tokens (%d) is greater than n_max_tokens.", -n_tokens);
1773
- return Qnil;
1774
- }
1775
-
1776
- VALUE ret = rb_ary_new2(n_tokens);
1777
- for (int i = 0; i < n_tokens; i++) {
1778
- rb_ary_store(ret, i, INT2NUM(tokens[i]));
1779
- }
1780
-
1781
- RB_GC_GUARD(text_);
1782
- return ret;
1783
- }
1784
-
1785
- static VALUE _llama_model_get_model_desc(VALUE self) {
1786
- LLaMAModelWrapper* ptr = get_llama_model(self);
1787
- char buf[128];
1788
- llama_model_desc(ptr->model, buf, sizeof(buf));
1789
- return rb_utf8_str_new_cstr(buf);
1790
- }
1791
-
1792
- static VALUE _llama_model_get_model_size(VALUE self) {
1793
- LLaMAModelWrapper* ptr = get_llama_model(self);
1794
- return UINT2NUM(llama_model_size(ptr->model));
1795
- }
1796
-
1797
- static VALUE _llama_model_get_model_n_params(VALUE self) {
1798
- LLaMAModelWrapper* ptr = get_llama_model(self);
1799
- return UINT2NUM(llama_model_n_params(ptr->model));
1800
- }
1801
-
1802
- static VALUE _llama_model_get_text(VALUE self, VALUE token_) {
1803
- LLaMAModelWrapper* ptr = get_llama_model(self);
1804
- const llama_token token = NUM2INT(token_);
1805
- const char* text = llama_token_get_text(ptr->model, token);
1806
- return rb_utf8_str_new_cstr(text);
1807
- }
1808
-
1809
- static VALUE _llama_model_get_score(VALUE self, VALUE token_) {
1810
- LLaMAModelWrapper* ptr = get_llama_model(self);
1811
- const llama_token token = NUM2INT(token_);
1812
- const float score = llama_token_get_score(ptr->model, token);
1813
- return DBL2NUM(score);
1814
- }
1815
-
1816
- static VALUE _llama_model_get_token_attr(VALUE self, VALUE token_) {
1817
- LLaMAModelWrapper* ptr = get_llama_model(self);
1818
- const llama_token token = NUM2INT(token_);
1819
- const llama_token_attr type = llama_token_get_attr(ptr->model, token);
1820
- return INT2NUM(type);
1821
- }
1822
-
1823
- static VALUE _llama_model_token_bos(VALUE self) {
1824
- LLaMAModelWrapper* ptr = get_llama_model(self);
1825
- return INT2NUM(llama_token_bos(ptr->model));
1826
- }
1827
-
1828
- static VALUE _llama_model_token_eos(VALUE self) {
1829
- LLaMAModelWrapper* ptr = get_llama_model(self);
1830
- return INT2NUM(llama_token_eos(ptr->model));
1831
- }
1832
-
1833
- static VALUE _llama_model_token_cls(VALUE self) {
1834
- LLaMAModelWrapper* ptr = get_llama_model(self);
1835
- return INT2NUM(llama_token_cls(ptr->model));
1836
- }
1837
-
1838
- static VALUE _llama_model_token_sep(VALUE self) {
1839
- LLaMAModelWrapper* ptr = get_llama_model(self);
1840
- return INT2NUM(llama_token_sep(ptr->model));
1841
- }
1842
-
1843
- static VALUE _llama_model_token_nl(VALUE self) {
1844
- LLaMAModelWrapper* ptr = get_llama_model(self);
1845
- return INT2NUM(llama_token_nl(ptr->model));
1846
- }
1847
-
1848
- static VALUE _llama_model_token_pad(VALUE self) {
1849
- LLaMAModelWrapper* ptr = get_llama_model(self);
1850
- return INT2NUM(llama_token_pad(ptr->model));
1851
- }
1852
-
1853
- static VALUE _llama_model_add_bos_token(VALUE self) {
1854
- LLaMAModelWrapper* ptr = get_llama_model(self);
1855
- return llama_add_bos_token(ptr->model) ? Qtrue : Qfalse;
1856
- }
1857
-
1858
- static VALUE _llama_model_add_eos_token(VALUE self) {
1859
- LLaMAModelWrapper* ptr = get_llama_model(self);
1860
- return llama_add_eos_token(ptr->model) ? Qtrue : Qfalse;
1861
- }
1862
-
1863
- static VALUE _llama_model_token_prefix(VALUE self) {
1864
- LLaMAModelWrapper* ptr = get_llama_model(self);
1865
- return INT2NUM(llama_token_prefix(ptr->model));
1866
- }
1867
-
1868
- static VALUE _llama_model_token_middle(VALUE self) {
1869
- LLaMAModelWrapper* ptr = get_llama_model(self);
1870
- return INT2NUM(llama_token_middle(ptr->model));
1871
- }
1872
-
1873
- static VALUE _llama_model_token_suffix(VALUE self) {
1874
- LLaMAModelWrapper* ptr = get_llama_model(self);
1875
- return INT2NUM(llama_token_suffix(ptr->model));
1876
- }
1877
-
1878
- static VALUE _llama_model_token_eot(VALUE self) {
1879
- LLaMAModelWrapper* ptr = get_llama_model(self);
1880
- return INT2NUM(llama_token_eot(ptr->model));
1881
- }
1882
-
1883
- static VALUE _llama_model_token_is_eog(VALUE self, VALUE token_) {
1884
- if (!RB_INTEGER_TYPE_P(token_)) {
1885
- rb_raise(rb_eArgError, "token must be an integer");
1886
- return Qnil;
1887
- }
1888
- const llama_token token = NUM2INT(token_);
1889
- LLaMAModelWrapper* ptr = get_llama_model(self);
1890
- return llama_token_is_eog(ptr->model, token) ? Qtrue : Qfalse;
1891
- }
1892
-
1893
- static VALUE _llama_model_token_is_control(VALUE self, VALUE token_) {
1894
- if (!RB_INTEGER_TYPE_P(token_)) {
1895
- rb_raise(rb_eArgError, "token must be an integer");
1896
- return Qnil;
1897
- }
1898
- const llama_token token = NUM2INT(token_);
1899
- LLaMAModelWrapper* ptr = get_llama_model(self);
1900
- return llama_token_is_control(ptr->model, token) ? Qtrue : Qfalse;
1901
- }
1902
-
1903
- static VALUE _llama_model_has_encoder(VALUE self) {
1904
- LLaMAModelWrapper* ptr = get_llama_model(self);
1905
- return llama_model_has_encoder(ptr->model) ? Qtrue : Qfalse;
1906
- }
1907
-
1908
- static VALUE _llama_model_has_decoder(VALUE self) {
1909
- LLaMAModelWrapper* ptr = get_llama_model(self);
1910
- return llama_model_has_decoder(ptr->model) ? Qtrue : Qfalse;
1911
- }
1912
-
1913
- static VALUE _llama_model_decoder_start_token(VALUE self) {
1914
- LLaMAModelWrapper* ptr = get_llama_model(self);
1915
- return INT2NUM(llama_model_decoder_start_token(ptr->model));
1916
- }
1917
-
1918
- static VALUE _llama_model_is_recurrent(VALUE self) {
1919
- LLaMAModelWrapper* ptr = get_llama_model(self);
1920
- return llama_model_is_recurrent(ptr->model) ? Qtrue : Qfalse;
1921
- }
1922
-
1923
- static VALUE _llama_model_detokenize(int argc, VALUE* argv, VALUE self) {
1924
- VALUE kw_args = Qnil;
1925
- ID kw_table[2] = { rb_intern("remove_special"), rb_intern("unparse_special") };
1926
- VALUE kw_values[2] = { Qundef, Qundef };
1927
- VALUE tokens_ = Qnil;
1928
- rb_scan_args(argc, argv, "1:", &tokens_, &kw_args);
1929
- rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
1930
-
1931
- if (!RB_TYPE_P(tokens_, T_ARRAY)) {
1932
- rb_raise(rb_eArgError, "tokens must be an array");
1933
- return Qnil;
1934
- }
1935
-
1936
- const int32_t n_tokens = RARRAY_LEN(tokens_);
1937
- llama_token* tokens = ALLOCA_N(llama_token, n_tokens);
1938
- for (int32_t i = 0; i < n_tokens; i++) {
1939
- tokens[i] = NUM2INT(rb_ary_entry(tokens_, i));
1940
- }
1941
-
1942
- std::string text;
1943
- text.resize(std::max(text.capacity(), static_cast<unsigned long>(n_tokens)));
1944
- const int32_t text_len_max = text.size();
1945
-
1946
- bool remove_special = kw_values[0] != Qundef ? RTEST(kw_values[0]) : false;
1947
- bool unparse_special = kw_values[1] != Qundef ? RTEST(kw_values[1]) : false;
1948
-
1949
- LLaMAModelWrapper* ptr = get_llama_model(self);
1950
- std::string result;
1951
- int32_t n_chars = llama_detokenize(ptr->model, tokens, n_tokens, &text[0], text_len_max, remove_special, unparse_special);
1952
- if (n_chars < 0) {
1953
- text.resize(-n_chars);
1954
- n_chars = llama_detokenize(ptr->model, tokens, n_tokens, &text[0], text_len_max, remove_special, unparse_special);
1955
- if (n_chars <= text.size()) {
1956
- rb_raise(rb_eRuntimeError, "Failed to detokenize");
1957
- return Qnil;
1958
- }
1959
- }
1960
-
1961
- text.resize(n_chars);
1962
- return rb_utf8_str_new_cstr(text.c_str());
1963
- }
1964
- };
1965
-
1966
- const rb_data_type_t RbLLaMAModel::llama_model_type = {
1967
- "RbLLaMAModel",
1968
- { NULL,
1969
- RbLLaMAModel::llama_model_free,
1970
- RbLLaMAModel::llama_model_size },
1971
- NULL,
1972
- NULL,
1973
- RUBY_TYPED_FREE_IMMEDIATELY
1974
- };
1975
-
1976
- class LLaMAGrammarElementWrapper {
1977
- public:
1978
- llama_grammar_element element;
1979
-
1980
- LLaMAGrammarElementWrapper() {
1981
- element.type = LLAMA_GRETYPE_END;
1982
- element.value = 0;
1983
- }
1984
-
1985
- ~LLaMAGrammarElementWrapper() {}
1986
- };
1987
-
1988
- class RbLLaMAGrammarElement {
1989
- public:
1990
- static VALUE llama_grammar_element_alloc(VALUE self) {
1991
- LLaMAGrammarElementWrapper* ptr = (LLaMAGrammarElementWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarElementWrapper));
1992
- new (ptr) LLaMAGrammarElementWrapper();
1993
- return TypedData_Wrap_Struct(self, &llama_grammar_element_type, ptr);
1994
- }
1995
-
1996
- static void llama_grammar_element_free(void* ptr) {
1997
- ((LLaMAGrammarElementWrapper*)ptr)->~LLaMAGrammarElementWrapper();
1998
- ruby_xfree(ptr);
1999
- }
2000
-
2001
- static size_t llama_grammar_element_size(const void* ptr) {
2002
- return sizeof(*((LLaMAGrammarElementWrapper*)ptr));
2003
- }
2004
-
2005
- static LLaMAGrammarElementWrapper* get_llama_grammar_element(VALUE self) {
2006
- LLaMAGrammarElementWrapper* ptr;
2007
- TypedData_Get_Struct(self, LLaMAGrammarElementWrapper, &llama_grammar_element_type, ptr);
2008
- return ptr;
2009
- }
2010
-
2011
- static void define_class(VALUE outer) {
2012
- rb_cLLaMAGrammarElement = rb_define_class_under(outer, "GrammarElement", rb_cObject);
2013
- rb_define_alloc_func(rb_cLLaMAGrammarElement, llama_grammar_element_alloc);
2014
- rb_define_method(rb_cLLaMAGrammarElement, "initialize", RUBY_METHOD_FUNC(_llama_grammar_element_init), -1);
2015
- rb_define_method(rb_cLLaMAGrammarElement, "type=", RUBY_METHOD_FUNC(_llama_grammar_element_set_type), 1);
2016
- rb_define_method(rb_cLLaMAGrammarElement, "type", RUBY_METHOD_FUNC(_llama_grammar_element_get_type), 0);
2017
- rb_define_method(rb_cLLaMAGrammarElement, "value=", RUBY_METHOD_FUNC(_llama_grammar_element_set_value), 1);
2018
- rb_define_method(rb_cLLaMAGrammarElement, "value", RUBY_METHOD_FUNC(_llama_grammar_element_get_value), 0);
2019
- }
2020
-
2021
- private:
2022
- static const rb_data_type_t llama_grammar_element_type;
2023
-
2024
- static VALUE _llama_grammar_element_init(int argc, VALUE* argv, VALUE self) {
2025
- VALUE kw_args = Qnil;
2026
- ID kw_table[2] = { rb_intern("type"), rb_intern("value") };
2027
- VALUE kw_values[2] = { Qundef, Qundef };
2028
- VALUE arr = Qnil;
2029
- rb_scan_args(argc, argv, ":", &arr, &kw_args);
2030
- rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
2031
-
2032
- if (kw_values[0] != Qundef && !RB_INTEGER_TYPE_P(kw_values[0])) {
2033
- rb_raise(rb_eArgError, "type must be an integer");
2034
- return Qnil;
2035
- }
2036
- if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
2037
- rb_raise(rb_eArgError, "value must be an integer");
2038
- return Qnil;
2039
- }
2040
-
2041
- LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
2042
- new (ptr) LLaMAGrammarElementWrapper();
2043
-
2044
- if (kw_values[0] != Qundef) {
2045
- ptr->element.type = (enum llama_gretype)NUM2INT(kw_values[0]);
2046
- }
2047
- if (kw_values[1] != Qundef) {
2048
- ptr->element.value = NUM2INT(kw_values[1]);
2049
- }
2050
-
2051
- return self;
2052
- }
2053
-
2054
- // type
2055
- static VALUE _llama_grammar_element_set_type(VALUE self, VALUE type) {
2056
- LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
2057
- ptr->element.type = (enum llama_gretype)NUM2INT(type);
2058
- return INT2NUM(ptr->element.type);
2059
- }
2060
-
2061
- static VALUE _llama_grammar_element_get_type(VALUE self) {
2062
- LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
2063
- return INT2NUM(ptr->element.type);
2064
- }
2065
-
2066
- // value
2067
- static VALUE _llama_grammar_element_set_value(VALUE self, VALUE type) {
2068
- LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
2069
- ptr->element.value = NUM2INT(type);
2070
- return INT2NUM(ptr->element.value);
2071
- }
2072
-
2073
- static VALUE _llama_grammar_element_get_value(VALUE self) {
2074
- LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
2075
- return INT2NUM(ptr->element.value);
2076
- }
2077
- };
2078
-
2079
- const rb_data_type_t RbLLaMAGrammarElement::llama_grammar_element_type = {
2080
- "RbLLaMAGrammarElement",
2081
- { NULL,
2082
- RbLLaMAGrammarElement::llama_grammar_element_free,
2083
- RbLLaMAGrammarElement::llama_grammar_element_size },
2084
- NULL,
2085
- NULL,
2086
- RUBY_TYPED_FREE_IMMEDIATELY
2087
- };
2088
-
2089
- class LLaMAGrammarWrapper {
2090
- public:
2091
- struct llama_grammar* grammar;
2092
-
2093
- LLaMAGrammarWrapper() : grammar(nullptr) {}
2094
-
2095
- ~LLaMAGrammarWrapper() {
2096
- if (grammar) {
2097
- llama_grammar_free(grammar);
2098
- }
2099
- }
2100
- };
2101
-
2102
- class RbLLaMAGrammar {
2103
- public:
2104
- static VALUE llama_grammar_alloc(VALUE self) {
2105
- LLaMAGrammarWrapper* ptr = (LLaMAGrammarWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarWrapper));
2106
- new (ptr) LLaMAGrammarWrapper();
2107
- return TypedData_Wrap_Struct(self, &llama_grammar_type, ptr);
2108
- }
2109
-
2110
- static void llama_grammar_free(void* ptr) {
2111
- ((LLaMAGrammarWrapper*)ptr)->~LLaMAGrammarWrapper();
2112
- ruby_xfree(ptr);
2113
- }
2114
-
2115
- static size_t llama_grammar_size(const void* ptr) {
2116
- return sizeof(*((LLaMAGrammarWrapper*)ptr));
2117
- }
2118
-
2119
- static LLaMAGrammarWrapper* get_llama_grammar(VALUE self) {
2120
- LLaMAGrammarWrapper* ptr;
2121
- TypedData_Get_Struct(self, LLaMAGrammarWrapper, &llama_grammar_type, ptr);
2122
- return ptr;
2123
- }
2124
-
2125
- static void define_class(VALUE outer) {
2126
- rb_cLLaMAGrammar = rb_define_class_under(outer, "Grammar", rb_cObject);
2127
- rb_define_alloc_func(rb_cLLaMAGrammar, llama_grammar_alloc);
2128
- rb_define_method(rb_cLLaMAGrammar, "initialize", RUBY_METHOD_FUNC(_llama_grammar_init), -1);
2129
- }
2130
-
2131
- private:
2132
- static const rb_data_type_t llama_grammar_type;
2133
-
2134
- static VALUE _llama_grammar_init(int argc, VALUE* argv, VALUE self) {
2135
- VALUE kw_args = Qnil;
2136
- ID kw_table[2] = { rb_intern("rules"), rb_intern("start_rule_index") };
2137
- VALUE kw_values[2] = { Qundef, Qundef };
2138
- rb_scan_args(argc, argv, ":", &kw_args);
2139
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
2140
-
2141
- if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
2142
- rb_raise(rb_eArgError, "rules must be an array");
2143
- return Qnil;
2144
- }
2145
- if (!RB_INTEGER_TYPE_P(kw_values[1])) {
2146
- rb_raise(rb_eArgError, "start_rule_index must be an integer");
2147
- return Qnil;
2148
- }
2149
-
2150
- const int n_rules = RARRAY_LEN(kw_values[0]);
2151
- llama_grammar_element** rules = ALLOCA_N(llama_grammar_element*, n_rules);
2152
- for (int i = 0; i < n_rules; ++i) {
2153
- VALUE rule = rb_ary_entry(kw_values[0], i);
2154
- if (!RB_TYPE_P(rule, T_ARRAY)) {
2155
- rb_raise(rb_eArgError, "element of rules must be an array");
2156
- return Qnil;
2157
- }
2158
- const int n_elements = RARRAY_LEN(rule);
2159
- llama_grammar_element* elements = ALLOCA_N(llama_grammar_element, n_elements);
2160
- for (int j = 0; j < n_elements; ++j) {
2161
- VALUE element = rb_ary_entry(rule, j);
2162
- if (!rb_obj_is_kind_of(element, rb_cLLaMAGrammarElement)) {
2163
- rb_raise(rb_eArgError, "element of rule must be an instance of GrammarElement");
2164
- return Qnil;
2165
- }
2166
- LLaMAGrammarElementWrapper* ptr = RbLLaMAGrammarElement::get_llama_grammar_element(element);
2167
- elements[j] = ptr->element;
2168
- }
2169
- rules[i] = elements;
2170
- }
2171
-
2172
- const size_t start_rule_index = NUM2SIZET(kw_values[1]);
2173
-
2174
- LLaMAGrammarWrapper* ptr = get_llama_grammar(self);
2175
- new (ptr) LLaMAGrammarWrapper();
2176
- ptr->grammar = llama_grammar_init((const llama_grammar_element**)rules, n_rules, start_rule_index);
2177
-
2178
- return self;
2179
- }
2180
- };
2181
-
2182
- const rb_data_type_t RbLLaMAGrammar::llama_grammar_type = {
2183
- "RbLLaMAGrammar",
2184
- { NULL,
2185
- RbLLaMAGrammar::llama_grammar_free,
2186
- RbLLaMAGrammar::llama_grammar_size },
2187
- NULL,
2188
- NULL,
2189
- RUBY_TYPED_FREE_IMMEDIATELY
2190
- };
2191
-
2192
- class LLaMAContextWrapper {
2193
- public:
2194
- struct llama_context* ctx;
2195
-
2196
- LLaMAContextWrapper() : ctx(NULL) {}
2197
-
2198
- ~LLaMAContextWrapper() {
2199
- if (ctx != NULL) {
2200
- llama_free(ctx);
2201
- }
2202
- }
2203
- };
2204
-
2205
- class RbLLaMAContext {
2206
- public:
2207
- static VALUE llama_context_alloc(VALUE self) {
2208
- LLaMAContextWrapper* ptr = (LLaMAContextWrapper*)ruby_xmalloc(sizeof(LLaMAContextWrapper));
2209
- new (ptr) LLaMAContextWrapper();
2210
- return TypedData_Wrap_Struct(self, &llama_context_type, ptr);
2211
- }
2212
-
2213
- static void llama_context_free(void* ptr) {
2214
- ((LLaMAContextWrapper*)ptr)->~LLaMAContextWrapper();
2215
- ruby_xfree(ptr);
2216
- }
2217
-
2218
- static size_t llama_context_size(const void* ptr) {
2219
- return sizeof(*((LLaMAContextWrapper*)ptr));
2220
- }
2221
-
2222
- static LLaMAContextWrapper* get_llama_context(VALUE self) {
2223
- LLaMAContextWrapper* ptr;
2224
- TypedData_Get_Struct(self, LLaMAContextWrapper, &llama_context_type, ptr);
2225
- return ptr;
2226
- }
2227
-
2228
- static void define_class(VALUE outer) {
2229
- rb_cLLaMAContext = rb_define_class_under(outer, "Context", rb_cObject);
2230
- rb_define_alloc_func(rb_cLLaMAContext, llama_context_alloc);
2231
- rb_define_attr(rb_cLLaMAContext, "model", 1, 0);
2232
- rb_define_method(rb_cLLaMAContext, "initialize", RUBY_METHOD_FUNC(_llama_context_initialize), -1);
2233
- rb_define_method(rb_cLLaMAContext, "encode", RUBY_METHOD_FUNC(_llama_context_encode), 1);
2234
- rb_define_method(rb_cLLaMAContext, "decode", RUBY_METHOD_FUNC(_llama_context_decode), 1);
2235
- rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
2236
- rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
2237
- rb_define_method(rb_cLLaMAContext, "embeddings_ith", RUBY_METHOD_FUNC(_llama_context_embeddings_ith), 1);
2238
- rb_define_method(rb_cLLaMAContext, "embeddings_seq", RUBY_METHOD_FUNC(_llama_context_embeddings_seq), 1);
2239
- rb_define_method(rb_cLLaMAContext, "set_embeddings", RUBY_METHOD_FUNC(_llama_context_set_embeddings), 1);
2240
- rb_define_method(rb_cLLaMAContext, "set_n_threads", RUBY_METHOD_FUNC(_llama_context_set_n_threads), -1);
2241
- rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
2242
- rb_define_method(rb_cLLaMAContext, "n_batch", RUBY_METHOD_FUNC(_llama_context_n_batch), 0);
2243
- rb_define_method(rb_cLLaMAContext, "n_ubatch", RUBY_METHOD_FUNC(_llama_context_n_ubatch), 0);
2244
- rb_define_method(rb_cLLaMAContext, "n_seq_max", RUBY_METHOD_FUNC(_llama_context_n_seq_max), 0);
2245
- rb_define_method(rb_cLLaMAContext, "n_threads", RUBY_METHOD_FUNC(_llama_context_n_threads), 0);
2246
- rb_define_method(rb_cLLaMAContext, "n_threads_batch", RUBY_METHOD_FUNC(_llama_context_n_threads_batch), 0);
2247
- rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
2248
- rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
2249
- rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
2250
- rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
2251
- rb_define_method(rb_cLLaMAContext, "kv_cache_clear", RUBY_METHOD_FUNC(_llama_context_kv_cache_clear), 0);
2252
- rb_define_method(rb_cLLaMAContext, "kv_cache_seq_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_rm), 3);
2253
- rb_define_method(rb_cLLaMAContext, "kv_cache_seq_cp", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_cp), 4);
2254
- rb_define_method(rb_cLLaMAContext, "kv_cache_seq_keep", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_keep), 1);
2255
- rb_define_method(rb_cLLaMAContext, "kv_cache_seq_add", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_add), 4);
2256
- rb_define_method(rb_cLLaMAContext, "kv_cache_seq_div", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_div), 4);
2257
- rb_define_method(rb_cLLaMAContext, "kv_cache_seq_pos_max", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_pos_max), 1);
2258
- rb_define_method(rb_cLLaMAContext, "kv_cache_kv_cache_defrag", RUBY_METHOD_FUNC(_llama_context_kv_cache_defrag), 0);
2259
- rb_define_method(rb_cLLaMAContext, "kv_cache_kv_cache_update", RUBY_METHOD_FUNC(_llama_context_kv_cache_update), 0);
2260
- rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
2261
- rb_define_method(rb_cLLaMAContext, "set_causal_attn", RUBY_METHOD_FUNC(_llama_context_set_causal_attn), 1);
2262
- rb_define_method(rb_cLLaMAContext, "synchronize", RUBY_METHOD_FUNC(_llama_context_synchronize), 0);
2263
- rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
2264
- rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
2265
- rb_define_method(rb_cLLaMAContext, "sample_repetition_penalties", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalties), -1);
2266
- rb_define_method(rb_cLLaMAContext, "sample_apply_guidance", RUBY_METHOD_FUNC(_llama_context_sample_apply_guidance), -1);
2267
- rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
2268
- rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
2269
- rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
2270
- rb_define_method(rb_cLLaMAContext, "sample_min_p", RUBY_METHOD_FUNC(_llama_context_sample_min_p), -1);
2271
- rb_define_method(rb_cLLaMAContext, "sample_tail_free", RUBY_METHOD_FUNC(_llama_context_sample_tail_free), -1);
2272
- rb_define_method(rb_cLLaMAContext, "sample_typical", RUBY_METHOD_FUNC(_llama_context_sample_typical), -1);
2273
- rb_define_method(rb_cLLaMAContext, "sample_temp", RUBY_METHOD_FUNC(_llama_context_sample_temp), -1);
2274
- rb_define_method(rb_cLLaMAContext, "sample_entropy", RUBY_METHOD_FUNC(_llama_context_sample_entropy), -1);
2275
- rb_define_method(rb_cLLaMAContext, "sample_token_mirostat", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat), -1);
2276
- rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
2277
- rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
2278
- rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
2279
- rb_define_method(rb_cLLaMAContext, "sample_grammar", RUBY_METHOD_FUNC(_llama_context_sample_grammar), -1);
2280
- rb_define_method(rb_cLLaMAContext, "grammar_accept_token", RUBY_METHOD_FUNC(_llama_context_grammar_accept_token), -1);
2281
- rb_define_method(rb_cLLaMAContext, "apply_control_vector", RUBY_METHOD_FUNC(_llama_context_apply_control_vector), -1);
2282
- rb_define_method(rb_cLLaMAContext, "pooling_type", RUBY_METHOD_FUNC(_llama_context_pooling_type), 0);
2283
- }
2284
-
2285
- private:
2286
- static const rb_data_type_t llama_context_type;
2287
-
2288
- static VALUE _llama_context_initialize(int argc, VALUE* argv, VALUE self) {
2289
- VALUE kw_args = Qnil;
2290
- ID kw_table[2] = { rb_intern("model"), rb_intern("params") };
2291
- VALUE kw_values[2] = { Qundef, Qundef };
2292
- rb_scan_args(argc, argv, ":", &kw_args);
2293
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
2294
-
2295
- VALUE model = kw_values[0];
2296
- if (!rb_obj_is_kind_of(model, rb_cLLaMAModel)) {
2297
- rb_raise(rb_eArgError, "model must be a Model");
2298
- return Qnil;
2299
- }
2300
- VALUE params = kw_values[1];
2301
- if (!rb_obj_is_kind_of(params, rb_cLLaMAContextParams)) {
2302
- rb_raise(rb_eArgError, "params must be a ContextParams");
2303
- return Qnil;
2304
- }
2305
-
2306
- LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
2307
- if (model_ptr->model == NULL) {
2308
- rb_raise(rb_eRuntimeError, "Model is empty");
2309
- return Qnil;
2310
- }
2311
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
2312
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2313
-
2314
- ctx_ptr->ctx = llama_new_context_with_model(model_ptr->model, prms_ptr->params);
2315
-
2316
- if (ctx_ptr->ctx == NULL) {
2317
- rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA context");
2318
- return Qnil;
2319
- }
2320
-
2321
- rb_iv_set(self, "@model", model);
2322
- rb_iv_set(self, "@params", params);
2323
- rb_iv_set(self, "@has_evaluated", Qfalse);
2324
-
2325
- return Qnil;
2326
- }
2327
-
2328
- static VALUE _llama_context_encode(VALUE self, VALUE batch) {
2329
- LLaMAContextWrapper* ptr = get_llama_context(self);
2330
- if (ptr->ctx == NULL) {
2331
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2332
- return Qnil;
2333
- }
2334
- if (!rb_obj_is_kind_of(batch, rb_cLLaMABatch)) {
2335
- rb_raise(rb_eArgError, "batch must be a Batch");
2336
- return Qnil;
2337
- }
2338
- LLaMABatchWrapper* batch_ptr = RbLLaMABatch::get_llama_batch(batch);
2339
- if (llama_encode(ptr->ctx, batch_ptr->batch) < 0) {
2340
- rb_raise(rb_eRuntimeError, "Failed to encode");
2341
- return Qnil;
2342
- }
2343
- return Qnil;
2344
- }
2345
-
2346
- static VALUE _llama_context_decode(VALUE self, VALUE batch) {
2347
- LLaMAContextWrapper* ptr = get_llama_context(self);
2348
- if (ptr->ctx == NULL) {
2349
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2350
- return Qnil;
2351
- }
2352
- if (!rb_obj_is_kind_of(batch, rb_cLLaMABatch)) {
2353
- rb_raise(rb_eArgError, "batch must be a Batch");
2354
- return Qnil;
2355
- }
2356
- LLaMABatchWrapper* batch_ptr = RbLLaMABatch::get_llama_batch(batch);
2357
- if (llama_decode(ptr->ctx, batch_ptr->batch) < 0) {
2358
- rb_raise(rb_eRuntimeError, "Failed to decode");
2359
- return Qnil;
2360
- }
2361
- rb_iv_set(self, "@n_tokens", INT2NUM(batch_ptr->batch.n_tokens));
2362
- rb_iv_set(self, "@has_evaluated", Qtrue);
2363
- return Qnil;
2364
- }
2365
-
2366
- static VALUE _llama_context_logits(VALUE self) {
2367
- LLaMAContextWrapper* ptr = get_llama_context(self);
2368
- if (ptr->ctx == NULL) {
2369
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2370
- return Qnil;
2371
- }
2372
- if (rb_iv_get(self, "@has_evaluated") != Qtrue) {
2373
- rb_raise(rb_eRuntimeError, "LLaMA context has not been evaluated");
2374
- return Qnil;
2375
- }
2376
-
2377
- VALUE model = rb_iv_get(self, "@model");
2378
- LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
2379
- VALUE params = rb_iv_get(self, "@params");
2380
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
2381
- const int n_tokens = prms_ptr->params.logits_all ? NUM2INT(rb_iv_get(self, "@n_tokens")) : 1;
2382
- const int n_vocab = llama_n_vocab(model_ptr->model);
2383
- const float* logits = llama_get_logits(ptr->ctx);
2384
- VALUE output = rb_ary_new();
2385
- for (int i = 0; i < n_tokens * n_vocab; i++) {
2386
- rb_ary_push(output, DBL2NUM((double)(logits[i])));
2387
- }
2388
-
2389
- return output;
2390
- }
2391
-
2392
- static VALUE _llama_context_embeddings(VALUE self) {
2393
- LLaMAContextWrapper* ptr = get_llama_context(self);
2394
- if (ptr->ctx == NULL) {
2395
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2396
- return Qnil;
2397
- }
2398
- VALUE model = rb_iv_get(self, "@model");
2399
- LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
2400
- VALUE params = rb_iv_get(self, "@params");
2401
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
2402
- if (!prms_ptr->params.embeddings) {
2403
- rb_raise(rb_eRuntimeError, "embedding parameter is false");
2404
- return Qnil;
2405
- }
2406
- if (rb_iv_get(self, "@has_evaluated") != Qtrue) {
2407
- rb_raise(rb_eRuntimeError, "LLaMA context has not been evaluated");
2408
- return Qnil;
2409
- }
2410
-
2411
- const int n_tokens = NUM2INT(rb_iv_get(self, "@n_tokens"));
2412
- const int n_embd = llama_n_embd(model_ptr->model);
2413
- const float* embd = llama_get_embeddings(ptr->ctx);
2414
- VALUE output = rb_ary_new();
2415
- for (int i = 0; i < n_tokens * n_embd; i++) {
2416
- rb_ary_push(output, DBL2NUM((double)(embd[i])));
2417
- }
2418
-
2419
- return output;
2420
- }
2421
-
2422
- static VALUE _llama_context_embeddings_ith(VALUE self, VALUE ith) {
2423
- if (!RB_INTEGER_TYPE_P(ith)) {
2424
- rb_raise(rb_eArgError, "ith must be an integer");
2425
- return Qnil;
2426
- }
2427
- LLaMAContextWrapper* ptr = get_llama_context(self);
2428
- if (ptr->ctx == NULL) {
2429
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2430
- return Qnil;
2431
- }
2432
- VALUE params = rb_iv_get(self, "@params");
2433
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
2434
- if (!prms_ptr->params.embeddings) {
2435
- rb_raise(rb_eRuntimeError, "embedding parameter is false");
2436
- return Qnil;
2437
- }
2438
-
2439
- VALUE model = rb_iv_get(self, "@model");
2440
- LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
2441
- const int n_embd = llama_n_embd(model_ptr->model);
2442
-
2443
- VALUE output = rb_ary_new();
2444
- const float* embd = llama_get_embeddings_ith(ptr->ctx, NUM2INT(ith));
2445
- for (int i = 0; i < n_embd; i++) {
2446
- rb_ary_push(output, DBL2NUM((double)(embd[i])));
2447
- }
2448
-
2449
- return output;
2450
- }
2451
-
2452
- static VALUE _llama_context_embeddings_seq(VALUE self, VALUE seq_id) {
2453
- if (!RB_INTEGER_TYPE_P(seq_id)) {
2454
- rb_raise(rb_eArgError, "seq_id must be an integer");
2455
- return Qnil;
2456
- }
2457
- LLaMAContextWrapper* ptr = get_llama_context(self);
2458
- if (ptr->ctx == NULL) {
2459
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2460
- return Qnil;
2461
- }
2462
- VALUE params = rb_iv_get(self, "@params");
2463
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
2464
- if (!prms_ptr->params.embeddings) {
2465
- rb_raise(rb_eRuntimeError, "embedding parameter is false");
2466
- return Qnil;
2467
- }
2468
-
2469
- VALUE model = rb_iv_get(self, "@model");
2470
- LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
2471
- const int n_embd = llama_n_embd(model_ptr->model);
2472
-
2473
- VALUE output = rb_ary_new();
2474
- const float* embd = llama_get_embeddings_seq(ptr->ctx, NUM2INT(seq_id));
2475
- for (int i = 0; i < n_embd; i++) {
2476
- rb_ary_push(output, DBL2NUM((double)(embd[i])));
2477
- }
2478
-
2479
- return output;
2480
- }
2481
-
2482
- static VALUE _llama_context_set_embeddings(VALUE self, VALUE embs) {
2483
- LLaMAContextWrapper* ptr = get_llama_context(self);
2484
- if (ptr->ctx == NULL) {
2485
- rb_raise(rb_eArgError, "LLaMA context is not initialized");
2486
- return Qnil;
2487
- }
2488
- llama_set_embeddings(ptr->ctx, RTEST(embs) ? true : false);
2489
- return Qnil;
2490
- }
2491
-
2492
- static VALUE _llama_context_set_n_threads(int argc, VALUE* argv, VALUE self) {
2493
- VALUE kw_args = Qnil;
2494
- ID kw_table[2] = { rb_intern("n_threads"), rb_intern("n_threads_batch") };
2495
- VALUE kw_values[2] = { Qundef, Qundef };
2496
- rb_scan_args(argc, argv, ":", &kw_args);
2497
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
2498
-
2499
- VALUE n_threads = kw_values[0];
2500
- if (!RB_INTEGER_TYPE_P(n_threads)) {
2501
- rb_raise(rb_eArgError, "n_threads must be an integer");
2502
- return Qnil;
2503
- }
2504
- VALUE n_threads_batch = kw_values[1];
2505
- if (!RB_INTEGER_TYPE_P(n_threads_batch)) {
2506
- rb_raise(rb_eArgError, "n_threads_batch must be an integer");
2507
- return Qnil;
2508
- }
2509
-
2510
- LLaMAContextWrapper* ptr = get_llama_context(self);
2511
- if (ptr->ctx == NULL) {
2512
- rb_raise(rb_eArgError, "LLaMA context is not initialized");
2513
- return Qnil;
2514
- }
2515
- llama_set_n_threads(ptr->ctx, NUM2UINT(n_threads), NUM2UINT(n_threads_batch));
2516
- return Qnil;
2517
- }
2518
-
2519
- static VALUE _llama_context_n_ctx(VALUE self) {
2520
- LLaMAContextWrapper* ptr = get_llama_context(self);
2521
- if (ptr->ctx == NULL) {
2522
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2523
- return Qnil;
2524
- }
2525
- return UINT2NUM(llama_n_ctx(ptr->ctx));
2526
- }
2527
-
2528
- static VALUE _llama_context_n_batch(VALUE self) {
2529
- LLaMAContextWrapper* ptr = get_llama_context(self);
2530
- if (ptr->ctx == NULL) {
2531
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2532
- return Qnil;
2533
- }
2534
- return UINT2NUM(llama_n_batch(ptr->ctx));
2535
- }
2536
-
2537
- static VALUE _llama_context_n_ubatch(VALUE self) {
2538
- LLaMAContextWrapper* ptr = get_llama_context(self);
2539
- if (ptr->ctx == NULL) {
2540
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2541
- return Qnil;
2542
- }
2543
- return UINT2NUM(llama_n_ubatch(ptr->ctx));
2544
- }
2545
-
2546
- static VALUE _llama_context_n_seq_max(VALUE self) {
2547
- LLaMAContextWrapper* ptr = get_llama_context(self);
2548
- if (ptr->ctx == NULL) {
2549
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2550
- return Qnil;
2551
- }
2552
- return UINT2NUM(llama_n_seq_max(ptr->ctx));
2553
- }
2554
-
2555
- static VALUE _llama_context_n_threads(VALUE self) {
2556
- LLaMAContextWrapper* ptr = get_llama_context(self);
2557
- if (ptr->ctx == NULL) {
2558
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2559
- return Qnil;
2560
- }
2561
- return UINT2NUM(llama_n_threads(ptr->ctx));
2562
- }
2563
-
2564
- static VALUE _llama_context_n_threads_batch(VALUE self) {
2565
- LLaMAContextWrapper* ptr = get_llama_context(self);
2566
- if (ptr->ctx == NULL) {
2567
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2568
- return Qnil;
2569
- }
2570
- return UINT2NUM(llama_n_threads_batch(ptr->ctx));
2571
- }
2572
-
2573
- static VALUE _llama_context_get_timings(VALUE self) {
2574
- LLaMAContextWrapper* ptr = get_llama_context(self);
2575
- if (ptr->ctx == NULL) {
2576
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2577
- return Qnil;
2578
- }
2579
- VALUE tm_obj = rb_funcall(rb_cLLaMATimings, rb_intern("new"), 0);
2580
- LLaMATimingsWrapper* tm_ptr = RbLLaMATimings::get_llama_timings(tm_obj);
2581
- tm_ptr->timings = llama_get_timings(ptr->ctx);
2582
- return tm_obj;
2583
- }
2584
-
2585
- static VALUE _llama_context_print_timings(VALUE self) {
2586
- LLaMAContextWrapper* ptr = get_llama_context(self);
2587
- if (ptr->ctx == NULL) {
2588
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2589
- return Qnil;
2590
- }
2591
- llama_print_timings(ptr->ctx);
2592
- return Qnil;
2593
- }
2594
-
2595
- static VALUE _llama_context_reset_timings(VALUE self) {
2596
- LLaMAContextWrapper* ptr = get_llama_context(self);
2597
- if (ptr->ctx == NULL) {
2598
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2599
- return Qnil;
2600
- }
2601
- llama_reset_timings(ptr->ctx);
2602
- return Qnil;
2603
- }
2604
-
2605
- static VALUE _llama_context_kv_cache_token_count(VALUE self) {
2606
- LLaMAContextWrapper* ptr = get_llama_context(self);
2607
- if (ptr->ctx == NULL) {
2608
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2609
- return Qnil;
2610
- }
2611
- return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
2612
- }
2613
-
2614
- static VALUE _llama_context_kv_cache_clear(VALUE self) {
2615
- LLaMAContextWrapper* ptr = get_llama_context(self);
2616
- if (ptr->ctx == NULL) {
2617
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2618
- return Qnil;
2619
- }
2620
- llama_kv_cache_clear(ptr->ctx);
2621
- return Qnil;
2622
- }
2623
-
2624
- static VALUE _llama_context_kv_cache_seq_rm(VALUE self, VALUE seq_id, VALUE p0, VALUE p1) {
2625
- LLaMAContextWrapper* ptr = get_llama_context(self);
2626
- if (ptr->ctx == NULL) {
2627
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2628
- return Qnil;
2629
- }
2630
- llama_kv_cache_seq_rm(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1));
2631
- return Qnil;
2632
- }
2633
-
2634
- static VALUE _llama_context_kv_cache_seq_cp(VALUE self, VALUE seq_id_src, VALUE seq_id_dst, VALUE p0, VALUE p1) {
2635
- LLaMAContextWrapper* ptr = get_llama_context(self);
2636
- if (ptr->ctx == NULL) {
2637
- rb_raise(rb_eArgError, "LLaMA context is not initialized");
2638
- return Qnil;
2639
- }
2640
- llama_kv_cache_seq_cp(ptr->ctx, NUM2INT(seq_id_src), NUM2INT(seq_id_dst), NUM2INT(p0), NUM2INT(p1));
2641
- return Qnil;
2642
- }
2643
-
2644
- static VALUE _llama_context_kv_cache_seq_keep(VALUE self, VALUE seq_id) {
2645
- LLaMAContextWrapper* ptr = get_llama_context(self);
2646
- if (ptr->ctx == NULL) {
2647
- rb_raise(rb_eArgError, "LLaMA context is not initialized");
2648
- return Qnil;
2649
- }
2650
- llama_kv_cache_seq_keep(ptr->ctx, NUM2INT(seq_id));
2651
- return Qnil;
2652
- }
2653
-
2654
- static VALUE _llama_context_kv_cache_seq_add(VALUE self, VALUE seq_id, VALUE p0, VALUE p1, VALUE delta) {
2655
- LLaMAContextWrapper* ptr = get_llama_context(self);
2656
- if (ptr->ctx == NULL) {
2657
- rb_raise(rb_eArgError, "LLaMA context is not initialized");
2658
- return Qnil;
2659
- }
2660
- llama_kv_cache_seq_add(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1), NUM2INT(delta));
2661
- return Qnil;
2662
- }
2663
-
2664
- static VALUE _llama_context_kv_cache_seq_div(VALUE self, VALUE seq_id, VALUE p0, VALUE p1, VALUE d) {
2665
- LLaMAContextWrapper* ptr = get_llama_context(self);
2666
- if (ptr->ctx == NULL) {
2667
- rb_raise(rb_eArgError, "LLaMA context is not initialized");
2668
- return Qnil;
2669
- }
2670
- llama_kv_cache_seq_div(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1), NUM2INT(d));
2671
- return Qnil;
2672
- }
2673
-
2674
- static VALUE _llama_context_kv_cache_seq_pos_max(VALUE self, VALUE seq_id) {
2675
- LLaMAContextWrapper* ptr = get_llama_context(self);
2676
- if (ptr->ctx == NULL) {
2677
- rb_raise(rb_eArgError, "LLaMA context is not initialized");
2678
- return Qnil;
2679
- }
2680
- return INT2NUM(llama_kv_cache_seq_pos_max(ptr->ctx, NUM2INT(seq_id)));
2681
- }
2682
-
2683
- static VALUE _llama_context_kv_cache_defrag(VALUE self) {
2684
- LLaMAContextWrapper* ptr = get_llama_context(self);
2685
- if (ptr->ctx == NULL) {
2686
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2687
- return Qnil;
2688
- }
2689
- llama_kv_cache_defrag(ptr->ctx);
2690
- return Qnil;
2691
- }
2692
-
2693
- static VALUE _llama_context_kv_cache_update(VALUE self) {
2694
- LLaMAContextWrapper* ptr = get_llama_context(self);
2695
- if (ptr->ctx == NULL) {
2696
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2697
- return Qnil;
2698
- }
2699
- llama_kv_cache_update(ptr->ctx);
2700
- return Qnil;
2701
- }
2702
-
2703
- static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
2704
- LLaMAContextWrapper* ptr = get_llama_context(self);
2705
- if (ptr->ctx == NULL) {
2706
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2707
- return Qnil;
2708
- }
2709
- if (NUM2INT(seed_) < 0) {
2710
- rb_raise(rb_eArgError, "seed must be a non-negative integer");
2711
- return Qnil;
2712
- }
2713
- const uint32_t seed = NUM2INT(seed_);
2714
- llama_set_rng_seed(ptr->ctx, seed);
2715
- return Qnil;
2716
- }
2717
-
2718
- static VALUE _llama_context_set_causal_attn(VALUE self, VALUE causal_attn) {
2719
- LLaMAContextWrapper* ptr = get_llama_context(self);
2720
- if (ptr->ctx == NULL) {
2721
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2722
- return Qnil;
2723
- }
2724
- llama_set_causal_attn(ptr->ctx, RTEST(causal_attn) ? true : false);
2725
- return Qnil;
2726
- }
2727
-
2728
- static VALUE _llama_context_synchronize(VALUE self) {
2729
- LLaMAContextWrapper* ptr = get_llama_context(self);
2730
- if (ptr->ctx == NULL) {
2731
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2732
- return Qnil;
2733
- }
2734
- llama_synchronize(ptr->ctx);
2735
- return Qnil;
2736
- }
2737
-
2738
- static VALUE _llama_context_load_session_file(int argc, VALUE* argv, VALUE self) {
2739
- VALUE kw_args = Qnil;
2740
- ID kw_table[1] = { rb_intern("session_path") };
2741
- VALUE kw_values[1] = { Qundef };
2742
- VALUE candidates = Qnil;
2743
- VALUE last_n_tokens = Qnil;
2744
- rb_scan_args(argc, argv, ":", &kw_args);
2745
- rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
2746
-
2747
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
2748
- rb_raise(rb_eArgError, "session_path must be a String");
2749
- return Qnil;
2750
- }
2751
-
2752
- VALUE filename = kw_values[0];
2753
-
2754
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2755
- if (ctx_ptr->ctx == NULL) {
2756
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2757
- return Qnil;
2758
- }
2759
-
2760
- VALUE model = rb_iv_get(self, "@model");
2761
- VALUE params = rb_iv_get(self, "@params");
2762
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
2763
- const int n_ctx = prms_ptr->params.n_ctx;
2764
-
2765
- std::vector<llama_token> session_tokens(n_ctx);
2766
- size_t n_token_count_out = 0;
2767
-
2768
- try {
2769
- bool res = llama_load_session_file(ctx_ptr->ctx, StringValueCStr(filename), session_tokens.data(), session_tokens.capacity(), &n_token_count_out);
2770
- if (!res) {
2771
- rb_raise(rb_eRuntimeError, "Failed to load session file");
2772
- return Qnil;
2773
- }
2774
- session_tokens.resize(n_token_count_out);
2775
- } catch (const std::runtime_error& e) {
2776
- rb_raise(rb_eRuntimeError, "%s", e.what());
2777
- return Qnil;
2778
- }
2779
-
2780
- VALUE ary_session_tokens = rb_ary_new2(n_token_count_out);
2781
- for (size_t i = 0; i < n_token_count_out; i++) {
2782
- rb_ary_store(ary_session_tokens, i, INT2NUM(session_tokens[i]));
2783
- }
2784
-
2785
- RB_GC_GUARD(filename);
2786
- return ary_session_tokens;
2787
- }
2788
-
2789
- static VALUE _llama_context_save_session_file(int argc, VALUE* argv, VALUE self) {
2790
- VALUE kw_args = Qnil;
2791
- ID kw_table[2] = { rb_intern("session_path"), rb_intern("session_tokens") };
2792
- VALUE kw_values[2] = { Qundef, Qundef };
2793
- VALUE candidates = Qnil;
2794
- VALUE last_n_tokens = Qnil;
2795
- rb_scan_args(argc, argv, ":", &kw_args);
2796
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
2797
-
2798
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
2799
- rb_raise(rb_eArgError, "session_path must be a String");
2800
- return Qnil;
2801
- }
2802
- if (!RB_TYPE_P(kw_values[1], T_ARRAY)) {
2803
- rb_raise(rb_eArgError, "session_tokens must be an Array");
2804
- return Qnil;
2805
- }
2806
-
2807
- VALUE filename = kw_values[0];
2808
- const size_t sz_session_tokens = RARRAY_LEN(kw_values[1]);
2809
- std::vector<llama_token> session_tokens(sz_session_tokens);
2810
- for (size_t i = 0; i < sz_session_tokens; i++) {
2811
- session_tokens[i] = NUM2INT(rb_ary_entry(kw_values[1], i));
2812
- }
2813
-
2814
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2815
- if (ctx_ptr->ctx == NULL) {
2816
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2817
- return Qnil;
2818
- }
2819
-
2820
- bool res = llama_save_session_file(ctx_ptr->ctx, StringValueCStr(filename), session_tokens.data(), sz_session_tokens);
2821
-
2822
- if (!res) {
2823
- rb_raise(rb_eRuntimeError, "Failed to save session file");
2824
- return Qnil;
2825
- }
2826
-
2827
- RB_GC_GUARD(filename);
2828
- return Qnil;
2829
- }
2830
-
2831
- static VALUE _llama_context_sample_repetition_penalties(int argc, VALUE* argv, VALUE self) {
2832
- VALUE kw_args = Qnil;
2833
- ID kw_table[3] = { rb_intern("penalty_repeat"), rb_intern("penalty_freq"), rb_intern("penalty_present") };
2834
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
2835
- VALUE candidates = Qnil;
2836
- VALUE last_n_tokens = Qnil;
2837
- rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
2838
- rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
2839
-
2840
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
2841
- rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
2842
- return Qnil;
2843
- }
2844
- if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
2845
- rb_raise(rb_eArgError, "last_n_tokens must be an Array");
2846
- return Qnil;
2847
- }
2848
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
2849
- rb_raise(rb_eArgError, "penalty_repeat must be a float");
2850
- return Qnil;
2851
- }
2852
- if (!RB_FLOAT_TYPE_P(kw_values[1])) {
2853
- rb_raise(rb_eArgError, "penalty_freq must be a float");
2854
- return Qnil;
2855
- }
2856
- if (!RB_FLOAT_TYPE_P(kw_values[2])) {
2857
- rb_raise(rb_eArgError, "penalty_present must be a float");
2858
- return Qnil;
2859
- }
2860
-
2861
- const size_t last_tokens_size = RARRAY_LEN(last_n_tokens);
2862
- std::vector<llama_token> last_n_tokens_data(last_tokens_size);
2863
- for (size_t i = 0; i < last_tokens_size; i++) {
2864
- last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
2865
- }
2866
-
2867
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2868
- if (ctx_ptr->ctx == NULL) {
2869
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2870
- return Qnil;
2871
- }
2872
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
2873
- if (cnd_ptr->array.data == nullptr) {
2874
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
2875
- return Qnil;
2876
- }
2877
- const float penalty_repeat = NUM2DBL(kw_values[0]);
2878
- const float penalty_freq = NUM2DBL(kw_values[1]);
2879
- const float penalty_present = NUM2DBL(kw_values[2]);
2880
-
2881
- llama_sample_repetition_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size,
2882
- penalty_repeat, penalty_freq, penalty_present);
2883
-
2884
- return Qnil;
2885
- }
2886
-
2887
- static VALUE _llama_context_sample_apply_guidance(int argc, VALUE* argv, VALUE self) {
2888
- VALUE kw_args = Qnil;
2889
- ID kw_table[3] = { rb_intern("logits"), rb_intern("logits_guidance"), rb_intern("scale") };
2890
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
2891
- rb_scan_args(argc, argv, ":", &kw_args);
2892
- rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
2893
-
2894
- if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
2895
- rb_raise(rb_eArgError, "logits must be an Array");
2896
- return Qnil;
2897
- }
2898
- if (!RB_TYPE_P(kw_values[1], T_ARRAY)) {
2899
- rb_raise(rb_eArgError, "logits_guidance must be an Array");
2900
- return Qnil;
2901
- }
2902
- if (!RB_FLOAT_TYPE_P(kw_values[2])) {
2903
- rb_raise(rb_eArgError, "scale must be a float");
2904
- return Qnil;
2905
- }
2906
-
2907
- const size_t sz_logits = RARRAY_LEN(kw_values[0]);
2908
- std::vector<float> logits(sz_logits);
2909
- for (size_t i = 0; i < sz_logits; i++) {
2910
- logits[i] = NUM2DBL(rb_ary_entry(kw_values[0], i));
2911
- }
2912
-
2913
- const size_t sz_logits_guidance = RARRAY_LEN(kw_values[1]);
2914
- std::vector<float> logits_guidance(sz_logits_guidance);
2915
- for (size_t i = 0; i < sz_logits_guidance; i++) {
2916
- logits_guidance[i] = NUM2DBL(rb_ary_entry(kw_values[1], i));
2917
- }
2918
-
2919
- const float scale = NUM2DBL(kw_values[2]);
2920
-
2921
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2922
- if (ctx_ptr->ctx == NULL) {
2923
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2924
- return Qnil;
2925
- }
2926
-
2927
- llama_sample_apply_guidance(ctx_ptr->ctx, logits.data(), logits_guidance.data(), scale);
2928
-
2929
- return Qnil;
2930
- }
2931
-
2932
- static VALUE _llama_context_sample_softmax(VALUE self, VALUE candidates) {
2933
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
2934
- rb_raise(rb_eArgError, "argument must be a TokenDataArray");
2935
- return Qnil;
2936
- }
2937
-
2938
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2939
- if (ctx_ptr->ctx == NULL) {
2940
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2941
- return Qnil;
2942
- }
2943
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
2944
- if (cnd_ptr->array.data == nullptr) {
2945
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
2946
- return Qnil;
2947
- }
2948
-
2949
- llama_sample_softmax(ctx_ptr->ctx, &(cnd_ptr->array));
2950
-
2951
- return Qnil;
2952
- }
2953
-
2954
- static VALUE _llama_context_sample_top_k(int argc, VALUE* argv, VALUE self) {
2955
- VALUE kw_args = Qnil;
2956
- ID kw_table[2] = { rb_intern("k"), rb_intern("min_keep") };
2957
- VALUE kw_values[2] = { Qundef, Qundef };
2958
- VALUE candidates = Qnil;
2959
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
2960
- rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
2961
-
2962
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
2963
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
2964
- return Qnil;
2965
- }
2966
- if (!RB_INTEGER_TYPE_P(kw_values[0])) {
2967
- rb_raise(rb_eArgError, "k must be an integer");
2968
- return Qnil;
2969
- }
2970
- if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
2971
- rb_raise(rb_eArgError, "min_keep must be an integer");
2972
- return Qnil;
2973
- }
2974
-
2975
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2976
- if (ctx_ptr->ctx == NULL) {
2977
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2978
- return Qnil;
2979
- }
2980
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
2981
- if (cnd_ptr->array.data == nullptr) {
2982
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
2983
- return Qnil;
2984
- }
2985
- const int k = NUM2DBL(kw_values[0]);
2986
- const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
2987
-
2988
- llama_sample_top_k(ctx_ptr->ctx, &(cnd_ptr->array), k, min_keep);
2989
-
2990
- return Qnil;
2991
- }
2992
-
2993
- static VALUE _llama_context_sample_top_p(int argc, VALUE* argv, VALUE self) {
2994
- VALUE kw_args = Qnil;
2995
- ID kw_table[2] = { rb_intern("prob"), rb_intern("min_keep") };
2996
- VALUE kw_values[2] = { Qundef, Qundef };
2997
- VALUE candidates = Qnil;
2998
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
2999
- rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
3000
-
3001
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3002
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3003
- return Qnil;
3004
- }
3005
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
3006
- rb_raise(rb_eArgError, "prob must be a float");
3007
- return Qnil;
3008
- }
3009
- if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
3010
- rb_raise(rb_eArgError, "min_keep must be an integer");
3011
- return Qnil;
3012
- }
3013
-
3014
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3015
- if (ctx_ptr->ctx == NULL) {
3016
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3017
- return Qnil;
3018
- }
3019
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3020
- if (cnd_ptr->array.data == nullptr) {
3021
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3022
- return Qnil;
3023
- }
3024
- const float prob = NUM2DBL(kw_values[0]);
3025
- const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
3026
-
3027
- llama_sample_top_p(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
3028
-
3029
- return Qnil;
3030
- }
3031
-
3032
- static VALUE _llama_context_sample_min_p(int argc, VALUE* argv, VALUE self) {
3033
- VALUE kw_args = Qnil;
3034
- ID kw_table[2] = { rb_intern("prob"), rb_intern("min_keep") };
3035
- VALUE kw_values[2] = { Qundef, Qundef };
3036
- VALUE candidates = Qnil;
3037
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
3038
- rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
3039
-
3040
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3041
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3042
- return Qnil;
3043
- }
3044
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
3045
- rb_raise(rb_eArgError, "prob must be a float");
3046
- return Qnil;
3047
- }
3048
- if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
3049
- rb_raise(rb_eArgError, "min_keep must be an integer");
3050
- return Qnil;
3051
- }
3052
-
3053
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3054
- if (ctx_ptr->ctx == NULL) {
3055
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3056
- return Qnil;
3057
- }
3058
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3059
- if (cnd_ptr->array.data == nullptr) {
3060
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3061
- return Qnil;
3062
- }
3063
- const float prob = NUM2DBL(kw_values[0]);
3064
- const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
3065
-
3066
- llama_sample_min_p(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
3067
-
3068
- return Qnil;
3069
- }
3070
-
3071
- static VALUE _llama_context_sample_tail_free(int argc, VALUE* argv, VALUE self) {
3072
- VALUE kw_args = Qnil;
3073
- ID kw_table[2] = { rb_intern("z"), rb_intern("min_keep") };
3074
- VALUE kw_values[2] = { Qundef, Qundef };
3075
- VALUE candidates = Qnil;
3076
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
3077
- rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
3078
-
3079
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3080
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3081
- return Qnil;
3082
- }
3083
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
3084
- rb_raise(rb_eArgError, "prob must be a float");
3085
- return Qnil;
3086
- }
3087
- if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
3088
- rb_raise(rb_eArgError, "min_keep must be an integer");
3089
- return Qnil;
3090
- }
3091
-
3092
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3093
- if (ctx_ptr->ctx == NULL) {
3094
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3095
- return Qnil;
3096
- }
3097
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3098
- if (cnd_ptr->array.data == nullptr) {
3099
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3100
- return Qnil;
3101
- }
3102
- const float z = NUM2DBL(kw_values[0]);
3103
- const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
3104
-
3105
- llama_sample_tail_free(ctx_ptr->ctx, &(cnd_ptr->array), z, min_keep);
3106
-
3107
- return Qnil;
3108
- }
3109
-
3110
- static VALUE _llama_context_sample_typical(int argc, VALUE* argv, VALUE self) {
3111
- VALUE kw_args = Qnil;
3112
- ID kw_table[2] = { rb_intern("prob"), rb_intern("min_keep") };
3113
- VALUE kw_values[2] = { Qundef, Qundef };
3114
- VALUE candidates = Qnil;
3115
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
3116
- rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
3117
-
3118
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3119
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3120
- return Qnil;
3121
- }
3122
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
3123
- rb_raise(rb_eArgError, "prob must be a float");
3124
- return Qnil;
3125
- }
3126
- if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
3127
- rb_raise(rb_eArgError, "min_keep must be an integer");
3128
- return Qnil;
3129
- }
3130
-
3131
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3132
- if (ctx_ptr->ctx == NULL) {
3133
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3134
- return Qnil;
3135
- }
3136
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3137
- if (cnd_ptr->array.data == nullptr) {
3138
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3139
- return Qnil;
3140
- }
3141
- const float prob = NUM2DBL(kw_values[0]);
3142
- const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
3143
-
3144
- llama_sample_typical(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
3145
-
3146
- return Qnil;
3147
- }
3148
-
3149
- static VALUE _llama_context_sample_temp(int argc, VALUE* argv, VALUE self) {
3150
- VALUE kw_args = Qnil;
3151
- ID kw_table[1] = { rb_intern("temp") };
3152
- VALUE kw_values[1] = { Qundef };
3153
- VALUE candidates = Qnil;
3154
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
3155
- rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
3156
-
3157
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3158
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3159
- return Qnil;
3160
- }
3161
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
3162
- rb_raise(rb_eArgError, "temp must be a float");
3163
- return Qnil;
3164
- }
3165
-
3166
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3167
- if (ctx_ptr->ctx == NULL) {
3168
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3169
- return Qnil;
3170
- }
3171
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3172
- if (cnd_ptr->array.data == nullptr) {
3173
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3174
- return Qnil;
3175
- }
3176
- const float temp = NUM2DBL(kw_values[0]);
3177
-
3178
- llama_sample_temp(ctx_ptr->ctx, &(cnd_ptr->array), temp);
3179
-
3180
- return Qnil;
3181
- }
3182
-
3183
- static VALUE _llama_context_sample_entropy(int argc, VALUE* argv, VALUE self) {
3184
- VALUE kw_args = Qnil;
3185
- ID kw_table[3] = { rb_intern("min_temp"), rb_intern("max_temp"), rb_intern("exponent_val") };
3186
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
3187
- VALUE candidates = Qnil;
3188
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
3189
- rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
3190
-
3191
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3192
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3193
- return Qnil;
3194
- }
3195
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
3196
- rb_raise(rb_eArgError, "min_temp must be a float");
3197
- return Qnil;
3198
- }
3199
- if (!RB_FLOAT_TYPE_P(kw_values[1])) {
3200
- rb_raise(rb_eArgError, "max_temp must be a float");
3201
- return Qnil;
3202
- }
3203
- if (!RB_FLOAT_TYPE_P(kw_values[2])) {
3204
- rb_raise(rb_eArgError, "exponent_val must be a float");
3205
- return Qnil;
3206
- }
3207
-
3208
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3209
- if (ctx_ptr->ctx == NULL) {
3210
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3211
- return Qnil;
3212
- }
3213
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3214
- if (cnd_ptr->array.data == nullptr) {
3215
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3216
- return Qnil;
3217
- }
3218
- const float min_temp = NUM2DBL(kw_values[0]);
3219
- const float max_temp = NUM2DBL(kw_values[1]);
3220
- const float exponent_val = NUM2DBL(kw_values[2]);
3221
-
3222
- llama_sample_entropy(ctx_ptr->ctx, &(cnd_ptr->array), min_temp, max_temp, exponent_val);
3223
-
3224
- return Qnil;
3225
- }
3226
-
3227
- static VALUE _llama_context_sample_token_mirostat(int argc, VALUE* argv, VALUE self) {
3228
- VALUE kw_args = Qnil;
3229
- ID kw_table[4] = { rb_intern("tau"), rb_intern("eta"), rb_intern("m"), rb_intern("mu") };
3230
- VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
3231
- VALUE candidates = Qnil;
3232
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
3233
- rb_get_kwargs(kw_args, kw_table, 4, 0, kw_values);
3234
-
3235
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3236
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3237
- return Qnil;
3238
- }
3239
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
3240
- rb_raise(rb_eArgError, "tau must be a float");
3241
- return Qnil;
3242
- }
3243
- if (!RB_FLOAT_TYPE_P(kw_values[1])) {
3244
- rb_raise(rb_eArgError, "eta must be a float");
3245
- return Qnil;
3246
- }
3247
- if (!RB_INTEGER_TYPE_P(kw_values[2])) {
3248
- rb_raise(rb_eArgError, "m must be an integer");
3249
- return Qnil;
3250
- }
3251
- if (!RB_FLOAT_TYPE_P(kw_values[3])) {
3252
- rb_raise(rb_eArgError, "mu must be a float");
3253
- return Qnil;
3254
- }
3255
-
3256
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3257
- if (ctx_ptr->ctx == NULL) {
3258
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3259
- return Qnil;
3260
- }
3261
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3262
- if (cnd_ptr->array.data == nullptr) {
3263
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3264
- return Qnil;
3265
- }
3266
- const float tau = NUM2DBL(kw_values[0]);
3267
- const float eta = NUM2DBL(kw_values[1]);
3268
- const int m = NUM2INT(kw_values[2]);
3269
- float mu = NUM2DBL(kw_values[3]);
3270
-
3271
- llama_token id = llama_sample_token_mirostat(ctx_ptr->ctx, &(cnd_ptr->array), tau, eta, m, &mu);
3272
-
3273
- VALUE ret = rb_ary_new2(2);
3274
- rb_ary_store(ret, 0, INT2NUM(id));
3275
- rb_ary_store(ret, 1, DBL2NUM(mu));
3276
- return ret;
3277
- }
3278
-
3279
- static VALUE _llama_context_sample_token_mirostat_v2(int argc, VALUE* argv, VALUE self) {
3280
- VALUE kw_args = Qnil;
3281
- ID kw_table[3] = { rb_intern("tau"), rb_intern("eta"), rb_intern("mu") };
3282
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
3283
- VALUE candidates = Qnil;
3284
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
3285
- rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
3286
-
3287
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3288
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3289
- return Qnil;
3290
- }
3291
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
3292
- rb_raise(rb_eArgError, "tau must be a float");
3293
- return Qnil;
3294
- }
3295
- if (!RB_FLOAT_TYPE_P(kw_values[1])) {
3296
- rb_raise(rb_eArgError, "eta must be a float");
3297
- return Qnil;
3298
- }
3299
- if (!RB_FLOAT_TYPE_P(kw_values[2])) {
3300
- rb_raise(rb_eArgError, "mu must be a float");
3301
- return Qnil;
3302
- }
3303
-
3304
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3305
- if (ctx_ptr->ctx == NULL) {
3306
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3307
- return Qnil;
3308
- }
3309
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3310
- if (cnd_ptr->array.data == nullptr) {
3311
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3312
- return Qnil;
3313
- }
3314
- const float tau = NUM2DBL(kw_values[0]);
3315
- const float eta = NUM2DBL(kw_values[1]);
3316
- float mu = NUM2DBL(kw_values[2]);
3317
-
3318
- llama_token id = llama_sample_token_mirostat_v2(ctx_ptr->ctx, &(cnd_ptr->array), tau, eta, &mu);
3319
-
3320
- VALUE ret = rb_ary_new2(2);
3321
- rb_ary_store(ret, 0, INT2NUM(id));
3322
- rb_ary_store(ret, 1, DBL2NUM(mu));
3323
- return ret;
3324
- }
3325
-
3326
- static VALUE _llama_context_sample_token_greedy(VALUE self, VALUE candidates) {
3327
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3328
- if (ctx_ptr->ctx == NULL) {
3329
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3330
- return Qnil;
3331
- }
3332
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3333
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3334
- return Qnil;
3335
- }
3336
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3337
- if (cnd_ptr->array.data == nullptr) {
3338
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3339
- return Qnil;
3340
- }
3341
- llama_token id = llama_sample_token_greedy(ctx_ptr->ctx, &(cnd_ptr->array));
3342
- return INT2NUM(id);
3343
- }
3344
-
3345
- static VALUE _llama_context_sample_token(VALUE self, VALUE candidates) {
3346
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3347
- if (ctx_ptr->ctx == NULL) {
3348
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3349
- return Qnil;
3350
- }
3351
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3352
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3353
- return Qnil;
3354
- }
3355
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3356
- if (cnd_ptr->array.data == nullptr) {
3357
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3358
- return Qnil;
3359
- }
3360
- llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
3361
- return INT2NUM(id);
3362
- }
3363
-
3364
- static VALUE _llama_context_sample_grammar(int argc, VALUE* argv, VALUE self) {
3365
- VALUE kw_args = Qnil;
3366
- ID kw_table[1] = { rb_intern("grammar") };
3367
- VALUE kw_values[1] = { Qundef };
3368
- VALUE candidates = Qnil;
3369
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
3370
- rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
3371
-
3372
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3373
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3374
- return Qnil;
3375
- }
3376
- if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
3377
- rb_raise(rb_eArgError, "grammar must be a Grammar");
3378
- return Qnil;
3379
- }
3380
-
3381
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3382
- if (ctx_ptr->ctx == NULL) {
3383
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3384
- return Qnil;
3385
- }
3386
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3387
- if (cnd_ptr->array.data == nullptr) {
3388
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3389
- return Qnil;
3390
- }
3391
- LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
3392
-
3393
- llama_grammar_sample(grm_ptr->grammar, ctx_ptr->ctx, &(cnd_ptr->array));
3394
-
3395
- return Qnil;
3396
- }
3397
-
3398
- static VALUE _llama_context_grammar_accept_token(int argc, VALUE* argv, VALUE self) {
3399
- VALUE kw_args = Qnil;
3400
- ID kw_table[2] = { rb_intern("grammar"), rb_intern("token") };
3401
- VALUE kw_values[2] = { Qundef, Qundef };
3402
- rb_scan_args(argc, argv, ":", &kw_args);
3403
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
3404
-
3405
- if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
3406
- rb_raise(rb_eArgError, "grammar must be a Grammar");
3407
- return Qnil;
3408
- }
3409
- if (!RB_INTEGER_TYPE_P(kw_values[1])) {
3410
- rb_raise(rb_eArgError, "token must be an Integer");
3411
- return Qnil;
3412
- }
3413
-
3414
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3415
- if (ctx_ptr->ctx == NULL) {
3416
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3417
- return Qnil;
3418
- }
3419
- LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
3420
- llama_token token = NUM2INT(kw_values[1]);
3421
-
3422
- llama_grammar_accept_token(grm_ptr->grammar, ctx_ptr->ctx, token);
3423
-
3424
- return Qnil;
3425
- }
3426
-
3427
- static VALUE _llama_context_apply_control_vector(int argc, VALUE* argv, VALUE self) {
3428
- VALUE kw_args = Qnil;
3429
- ID kw_table[4] = { rb_intern("data"), rb_intern("n_embd"), rb_intern("il_start"), rb_intern("il_end") };
3430
- VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
3431
- rb_scan_args(argc, argv, ":", &kw_args);
3432
- rb_get_kwargs(kw_args, kw_table, 4, 0, kw_values);
3433
-
3434
- if (!RB_TYPE_P(kw_values[0], T_ARRAY) && !NIL_P(kw_values[0])) {
3435
- rb_raise(rb_eArgError, "data must be an Array or nil");
3436
- return Qnil;
3437
- }
3438
- if (!RB_INTEGER_TYPE_P(kw_values[1])) {
3439
- rb_raise(rb_eArgError, "n_embd must be an Integer");
3440
- return Qnil;
3441
- }
3442
- if (!RB_INTEGER_TYPE_P(kw_values[2])) {
3443
- rb_raise(rb_eArgError, "il_start must be an Integer");
3444
- return Qnil;
3445
- }
3446
- if (!RB_INTEGER_TYPE_P(kw_values[3])) {
3447
- rb_raise(rb_eArgError, "il_end must be an Integer");
3448
- return Qnil;
3449
- }
3450
-
3451
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3452
- if (ctx_ptr->ctx == NULL) {
3453
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3454
- return Qnil;
3455
- }
3456
-
3457
- std::vector<float> data(RARRAY_LEN(kw_values[0]));
3458
- for (size_t i = 0; i < data.size(); i++) {
3459
- data[i] = NUM2DBL(rb_ary_entry(kw_values[0], i));
3460
- }
3461
- const int32_t n_embd = NUM2INT(kw_values[1]);
3462
- const int32_t il_start = NUM2INT(kw_values[2]);
3463
- const int32_t il_end = NUM2INT(kw_values[3]);
3464
-
3465
- int32_t err = 0;
3466
- if (NIL_P(kw_values[0])) {
3467
- err = llama_control_vector_apply(ctx_ptr->ctx, NULL, 0, n_embd, il_start, il_end);
3468
- } else {
3469
- err = llama_control_vector_apply(ctx_ptr->ctx, data.data(), data.size(), n_embd, il_start, il_end);
3470
- }
3471
-
3472
- if (err) {
3473
- rb_raise(rb_eRuntimeError, "Failed to apply control vector");
3474
- return Qnil;
3475
- }
3476
-
3477
- return Qnil;
3478
- }
3479
-
3480
- static VALUE _llama_context_pooling_type(VALUE self) {
3481
- LLaMAContextWrapper* ptr = get_llama_context(self);
3482
- if (ptr->ctx == NULL) {
3483
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3484
- return Qnil;
3485
- }
3486
- return INT2NUM(static_cast<int>(llama_pooling_type(ptr->ctx)));
3487
- }
3488
- };
3489
-
3490
- const rb_data_type_t RbLLaMAContext::llama_context_type = {
3491
- "RbLLaMAContext",
3492
- { NULL,
3493
- RbLLaMAContext::llama_context_free,
3494
- RbLLaMAContext::llama_context_size },
3495
- NULL,
3496
- NULL,
3497
- RUBY_TYPED_FREE_IMMEDIATELY
3498
- };
3499
-
3500
- // module functions
3501
-
3502
- static VALUE rb_llama_llama_backend_init(VALUE self) {
3503
- llama_backend_init();
3504
-
3505
- return Qnil;
3506
- }
3507
-
3508
- static VALUE rb_llama_llama_backend_free(VALUE self) {
3509
- llama_backend_free();
3510
-
3511
- return Qnil;
3512
- }
3513
-
3514
- static VALUE rb_llama_llama_numa_init(VALUE self, VALUE strategy) {
3515
- if (!RB_INTEGER_TYPE_P(strategy)) {
3516
- rb_raise(rb_eArgError, "strategy must be an integer");
3517
- return Qnil;
3518
- }
3519
-
3520
- llama_numa_init(static_cast<enum ggml_numa_strategy>(NUM2INT(strategy)));
3521
-
3522
- return Qnil;
3523
- }
3524
-
3525
- static VALUE rb_llama_model_quantize(int argc, VALUE* argv, VALUE self) {
3526
- VALUE kw_args = Qnil;
3527
- ID kw_table[3] = { rb_intern("input_path"), rb_intern("output_path"), rb_intern("params") };
3528
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
3529
- rb_scan_args(argc, argv, ":", &kw_args);
3530
- rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
3531
-
3532
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
3533
- rb_raise(rb_eArgError, "input_path must be a string");
3534
- return Qnil;
3535
- }
3536
- if (!RB_TYPE_P(kw_values[1], T_STRING)) {
3537
- rb_raise(rb_eArgError, "output_path must be a string");
3538
- return Qnil;
3539
- }
3540
- if (!rb_obj_is_kind_of(kw_values[2], rb_cLLaMAModelQuantizeParams)) {
3541
- rb_raise(rb_eArgError, "params must be a ModelQuantizeParams");
3542
- return Qnil;
3543
- }
3544
-
3545
- const char* input_path = StringValueCStr(kw_values[0]);
3546
- const char* output_path = StringValueCStr(kw_values[1]);
3547
- LLaMAModelQuantizeParamsWrapper* wrapper = RbLLaMAModelQuantizeParams::get_llama_model_quantize_params(kw_values[2]);
3548
-
3549
- if (llama_model_quantize(input_path, output_path, &(wrapper->params)) != 0) {
3550
- rb_raise(rb_eRuntimeError, "Failed to quantize model");
3551
- return Qnil;
3552
- }
3553
-
3554
- return Qnil;
3555
- }
3556
-
3557
- static VALUE rb_llama_print_system_info(VALUE self) {
3558
- const char* result = llama_print_system_info();
3559
- return rb_utf8_str_new_cstr(result);
3560
- }
3561
-
3562
- static VALUE rb_llama_time_us(VALUE self) {
3563
- return LONG2NUM(llama_time_us());
3564
- }
3565
-
3566
- static VALUE rb_llama_max_devices(VALUE self) {
3567
- return SIZET2NUM(llama_max_devices());
3568
- }
3569
-
3570
- static VALUE rb_llama_supports_mmap(VALUE self) {
3571
- return llama_supports_mmap() ? Qtrue : Qfalse;
3572
- }
3573
-
3574
- static VALUE rb_llama_supports_mlock(VALUE self) {
3575
- return llama_supports_mlock() ? Qtrue : Qfalse;
3576
- }
3577
-
3578
- static VALUE rb_llama_supports_gpu_offload(VALUE self) {
3579
- return llama_supports_gpu_offload() ? Qtrue : Qfalse;
3580
- }
3581
-
3582
- extern "C" void Init_llama_cpp(void) {
3583
- rb_mLLaMACpp = rb_define_module("LLaMACpp");
3584
-
3585
- RbLLaMABatch::define_class(rb_mLLaMACpp);
3586
- RbLLaMATokenData::define_class(rb_mLLaMACpp);
3587
- RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
3588
- RbLLaMAModel::define_class(rb_mLLaMACpp);
3589
- RbLLaMAModelKVOverride::define_class(rb_mLLaMACpp);
3590
- RbLLaMAModelParams::define_class(rb_mLLaMACpp);
3591
- RbLLaMATimings::define_class(rb_mLLaMACpp);
3592
- RbLLaMAContext::define_class(rb_mLLaMACpp);
3593
- RbLLaMAContextParams::define_class(rb_mLLaMACpp);
3594
- RbLLaMAModelQuantizeParams::define_class(rb_mLLaMACpp);
3595
- RbLLaMAGrammarElement::define_class(rb_mLLaMACpp);
3596
- RbLLaMAGrammar::define_class(rb_mLLaMACpp);
3597
-
3598
- rb_define_module_function(rb_mLLaMACpp, "backend_init", rb_llama_llama_backend_init, 0);
3599
- rb_define_module_function(rb_mLLaMACpp, "backend_free", rb_llama_llama_backend_free, 0);
3600
- rb_define_module_function(rb_mLLaMACpp, "numa_init", rb_llama_llama_numa_init, 1);
3601
- rb_define_module_function(rb_mLLaMACpp, "model_quantize", rb_llama_model_quantize, -1);
3602
- rb_define_module_function(rb_mLLaMACpp, "print_system_info", rb_llama_print_system_info, 0);
3603
- rb_define_module_function(rb_mLLaMACpp, "time_us", rb_llama_time_us, 0);
3604
- rb_define_module_function(rb_mLLaMACpp, "max_devices", rb_llama_max_devices, 0);
3605
- rb_define_module_function(rb_mLLaMACpp, "supports_mmap?", rb_llama_supports_mmap, 0);
3606
- rb_define_module_function(rb_mLLaMACpp, "supports_mlock?", rb_llama_supports_mlock, 0);
3607
- rb_define_module_function(rb_mLLaMACpp, "supports_gpu_offload?", rb_llama_supports_gpu_offload, 0);
3608
-
3609
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_NONE", INT2NUM(LLAMA_VOCAB_TYPE_NONE));
3610
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_SPM", INT2NUM(LLAMA_VOCAB_TYPE_SPM));
3611
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_BPE", INT2NUM(LLAMA_VOCAB_TYPE_BPE));
3612
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_WPM", INT2NUM(LLAMA_VOCAB_TYPE_WPM));
3613
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_UGM", INT2NUM(LLAMA_VOCAB_TYPE_UGM));
3614
-
3615
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_DEFAULT", INT2NUM(LLAMA_VOCAB_PRE_TYPE_DEFAULT));
3616
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_LLAMA3", INT2NUM(LLAMA_VOCAB_PRE_TYPE_LLAMA3));
3617
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM", INT2NUM(LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM));
3618
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER", INT2NUM(LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER));
3619
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_FALCON", INT2NUM(LLAMA_VOCAB_PRE_TYPE_FALCON));
3620
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_MPT", INT2NUM(LLAMA_VOCAB_PRE_TYPE_MPT));
3621
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_STARCODER", INT2NUM(LLAMA_VOCAB_PRE_TYPE_STARCODER));
3622
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_GPT2", INT2NUM(LLAMA_VOCAB_PRE_TYPE_GPT2));
3623
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_REFACT", INT2NUM(LLAMA_VOCAB_PRE_TYPE_REFACT));
3624
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_COMMAND_R", INT2NUM(LLAMA_VOCAB_PRE_TYPE_COMMAND_R));
3625
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_STABLELM2", INT2NUM(LLAMA_VOCAB_PRE_TYPE_STABLELM2));
3626
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_QWEN2", INT2NUM(LLAMA_VOCAB_PRE_TYPE_QWEN2));
3627
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_OLMO", INT2NUM(LLAMA_VOCAB_PRE_TYPE_OLMO));
3628
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_DBRX", INT2NUM(LLAMA_VOCAB_PRE_TYPE_DBRX));
3629
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_SMAUG", INT2NUM(LLAMA_VOCAB_PRE_TYPE_SMAUG));
3630
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_PORO", INT2NUM(LLAMA_VOCAB_PRE_TYPE_PORO));
3631
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_CHATGLM3", INT2NUM(LLAMA_VOCAB_PRE_TYPE_CHATGLM3));
3632
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_CHATGLM4", INT2NUM(LLAMA_VOCAB_PRE_TYPE_CHATGLM4));
3633
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_VIKING", INT2NUM(LLAMA_VOCAB_PRE_TYPE_VIKING));
3634
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_JAIS", INT2NUM(LLAMA_VOCAB_PRE_TYPE_JAIS));
3635
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_TEKKEN", INT2NUM(LLAMA_VOCAB_PRE_TYPE_TEKKEN));
3636
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_SMOLLM", INT2NUM(LLAMA_VOCAB_PRE_TYPE_SMOLLM));
3637
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_CODESHELL", INT2NUM(LLAMA_VOCAB_PRE_TYPE_CODESHELL));
3638
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_BLOOM", INT2NUM(LLAMA_VOCAB_PRE_TYPE_BLOOM));
3639
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH", INT2NUM(LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH));
3640
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_EXAONE", INT2NUM(LLAMA_VOCAB_PRE_TYPE_EXAONE));
3641
-
3642
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_UNDEFINED", INT2NUM(LLAMA_TOKEN_TYPE_UNDEFINED));
3643
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_NORMAL", INT2NUM(LLAMA_TOKEN_TYPE_NORMAL));
3644
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_UNKNOWN", INT2NUM(LLAMA_TOKEN_TYPE_UNKNOWN));
3645
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_CONTROL", INT2NUM(LLAMA_TOKEN_TYPE_CONTROL));
3646
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_USER_DEFINED", INT2NUM(LLAMA_TOKEN_TYPE_USER_DEFINED));
3647
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_UNUSED", INT2NUM(LLAMA_TOKEN_TYPE_UNUSED));
3648
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_BYTE", INT2NUM(LLAMA_TOKEN_TYPE_BYTE));
3649
-
3650
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_UNDEFINED", INT2NUM(LLAMA_TOKEN_ATTR_UNDEFINED));
3651
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_UNKNOWN", INT2NUM(LLAMA_TOKEN_ATTR_UNKNOWN));
3652
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_UNUSED", INT2NUM(LLAMA_TOKEN_ATTR_UNUSED));
3653
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_NORMAL", INT2NUM(LLAMA_TOKEN_ATTR_NORMAL));
3654
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_CONTROL", INT2NUM(LLAMA_TOKEN_ATTR_CONTROL));
3655
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_USER_DEFINED", INT2NUM(LLAMA_TOKEN_ATTR_USER_DEFINED));
3656
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_BYTE", INT2NUM(LLAMA_TOKEN_ATTR_BYTE));
3657
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_NORMALIZED", INT2NUM(LLAMA_TOKEN_ATTR_NORMALIZED));
3658
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_LSTRIP", INT2NUM(LLAMA_TOKEN_ATTR_LSTRIP));
3659
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_RSTRIP", INT2NUM(LLAMA_TOKEN_ATTR_RSTRIP));
3660
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_SINGLE_WORD", INT2NUM(LLAMA_TOKEN_ATTR_SINGLE_WORD));
3661
-
3662
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_ALL_F32", INT2NUM(LLAMA_FTYPE_ALL_F32));
3663
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_F16", INT2NUM(LLAMA_FTYPE_MOSTLY_F16));
3664
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_0));
3665
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_1", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_1));
3666
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q8_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q8_0));
3667
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_0));
3668
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_1", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_1));
3669
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q2_K", INT2NUM(LLAMA_FTYPE_MOSTLY_Q2_K));
3670
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q3_K_S", INT2NUM(LLAMA_FTYPE_MOSTLY_Q3_K_S));
3671
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q3_K_M", INT2NUM(LLAMA_FTYPE_MOSTLY_Q3_K_M));
3672
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q3_K_L", INT2NUM(LLAMA_FTYPE_MOSTLY_Q3_K_L));
3673
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_K_S", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_K_S));
3674
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_K_M", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_K_M));
3675
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_K_S", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_K_S));
3676
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_K_M", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_K_M));
3677
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q6_K", INT2NUM(LLAMA_FTYPE_MOSTLY_Q6_K));
3678
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ2_XXS", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ2_XXS));
3679
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ2_XS", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ2_XS));
3680
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q2_K_S", INT2NUM(LLAMA_FTYPE_MOSTLY_Q2_K_S));
3681
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ3_XS", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ3_XS));
3682
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ3_XXS", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ3_XXS));
3683
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ1_S", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ1_S));
3684
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ4_NL", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ4_NL));
3685
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ3_S", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ3_S));
3686
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ3_M", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ3_M));
3687
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ4_XS", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ4_XS));
3688
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ1_M", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ1_M));
3689
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_BF16", INT2NUM(LLAMA_FTYPE_MOSTLY_BF16));
3690
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_0_4_4", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_0_4_4));
3691
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_0_4_8", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_0_4_8));
3692
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_0_8_8", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_0_8_8));
3693
-
3694
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_GUESSED", INT2NUM(LLAMA_FTYPE_GUESSED));
3695
-
3696
- rb_define_const(rb_mLLaMACpp, "LLAMA_KV_OVERRIDE_TYPE_INT", INT2NUM(LLAMA_KV_OVERRIDE_TYPE_INT));
3697
- rb_define_const(rb_mLLaMACpp, "LLAMA_KV_OVERRIDE_TYPE_FLOAT", INT2NUM(LLAMA_KV_OVERRIDE_TYPE_FLOAT));
3698
- rb_define_const(rb_mLLaMACpp, "LLAMA_KV_OVERRIDE_TYPE_BOOL", INT2NUM(LLAMA_KV_OVERRIDE_TYPE_BOOL));
3699
- rb_define_const(rb_mLLaMACpp, "LLAMA_KV_OVERRIDE_TYPE_STR", INT2NUM(LLAMA_KV_OVERRIDE_TYPE_STR));
3700
-
3701
- rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_END", INT2NUM(LLAMA_GRETYPE_END));
3702
- rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_ALT", INT2NUM(LLAMA_GRETYPE_ALT));
3703
- rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_RULE_REF", INT2NUM(LLAMA_GRETYPE_RULE_REF));
3704
- rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR", INT2NUM(LLAMA_GRETYPE_CHAR));
3705
- rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_NOT", INT2NUM(LLAMA_GRETYPE_CHAR_NOT));
3706
- rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_RNG_UPPER", INT2NUM(LLAMA_GRETYPE_CHAR_RNG_UPPER));
3707
- rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_ALT", INT2NUM(LLAMA_GRETYPE_CHAR_ALT));
3708
- rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_ANY", INT2NUM(LLAMA_GRETYPE_CHAR_ANY));
3709
-
3710
- rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED", INT2NUM(LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED));
3711
- rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_TYPE_NONE", INT2NUM(LLAMA_ROPE_SCALING_TYPE_NONE));
3712
- rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_TYPE_LINEAR", INT2NUM(LLAMA_ROPE_SCALING_TYPE_LINEAR));
3713
- rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_TYPE_YARN", INT2NUM(LLAMA_ROPE_SCALING_TYPE_YARN));
3714
- rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_TYPE_MAX_VALUE", INT2NUM(LLAMA_ROPE_SCALING_TYPE_MAX_VALUE));
3715
-
3716
- rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_UNSPECIFIED", INT2NUM(LLAMA_POOLING_TYPE_UNSPECIFIED));
3717
- rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_NONE", INT2NUM(LLAMA_POOLING_TYPE_NONE));
3718
- rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_MEAN", INT2NUM(LLAMA_POOLING_TYPE_MEAN));
3719
- rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_CLS", INT2NUM(LLAMA_POOLING_TYPE_CLS));
3720
- rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_LAST", INT2NUM(LLAMA_POOLING_TYPE_LAST));
3721
-
3722
- rb_define_const(rb_mLLaMACpp, "LLAMA_ATTENTION_TYPE_UNSPECIFIED", INT2NUM(LLAMA_ATTENTION_TYPE_UNSPECIFIED));
3723
- rb_define_const(rb_mLLaMACpp, "LLAMA_ATTENTION_TYPE_CAUSAL", INT2NUM(LLAMA_ATTENTION_TYPE_CAUSAL));
3724
- rb_define_const(rb_mLLaMACpp, "LLAMA_ATTENTION_TYPE_NON_CAUSAL", INT2NUM(LLAMA_ATTENTION_TYPE_NON_CAUSAL));
3725
-
3726
- rb_define_const(rb_mLLaMACpp, "LLAMA_SPLIT_MODE_NONE", INT2NUM(LLAMA_SPLIT_MODE_NONE));
3727
- rb_define_const(rb_mLLaMACpp, "LLAMA_SPLIT_MODE_LAYER", INT2NUM(LLAMA_SPLIT_MODE_LAYER));
3728
- rb_define_const(rb_mLLaMACpp, "LLAMA_SPLIT_MODE_ROW", INT2NUM(LLAMA_SPLIT_MODE_ROW));
3729
-
3730
- std::stringstream ss_magic;
3731
- ss_magic << std::showbase << std::hex << LLAMA_FILE_MAGIC_GGLA;
3732
- rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_MAGIC_GGLA", rb_str_new2(ss_magic.str().c_str()));
3733
-
3734
- ss_magic.str("");
3735
- ss_magic.clear(std::stringstream::goodbit);
3736
- ss_magic << std::showbase << std::hex << LLAMA_FILE_MAGIC_GGSN;
3737
- rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_MAGIC_GGSN", rb_str_new2(ss_magic.str().c_str()));
3738
-
3739
- ss_magic.str("");
3740
- ss_magic.clear(std::stringstream::goodbit);
3741
- ss_magic << std::showbase << std::hex << LLAMA_FILE_MAGIC_GGSQ;
3742
- rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_MAGIC_GGSQ", rb_str_new2(ss_magic.str().c_str()));
3743
-
3744
- ss_magic.str("");
3745
- ss_magic.clear(std::stringstream::goodbit);
3746
- ss_magic << std::showbase << std::hex << LLAMA_SESSION_MAGIC;
3747
- rb_define_const(rb_mLLaMACpp, "LLAMA_SESSION_MAGIC", rb_str_new2(ss_magic.str().c_str()));
3748
-
3749
- ss_magic.str("");
3750
- ss_magic.clear(std::stringstream::goodbit);
3751
- ss_magic << std::showbase << std::hex << LLAMA_STATE_SEQ_MAGIC;
3752
- rb_define_const(rb_mLLaMACpp, "LLAMA_STATE_SEQ_MAGIC", rb_str_new2(ss_magic.str().c_str()));
3753
-
3754
- ss_magic.str("");
3755
- ss_magic.clear(std::stringstream::goodbit);
3756
- ss_magic << std::showbase << std::hex << LLAMA_DEFAULT_SEED;
3757
- rb_define_const(rb_mLLaMACpp, "LLAMA_DEFAULT_SEED", rb_str_new2(ss_magic.str().c_str()));
3758
-
3759
- rb_define_const(rb_mLLaMACpp, "LLAMA_SESSION_VERSION", rb_str_new2(std::to_string(LLAMA_SESSION_VERSION).c_str()));
3760
- rb_define_const(rb_mLLaMACpp, "LLAMA_STATE_SEQ_VERSION", rb_str_new2(std::to_string(LLAMA_STATE_SEQ_VERSION).c_str()));
3761
- }