llama_cpp 0.17.9 → 0.18.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,3761 +0,0 @@
1
- #include "llama_cpp.h"
2
-
3
- VALUE rb_mLLaMACpp;
4
- VALUE rb_cLLaMABatch;
5
- VALUE rb_cLLaMAModel;
6
- VALUE rb_cLLaMAModelKVOverride;
7
- VALUE rb_cLLaMAModelParams;
8
- VALUE rb_cLLaMATimings;
9
- VALUE rb_cLLaMAContext;
10
- VALUE rb_cLLaMAContextParams;
11
- VALUE rb_cLLaMAModelQuantizeParams;
12
- VALUE rb_cLLaMATokenData;
13
- VALUE rb_cLLaMATokenDataArray;
14
- VALUE rb_cLLaMAGrammarElement;
15
- VALUE rb_cLLaMAGrammar;
16
-
17
- class LLaMABatchWrapper {
18
- public:
19
- llama_batch batch;
20
-
21
- LLaMABatchWrapper() {}
22
-
23
- ~LLaMABatchWrapper() {
24
- llama_batch_free(batch);
25
- }
26
- };
27
-
28
- class RbLLaMABatch {
29
- public:
30
- static VALUE llama_batch_alloc(VALUE self) {
31
- LLaMABatchWrapper* ptr = (LLaMABatchWrapper*)ruby_xmalloc(sizeof(LLaMABatchWrapper));
32
- new (ptr) LLaMABatchWrapper();
33
- return TypedData_Wrap_Struct(self, &llama_batch_type, ptr);
34
- }
35
-
36
- static void llama_batch_free(void* ptr) {
37
- ((LLaMABatchWrapper*)ptr)->~LLaMABatchWrapper();
38
- ruby_xfree(ptr);
39
- }
40
-
41
- static size_t llama_batch_size(const void* ptr) {
42
- return sizeof(*((LLaMABatchWrapper*)ptr));
43
- }
44
-
45
- static LLaMABatchWrapper* get_llama_batch(VALUE self) {
46
- LLaMABatchWrapper* ptr;
47
- TypedData_Get_Struct(self, LLaMABatchWrapper, &llama_batch_type, ptr);
48
- return ptr;
49
- }
50
-
51
- static void define_class(VALUE outer) {
52
- rb_cLLaMABatch = rb_define_class_under(outer, "Batch", rb_cObject);
53
- rb_define_alloc_func(rb_cLLaMABatch, llama_batch_alloc);
54
- rb_define_singleton_method(rb_cLLaMABatch, "get_one", RUBY_METHOD_FUNC(_llama_batch_get_one), -1);
55
- rb_define_method(rb_cLLaMABatch, "initialize", RUBY_METHOD_FUNC(_llama_batch_initialize), -1);
56
- rb_define_method(rb_cLLaMABatch, "n_tokens=", RUBY_METHOD_FUNC(_llama_batch_set_n_tokens), 1);
57
- rb_define_method(rb_cLLaMABatch, "n_tokens", RUBY_METHOD_FUNC(_llama_batch_get_n_tokens), 0);
58
- rb_define_method(rb_cLLaMABatch, "all_pos_zero=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_zero), 1);
59
- rb_define_method(rb_cLLaMABatch, "all_pos_zero", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_zero), 0);
60
- rb_define_method(rb_cLLaMABatch, "all_pos_one=", RUBY_METHOD_FUNC(_llama_batch_set_all_pos_one), 1);
61
- rb_define_method(rb_cLLaMABatch, "all_pos_one", RUBY_METHOD_FUNC(_llama_batch_get_all_pos_one), 0);
62
- rb_define_method(rb_cLLaMABatch, "all_seq_id=", RUBY_METHOD_FUNC(_llama_batch_set_all_seq_id), 1);
63
- rb_define_method(rb_cLLaMABatch, "all_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_all_seq_id), 0);
64
- rb_define_method(rb_cLLaMABatch, "set_token", RUBY_METHOD_FUNC(_llama_batch_set_token), 2);
65
- rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
66
- rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
67
- rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
68
- rb_define_method(rb_cLLaMABatch, "set_n_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_n_seq_id), 2);
69
- rb_define_method(rb_cLLaMABatch, "get_n_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_n_seq_id), 1);
70
- rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 3);
71
- rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 2);
72
- rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
73
- rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
74
- }
75
-
76
- private:
77
- static const rb_data_type_t llama_batch_type;
78
-
79
- static VALUE _llama_batch_get_one(int argc, VALUE* argv, VALUE klass) {
80
- VALUE kw_args = Qnil;
81
- ID kw_table[4] = { rb_intern("tokens"), rb_intern("n_tokens"), rb_intern("pos_zero"), rb_intern("seq_id") };
82
- VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
83
- rb_scan_args(argc, argv, ":", &kw_args);
84
- rb_get_kwargs(kw_args, kw_table, 4, 0, kw_values);
85
-
86
- if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
87
- rb_raise(rb_eArgError, "tokens must be an array");
88
- return Qnil;
89
- }
90
- if (!RB_INTEGER_TYPE_P(kw_values[1])) {
91
- rb_raise(rb_eArgError, "n_tokens must be an integer");
92
- return Qnil;
93
- }
94
- if (!RB_INTEGER_TYPE_P(kw_values[2])) {
95
- rb_raise(rb_eArgError, "pos_zero must be an integer");
96
- return Qnil;
97
- }
98
- if (!RB_INTEGER_TYPE_P(kw_values[3])) {
99
- rb_raise(rb_eArgError, "seq_id must be an integer");
100
- return Qnil;
101
- }
102
-
103
- const size_t sz_array = RARRAY_LEN(kw_values[0]);
104
- const int32_t n_tokens = NUM2INT(kw_values[1]);
105
- const llama_pos pos_zero = NUM2INT(kw_values[2]);
106
- const llama_seq_id seq_id = NUM2INT(kw_values[3]);
107
-
108
- LLaMABatchWrapper* ptr = (LLaMABatchWrapper*)ruby_xmalloc(sizeof(LLaMABatchWrapper));
109
- new (ptr) LLaMABatchWrapper();
110
- ptr->batch = llama_batch_get_one(nullptr, n_tokens, pos_zero, seq_id);
111
-
112
- ptr->batch.token = (llama_token*)malloc(sizeof(llama_token) * sz_array);
113
- for (size_t i = 0; i < sz_array; i++) {
114
- VALUE el = rb_ary_entry(kw_values[0], i);
115
- ptr->batch.token[i] = NUM2INT(el);
116
- }
117
-
118
- return TypedData_Wrap_Struct(klass, &llama_batch_type, ptr);
119
- }
120
-
121
- static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
122
- VALUE kw_args = Qnil;
123
- ID kw_table[3] = { rb_intern("max_n_token"), rb_intern("n_embd"), rb_intern("max_n_seq") };
124
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
125
- rb_scan_args(argc, argv, ":", &kw_args);
126
- rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
127
-
128
- if (!RB_INTEGER_TYPE_P(kw_values[0])) {
129
- rb_raise(rb_eArgError, "max_n_token must be an integer");
130
- return Qnil;
131
- }
132
- if (!RB_INTEGER_TYPE_P(kw_values[1])) {
133
- rb_raise(rb_eArgError, "n_embd must be an integer");
134
- return Qnil;
135
- }
136
- if (!RB_INTEGER_TYPE_P(kw_values[2])) {
137
- rb_raise(rb_eArgError, "max_n_seq must be an integer");
138
- return Qnil;
139
- }
140
-
141
- const int32_t max_n_token = NUM2INT(kw_values[0]);
142
- const int32_t n_embd = NUM2INT(kw_values[1]);
143
- const int32_t max_n_seq = NUM2INT(kw_values[2]);
144
-
145
- LLaMABatchWrapper* ptr = get_llama_batch(self);
146
- ptr->batch = llama_batch_init(max_n_token, n_embd, max_n_seq);
147
-
148
- return Qnil;
149
- }
150
-
151
- // n_tokens
152
- static VALUE _llama_batch_set_n_tokens(VALUE self, VALUE n_tokens) {
153
- LLaMABatchWrapper* ptr = get_llama_batch(self);
154
- ptr->batch.n_tokens = NUM2INT(n_tokens);
155
- return INT2NUM(ptr->batch.n_tokens);
156
- }
157
-
158
- static VALUE _llama_batch_get_n_tokens(VALUE self) {
159
- LLaMABatchWrapper* ptr = get_llama_batch(self);
160
- return INT2NUM(ptr->batch.n_tokens);
161
- }
162
-
163
- // all_pos_0
164
- static VALUE _llama_batch_set_all_pos_zero(VALUE self, VALUE all_pos_0) {
165
- LLaMABatchWrapper* ptr = get_llama_batch(self);
166
- ptr->batch.all_pos_0 = NUM2INT(all_pos_0);
167
- return INT2NUM(ptr->batch.all_pos_0);
168
- }
169
-
170
- static VALUE _llama_batch_get_all_pos_zero(VALUE self) {
171
- LLaMABatchWrapper* ptr = get_llama_batch(self);
172
- return INT2NUM(ptr->batch.all_pos_0);
173
- }
174
-
175
- // all_pos_1
176
- static VALUE _llama_batch_set_all_pos_one(VALUE self, VALUE all_pos_1) {
177
- LLaMABatchWrapper* ptr = get_llama_batch(self);
178
- ptr->batch.all_pos_1 = NUM2INT(all_pos_1);
179
- return INT2NUM(ptr->batch.all_pos_1);
180
- }
181
-
182
- static VALUE _llama_batch_get_all_pos_one(VALUE self) {
183
- LLaMABatchWrapper* ptr = get_llama_batch(self);
184
- return INT2NUM(ptr->batch.all_pos_1);
185
- }
186
-
187
- // all_seq_id
188
- static VALUE _llama_batch_set_all_seq_id(VALUE self, VALUE all_seq_id) {
189
- LLaMABatchWrapper* ptr = get_llama_batch(self);
190
- ptr->batch.all_seq_id = NUM2INT(all_seq_id);
191
- return INT2NUM(ptr->batch.all_seq_id);
192
- }
193
-
194
- static VALUE _llama_batch_get_all_seq_id(VALUE self) {
195
- LLaMABatchWrapper* ptr = get_llama_batch(self);
196
- return INT2NUM(ptr->batch.all_seq_id);
197
- }
198
-
199
- // token
200
- static VALUE _llama_batch_set_token(VALUE self, VALUE idx, VALUE value) {
201
- LLaMABatchWrapper* ptr = get_llama_batch(self);
202
- const int32_t id = NUM2INT(idx);
203
- if (id < 0) {
204
- rb_raise(rb_eArgError, "id must be greater or equal to 0");
205
- return Qnil;
206
- }
207
- ptr->batch.token[id] = NUM2INT(value);
208
- return INT2NUM(ptr->batch.token[id]);
209
- }
210
-
211
- static VALUE _llama_batch_get_token(VALUE self, VALUE idx) {
212
- LLaMABatchWrapper* ptr = get_llama_batch(self);
213
- const int32_t id = NUM2INT(idx);
214
- if (id < 0) {
215
- rb_raise(rb_eArgError, "id must be greater or equal to 0");
216
- return Qnil;
217
- }
218
- return INT2NUM(ptr->batch.token[id]);
219
- }
220
-
221
- // pos
222
- static VALUE _llama_batch_set_pos(VALUE self, VALUE idx, VALUE value) {
223
- LLaMABatchWrapper* ptr = get_llama_batch(self);
224
- const int32_t id = NUM2INT(idx);
225
- if (id < 0) {
226
- rb_raise(rb_eArgError, "id must be greater or equal to 0");
227
- return Qnil;
228
- }
229
- ptr->batch.pos[id] = NUM2INT(value);
230
- return INT2NUM(ptr->batch.pos[id]);
231
- }
232
-
233
- static VALUE _llama_batch_get_pos(VALUE self, VALUE idx) {
234
- LLaMABatchWrapper* ptr = get_llama_batch(self);
235
- const int32_t id = NUM2INT(idx);
236
- if (id < 0) {
237
- rb_raise(rb_eArgError, "id must be greater or equal to 0");
238
- return Qnil;
239
- }
240
- return INT2NUM(ptr->batch.pos[id]);
241
- }
242
-
243
- // n_seq_id
244
- static VALUE _llama_batch_set_n_seq_id(VALUE self, VALUE idx, VALUE value) {
245
- LLaMABatchWrapper* ptr = get_llama_batch(self);
246
- const int32_t id = NUM2INT(idx);
247
- if (id < 0) {
248
- rb_raise(rb_eArgError, "id must be greater or equal to 0");
249
- return Qnil;
250
- }
251
- ptr->batch.n_seq_id[id] = NUM2INT(value);
252
- return INT2NUM(ptr->batch.n_seq_id[id]);
253
- }
254
-
255
- static VALUE _llama_batch_get_n_seq_id(VALUE self, VALUE idx) {
256
- LLaMABatchWrapper* ptr = get_llama_batch(self);
257
- const int32_t id = NUM2INT(idx);
258
- if (id < 0) {
259
- rb_raise(rb_eArgError, "id must be greater or equal to 0");
260
- return Qnil;
261
- }
262
- return INT2NUM(ptr->batch.n_seq_id[id]);
263
- }
264
-
265
- // seq_id
266
- static VALUE _llama_batch_set_seq_id(VALUE self, VALUE i_, VALUE j_, VALUE value) {
267
- LLaMABatchWrapper* ptr = get_llama_batch(self);
268
- const int32_t i = NUM2INT(i_);
269
- if (i < 0) {
270
- rb_raise(rb_eArgError, "i must be greater or equal to 0");
271
- return Qnil;
272
- }
273
- const int32_t j = NUM2INT(j_);
274
- if (j < 0) {
275
- rb_raise(rb_eArgError, "j must be greater or equal to 0");
276
- return Qnil;
277
- }
278
- ptr->batch.seq_id[i][j] = NUM2INT(value);
279
- return INT2NUM(ptr->batch.seq_id[i][j]);
280
- }
281
-
282
- static VALUE _llama_batch_get_seq_id(VALUE self, VALUE i_, VALUE j_) {
283
- LLaMABatchWrapper* ptr = get_llama_batch(self);
284
- const int32_t i = NUM2INT(i_);
285
- if (i < 0) {
286
- rb_raise(rb_eArgError, "i must be greater or equal to 0");
287
- return Qnil;
288
- }
289
- const int32_t j = NUM2INT(j_);
290
- if (j < 0) {
291
- rb_raise(rb_eArgError, "j must be greater or equal to 0");
292
- return Qnil;
293
- }
294
- return INT2NUM(ptr->batch.seq_id[i][j]);
295
- }
296
-
297
- // logits
298
- static VALUE _llama_batch_set_logits(VALUE self, VALUE idx, VALUE value) {
299
- LLaMABatchWrapper* ptr = get_llama_batch(self);
300
- const int32_t id = NUM2INT(idx);
301
- if (id < 0) {
302
- rb_raise(rb_eArgError, "id must be greater or equal to 0");
303
- return Qnil;
304
- }
305
- ptr->batch.logits[id] = RTEST(value) ? true : false;
306
- return ptr->batch.logits[id] ? Qtrue : Qfalse;
307
- }
308
-
309
- static VALUE _llama_batch_get_logits(VALUE self, VALUE idx) {
310
- LLaMABatchWrapper* ptr = get_llama_batch(self);
311
- const int32_t id = NUM2INT(idx);
312
- if (id < 0) {
313
- rb_raise(rb_eArgError, "id must be greater or equal to 0");
314
- return Qnil;
315
- }
316
- return ptr->batch.logits[id] ? Qtrue : Qfalse;
317
- }
318
- };
319
-
320
- const rb_data_type_t RbLLaMABatch::llama_batch_type = {
321
- "RbLLaMABatch",
322
- { NULL,
323
- RbLLaMABatch::llama_batch_free,
324
- RbLLaMABatch::llama_batch_size },
325
- NULL,
326
- NULL,
327
- RUBY_TYPED_FREE_IMMEDIATELY
328
-
329
- };
330
-
331
- class LLaMATokenDataWrapper {
332
- public:
333
- llama_token_data data;
334
-
335
- LLaMATokenDataWrapper() {
336
- data.id = 0;
337
- data.logit = 0.0;
338
- data.p = 0.0;
339
- }
340
-
341
- ~LLaMATokenDataWrapper() {}
342
- };
343
-
344
- class RbLLaMATokenData {
345
- public:
346
- static VALUE llama_token_data_alloc(VALUE self) {
347
- LLaMATokenDataWrapper* ptr = (LLaMATokenDataWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataWrapper));
348
- new (ptr) LLaMATokenDataWrapper();
349
- return TypedData_Wrap_Struct(self, &llama_token_data_type, ptr);
350
- }
351
-
352
- static void llama_token_data_free(void* ptr) {
353
- ((LLaMATokenDataWrapper*)ptr)->~LLaMATokenDataWrapper();
354
- ruby_xfree(ptr);
355
- }
356
-
357
- static size_t llama_token_data_size(const void* ptr) {
358
- return sizeof(*((LLaMATokenDataWrapper*)ptr));
359
- }
360
-
361
- static LLaMATokenDataWrapper* get_llama_token_data(VALUE self) {
362
- LLaMATokenDataWrapper* ptr;
363
- TypedData_Get_Struct(self, LLaMATokenDataWrapper, &llama_token_data_type, ptr);
364
- return ptr;
365
- }
366
-
367
- static void define_class(VALUE outer) {
368
- rb_cLLaMATokenData = rb_define_class_under(outer, "TokenData", rb_cObject);
369
- rb_define_alloc_func(rb_cLLaMATokenData, llama_token_data_alloc);
370
- rb_define_method(rb_cLLaMATokenData, "initialize", RUBY_METHOD_FUNC(_llama_token_data_init), -1);
371
- rb_define_method(rb_cLLaMATokenData, "id=", RUBY_METHOD_FUNC(_llama_token_data_set_id), 1);
372
- rb_define_method(rb_cLLaMATokenData, "id", RUBY_METHOD_FUNC(_llama_token_data_get_id), 0);
373
- rb_define_method(rb_cLLaMATokenData, "logit=", RUBY_METHOD_FUNC(_llama_token_data_set_logit), 1);
374
- rb_define_method(rb_cLLaMATokenData, "logit", RUBY_METHOD_FUNC(_llama_token_data_get_logit), 0);
375
- rb_define_method(rb_cLLaMATokenData, "p=", RUBY_METHOD_FUNC(_llama_token_data_set_p), 1);
376
- rb_define_method(rb_cLLaMATokenData, "p", RUBY_METHOD_FUNC(_llama_token_data_get_p), 0);
377
- }
378
-
379
- private:
380
- static const rb_data_type_t llama_token_data_type;
381
-
382
- static VALUE _llama_token_data_init(int argc, VALUE* argv, VALUE self) {
383
- VALUE kw_args = Qnil;
384
- ID kw_table[3] = { rb_intern("id"), rb_intern("logit"), rb_intern("p") };
385
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
386
- rb_scan_args(argc, argv, ":", &kw_args);
387
- rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
388
-
389
- if (!RB_INTEGER_TYPE_P(kw_values[0])) {
390
- rb_raise(rb_eArgError, "id must be an integer");
391
- return Qnil;
392
- }
393
- if (!RB_FLOAT_TYPE_P(kw_values[1])) {
394
- rb_raise(rb_eArgError, "logit must be a float");
395
- return Qnil;
396
- }
397
- if (!RB_FLOAT_TYPE_P(kw_values[2])) {
398
- rb_raise(rb_eArgError, "p must be a float");
399
- return Qnil;
400
- }
401
-
402
- LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
403
- new (ptr) LLaMATokenDataWrapper();
404
-
405
- ptr->data.id = NUM2INT(kw_values[0]);
406
- ptr->data.logit = NUM2DBL(kw_values[1]);
407
- ptr->data.p = NUM2DBL(kw_values[2]);
408
-
409
- return self;
410
- }
411
-
412
- // id
413
- static VALUE _llama_token_data_set_id(VALUE self, VALUE id) {
414
- LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
415
- ptr->data.id = NUM2INT(id);
416
- return INT2NUM(ptr->data.id);
417
- }
418
-
419
- static VALUE _llama_token_data_get_id(VALUE self) {
420
- LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
421
- return INT2NUM(ptr->data.id);
422
- }
423
-
424
- // logit
425
- static VALUE _llama_token_data_set_logit(VALUE self, VALUE logit) {
426
- LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
427
- ptr->data.logit = NUM2DBL(logit);
428
- return DBL2NUM(ptr->data.logit);
429
- }
430
-
431
- static VALUE _llama_token_data_get_logit(VALUE self) {
432
- LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
433
- return DBL2NUM(ptr->data.logit);
434
- }
435
-
436
- // p
437
- static VALUE _llama_token_data_set_p(VALUE self, VALUE p) {
438
- LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
439
- ptr->data.p = NUM2DBL(p);
440
- return DBL2NUM(ptr->data.p);
441
- }
442
-
443
- static VALUE _llama_token_data_get_p(VALUE self) {
444
- LLaMATokenDataWrapper* ptr = get_llama_token_data(self);
445
- return DBL2NUM(ptr->data.p);
446
- }
447
- };
448
-
449
- const rb_data_type_t RbLLaMATokenData::llama_token_data_type = {
450
- "RbLLaMATokenData",
451
- { NULL,
452
- RbLLaMATokenData::llama_token_data_free,
453
- RbLLaMATokenData::llama_token_data_size },
454
- NULL,
455
- NULL,
456
- RUBY_TYPED_FREE_IMMEDIATELY
457
- };
458
-
459
- class LLaMATokenDataArrayWrapper {
460
- public:
461
- llama_token_data_array array;
462
-
463
- LLaMATokenDataArrayWrapper() {
464
- array.data = nullptr;
465
- array.size = 0;
466
- array.sorted = false;
467
- }
468
-
469
- ~LLaMATokenDataArrayWrapper() {
470
- if (array.data) {
471
- ruby_xfree(array.data);
472
- array.data = nullptr;
473
- }
474
- }
475
- };
476
-
477
- class RbLLaMATokenDataArray {
478
- public:
479
- static VALUE llama_token_data_array_alloc(VALUE self) {
480
- LLaMATokenDataArrayWrapper* ptr = (LLaMATokenDataArrayWrapper*)ruby_xmalloc(sizeof(LLaMATokenDataArrayWrapper));
481
- new (ptr) LLaMATokenDataArrayWrapper();
482
- return TypedData_Wrap_Struct(self, &llama_token_data_array_type, ptr);
483
- }
484
-
485
- static void llama_token_data_array_free(void* ptr) {
486
- ((LLaMATokenDataArrayWrapper*)ptr)->~LLaMATokenDataArrayWrapper();
487
- ruby_xfree(ptr);
488
- }
489
-
490
- static size_t llama_token_data_array_size(const void* ptr) {
491
- return sizeof(*((LLaMATokenDataArrayWrapper*)ptr));
492
- }
493
-
494
- static LLaMATokenDataArrayWrapper* get_llama_token_data_array(VALUE self) {
495
- LLaMATokenDataArrayWrapper* ptr;
496
- TypedData_Get_Struct(self, LLaMATokenDataArrayWrapper, &llama_token_data_array_type, ptr);
497
- return ptr;
498
- }
499
-
500
- static void define_class(VALUE outer) {
501
- rb_cLLaMATokenDataArray = rb_define_class_under(outer, "TokenDataArray", rb_cObject);
502
- rb_define_alloc_func(rb_cLLaMATokenDataArray, llama_token_data_array_alloc);
503
- rb_define_method(rb_cLLaMATokenDataArray, "initialize", RUBY_METHOD_FUNC(_llama_token_data_array_init), -1);
504
- rb_define_method(rb_cLLaMATokenDataArray, "size", RUBY_METHOD_FUNC(_llama_token_data_array_get_size), 0);
505
- rb_define_method(rb_cLLaMATokenDataArray, "sorted", RUBY_METHOD_FUNC(_llama_token_data_array_get_sorted), 0);
506
- }
507
-
508
- private:
509
- static const rb_data_type_t llama_token_data_array_type;
510
-
511
- static VALUE _llama_token_data_array_init(int argc, VALUE* argv, VALUE self) {
512
- VALUE kw_args = Qnil;
513
- ID kw_table[1] = { rb_intern("sorted") };
514
- VALUE kw_values[1] = { Qundef };
515
- VALUE arr = Qnil;
516
- rb_scan_args(argc, argv, "1:", &arr, &kw_args);
517
- rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
518
-
519
- if (!RB_TYPE_P(arr, T_ARRAY)) {
520
- rb_raise(rb_eArgError, "1st argument must be an array");
521
- return Qnil;
522
- }
523
- size_t sz_array = RARRAY_LEN(arr);
524
- if (sz_array == 0) {
525
- rb_raise(rb_eArgError, "array must not be empty");
526
- return Qnil;
527
- }
528
- if (kw_values[0] != Qundef && !RB_TYPE_P(kw_values[0], T_TRUE) && !RB_TYPE_P(kw_values[0], T_FALSE)) {
529
- rb_raise(rb_eArgError, "sorted must be a boolean");
530
- return Qnil;
531
- }
532
-
533
- LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
534
- new (ptr) LLaMATokenDataArrayWrapper();
535
-
536
- ptr->array.data = (llama_token_data*)ruby_xmalloc(sizeof(llama_token_data) * sz_array);
537
- for (size_t i = 0; i < sz_array; ++i) {
538
- VALUE el = rb_ary_entry(arr, i);
539
- if (!rb_obj_is_kind_of(el, rb_cLLaMATokenData)) {
540
- rb_raise(rb_eArgError, "array element must be a TokenData");
541
- xfree(ptr->array.data);
542
- ptr->array.data = nullptr;
543
- return Qnil;
544
- }
545
- llama_token_data token_data = RbLLaMATokenData::get_llama_token_data(el)->data;
546
- ptr->array.data[i].id = token_data.id;
547
- ptr->array.data[i].logit = token_data.logit;
548
- ptr->array.data[i].p = token_data.p;
549
- }
550
-
551
- ptr->array.size = sz_array;
552
- ptr->array.sorted = kw_values[0] == Qtrue;
553
-
554
- return self;
555
- }
556
-
557
- static VALUE _llama_token_data_array_get_size(VALUE self) {
558
- LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
559
- return SIZET2NUM(ptr->array.size);
560
- }
561
-
562
- static VALUE _llama_token_data_array_get_sorted(VALUE self) {
563
- LLaMATokenDataArrayWrapper* ptr = get_llama_token_data_array(self);
564
- return ptr->array.sorted ? Qtrue : Qfalse;
565
- }
566
- };
567
-
568
- const rb_data_type_t RbLLaMATokenDataArray::llama_token_data_array_type = {
569
- "RbLLaMATokenDataArray",
570
- { NULL,
571
- RbLLaMATokenDataArray::llama_token_data_array_free,
572
- RbLLaMATokenDataArray::llama_token_data_array_size },
573
- NULL,
574
- NULL,
575
- RUBY_TYPED_FREE_IMMEDIATELY
576
- };
577
-
578
- class LLaMATimingsWrapper {
579
- public:
580
- struct llama_timings timings;
581
-
582
- LLaMATimingsWrapper() {}
583
-
584
- ~LLaMATimingsWrapper() {}
585
- };
586
-
587
- class RbLLaMATimings {
588
- public:
589
- static VALUE llama_timings_alloc(VALUE self) {
590
- LLaMATimingsWrapper* ptr = (LLaMATimingsWrapper*)ruby_xmalloc(sizeof(LLaMATimingsWrapper));
591
- new (ptr) LLaMATimingsWrapper();
592
- return TypedData_Wrap_Struct(self, &llama_timings_type, ptr);
593
- }
594
-
595
- static void llama_timings_free(void* ptr) {
596
- ((LLaMATimingsWrapper*)ptr)->~LLaMATimingsWrapper();
597
- ruby_xfree(ptr);
598
- }
599
-
600
- static size_t llama_timings_size(const void* ptr) {
601
- return sizeof(*((LLaMATimingsWrapper*)ptr));
602
- }
603
-
604
- static LLaMATimingsWrapper* get_llama_timings(VALUE self) {
605
- LLaMATimingsWrapper* ptr;
606
- TypedData_Get_Struct(self, LLaMATimingsWrapper, &llama_timings_type, ptr);
607
- return ptr;
608
- }
609
-
610
- static void define_class(VALUE outer) {
611
- rb_cLLaMATimings = rb_define_class_under(outer, "Timings", rb_cObject);
612
- rb_define_alloc_func(rb_cLLaMATimings, llama_timings_alloc);
613
- rb_define_method(rb_cLLaMATimings, "t_start_ms", RUBY_METHOD_FUNC(_llama_timings_get_t_start_ms), 0);
614
- rb_define_method(rb_cLLaMATimings, "t_end_ms", RUBY_METHOD_FUNC(_llama_timings_get_t_end_ms), 0);
615
- rb_define_method(rb_cLLaMATimings, "t_load_ms", RUBY_METHOD_FUNC(_llama_timings_get_t_load_ms), 0);
616
- rb_define_method(rb_cLLaMATimings, "t_sample_ms", RUBY_METHOD_FUNC(_llama_timings_get_t_sample_ms), 0);
617
- rb_define_method(rb_cLLaMATimings, "t_p_eval_ms", RUBY_METHOD_FUNC(_llama_timings_get_t_p_eval_ms), 0);
618
- rb_define_method(rb_cLLaMATimings, "t_eval_ms", RUBY_METHOD_FUNC(_llama_timings_get_t_eval_ms), 0);
619
- rb_define_method(rb_cLLaMATimings, "n_sample", RUBY_METHOD_FUNC(_llama_timings_get_n_sample), 0);
620
- rb_define_method(rb_cLLaMATimings, "n_p_eval", RUBY_METHOD_FUNC(_llama_timings_get_n_p_eval), 0);
621
- rb_define_method(rb_cLLaMATimings, "n_eval", RUBY_METHOD_FUNC(_llama_timings_get_n_eval), 0);
622
- }
623
-
624
- private:
625
- static const rb_data_type_t llama_timings_type;
626
-
627
- static VALUE _llama_timings_get_t_start_ms(VALUE self) {
628
- LLaMATimingsWrapper* ptr = get_llama_timings(self);
629
- return DBL2NUM(ptr->timings.t_start_ms);
630
- }
631
-
632
- static VALUE _llama_timings_get_t_end_ms(VALUE self) {
633
- LLaMATimingsWrapper* ptr = get_llama_timings(self);
634
- return DBL2NUM(ptr->timings.t_end_ms);
635
- }
636
-
637
- static VALUE _llama_timings_get_t_load_ms(VALUE self) {
638
- LLaMATimingsWrapper* ptr = get_llama_timings(self);
639
- return DBL2NUM(ptr->timings.t_load_ms);
640
- }
641
-
642
- static VALUE _llama_timings_get_t_sample_ms(VALUE self) {
643
- LLaMATimingsWrapper* ptr = get_llama_timings(self);
644
- return DBL2NUM(ptr->timings.t_sample_ms);
645
- }
646
-
647
- static VALUE _llama_timings_get_t_p_eval_ms(VALUE self) {
648
- LLaMATimingsWrapper* ptr = get_llama_timings(self);
649
- return DBL2NUM(ptr->timings.t_p_eval_ms);
650
- }
651
-
652
- static VALUE _llama_timings_get_t_eval_ms(VALUE self) {
653
- LLaMATimingsWrapper* ptr = get_llama_timings(self);
654
- return DBL2NUM(ptr->timings.t_eval_ms);
655
- }
656
-
657
- static VALUE _llama_timings_get_n_sample(VALUE self) {
658
- LLaMATimingsWrapper* ptr = get_llama_timings(self);
659
- return INT2NUM(ptr->timings.n_sample);
660
- }
661
-
662
- static VALUE _llama_timings_get_n_p_eval(VALUE self) {
663
- LLaMATimingsWrapper* ptr = get_llama_timings(self);
664
- return INT2NUM(ptr->timings.n_p_eval);
665
- }
666
-
667
- static VALUE _llama_timings_get_n_eval(VALUE self) {
668
- LLaMATimingsWrapper* ptr = get_llama_timings(self);
669
- return INT2NUM(ptr->timings.n_eval);
670
- }
671
- };
672
-
673
- const rb_data_type_t RbLLaMATimings::llama_timings_type = {
674
- "RbLLaMATimings",
675
- { NULL,
676
- RbLLaMATimings::llama_timings_free,
677
- RbLLaMATimings::llama_timings_size },
678
- NULL,
679
- NULL,
680
- RUBY_TYPED_FREE_IMMEDIATELY
681
- };
682
-
683
- class RbLLaMAModelKVOverride {
684
- public:
685
- static VALUE llama_model_kv_override_alloc(VALUE self) {
686
- llama_model_kv_override* ptr = (llama_model_kv_override*)ruby_xmalloc(sizeof(llama_model_kv_override));
687
- new (ptr) llama_model_kv_override();
688
- return TypedData_Wrap_Struct(self, &llama_model_kv_override_type, ptr);
689
- }
690
-
691
- static void llama_model_kv_override_free(void* ptr) {
692
- ((llama_model_kv_override*)ptr)->~llama_model_kv_override();
693
- ruby_xfree(ptr);
694
- }
695
-
696
- static size_t llama_model_kv_override_size(const void* ptr) {
697
- return sizeof(*((llama_model_kv_override*)ptr));
698
- }
699
-
700
- static llama_model_kv_override* get_llama_model_kv_override(VALUE self) {
701
- llama_model_kv_override* ptr;
702
- TypedData_Get_Struct(self, llama_model_kv_override, &llama_model_kv_override_type, ptr);
703
- return ptr;
704
- }
705
-
706
- static void define_class(VALUE outer) {
707
- rb_cLLaMAModelKVOverride = rb_define_class_under(outer, "ModelKVOverride", rb_cObject);
708
- rb_define_alloc_func(rb_cLLaMAModelKVOverride, llama_model_kv_override_alloc);
709
- rb_define_method(rb_cLLaMAModelKVOverride, "key", RUBY_METHOD_FUNC(_llama_model_kv_override_get_key), 0);
710
- rb_define_method(rb_cLLaMAModelKVOverride, "tag", RUBY_METHOD_FUNC(_llama_model_kv_override_get_tag), 0);
711
- rb_define_method(rb_cLLaMAModelKVOverride, "val_i64", RUBY_METHOD_FUNC(_llama_model_kv_override_get_val_i64), 0);
712
- rb_define_method(rb_cLLaMAModelKVOverride, "val_f64", RUBY_METHOD_FUNC(_llama_model_kv_override_get_val_f64), 0);
713
- rb_define_method(rb_cLLaMAModelKVOverride, "val_bool", RUBY_METHOD_FUNC(_llama_model_kv_override_get_val_bool), 0);
714
- rb_define_method(rb_cLLaMAModelKVOverride, "val_str", RUBY_METHOD_FUNC(_llama_model_kv_override_get_val_str), 0);
715
- }
716
-
717
- static const rb_data_type_t llama_model_kv_override_type;
718
-
719
- private:
720
- static VALUE _llama_model_kv_override_get_key(VALUE self) {
721
- llama_model_kv_override* ptr = get_llama_model_kv_override(self);
722
- return rb_utf8_str_new_cstr(ptr->key);
723
- }
724
-
725
- static VALUE _llama_model_kv_override_get_tag(VALUE self) {
726
- llama_model_kv_override* ptr = get_llama_model_kv_override(self);
727
- return INT2NUM(ptr->tag);
728
- }
729
-
730
- static VALUE _llama_model_kv_override_get_val_i64(VALUE self) {
731
- llama_model_kv_override* ptr = get_llama_model_kv_override(self);
732
- return INT2NUM(ptr->val_i64);
733
- }
734
-
735
- static VALUE _llama_model_kv_override_get_val_f64(VALUE self) {
736
- llama_model_kv_override* ptr = get_llama_model_kv_override(self);
737
- return DBL2NUM(ptr->val_f64);
738
- }
739
-
740
- static VALUE _llama_model_kv_override_get_val_bool(VALUE self) {
741
- llama_model_kv_override* ptr = get_llama_model_kv_override(self);
742
- return ptr->val_bool ? Qtrue : Qfalse;
743
- }
744
-
745
- static VALUE _llama_model_kv_override_get_val_str(VALUE self) {
746
- llama_model_kv_override* ptr = get_llama_model_kv_override(self);
747
- return rb_utf8_str_new_cstr(ptr->val_str);
748
- }
749
- };
750
-
751
- const rb_data_type_t RbLLaMAModelKVOverride::llama_model_kv_override_type = {
752
- "RbLLaMAModelKVOverride",
753
- { NULL,
754
- RbLLaMAModelKVOverride::llama_model_kv_override_free,
755
- RbLLaMAModelKVOverride::llama_model_kv_override_size },
756
- NULL,
757
- NULL,
758
- RUBY_TYPED_FREE_IMMEDIATELY
759
- };
760
-
761
- class LLaMAModelParamsWrapper {
762
- public:
763
- struct llama_model_params params;
764
-
765
- LLaMAModelParamsWrapper() : params(llama_model_default_params()) {}
766
-
767
- ~LLaMAModelParamsWrapper() {}
768
- };
769
-
770
- class RbLLaMAModelParams {
771
- public:
772
- static VALUE llama_model_params_alloc(VALUE self) {
773
- LLaMAModelParamsWrapper* ptr = (LLaMAModelParamsWrapper*)ruby_xmalloc(sizeof(LLaMAModelParamsWrapper));
774
- new (ptr) LLaMAModelParamsWrapper();
775
- return TypedData_Wrap_Struct(self, &llama_model_params_type, ptr);
776
- }
777
-
778
- static void llama_model_params_free(void* ptr) {
779
- ((LLaMAModelParamsWrapper*)ptr)->~LLaMAModelParamsWrapper();
780
- ruby_xfree(ptr);
781
- }
782
-
783
- static size_t llama_model_params_size(const void* ptr) {
784
- return sizeof(*((LLaMAModelParamsWrapper*)ptr));
785
- }
786
-
787
- static LLaMAModelParamsWrapper* get_llama_model_params(VALUE self) {
788
- LLaMAModelParamsWrapper* ptr;
789
- TypedData_Get_Struct(self, LLaMAModelParamsWrapper, &llama_model_params_type, ptr);
790
- return ptr;
791
- }
792
-
793
- static void define_class(VALUE outer) {
794
- rb_cLLaMAModelParams = rb_define_class_under(outer, "ModelParams", rb_cObject);
795
- rb_define_alloc_func(rb_cLLaMAModelParams, llama_model_params_alloc);
796
- rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers=", RUBY_METHOD_FUNC(_llama_model_params_set_n_gpu_layers), 1);
797
- rb_define_method(rb_cLLaMAModelParams, "n_gpu_layers", RUBY_METHOD_FUNC(_llama_model_params_get_n_gpu_layers), 0);
798
- rb_define_method(rb_cLLaMAModelParams, "split_mode=", RUBY_METHOD_FUNC(_llama_model_params_set_split_mode), 1);
799
- rb_define_method(rb_cLLaMAModelParams, "split_mode", RUBY_METHOD_FUNC(_llama_model_params_get_split_mode), 0);
800
- rb_define_method(rb_cLLaMAModelParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_model_params_set_main_gpu), 1);
801
- rb_define_method(rb_cLLaMAModelParams, "main_gpu", RUBY_METHOD_FUNC(_llama_model_params_get_main_gpu), 0);
802
- rb_define_method(rb_cLLaMAModelParams, "tensor_split", RUBY_METHOD_FUNC(_llama_model_params_get_tensor_split), 0);
803
- rb_define_method(rb_cLLaMAModelParams, "vocab_only=", RUBY_METHOD_FUNC(_llama_model_params_set_vocab_only), 1);
804
- rb_define_method(rb_cLLaMAModelParams, "vocab_only", RUBY_METHOD_FUNC(_llama_model_params_get_vocab_only), 0);
805
- rb_define_method(rb_cLLaMAModelParams, "use_mmap=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mmap), 1);
806
- rb_define_method(rb_cLLaMAModelParams, "use_mmap", RUBY_METHOD_FUNC(_llama_model_params_get_use_mmap), 0);
807
- rb_define_method(rb_cLLaMAModelParams, "use_mlock=", RUBY_METHOD_FUNC(_llama_model_params_set_use_mlock), 1);
808
- rb_define_method(rb_cLLaMAModelParams, "use_mlock", RUBY_METHOD_FUNC(_llama_model_params_get_use_mlock), 0);
809
- rb_define_method(rb_cLLaMAModelParams, "check_tensors=", RUBY_METHOD_FUNC(_llama_model_params_set_check_tensors), 1);
810
- rb_define_method(rb_cLLaMAModelParams, "check_tensors", RUBY_METHOD_FUNC(_llama_model_params_get_check_tensors), 0);
811
- }
812
-
813
- private:
814
- static const rb_data_type_t llama_model_params_type;
815
-
816
- // n_gpu_layers
817
- static VALUE _llama_model_params_set_n_gpu_layers(VALUE self, VALUE n_gpu_layers) {
818
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
819
- ptr->params.n_gpu_layers = NUM2INT(n_gpu_layers);
820
- return INT2NUM(ptr->params.n_gpu_layers);
821
- }
822
-
823
- static VALUE _llama_model_params_get_n_gpu_layers(VALUE self) {
824
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
825
- return INT2NUM(ptr->params.n_gpu_layers);
826
- }
827
-
828
- // split_mode
829
- static VALUE _llama_model_params_set_split_mode(VALUE self, VALUE split_mode) {
830
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
831
- ptr->params.split_mode = static_cast<enum llama_split_mode>(NUM2INT(split_mode));
832
- return INT2NUM(ptr->params.split_mode);
833
- }
834
-
835
- static VALUE _llama_model_params_get_split_mode(VALUE self) {
836
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
837
- return INT2NUM(ptr->params.split_mode);
838
- }
839
-
840
- // main_gpu
841
- static VALUE _llama_model_params_set_main_gpu(VALUE self, VALUE main_gpu) {
842
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
843
- ptr->params.main_gpu = NUM2INT(main_gpu);
844
- return INT2NUM(ptr->params.main_gpu);
845
- }
846
-
847
- static VALUE _llama_model_params_get_main_gpu(VALUE self) {
848
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
849
- return INT2NUM(ptr->params.main_gpu);
850
- }
851
-
852
- // tensor_split
853
- static VALUE _llama_model_params_get_tensor_split(VALUE self) {
854
- if (llama_max_devices() < 1) {
855
- return rb_ary_new();
856
- }
857
- VALUE ret = rb_ary_new2(llama_max_devices());
858
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
859
- if (ptr->params.tensor_split == nullptr) {
860
- return rb_ary_new();
861
- }
862
- for (size_t i = 0; i < llama_max_devices(); i++) {
863
- rb_ary_store(ret, i, DBL2NUM(ptr->params.tensor_split[i]));
864
- }
865
- return ret;
866
- }
867
-
868
- // vocab_only
869
- static VALUE _llama_model_params_set_vocab_only(VALUE self, VALUE vocab_only) {
870
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
871
- ptr->params.vocab_only = RTEST(vocab_only) ? true : false;
872
- return ptr->params.vocab_only ? Qtrue : Qfalse;
873
- }
874
-
875
- static VALUE _llama_model_params_get_vocab_only(VALUE self) {
876
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
877
- return ptr->params.vocab_only ? Qtrue : Qfalse;
878
- }
879
-
880
- // use_mmap
881
- static VALUE _llama_model_params_set_use_mmap(VALUE self, VALUE use_mmap) {
882
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
883
- ptr->params.use_mmap = RTEST(use_mmap) ? true : false;
884
- return ptr->params.use_mmap ? Qtrue : Qfalse;
885
- }
886
-
887
- static VALUE _llama_model_params_get_use_mmap(VALUE self) {
888
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
889
- return ptr->params.use_mmap ? Qtrue : Qfalse;
890
- }
891
-
892
- // use_mlock
893
- static VALUE _llama_model_params_set_use_mlock(VALUE self, VALUE use_mlock) {
894
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
895
- ptr->params.use_mlock = RTEST(use_mlock) ? true : false;
896
- return ptr->params.use_mlock ? Qtrue : Qfalse;
897
- }
898
-
899
- static VALUE _llama_model_params_get_use_mlock(VALUE self) {
900
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
901
- return ptr->params.use_mlock ? Qtrue : Qfalse;
902
- }
903
-
904
- // check_tensors
905
- static VALUE _llama_model_params_set_check_tensors(VALUE self, VALUE check_tensors) {
906
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
907
- ptr->params.check_tensors = RTEST(check_tensors) ? true : false;
908
- return ptr->params.check_tensors ? Qtrue : Qfalse;
909
- }
910
-
911
- static VALUE _llama_model_params_get_check_tensors(VALUE self) {
912
- LLaMAModelParamsWrapper* ptr = get_llama_model_params(self);
913
- return ptr->params.check_tensors ? Qtrue : Qfalse;
914
- }
915
- };
916
-
917
- const rb_data_type_t RbLLaMAModelParams::llama_model_params_type = {
918
- "RbLLaMAModelParams",
919
- { NULL,
920
- RbLLaMAModelParams::llama_model_params_free,
921
- RbLLaMAModelParams::llama_model_params_size },
922
- NULL,
923
- NULL,
924
- RUBY_TYPED_FREE_IMMEDIATELY
925
- };
926
-
927
- class LLaMAContextParamsWrapper {
928
- public:
929
- struct llama_context_params params;
930
-
931
- LLaMAContextParamsWrapper() : params(llama_context_default_params()) {}
932
-
933
- ~LLaMAContextParamsWrapper() {}
934
- };
935
-
936
- class RbLLaMAContextParams {
937
- public:
938
- static VALUE llama_context_params_alloc(VALUE self) {
939
- LLaMAContextParamsWrapper* ptr = (LLaMAContextParamsWrapper*)ruby_xmalloc(sizeof(LLaMAContextParamsWrapper));
940
- new (ptr) LLaMAContextParamsWrapper();
941
- return TypedData_Wrap_Struct(self, &llama_context_params_type, ptr);
942
- }
943
-
944
- static void llama_context_params_free(void* ptr) {
945
- ((LLaMAContextParamsWrapper*)ptr)->~LLaMAContextParamsWrapper();
946
- ruby_xfree(ptr);
947
- }
948
-
949
- static size_t llama_context_params_size(const void* ptr) {
950
- return sizeof(*((LLaMAContextParamsWrapper*)ptr));
951
- }
952
-
953
- static LLaMAContextParamsWrapper* get_llama_context_params(VALUE self) {
954
- LLaMAContextParamsWrapper* ptr;
955
- TypedData_Get_Struct(self, LLaMAContextParamsWrapper, &llama_context_params_type, ptr);
956
- return ptr;
957
- }
958
-
959
- static void define_class(VALUE outer) {
960
- rb_cLLaMAContextParams = rb_define_class_under(outer, "ContextParams", rb_cObject);
961
- rb_define_alloc_func(rb_cLLaMAContextParams, llama_context_params_alloc);
962
- // rb_define_method(rb_cLLaMAContextParams, "initialize", RUBY_METHOD_FUNC(_llama_context_params_init), 0);
963
- rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
964
- rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
965
- rb_define_method(rb_cLLaMAContextParams, "n_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_n_ctx), 1);
966
- rb_define_method(rb_cLLaMAContextParams, "n_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_n_ctx), 0);
967
- rb_define_method(rb_cLLaMAContextParams, "n_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_batch), 1);
968
- rb_define_method(rb_cLLaMAContextParams, "n_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_batch), 0);
969
- rb_define_method(rb_cLLaMAContextParams, "n_ubatch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_ubatch), 1);
970
- rb_define_method(rb_cLLaMAContextParams, "n_ubatch", RUBY_METHOD_FUNC(_llama_context_params_get_n_ubatch), 0);
971
- rb_define_method(rb_cLLaMAContextParams, "n_seq_max=", RUBY_METHOD_FUNC(_llama_context_params_set_n_seq_max), 1);
972
- rb_define_method(rb_cLLaMAContextParams, "n_seq_max", RUBY_METHOD_FUNC(_llama_context_params_get_n_seq_max), 0);
973
- rb_define_method(rb_cLLaMAContextParams, "n_threads=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads), 1);
974
- rb_define_method(rb_cLLaMAContextParams, "n_threads", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads), 0);
975
- rb_define_method(rb_cLLaMAContextParams, "n_threads_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads_batch), 1);
976
- rb_define_method(rb_cLLaMAContextParams, "n_threads_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads_batch), 0);
977
- rb_define_method(rb_cLLaMAContextParams, "rope_scaling_type=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_scaling_type), 1);
978
- rb_define_method(rb_cLLaMAContextParams, "rope_scaling_type", RUBY_METHOD_FUNC(_llama_context_params_get_rope_scaling_type), 0);
979
- rb_define_method(rb_cLLaMAContextParams, "pooling_type=", RUBY_METHOD_FUNC(_llama_context_params_set_pooling_type), 1);
980
- rb_define_method(rb_cLLaMAContextParams, "pooling_type", RUBY_METHOD_FUNC(_llama_context_params_get_pooling_type), 0);
981
- rb_define_method(rb_cLLaMAContextParams, "attention_type=", RUBY_METHOD_FUNC(_llama_context_params_set_attention_type), 1);
982
- rb_define_method(rb_cLLaMAContextParams, "attention_type", RUBY_METHOD_FUNC(_llama_context_params_get_attention_type), 0);
983
- rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
984
- rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
985
- rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
986
- rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
987
- rb_define_method(rb_cLLaMAContextParams, "yarn_ext_factor=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_ext_factor), 1);
988
- rb_define_method(rb_cLLaMAContextParams, "yarn_ext_factor", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_ext_factor), 0);
989
- rb_define_method(rb_cLLaMAContextParams, "yarn_attn_factor=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_attn_factor), 1);
990
- rb_define_method(rb_cLLaMAContextParams, "yarn_attn_factor", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_attn_factor), 0);
991
- rb_define_method(rb_cLLaMAContextParams, "yarn_beta_fast=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_beta_fast), 1);
992
- rb_define_method(rb_cLLaMAContextParams, "yarn_beta_fast", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_beta_fast), 0);
993
- rb_define_method(rb_cLLaMAContextParams, "yarn_beta_slow=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_beta_slow), 1);
994
- rb_define_method(rb_cLLaMAContextParams, "yarn_beta_slow", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_beta_slow), 0);
995
- rb_define_method(rb_cLLaMAContextParams, "yarn_orig_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_orig_ctx), 1);
996
- rb_define_method(rb_cLLaMAContextParams, "yarn_orig_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_orig_ctx), 0);
997
- rb_define_method(rb_cLLaMAContextParams, "defrag_thold=", RUBY_METHOD_FUNC(_llama_context_params_set_defrag_thold), 1);
998
- rb_define_method(rb_cLLaMAContextParams, "defrag_thold", RUBY_METHOD_FUNC(_llama_context_params_get_defrag_thold), 0);
999
- rb_define_method(rb_cLLaMAContextParams, "type_k=", RUBY_METHOD_FUNC(_llama_context_params_set_type_k), 1);
1000
- rb_define_method(rb_cLLaMAContextParams, "type_k", RUBY_METHOD_FUNC(_llama_context_params_get_type_k), 0);
1001
- rb_define_method(rb_cLLaMAContextParams, "type_v=", RUBY_METHOD_FUNC(_llama_context_params_set_type_v), 1);
1002
- rb_define_method(rb_cLLaMAContextParams, "type_v", RUBY_METHOD_FUNC(_llama_context_params_get_type_v), 0);
1003
- rb_define_method(rb_cLLaMAContextParams, "logits_all=", RUBY_METHOD_FUNC(_llama_context_params_set_logits_all), 1);
1004
- rb_define_method(rb_cLLaMAContextParams, "logits_all", RUBY_METHOD_FUNC(_llama_context_params_get_logits_all), 0);
1005
- rb_define_method(rb_cLLaMAContextParams, "embeddings=", RUBY_METHOD_FUNC(_llama_context_params_set_embeddings), 1);
1006
- rb_define_method(rb_cLLaMAContextParams, "embeddings", RUBY_METHOD_FUNC(_llama_context_params_get_embeddings), 0);
1007
- rb_define_method(rb_cLLaMAContextParams, "offload_kqv=", RUBY_METHOD_FUNC(_llama_context_params_set_offload_kqv), 1);
1008
- rb_define_method(rb_cLLaMAContextParams, "offload_kqv", RUBY_METHOD_FUNC(_llama_context_params_get_offload_kqv), 0);
1009
- rb_define_method(rb_cLLaMAContextParams, "flash_attn=", RUBY_METHOD_FUNC(_llama_context_params_set_flash_attn), 1);
1010
- rb_define_method(rb_cLLaMAContextParams, "flash_attn", RUBY_METHOD_FUNC(_llama_context_params_get_flash_attn), 0);
1011
- }
1012
-
1013
- private:
1014
- static const rb_data_type_t llama_context_params_type;
1015
-
1016
- // static VALUE _llama_context_params_init(VALUE self, VALUE seed) {
1017
- // LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1018
- // new (ptr) LLaMAContextParamsWrapper();
1019
- // return self;
1020
- // }
1021
-
1022
- // seed
1023
- static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
1024
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1025
- if (NUM2INT(seed) < 0) {
1026
- rb_raise(rb_eArgError, "seed must be positive");
1027
- return Qnil;
1028
- }
1029
- ptr->params.seed = NUM2INT(seed);
1030
- return INT2NUM(ptr->params.seed);
1031
- }
1032
-
1033
- static VALUE _llama_context_params_get_seed(VALUE self) {
1034
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1035
- return INT2NUM(ptr->params.seed);
1036
- }
1037
-
1038
- // n_ctx
1039
- static VALUE _llama_context_params_set_n_ctx(VALUE self, VALUE n_ctx) {
1040
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1041
- ptr->params.n_ctx = NUM2INT(n_ctx);
1042
- return INT2NUM(ptr->params.n_ctx);
1043
- }
1044
-
1045
- static VALUE _llama_context_params_get_n_ctx(VALUE self) {
1046
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1047
- return INT2NUM(ptr->params.n_ctx);
1048
- }
1049
-
1050
- // n_batch
1051
- static VALUE _llama_context_params_set_n_batch(VALUE self, VALUE n_batch) {
1052
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1053
- ptr->params.n_batch = NUM2INT(n_batch);
1054
- return INT2NUM(ptr->params.n_batch);
1055
- }
1056
-
1057
- static VALUE _llama_context_params_get_n_batch(VALUE self) {
1058
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1059
- return INT2NUM(ptr->params.n_batch);
1060
- }
1061
-
1062
- // n_ubatch
1063
- static VALUE _llama_context_params_set_n_ubatch(VALUE self, VALUE n_ubatch) {
1064
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1065
- ptr->params.n_ubatch = NUM2INT(n_ubatch);
1066
- return INT2NUM(ptr->params.n_ubatch);
1067
- }
1068
-
1069
- static VALUE _llama_context_params_get_n_ubatch(VALUE self) {
1070
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1071
- return INT2NUM(ptr->params.n_ubatch);
1072
- }
1073
-
1074
- // n_seq_max
1075
- static VALUE _llama_context_params_set_n_seq_max(VALUE self, VALUE n_seq_max) {
1076
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1077
- ptr->params.n_seq_max = NUM2INT(n_seq_max);
1078
- return INT2NUM(ptr->params.n_seq_max);
1079
- }
1080
-
1081
- static VALUE _llama_context_params_get_n_seq_max(VALUE self) {
1082
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1083
- return INT2NUM(ptr->params.n_seq_max);
1084
- }
1085
-
1086
- // n_threads
1087
- static VALUE _llama_context_params_set_n_threads(VALUE self, VALUE n_threads) {
1088
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1089
- ptr->params.n_threads = NUM2INT(n_threads);
1090
- return INT2NUM(ptr->params.n_threads);
1091
- }
1092
-
1093
- static VALUE _llama_context_params_get_n_threads(VALUE self) {
1094
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1095
- return INT2NUM(ptr->params.n_threads);
1096
- }
1097
-
1098
- // n_threads_batch
1099
- static VALUE _llama_context_params_set_n_threads_batch(VALUE self, VALUE n_threads_batch) {
1100
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1101
- ptr->params.n_threads_batch = NUM2INT(n_threads_batch);
1102
- return INT2NUM(ptr->params.n_threads_batch);
1103
- }
1104
-
1105
- static VALUE _llama_context_params_get_n_threads_batch(VALUE self) {
1106
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1107
- return INT2NUM(ptr->params.n_threads_batch);
1108
- }
1109
-
1110
- // rope_scaling_type
1111
- static VALUE _llama_context_params_set_rope_scaling_type(VALUE self, VALUE scaling_type) {
1112
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1113
- ptr->params.rope_scaling_type = static_cast<enum llama_rope_scaling_type>(NUM2INT(scaling_type));
1114
- return INT2NUM(ptr->params.rope_scaling_type);
1115
- }
1116
-
1117
- static VALUE _llama_context_params_get_rope_scaling_type(VALUE self) {
1118
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1119
- return INT2NUM(ptr->params.rope_scaling_type);
1120
- }
1121
-
1122
- // pooling_type
1123
- static VALUE _llama_context_params_set_pooling_type(VALUE self, VALUE scaling_type) {
1124
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1125
- ptr->params.pooling_type = static_cast<enum llama_pooling_type>(NUM2INT(scaling_type));
1126
- return INT2NUM(ptr->params.pooling_type);
1127
- }
1128
-
1129
- static VALUE _llama_context_params_get_pooling_type(VALUE self) {
1130
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1131
- return INT2NUM(ptr->params.pooling_type);
1132
- }
1133
-
1134
- // attention_type
1135
- static VALUE _llama_context_params_set_attention_type(VALUE self, VALUE scaling_type) {
1136
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1137
- ptr->params.attention_type = static_cast<enum llama_attention_type>(NUM2INT(scaling_type));
1138
- return INT2NUM(ptr->params.attention_type);
1139
- }
1140
-
1141
- static VALUE _llama_context_params_get_attention_type(VALUE self) {
1142
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1143
- return INT2NUM(ptr->params.attention_type);
1144
- }
1145
-
1146
- // rope_freq_base
1147
- static VALUE _llama_context_params_set_rope_freq_base(VALUE self, VALUE rope_freq_base) {
1148
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1149
- ptr->params.rope_freq_base = NUM2DBL(rope_freq_base);
1150
- return DBL2NUM(ptr->params.rope_freq_base);
1151
- }
1152
-
1153
- static VALUE _llama_context_params_get_rope_freq_base(VALUE self) {
1154
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1155
- return DBL2NUM(ptr->params.rope_freq_base);
1156
- }
1157
-
1158
- // rope_freq_scale
1159
- static VALUE _llama_context_params_set_rope_freq_scale(VALUE self, VALUE rope_freq_scale) {
1160
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1161
- ptr->params.rope_freq_scale = NUM2DBL(rope_freq_scale);
1162
- return DBL2NUM(ptr->params.rope_freq_scale);
1163
- }
1164
-
1165
- static VALUE _llama_context_params_get_rope_freq_scale(VALUE self) {
1166
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1167
- return DBL2NUM(ptr->params.rope_freq_scale);
1168
- }
1169
-
1170
- // yarn_ext_factor
1171
- static VALUE _llama_context_params_set_yarn_ext_factor(VALUE self, VALUE yarn_ext_factor) {
1172
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1173
- ptr->params.yarn_ext_factor = NUM2DBL(yarn_ext_factor);
1174
- return DBL2NUM(ptr->params.yarn_ext_factor);
1175
- }
1176
-
1177
- static VALUE _llama_context_params_get_yarn_ext_factor(VALUE self) {
1178
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1179
- return DBL2NUM(ptr->params.yarn_ext_factor);
1180
- }
1181
-
1182
- // yarn_attn_factor
1183
- static VALUE _llama_context_params_set_yarn_attn_factor(VALUE self, VALUE yarn_attn_factor) {
1184
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1185
- ptr->params.yarn_attn_factor = NUM2DBL(yarn_attn_factor);
1186
- return DBL2NUM(ptr->params.yarn_attn_factor);
1187
- }
1188
-
1189
- static VALUE _llama_context_params_get_yarn_attn_factor(VALUE self) {
1190
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1191
- return DBL2NUM(ptr->params.yarn_attn_factor);
1192
- }
1193
-
1194
- // yarn_beta_fast
1195
- static VALUE _llama_context_params_set_yarn_beta_fast(VALUE self, VALUE yarn_beta_fast) {
1196
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1197
- ptr->params.yarn_beta_fast = NUM2DBL(yarn_beta_fast);
1198
- return DBL2NUM(ptr->params.yarn_beta_fast);
1199
- }
1200
-
1201
- static VALUE _llama_context_params_get_yarn_beta_fast(VALUE self) {
1202
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1203
- return DBL2NUM(ptr->params.yarn_beta_fast);
1204
- }
1205
-
1206
- // yarn_beta_slow
1207
- static VALUE _llama_context_params_set_yarn_beta_slow(VALUE self, VALUE yarn_beta_slow) {
1208
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1209
- ptr->params.yarn_beta_slow = NUM2DBL(yarn_beta_slow);
1210
- return DBL2NUM(ptr->params.yarn_beta_slow);
1211
- }
1212
-
1213
- static VALUE _llama_context_params_get_yarn_beta_slow(VALUE self) {
1214
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1215
- return DBL2NUM(ptr->params.yarn_beta_slow);
1216
- }
1217
-
1218
- // yarn_orig_ctx
1219
- static VALUE _llama_context_params_set_yarn_orig_ctx(VALUE self, VALUE yarn_orig_ctx) {
1220
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1221
- ptr->params.yarn_orig_ctx = NUM2UINT(yarn_orig_ctx);
1222
- return UINT2NUM(ptr->params.yarn_orig_ctx);
1223
- }
1224
-
1225
- // defrag_thold
1226
- static VALUE _llama_context_params_set_defrag_thold(VALUE self, VALUE defrag_thold) {
1227
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1228
- ptr->params.defrag_thold = NUM2DBL(defrag_thold);
1229
- return DBL2NUM(ptr->params.defrag_thold);
1230
- }
1231
-
1232
- static VALUE _llama_context_params_get_defrag_thold(VALUE self) {
1233
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1234
- return DBL2NUM(ptr->params.defrag_thold);
1235
- }
1236
-
1237
- static VALUE _llama_context_params_get_yarn_orig_ctx(VALUE self) {
1238
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1239
- return UINT2NUM(ptr->params.yarn_orig_ctx);
1240
- }
1241
-
1242
- // type_k
1243
- static VALUE _llama_context_params_set_type_k(VALUE self, VALUE type_k) {
1244
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1245
- ptr->params.type_k = static_cast<enum ggml_type>(NUM2INT(type_k));
1246
- return INT2NUM(ptr->params.type_k);
1247
- }
1248
-
1249
- static VALUE _llama_context_params_get_type_k(VALUE self) {
1250
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1251
- return INT2NUM(ptr->params.type_k);
1252
- }
1253
-
1254
- // type_v
1255
- static VALUE _llama_context_params_set_type_v(VALUE self, VALUE type_v) {
1256
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1257
- ptr->params.type_v = static_cast<enum ggml_type>(NUM2INT(type_v));
1258
- return INT2NUM(ptr->params.type_v);
1259
- }
1260
-
1261
- static VALUE _llama_context_params_get_type_v(VALUE self) {
1262
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1263
- return INT2NUM(ptr->params.type_v);
1264
- }
1265
-
1266
- // logits_all
1267
- static VALUE _llama_context_params_set_logits_all(VALUE self, VALUE logits_all) {
1268
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1269
- ptr->params.logits_all = RTEST(logits_all) ? true : false;
1270
- return ptr->params.logits_all ? Qtrue : Qfalse;
1271
- }
1272
-
1273
- static VALUE _llama_context_params_get_logits_all(VALUE self) {
1274
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1275
- return ptr->params.logits_all ? Qtrue : Qfalse;
1276
- }
1277
-
1278
- // embeddings
1279
- static VALUE _llama_context_params_set_embeddings(VALUE self, VALUE embeddings) {
1280
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1281
- ptr->params.embeddings = RTEST(embeddings) ? true : false;
1282
- return ptr->params.embeddings ? Qtrue : Qfalse;
1283
- }
1284
-
1285
- static VALUE _llama_context_params_get_embeddings(VALUE self) {
1286
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1287
- return ptr->params.embeddings ? Qtrue : Qfalse;
1288
- }
1289
-
1290
- // offload_kqv
1291
- static VALUE _llama_context_params_set_offload_kqv(VALUE self, VALUE offload_kqv) {
1292
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1293
- ptr->params.offload_kqv = RTEST(offload_kqv) ? true : false;
1294
- return ptr->params.offload_kqv ? Qtrue : Qfalse;
1295
- }
1296
-
1297
- static VALUE _llama_context_params_get_offload_kqv(VALUE self) {
1298
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1299
- return ptr->params.offload_kqv ? Qtrue : Qfalse;
1300
- }
1301
-
1302
- // flash_attn
1303
- static VALUE _llama_context_params_set_flash_attn(VALUE self, VALUE flash_attn) {
1304
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1305
- ptr->params.flash_attn = RTEST(flash_attn) ? true : false;
1306
- return ptr->params.flash_attn ? Qtrue : Qfalse;
1307
- }
1308
-
1309
- static VALUE _llama_context_params_get_flash_attn(VALUE self) {
1310
- LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
1311
- return ptr->params.flash_attn ? Qtrue : Qfalse;
1312
- }
1313
- };
1314
-
1315
- const rb_data_type_t RbLLaMAContextParams::llama_context_params_type = {
1316
- "RbLLaMAContextParams",
1317
- { NULL,
1318
- RbLLaMAContextParams::llama_context_params_free,
1319
- RbLLaMAContextParams::llama_context_params_size },
1320
- NULL,
1321
- NULL,
1322
- RUBY_TYPED_FREE_IMMEDIATELY
1323
- };
1324
-
1325
- class LLaMAModelQuantizeParamsWrapper {
1326
- public:
1327
- llama_model_quantize_params params;
1328
-
1329
- LLaMAModelQuantizeParamsWrapper() : params(llama_model_quantize_default_params()) {}
1330
-
1331
- ~LLaMAModelQuantizeParamsWrapper() {}
1332
- };
1333
-
1334
- class RbLLaMAModelQuantizeParams {
1335
- public:
1336
- static VALUE llama_model_quantize_params_alloc(VALUE self) {
1337
- LLaMAModelQuantizeParamsWrapper* ptr = (LLaMAModelQuantizeParamsWrapper*)ruby_xmalloc(sizeof(LLaMAModelQuantizeParamsWrapper));
1338
- new (ptr) LLaMAModelQuantizeParamsWrapper();
1339
- return TypedData_Wrap_Struct(self, &llama_model_quantize_params_type, ptr);
1340
- }
1341
-
1342
- static void llama_model_quantize_params_free(void* ptr) {
1343
- ((LLaMAModelQuantizeParamsWrapper*)ptr)->~LLaMAModelQuantizeParamsWrapper();
1344
- ruby_xfree(ptr);
1345
- }
1346
-
1347
- static size_t llama_model_quantize_params_size(const void* ptr) {
1348
- return sizeof(*((LLaMAModelQuantizeParamsWrapper*)ptr));
1349
- }
1350
-
1351
- static LLaMAModelQuantizeParamsWrapper* get_llama_model_quantize_params(VALUE self) {
1352
- LLaMAModelQuantizeParamsWrapper* ptr;
1353
- TypedData_Get_Struct(self, LLaMAModelQuantizeParamsWrapper, &llama_model_quantize_params_type, ptr);
1354
- return ptr;
1355
- }
1356
-
1357
- static void define_class(VALUE outer) {
1358
- rb_cLLaMAModelQuantizeParams = rb_define_class_under(outer, "ModelQuantizeParams", rb_cObject);
1359
- rb_define_alloc_func(rb_cLLaMAModelQuantizeParams, llama_model_quantize_params_alloc);
1360
- rb_define_method(rb_cLLaMAModelQuantizeParams, "n_thread=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_n_thread), 1);
1361
- rb_define_method(rb_cLLaMAModelQuantizeParams, "n_thread", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_n_thread), 0);
1362
- rb_define_method(rb_cLLaMAModelQuantizeParams, "ftype=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_ftype), 1);
1363
- rb_define_method(rb_cLLaMAModelQuantizeParams, "ftype", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_ftype), 0);
1364
- rb_define_method(rb_cLLaMAModelQuantizeParams, "allow_requantize=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_allow_requantize), 1);
1365
- rb_define_method(rb_cLLaMAModelQuantizeParams, "allow_requantize", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_allow_requantize), 0);
1366
- rb_define_method(rb_cLLaMAModelQuantizeParams, "quantize_output_tensor=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_quantize_output_tensor), 1);
1367
- rb_define_method(rb_cLLaMAModelQuantizeParams, "quantize_output_tensor", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_quantize_output_tensor), 0);
1368
- rb_define_method(rb_cLLaMAModelQuantizeParams, "only_copy=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_only_copy), 1);
1369
- rb_define_method(rb_cLLaMAModelQuantizeParams, "only_copy", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_only_copy), 0);
1370
- rb_define_method(rb_cLLaMAModelQuantizeParams, "pure=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_pure), 1);
1371
- rb_define_method(rb_cLLaMAModelQuantizeParams, "pure", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_pure), 0);
1372
- rb_define_method(rb_cLLaMAModelQuantizeParams, "keep_split=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_keep_split), 1);
1373
- rb_define_method(rb_cLLaMAModelQuantizeParams, "keep_split", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_keep_split), 0);
1374
- }
1375
-
1376
- private:
1377
- static const rb_data_type_t llama_model_quantize_params_type;
1378
-
1379
- // n_thread
1380
- static VALUE _llama_model_quantize_params_set_n_thread(VALUE self, VALUE n_thread) {
1381
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1382
- ptr->params.nthread = NUM2INT(n_thread);
1383
- return INT2NUM(ptr->params.nthread);
1384
- }
1385
-
1386
- static VALUE _llama_model_quantize_params_get_n_thread(VALUE self) {
1387
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1388
- return INT2NUM(ptr->params.nthread);
1389
- }
1390
-
1391
- // ftype
1392
- static VALUE _llama_model_quantize_params_set_ftype(VALUE self, VALUE ftype) {
1393
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1394
- ptr->params.ftype = static_cast<enum llama_ftype>(NUM2INT(ftype));
1395
- return INT2NUM(ptr->params.ftype);
1396
- }
1397
-
1398
- static VALUE _llama_model_quantize_params_get_ftype(VALUE self) {
1399
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1400
- return INT2NUM(ptr->params.ftype);
1401
- }
1402
-
1403
- // allow_requantize
1404
- static VALUE _llama_model_quantize_params_set_allow_requantize(VALUE self, VALUE allow_requantize) {
1405
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1406
- if (NIL_P(allow_requantize) || allow_requantize == Qfalse) {
1407
- ptr->params.allow_requantize = false;
1408
- } else {
1409
- ptr->params.allow_requantize = true;
1410
- }
1411
- return ptr->params.allow_requantize ? Qtrue : Qfalse;
1412
- }
1413
-
1414
- static VALUE _llama_model_quantize_params_get_allow_requantize(VALUE self) {
1415
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1416
- return ptr->params.allow_requantize ? Qtrue : Qfalse;
1417
- }
1418
-
1419
- // quantize_output_tensor
1420
- static VALUE _llama_model_quantize_params_set_quantize_output_tensor(VALUE self, VALUE quantize_output_tensor) {
1421
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1422
- if (NIL_P(quantize_output_tensor) || quantize_output_tensor == Qfalse) {
1423
- ptr->params.quantize_output_tensor = false;
1424
- } else {
1425
- ptr->params.quantize_output_tensor = true;
1426
- }
1427
- return ptr->params.quantize_output_tensor ? Qtrue : Qfalse;
1428
- }
1429
-
1430
- static VALUE _llama_model_quantize_params_get_quantize_output_tensor(VALUE self) {
1431
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1432
- return ptr->params.quantize_output_tensor ? Qtrue : Qfalse;
1433
- }
1434
-
1435
- // only_copy
1436
- static VALUE _llama_model_quantize_params_set_only_copy(VALUE self, VALUE only_copy) {
1437
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1438
- ptr->params.only_copy = RTEST(only_copy) ? true : false;
1439
- return ptr->params.only_copy ? Qtrue : Qfalse;
1440
- }
1441
-
1442
- static VALUE _llama_model_quantize_params_get_only_copy(VALUE self) {
1443
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1444
- return ptr->params.only_copy ? Qtrue : Qfalse;
1445
- }
1446
-
1447
- // pure
1448
- static VALUE _llama_model_quantize_params_set_pure(VALUE self, VALUE pure) {
1449
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1450
- ptr->params.pure = RTEST(pure) ? true : false;
1451
- return ptr->params.pure ? Qtrue : Qfalse;
1452
- }
1453
-
1454
- static VALUE _llama_model_quantize_params_get_pure(VALUE self) {
1455
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1456
- return ptr->params.pure ? Qtrue : Qfalse;
1457
- }
1458
-
1459
- // keep_split
1460
- static VALUE _llama_model_quantize_params_set_keep_split(VALUE self, VALUE keep_split) {
1461
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1462
- ptr->params.keep_split = RTEST(keep_split) ? true : false;
1463
- return ptr->params.keep_split ? Qtrue : Qfalse;
1464
- }
1465
-
1466
- static VALUE _llama_model_quantize_params_get_keep_split(VALUE self) {
1467
- LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1468
- return ptr->params.keep_split ? Qtrue : Qfalse;
1469
- }
1470
- };
1471
-
1472
- const rb_data_type_t RbLLaMAModelQuantizeParams::llama_model_quantize_params_type = {
1473
- "RbLLaMAModelQuantizeParams",
1474
- { NULL,
1475
- RbLLaMAModelQuantizeParams::llama_model_quantize_params_free,
1476
- RbLLaMAModelQuantizeParams::llama_model_quantize_params_size },
1477
- NULL,
1478
- NULL,
1479
- RUBY_TYPED_FREE_IMMEDIATELY
1480
- };
1481
-
1482
- class LLaMAModelWrapper {
1483
- public:
1484
- struct llama_model* model;
1485
-
1486
- LLaMAModelWrapper() : model(NULL) {}
1487
-
1488
- ~LLaMAModelWrapper() {
1489
- if (model != NULL) {
1490
- llama_free_model(model);
1491
- }
1492
- }
1493
- };
1494
-
1495
- class RbLLaMAModel {
1496
- public:
1497
- static VALUE llama_model_alloc(VALUE self) {
1498
- LLaMAModelWrapper* ptr = (LLaMAModelWrapper*)ruby_xmalloc(sizeof(LLaMAModelWrapper));
1499
- new (ptr) LLaMAModelWrapper();
1500
- return TypedData_Wrap_Struct(self, &llama_model_type, ptr);
1501
- }
1502
-
1503
- static void llama_model_free(void* ptr) {
1504
- ((LLaMAModelWrapper*)ptr)->~LLaMAModelWrapper();
1505
- ruby_xfree(ptr);
1506
- }
1507
-
1508
- static size_t llama_model_size(const void* ptr) {
1509
- return sizeof(*((LLaMAModelWrapper*)ptr));
1510
- }
1511
-
1512
- static LLaMAModelWrapper* get_llama_model(VALUE self) {
1513
- LLaMAModelWrapper* ptr;
1514
- TypedData_Get_Struct(self, LLaMAModelWrapper, &llama_model_type, ptr);
1515
- return ptr;
1516
- }
1517
-
1518
- static void define_class(VALUE outer) {
1519
- rb_cLLaMAModel = rb_define_class_under(outer, "Model", rb_cObject);
1520
- rb_define_alloc_func(rb_cLLaMAModel, llama_model_alloc);
1521
- rb_define_attr(rb_cLLaMAModel, "params", 1, 0);
1522
- rb_define_method(rb_cLLaMAModel, "initialize", RUBY_METHOD_FUNC(_llama_model_initialize), -1);
1523
- rb_define_method(rb_cLLaMAModel, "empty?", RUBY_METHOD_FUNC(_llama_model_empty), 0);
1524
- rb_define_method(rb_cLLaMAModel, "free", RUBY_METHOD_FUNC(_llama_model_free), 0);
1525
- rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
1526
- rb_define_method(rb_cLLaMAModel, "vocab_type", RUBY_METHOD_FUNC(_llama_model_get_model_vocab_type), 0);
1527
- rb_define_method(rb_cLLaMAModel, "rope_type", RUBY_METHOD_FUNC(_llama_model_get_model_rope_type), 0);
1528
- rb_define_method(rb_cLLaMAModel, "n_vocab", RUBY_METHOD_FUNC(_llama_model_get_model_n_vocab), 0);
1529
- rb_define_method(rb_cLLaMAModel, "n_ctx_train", RUBY_METHOD_FUNC(_llama_model_get_model_n_ctx_train), 0);
1530
- rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_model_n_embd), 0);
1531
- rb_define_method(rb_cLLaMAModel, "n_layer", RUBY_METHOD_FUNC(_llama_model_get_model_n_layer), 0);
1532
- rb_define_method(rb_cLLaMAModel, "rope_freq_scale_train", RUBY_METHOD_FUNC(_llama_model_rope_freq_scale_train), 0);
1533
- rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece), -1);
1534
- rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize), -1);
1535
- rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
1536
- rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
1537
- rb_define_method(rb_cLLaMAModel, "n_params", RUBY_METHOD_FUNC(_llama_model_get_model_n_params), 0);
1538
- rb_define_method(rb_cLLaMAModel, "text", RUBY_METHOD_FUNC(_llama_model_get_text), 1);
1539
- rb_define_method(rb_cLLaMAModel, "score", RUBY_METHOD_FUNC(_llama_model_get_score), 1);
1540
- rb_define_method(rb_cLLaMAModel, "token_attr", RUBY_METHOD_FUNC(_llama_model_get_token_attr), 1);
1541
- rb_define_method(rb_cLLaMAModel, "token_bos", RUBY_METHOD_FUNC(_llama_model_token_bos), 0);
1542
- rb_define_method(rb_cLLaMAModel, "token_eos", RUBY_METHOD_FUNC(_llama_model_token_eos), 0);
1543
- rb_define_method(rb_cLLaMAModel, "token_cls", RUBY_METHOD_FUNC(_llama_model_token_cls), 0);
1544
- rb_define_method(rb_cLLaMAModel, "token_sep", RUBY_METHOD_FUNC(_llama_model_token_sep), 0);
1545
- rb_define_method(rb_cLLaMAModel, "token_nl", RUBY_METHOD_FUNC(_llama_model_token_nl), 0);
1546
- rb_define_method(rb_cLLaMAModel, "token_pad", RUBY_METHOD_FUNC(_llama_model_token_pad), 0);
1547
- rb_define_method(rb_cLLaMAModel, "add_bos_token?", RUBY_METHOD_FUNC(_llama_model_add_bos_token), 0);
1548
- rb_define_method(rb_cLLaMAModel, "add_eos_token?", RUBY_METHOD_FUNC(_llama_model_add_eos_token), 0);
1549
- rb_define_method(rb_cLLaMAModel, "token_prefix", RUBY_METHOD_FUNC(_llama_model_token_prefix), 0);
1550
- rb_define_method(rb_cLLaMAModel, "token_middle", RUBY_METHOD_FUNC(_llama_model_token_middle), 0);
1551
- rb_define_method(rb_cLLaMAModel, "token_suffix", RUBY_METHOD_FUNC(_llama_model_token_suffix), 0);
1552
- rb_define_method(rb_cLLaMAModel, "token_eot", RUBY_METHOD_FUNC(_llama_model_token_eot), 0);
1553
- rb_define_method(rb_cLLaMAModel, "token_is_eog?", RUBY_METHOD_FUNC(_llama_model_token_is_eog), 1);
1554
- rb_define_method(rb_cLLaMAModel, "token_is_control?", RUBY_METHOD_FUNC(_llama_model_token_is_control), 1);
1555
- rb_define_method(rb_cLLaMAModel, "has_encoder?", RUBY_METHOD_FUNC(_llama_model_has_encoder), 0);
1556
- rb_define_method(rb_cLLaMAModel, "has_decoder?", RUBY_METHOD_FUNC(_llama_model_has_decoder), 0);
1557
- rb_define_method(rb_cLLaMAModel, "decoder_start_token", RUBY_METHOD_FUNC(_llama_model_decoder_start_token), 0);
1558
- rb_define_method(rb_cLLaMAModel, "is_recurrent?", RUBY_METHOD_FUNC(_llama_model_is_recurrent), 0);
1559
- rb_define_method(rb_cLLaMAModel, "detokenize", RUBY_METHOD_FUNC(_llama_model_detokenize), -1);
1560
- }
1561
-
1562
- private:
1563
- static const rb_data_type_t llama_model_type;
1564
-
1565
- static VALUE _llama_model_initialize(int argc, VALUE* argv, VALUE self) {
1566
- VALUE kw_args = Qnil;
1567
- ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
1568
- VALUE kw_values[2] = { Qundef, Qundef };
1569
- rb_scan_args(argc, argv, ":", &kw_args);
1570
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
1571
-
1572
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1573
- rb_raise(rb_eArgError, "model_path must be a string");
1574
- return Qnil;
1575
- }
1576
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
1577
- rb_raise(rb_eArgError, "params must be a ModelParams");
1578
- return Qnil;
1579
- }
1580
-
1581
- VALUE filename = kw_values[0];
1582
- LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
1583
- LLaMAModelWrapper* model_ptr = get_llama_model(self);
1584
-
1585
- try {
1586
- model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
1587
- } catch (const std::runtime_error& e) {
1588
- rb_raise(rb_eRuntimeError, "%s", e.what());
1589
- return Qnil;
1590
- }
1591
-
1592
- if (model_ptr->model == NULL) {
1593
- rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA model");
1594
- return Qnil;
1595
- }
1596
-
1597
- rb_iv_set(self, "@params", kw_values[1]);
1598
-
1599
- RB_GC_GUARD(filename);
1600
- return Qnil;
1601
- }
1602
-
1603
- static VALUE _llama_model_empty(VALUE self) {
1604
- LLaMAModelWrapper* ptr = get_llama_model(self);
1605
- if (ptr->model != NULL) {
1606
- return Qfalse;
1607
- }
1608
- return Qtrue;
1609
- }
1610
-
1611
- static VALUE _llama_model_free(VALUE self) {
1612
- LLaMAModelWrapper* ptr = get_llama_model(self);
1613
- if (ptr->model != NULL) {
1614
- llama_free_model(ptr->model);
1615
- ptr->model = NULL;
1616
- rb_iv_set(self, "@params", Qnil);
1617
- }
1618
- return Qnil;
1619
- }
1620
-
1621
- static VALUE _llama_model_load(int argc, VALUE* argv, VALUE self) {
1622
- VALUE kw_args = Qnil;
1623
- ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
1624
- VALUE kw_values[2] = { Qundef, Qundef };
1625
- rb_scan_args(argc, argv, ":", &kw_args);
1626
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
1627
-
1628
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1629
- rb_raise(rb_eArgError, "model_path must be a string");
1630
- return Qnil;
1631
- }
1632
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAModelParams)) {
1633
- rb_raise(rb_eArgError, "params must be a LLaMAModelParams");
1634
- return Qnil;
1635
- }
1636
-
1637
- LLaMAModelWrapper* model_ptr = get_llama_model(self);
1638
- if (model_ptr->model != NULL) {
1639
- rb_raise(rb_eRuntimeError, "LLaMA model is already loaded");
1640
- return Qnil;
1641
- }
1642
-
1643
- VALUE filename = kw_values[0];
1644
- LLaMAModelParamsWrapper* prms_ptr = RbLLaMAModelParams::get_llama_model_params(kw_values[1]);
1645
-
1646
- try {
1647
- model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
1648
- } catch (const std::runtime_error& e) {
1649
- rb_raise(rb_eRuntimeError, "%s", e.what());
1650
- return Qnil;
1651
- }
1652
-
1653
- if (model_ptr->model == NULL) {
1654
- rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA model");
1655
- return Qnil;
1656
- }
1657
-
1658
- rb_iv_set(self, "@params", kw_values[1]);
1659
-
1660
- RB_GC_GUARD(filename);
1661
- return Qnil;
1662
- }
1663
-
1664
- static VALUE _llama_model_get_model_vocab_type(VALUE self) {
1665
- LLaMAModelWrapper* ptr = get_llama_model(self);
1666
- return INT2NUM(llama_vocab_type(ptr->model));
1667
- }
1668
-
1669
- static VALUE _llama_model_get_model_rope_type(VALUE self) {
1670
- LLaMAModelWrapper* ptr = get_llama_model(self);
1671
- return INT2NUM(llama_rope_type(ptr->model));
1672
- }
1673
-
1674
- static VALUE _llama_model_get_model_n_vocab(VALUE self) {
1675
- LLaMAModelWrapper* ptr = get_llama_model(self);
1676
- return INT2NUM(llama_n_vocab(ptr->model));
1677
- }
1678
-
1679
- static VALUE _llama_model_get_model_n_ctx_train(VALUE self) {
1680
- LLaMAModelWrapper* ptr = get_llama_model(self);
1681
- return INT2NUM(llama_n_ctx_train(ptr->model));
1682
- }
1683
-
1684
- static VALUE _llama_model_get_model_n_embd(VALUE self) {
1685
- LLaMAModelWrapper* ptr = get_llama_model(self);
1686
- return INT2NUM(llama_n_embd(ptr->model));
1687
- }
1688
-
1689
- static VALUE _llama_model_get_model_n_layer(VALUE self) {
1690
- LLaMAModelWrapper* ptr = get_llama_model(self);
1691
- return INT2NUM(llama_n_layer(ptr->model));
1692
- }
1693
-
1694
- static VALUE _llama_model_rope_freq_scale_train(VALUE self) {
1695
- LLaMAModelWrapper* ptr = get_llama_model(self);
1696
- return DBL2NUM(llama_rope_freq_scale_train(ptr->model));
1697
- }
1698
-
1699
- static VALUE _llama_model_token_to_piece(int argc, VALUE* argv, VALUE self) {
1700
- VALUE kw_args = Qnil;
1701
- ID kw_table[2] = { rb_intern("lstrip"), rb_intern("special") };
1702
- VALUE kw_values[2] = { Qundef, Qundef };
1703
- VALUE token_ = Qnil;
1704
- rb_scan_args(argc, argv, "1:", &token_, &kw_args);
1705
- rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
1706
-
1707
- if (!RB_INTEGER_TYPE_P(token_)) {
1708
- rb_raise(rb_eArgError, "token must be an integer");
1709
- return Qnil;
1710
- }
1711
- if (kw_values[0] != Qundef && !RB_INTEGER_TYPE_P(kw_values[0])) {
1712
- rb_raise(rb_eArgError, "lstrip must be an integer");
1713
- return Qnil;
1714
- }
1715
-
1716
- const llama_token token = NUM2INT(token_);
1717
- const int32_t lstrip = kw_values[0] != Qundef ? NUM2INT(kw_values[0]) : 0;
1718
- const bool special = kw_values[1] != Qundef ? RTEST(kw_values[1]) : false;
1719
-
1720
- LLaMAModelWrapper* ptr = get_llama_model(self);
1721
- std::vector<char> result(8, 0);
1722
- const int n_tokens = llama_token_to_piece(ptr->model, token, result.data(), result.size(), lstrip, special);
1723
- if (n_tokens < 0) {
1724
- result.resize(-n_tokens);
1725
- const int check = llama_token_to_piece(ptr->model, token, result.data(), result.size(), lstrip, special);
1726
- if (check != -n_tokens) {
1727
- rb_raise(rb_eRuntimeError, "failed to convert");
1728
- return Qnil;
1729
- }
1730
- } else {
1731
- result.resize(n_tokens);
1732
- }
1733
- std::string ret(result.data(), result.size());
1734
- return rb_utf8_str_new_cstr(ret.c_str());
1735
- }
1736
-
1737
- static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
1738
- VALUE kw_args = Qnil;
1739
- ID kw_table[4] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos"), rb_intern("special") };
1740
- VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
1741
- rb_scan_args(argc, argv, ":", &kw_args);
1742
- rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);
1743
-
1744
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1745
- rb_raise(rb_eArgError, "text must be a String");
1746
- return Qnil;
1747
- }
1748
- if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
1749
- rb_raise(rb_eArgError, "n_max_tokens must be an integer");
1750
- return Qnil;
1751
- }
1752
- if (kw_values[2] != Qundef && (kw_values[2] != Qtrue && kw_values[2] != Qfalse)) {
1753
- rb_raise(rb_eArgError, "add_bos must be a boolean");
1754
- return Qnil;
1755
- }
1756
- if (kw_values[3] != Qundef && (kw_values[3] != Qtrue && kw_values[3] != Qfalse)) {
1757
- rb_raise(rb_eArgError, "special must be a boolean");
1758
- return Qnil;
1759
- }
1760
-
1761
- VALUE text_ = kw_values[0];
1762
- std::string text = StringValueCStr(text_);
1763
- const bool add_bos = kw_values[2] == Qtrue ? true : false;
1764
- const bool special = kw_values[3] == Qtrue ? true : false;
1765
- const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
1766
-
1767
- llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
1768
- LLaMAModelWrapper* ptr = get_llama_model(self);
1769
- const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos, special);
1770
-
1771
- if (n_tokens < 0) {
1772
- rb_raise(rb_eRuntimeError, "failed to tokenize. The numebr of tokens (%d) is greater than n_max_tokens.", -n_tokens);
1773
- return Qnil;
1774
- }
1775
-
1776
- VALUE ret = rb_ary_new2(n_tokens);
1777
- for (int i = 0; i < n_tokens; i++) {
1778
- rb_ary_store(ret, i, INT2NUM(tokens[i]));
1779
- }
1780
-
1781
- RB_GC_GUARD(text_);
1782
- return ret;
1783
- }
1784
-
1785
- static VALUE _llama_model_get_model_desc(VALUE self) {
1786
- LLaMAModelWrapper* ptr = get_llama_model(self);
1787
- char buf[128];
1788
- llama_model_desc(ptr->model, buf, sizeof(buf));
1789
- return rb_utf8_str_new_cstr(buf);
1790
- }
1791
-
1792
- static VALUE _llama_model_get_model_size(VALUE self) {
1793
- LLaMAModelWrapper* ptr = get_llama_model(self);
1794
- return UINT2NUM(llama_model_size(ptr->model));
1795
- }
1796
-
1797
- static VALUE _llama_model_get_model_n_params(VALUE self) {
1798
- LLaMAModelWrapper* ptr = get_llama_model(self);
1799
- return UINT2NUM(llama_model_n_params(ptr->model));
1800
- }
1801
-
1802
- static VALUE _llama_model_get_text(VALUE self, VALUE token_) {
1803
- LLaMAModelWrapper* ptr = get_llama_model(self);
1804
- const llama_token token = NUM2INT(token_);
1805
- const char* text = llama_token_get_text(ptr->model, token);
1806
- return rb_utf8_str_new_cstr(text);
1807
- }
1808
-
1809
- static VALUE _llama_model_get_score(VALUE self, VALUE token_) {
1810
- LLaMAModelWrapper* ptr = get_llama_model(self);
1811
- const llama_token token = NUM2INT(token_);
1812
- const float score = llama_token_get_score(ptr->model, token);
1813
- return DBL2NUM(score);
1814
- }
1815
-
1816
- static VALUE _llama_model_get_token_attr(VALUE self, VALUE token_) {
1817
- LLaMAModelWrapper* ptr = get_llama_model(self);
1818
- const llama_token token = NUM2INT(token_);
1819
- const llama_token_attr type = llama_token_get_attr(ptr->model, token);
1820
- return INT2NUM(type);
1821
- }
1822
-
1823
- static VALUE _llama_model_token_bos(VALUE self) {
1824
- LLaMAModelWrapper* ptr = get_llama_model(self);
1825
- return INT2NUM(llama_token_bos(ptr->model));
1826
- }
1827
-
1828
- static VALUE _llama_model_token_eos(VALUE self) {
1829
- LLaMAModelWrapper* ptr = get_llama_model(self);
1830
- return INT2NUM(llama_token_eos(ptr->model));
1831
- }
1832
-
1833
- static VALUE _llama_model_token_cls(VALUE self) {
1834
- LLaMAModelWrapper* ptr = get_llama_model(self);
1835
- return INT2NUM(llama_token_cls(ptr->model));
1836
- }
1837
-
1838
- static VALUE _llama_model_token_sep(VALUE self) {
1839
- LLaMAModelWrapper* ptr = get_llama_model(self);
1840
- return INT2NUM(llama_token_sep(ptr->model));
1841
- }
1842
-
1843
- static VALUE _llama_model_token_nl(VALUE self) {
1844
- LLaMAModelWrapper* ptr = get_llama_model(self);
1845
- return INT2NUM(llama_token_nl(ptr->model));
1846
- }
1847
-
1848
- static VALUE _llama_model_token_pad(VALUE self) {
1849
- LLaMAModelWrapper* ptr = get_llama_model(self);
1850
- return INT2NUM(llama_token_pad(ptr->model));
1851
- }
1852
-
1853
- static VALUE _llama_model_add_bos_token(VALUE self) {
1854
- LLaMAModelWrapper* ptr = get_llama_model(self);
1855
- return llama_add_bos_token(ptr->model) ? Qtrue : Qfalse;
1856
- }
1857
-
1858
- static VALUE _llama_model_add_eos_token(VALUE self) {
1859
- LLaMAModelWrapper* ptr = get_llama_model(self);
1860
- return llama_add_eos_token(ptr->model) ? Qtrue : Qfalse;
1861
- }
1862
-
1863
- static VALUE _llama_model_token_prefix(VALUE self) {
1864
- LLaMAModelWrapper* ptr = get_llama_model(self);
1865
- return INT2NUM(llama_token_prefix(ptr->model));
1866
- }
1867
-
1868
- static VALUE _llama_model_token_middle(VALUE self) {
1869
- LLaMAModelWrapper* ptr = get_llama_model(self);
1870
- return INT2NUM(llama_token_middle(ptr->model));
1871
- }
1872
-
1873
- static VALUE _llama_model_token_suffix(VALUE self) {
1874
- LLaMAModelWrapper* ptr = get_llama_model(self);
1875
- return INT2NUM(llama_token_suffix(ptr->model));
1876
- }
1877
-
1878
- static VALUE _llama_model_token_eot(VALUE self) {
1879
- LLaMAModelWrapper* ptr = get_llama_model(self);
1880
- return INT2NUM(llama_token_eot(ptr->model));
1881
- }
1882
-
1883
- static VALUE _llama_model_token_is_eog(VALUE self, VALUE token_) {
1884
- if (!RB_INTEGER_TYPE_P(token_)) {
1885
- rb_raise(rb_eArgError, "token must be an integer");
1886
- return Qnil;
1887
- }
1888
- const llama_token token = NUM2INT(token_);
1889
- LLaMAModelWrapper* ptr = get_llama_model(self);
1890
- return llama_token_is_eog(ptr->model, token) ? Qtrue : Qfalse;
1891
- }
1892
-
1893
- static VALUE _llama_model_token_is_control(VALUE self, VALUE token_) {
1894
- if (!RB_INTEGER_TYPE_P(token_)) {
1895
- rb_raise(rb_eArgError, "token must be an integer");
1896
- return Qnil;
1897
- }
1898
- const llama_token token = NUM2INT(token_);
1899
- LLaMAModelWrapper* ptr = get_llama_model(self);
1900
- return llama_token_is_control(ptr->model, token) ? Qtrue : Qfalse;
1901
- }
1902
-
1903
- static VALUE _llama_model_has_encoder(VALUE self) {
1904
- LLaMAModelWrapper* ptr = get_llama_model(self);
1905
- return llama_model_has_encoder(ptr->model) ? Qtrue : Qfalse;
1906
- }
1907
-
1908
- static VALUE _llama_model_has_decoder(VALUE self) {
1909
- LLaMAModelWrapper* ptr = get_llama_model(self);
1910
- return llama_model_has_decoder(ptr->model) ? Qtrue : Qfalse;
1911
- }
1912
-
1913
- static VALUE _llama_model_decoder_start_token(VALUE self) {
1914
- LLaMAModelWrapper* ptr = get_llama_model(self);
1915
- return INT2NUM(llama_model_decoder_start_token(ptr->model));
1916
- }
1917
-
1918
- static VALUE _llama_model_is_recurrent(VALUE self) {
1919
- LLaMAModelWrapper* ptr = get_llama_model(self);
1920
- return llama_model_is_recurrent(ptr->model) ? Qtrue : Qfalse;
1921
- }
1922
-
1923
- static VALUE _llama_model_detokenize(int argc, VALUE* argv, VALUE self) {
1924
- VALUE kw_args = Qnil;
1925
- ID kw_table[2] = { rb_intern("remove_special"), rb_intern("unparse_special") };
1926
- VALUE kw_values[2] = { Qundef, Qundef };
1927
- VALUE tokens_ = Qnil;
1928
- rb_scan_args(argc, argv, "1:", &tokens_, &kw_args);
1929
- rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
1930
-
1931
- if (!RB_TYPE_P(tokens_, T_ARRAY)) {
1932
- rb_raise(rb_eArgError, "tokens must be an array");
1933
- return Qnil;
1934
- }
1935
-
1936
- const int32_t n_tokens = RARRAY_LEN(tokens_);
1937
- llama_token* tokens = ALLOCA_N(llama_token, n_tokens);
1938
- for (int32_t i = 0; i < n_tokens; i++) {
1939
- tokens[i] = NUM2INT(rb_ary_entry(tokens_, i));
1940
- }
1941
-
1942
- std::string text;
1943
- text.resize(std::max(text.capacity(), static_cast<unsigned long>(n_tokens)));
1944
- const int32_t text_len_max = text.size();
1945
-
1946
- bool remove_special = kw_values[0] != Qundef ? RTEST(kw_values[0]) : false;
1947
- bool unparse_special = kw_values[1] != Qundef ? RTEST(kw_values[1]) : false;
1948
-
1949
- LLaMAModelWrapper* ptr = get_llama_model(self);
1950
- std::string result;
1951
- int32_t n_chars = llama_detokenize(ptr->model, tokens, n_tokens, &text[0], text_len_max, remove_special, unparse_special);
1952
- if (n_chars < 0) {
1953
- text.resize(-n_chars);
1954
- n_chars = llama_detokenize(ptr->model, tokens, n_tokens, &text[0], text_len_max, remove_special, unparse_special);
1955
- if (n_chars <= text.size()) {
1956
- rb_raise(rb_eRuntimeError, "Failed to detokenize");
1957
- return Qnil;
1958
- }
1959
- }
1960
-
1961
- text.resize(n_chars);
1962
- return rb_utf8_str_new_cstr(text.c_str());
1963
- }
1964
- };
1965
-
1966
- const rb_data_type_t RbLLaMAModel::llama_model_type = {
1967
- "RbLLaMAModel",
1968
- { NULL,
1969
- RbLLaMAModel::llama_model_free,
1970
- RbLLaMAModel::llama_model_size },
1971
- NULL,
1972
- NULL,
1973
- RUBY_TYPED_FREE_IMMEDIATELY
1974
- };
1975
-
1976
- class LLaMAGrammarElementWrapper {
1977
- public:
1978
- llama_grammar_element element;
1979
-
1980
- LLaMAGrammarElementWrapper() {
1981
- element.type = LLAMA_GRETYPE_END;
1982
- element.value = 0;
1983
- }
1984
-
1985
- ~LLaMAGrammarElementWrapper() {}
1986
- };
1987
-
1988
- class RbLLaMAGrammarElement {
1989
- public:
1990
- static VALUE llama_grammar_element_alloc(VALUE self) {
1991
- LLaMAGrammarElementWrapper* ptr = (LLaMAGrammarElementWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarElementWrapper));
1992
- new (ptr) LLaMAGrammarElementWrapper();
1993
- return TypedData_Wrap_Struct(self, &llama_grammar_element_type, ptr);
1994
- }
1995
-
1996
- static void llama_grammar_element_free(void* ptr) {
1997
- ((LLaMAGrammarElementWrapper*)ptr)->~LLaMAGrammarElementWrapper();
1998
- ruby_xfree(ptr);
1999
- }
2000
-
2001
- static size_t llama_grammar_element_size(const void* ptr) {
2002
- return sizeof(*((LLaMAGrammarElementWrapper*)ptr));
2003
- }
2004
-
2005
- static LLaMAGrammarElementWrapper* get_llama_grammar_element(VALUE self) {
2006
- LLaMAGrammarElementWrapper* ptr;
2007
- TypedData_Get_Struct(self, LLaMAGrammarElementWrapper, &llama_grammar_element_type, ptr);
2008
- return ptr;
2009
- }
2010
-
2011
- static void define_class(VALUE outer) {
2012
- rb_cLLaMAGrammarElement = rb_define_class_under(outer, "GrammarElement", rb_cObject);
2013
- rb_define_alloc_func(rb_cLLaMAGrammarElement, llama_grammar_element_alloc);
2014
- rb_define_method(rb_cLLaMAGrammarElement, "initialize", RUBY_METHOD_FUNC(_llama_grammar_element_init), -1);
2015
- rb_define_method(rb_cLLaMAGrammarElement, "type=", RUBY_METHOD_FUNC(_llama_grammar_element_set_type), 1);
2016
- rb_define_method(rb_cLLaMAGrammarElement, "type", RUBY_METHOD_FUNC(_llama_grammar_element_get_type), 0);
2017
- rb_define_method(rb_cLLaMAGrammarElement, "value=", RUBY_METHOD_FUNC(_llama_grammar_element_set_value), 1);
2018
- rb_define_method(rb_cLLaMAGrammarElement, "value", RUBY_METHOD_FUNC(_llama_grammar_element_get_value), 0);
2019
- }
2020
-
2021
- private:
2022
- static const rb_data_type_t llama_grammar_element_type;
2023
-
2024
- static VALUE _llama_grammar_element_init(int argc, VALUE* argv, VALUE self) {
2025
- VALUE kw_args = Qnil;
2026
- ID kw_table[2] = { rb_intern("type"), rb_intern("value") };
2027
- VALUE kw_values[2] = { Qundef, Qundef };
2028
- VALUE arr = Qnil;
2029
- rb_scan_args(argc, argv, ":", &arr, &kw_args);
2030
- rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
2031
-
2032
- if (kw_values[0] != Qundef && !RB_INTEGER_TYPE_P(kw_values[0])) {
2033
- rb_raise(rb_eArgError, "type must be an integer");
2034
- return Qnil;
2035
- }
2036
- if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
2037
- rb_raise(rb_eArgError, "value must be an integer");
2038
- return Qnil;
2039
- }
2040
-
2041
- LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
2042
- new (ptr) LLaMAGrammarElementWrapper();
2043
-
2044
- if (kw_values[0] != Qundef) {
2045
- ptr->element.type = (enum llama_gretype)NUM2INT(kw_values[0]);
2046
- }
2047
- if (kw_values[1] != Qundef) {
2048
- ptr->element.value = NUM2INT(kw_values[1]);
2049
- }
2050
-
2051
- return self;
2052
- }
2053
-
2054
- // type
2055
- static VALUE _llama_grammar_element_set_type(VALUE self, VALUE type) {
2056
- LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
2057
- ptr->element.type = (enum llama_gretype)NUM2INT(type);
2058
- return INT2NUM(ptr->element.type);
2059
- }
2060
-
2061
- static VALUE _llama_grammar_element_get_type(VALUE self) {
2062
- LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
2063
- return INT2NUM(ptr->element.type);
2064
- }
2065
-
2066
- // value
2067
- static VALUE _llama_grammar_element_set_value(VALUE self, VALUE type) {
2068
- LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
2069
- ptr->element.value = NUM2INT(type);
2070
- return INT2NUM(ptr->element.value);
2071
- }
2072
-
2073
- static VALUE _llama_grammar_element_get_value(VALUE self) {
2074
- LLaMAGrammarElementWrapper* ptr = get_llama_grammar_element(self);
2075
- return INT2NUM(ptr->element.value);
2076
- }
2077
- };
2078
-
2079
- const rb_data_type_t RbLLaMAGrammarElement::llama_grammar_element_type = {
2080
- "RbLLaMAGrammarElement",
2081
- { NULL,
2082
- RbLLaMAGrammarElement::llama_grammar_element_free,
2083
- RbLLaMAGrammarElement::llama_grammar_element_size },
2084
- NULL,
2085
- NULL,
2086
- RUBY_TYPED_FREE_IMMEDIATELY
2087
- };
2088
-
2089
- class LLaMAGrammarWrapper {
2090
- public:
2091
- struct llama_grammar* grammar;
2092
-
2093
- LLaMAGrammarWrapper() : grammar(nullptr) {}
2094
-
2095
- ~LLaMAGrammarWrapper() {
2096
- if (grammar) {
2097
- llama_grammar_free(grammar);
2098
- }
2099
- }
2100
- };
2101
-
2102
- class RbLLaMAGrammar {
2103
- public:
2104
- static VALUE llama_grammar_alloc(VALUE self) {
2105
- LLaMAGrammarWrapper* ptr = (LLaMAGrammarWrapper*)ruby_xmalloc(sizeof(LLaMAGrammarWrapper));
2106
- new (ptr) LLaMAGrammarWrapper();
2107
- return TypedData_Wrap_Struct(self, &llama_grammar_type, ptr);
2108
- }
2109
-
2110
- static void llama_grammar_free(void* ptr) {
2111
- ((LLaMAGrammarWrapper*)ptr)->~LLaMAGrammarWrapper();
2112
- ruby_xfree(ptr);
2113
- }
2114
-
2115
- static size_t llama_grammar_size(const void* ptr) {
2116
- return sizeof(*((LLaMAGrammarWrapper*)ptr));
2117
- }
2118
-
2119
- static LLaMAGrammarWrapper* get_llama_grammar(VALUE self) {
2120
- LLaMAGrammarWrapper* ptr;
2121
- TypedData_Get_Struct(self, LLaMAGrammarWrapper, &llama_grammar_type, ptr);
2122
- return ptr;
2123
- }
2124
-
2125
- static void define_class(VALUE outer) {
2126
- rb_cLLaMAGrammar = rb_define_class_under(outer, "Grammar", rb_cObject);
2127
- rb_define_alloc_func(rb_cLLaMAGrammar, llama_grammar_alloc);
2128
- rb_define_method(rb_cLLaMAGrammar, "initialize", RUBY_METHOD_FUNC(_llama_grammar_init), -1);
2129
- }
2130
-
2131
- private:
2132
- static const rb_data_type_t llama_grammar_type;
2133
-
2134
- static VALUE _llama_grammar_init(int argc, VALUE* argv, VALUE self) {
2135
- VALUE kw_args = Qnil;
2136
- ID kw_table[2] = { rb_intern("rules"), rb_intern("start_rule_index") };
2137
- VALUE kw_values[2] = { Qundef, Qundef };
2138
- rb_scan_args(argc, argv, ":", &kw_args);
2139
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
2140
-
2141
- if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
2142
- rb_raise(rb_eArgError, "rules must be an array");
2143
- return Qnil;
2144
- }
2145
- if (!RB_INTEGER_TYPE_P(kw_values[1])) {
2146
- rb_raise(rb_eArgError, "start_rule_index must be an integer");
2147
- return Qnil;
2148
- }
2149
-
2150
- const int n_rules = RARRAY_LEN(kw_values[0]);
2151
- llama_grammar_element** rules = ALLOCA_N(llama_grammar_element*, n_rules);
2152
- for (int i = 0; i < n_rules; ++i) {
2153
- VALUE rule = rb_ary_entry(kw_values[0], i);
2154
- if (!RB_TYPE_P(rule, T_ARRAY)) {
2155
- rb_raise(rb_eArgError, "element of rules must be an array");
2156
- return Qnil;
2157
- }
2158
- const int n_elements = RARRAY_LEN(rule);
2159
- llama_grammar_element* elements = ALLOCA_N(llama_grammar_element, n_elements);
2160
- for (int j = 0; j < n_elements; ++j) {
2161
- VALUE element = rb_ary_entry(rule, j);
2162
- if (!rb_obj_is_kind_of(element, rb_cLLaMAGrammarElement)) {
2163
- rb_raise(rb_eArgError, "element of rule must be an instance of GrammarElement");
2164
- return Qnil;
2165
- }
2166
- LLaMAGrammarElementWrapper* ptr = RbLLaMAGrammarElement::get_llama_grammar_element(element);
2167
- elements[j] = ptr->element;
2168
- }
2169
- rules[i] = elements;
2170
- }
2171
-
2172
- const size_t start_rule_index = NUM2SIZET(kw_values[1]);
2173
-
2174
- LLaMAGrammarWrapper* ptr = get_llama_grammar(self);
2175
- new (ptr) LLaMAGrammarWrapper();
2176
- ptr->grammar = llama_grammar_init((const llama_grammar_element**)rules, n_rules, start_rule_index);
2177
-
2178
- return self;
2179
- }
2180
- };
2181
-
2182
- const rb_data_type_t RbLLaMAGrammar::llama_grammar_type = {
2183
- "RbLLaMAGrammar",
2184
- { NULL,
2185
- RbLLaMAGrammar::llama_grammar_free,
2186
- RbLLaMAGrammar::llama_grammar_size },
2187
- NULL,
2188
- NULL,
2189
- RUBY_TYPED_FREE_IMMEDIATELY
2190
- };
2191
-
2192
- class LLaMAContextWrapper {
2193
- public:
2194
- struct llama_context* ctx;
2195
-
2196
- LLaMAContextWrapper() : ctx(NULL) {}
2197
-
2198
- ~LLaMAContextWrapper() {
2199
- if (ctx != NULL) {
2200
- llama_free(ctx);
2201
- }
2202
- }
2203
- };
2204
-
2205
- class RbLLaMAContext {
2206
- public:
2207
- static VALUE llama_context_alloc(VALUE self) {
2208
- LLaMAContextWrapper* ptr = (LLaMAContextWrapper*)ruby_xmalloc(sizeof(LLaMAContextWrapper));
2209
- new (ptr) LLaMAContextWrapper();
2210
- return TypedData_Wrap_Struct(self, &llama_context_type, ptr);
2211
- }
2212
-
2213
- static void llama_context_free(void* ptr) {
2214
- ((LLaMAContextWrapper*)ptr)->~LLaMAContextWrapper();
2215
- ruby_xfree(ptr);
2216
- }
2217
-
2218
- static size_t llama_context_size(const void* ptr) {
2219
- return sizeof(*((LLaMAContextWrapper*)ptr));
2220
- }
2221
-
2222
- static LLaMAContextWrapper* get_llama_context(VALUE self) {
2223
- LLaMAContextWrapper* ptr;
2224
- TypedData_Get_Struct(self, LLaMAContextWrapper, &llama_context_type, ptr);
2225
- return ptr;
2226
- }
2227
-
2228
- static void define_class(VALUE outer) {
2229
- rb_cLLaMAContext = rb_define_class_under(outer, "Context", rb_cObject);
2230
- rb_define_alloc_func(rb_cLLaMAContext, llama_context_alloc);
2231
- rb_define_attr(rb_cLLaMAContext, "model", 1, 0);
2232
- rb_define_method(rb_cLLaMAContext, "initialize", RUBY_METHOD_FUNC(_llama_context_initialize), -1);
2233
- rb_define_method(rb_cLLaMAContext, "encode", RUBY_METHOD_FUNC(_llama_context_encode), 1);
2234
- rb_define_method(rb_cLLaMAContext, "decode", RUBY_METHOD_FUNC(_llama_context_decode), 1);
2235
- rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
2236
- rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
2237
- rb_define_method(rb_cLLaMAContext, "embeddings_ith", RUBY_METHOD_FUNC(_llama_context_embeddings_ith), 1);
2238
- rb_define_method(rb_cLLaMAContext, "embeddings_seq", RUBY_METHOD_FUNC(_llama_context_embeddings_seq), 1);
2239
- rb_define_method(rb_cLLaMAContext, "set_embeddings", RUBY_METHOD_FUNC(_llama_context_set_embeddings), 1);
2240
- rb_define_method(rb_cLLaMAContext, "set_n_threads", RUBY_METHOD_FUNC(_llama_context_set_n_threads), -1);
2241
- rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
2242
- rb_define_method(rb_cLLaMAContext, "n_batch", RUBY_METHOD_FUNC(_llama_context_n_batch), 0);
2243
- rb_define_method(rb_cLLaMAContext, "n_ubatch", RUBY_METHOD_FUNC(_llama_context_n_ubatch), 0);
2244
- rb_define_method(rb_cLLaMAContext, "n_seq_max", RUBY_METHOD_FUNC(_llama_context_n_seq_max), 0);
2245
- rb_define_method(rb_cLLaMAContext, "n_threads", RUBY_METHOD_FUNC(_llama_context_n_threads), 0);
2246
- rb_define_method(rb_cLLaMAContext, "n_threads_batch", RUBY_METHOD_FUNC(_llama_context_n_threads_batch), 0);
2247
- rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
2248
- rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
2249
- rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
2250
- rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
2251
- rb_define_method(rb_cLLaMAContext, "kv_cache_clear", RUBY_METHOD_FUNC(_llama_context_kv_cache_clear), 0);
2252
- rb_define_method(rb_cLLaMAContext, "kv_cache_seq_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_rm), 3);
2253
- rb_define_method(rb_cLLaMAContext, "kv_cache_seq_cp", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_cp), 4);
2254
- rb_define_method(rb_cLLaMAContext, "kv_cache_seq_keep", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_keep), 1);
2255
- rb_define_method(rb_cLLaMAContext, "kv_cache_seq_add", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_add), 4);
2256
- rb_define_method(rb_cLLaMAContext, "kv_cache_seq_div", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_div), 4);
2257
- rb_define_method(rb_cLLaMAContext, "kv_cache_seq_pos_max", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_pos_max), 1);
2258
- rb_define_method(rb_cLLaMAContext, "kv_cache_kv_cache_defrag", RUBY_METHOD_FUNC(_llama_context_kv_cache_defrag), 0);
2259
- rb_define_method(rb_cLLaMAContext, "kv_cache_kv_cache_update", RUBY_METHOD_FUNC(_llama_context_kv_cache_update), 0);
2260
- rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
2261
- rb_define_method(rb_cLLaMAContext, "set_causal_attn", RUBY_METHOD_FUNC(_llama_context_set_causal_attn), 1);
2262
- rb_define_method(rb_cLLaMAContext, "synchronize", RUBY_METHOD_FUNC(_llama_context_synchronize), 0);
2263
- rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
2264
- rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
2265
- rb_define_method(rb_cLLaMAContext, "sample_repetition_penalties", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalties), -1);
2266
- rb_define_method(rb_cLLaMAContext, "sample_apply_guidance", RUBY_METHOD_FUNC(_llama_context_sample_apply_guidance), -1);
2267
- rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
2268
- rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
2269
- rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
2270
- rb_define_method(rb_cLLaMAContext, "sample_min_p", RUBY_METHOD_FUNC(_llama_context_sample_min_p), -1);
2271
- rb_define_method(rb_cLLaMAContext, "sample_tail_free", RUBY_METHOD_FUNC(_llama_context_sample_tail_free), -1);
2272
- rb_define_method(rb_cLLaMAContext, "sample_typical", RUBY_METHOD_FUNC(_llama_context_sample_typical), -1);
2273
- rb_define_method(rb_cLLaMAContext, "sample_temp", RUBY_METHOD_FUNC(_llama_context_sample_temp), -1);
2274
- rb_define_method(rb_cLLaMAContext, "sample_entropy", RUBY_METHOD_FUNC(_llama_context_sample_entropy), -1);
2275
- rb_define_method(rb_cLLaMAContext, "sample_token_mirostat", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat), -1);
2276
- rb_define_method(rb_cLLaMAContext, "sample_token_mirostat_v2", RUBY_METHOD_FUNC(_llama_context_sample_token_mirostat_v2), -1);
2277
- rb_define_method(rb_cLLaMAContext, "sample_token_greedy", RUBY_METHOD_FUNC(_llama_context_sample_token_greedy), 1);
2278
- rb_define_method(rb_cLLaMAContext, "sample_token", RUBY_METHOD_FUNC(_llama_context_sample_token), 1);
2279
- rb_define_method(rb_cLLaMAContext, "sample_grammar", RUBY_METHOD_FUNC(_llama_context_sample_grammar), -1);
2280
- rb_define_method(rb_cLLaMAContext, "grammar_accept_token", RUBY_METHOD_FUNC(_llama_context_grammar_accept_token), -1);
2281
- rb_define_method(rb_cLLaMAContext, "apply_control_vector", RUBY_METHOD_FUNC(_llama_context_apply_control_vector), -1);
2282
- rb_define_method(rb_cLLaMAContext, "pooling_type", RUBY_METHOD_FUNC(_llama_context_pooling_type), 0);
2283
- }
2284
-
2285
- private:
2286
- static const rb_data_type_t llama_context_type;
2287
-
2288
- static VALUE _llama_context_initialize(int argc, VALUE* argv, VALUE self) {
2289
- VALUE kw_args = Qnil;
2290
- ID kw_table[2] = { rb_intern("model"), rb_intern("params") };
2291
- VALUE kw_values[2] = { Qundef, Qundef };
2292
- rb_scan_args(argc, argv, ":", &kw_args);
2293
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
2294
-
2295
- VALUE model = kw_values[0];
2296
- if (!rb_obj_is_kind_of(model, rb_cLLaMAModel)) {
2297
- rb_raise(rb_eArgError, "model must be a Model");
2298
- return Qnil;
2299
- }
2300
- VALUE params = kw_values[1];
2301
- if (!rb_obj_is_kind_of(params, rb_cLLaMAContextParams)) {
2302
- rb_raise(rb_eArgError, "params must be a ContextParams");
2303
- return Qnil;
2304
- }
2305
-
2306
- LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
2307
- if (model_ptr->model == NULL) {
2308
- rb_raise(rb_eRuntimeError, "Model is empty");
2309
- return Qnil;
2310
- }
2311
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
2312
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2313
-
2314
- ctx_ptr->ctx = llama_new_context_with_model(model_ptr->model, prms_ptr->params);
2315
-
2316
- if (ctx_ptr->ctx == NULL) {
2317
- rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA context");
2318
- return Qnil;
2319
- }
2320
-
2321
- rb_iv_set(self, "@model", model);
2322
- rb_iv_set(self, "@params", params);
2323
- rb_iv_set(self, "@has_evaluated", Qfalse);
2324
-
2325
- return Qnil;
2326
- }
2327
-
2328
- static VALUE _llama_context_encode(VALUE self, VALUE batch) {
2329
- LLaMAContextWrapper* ptr = get_llama_context(self);
2330
- if (ptr->ctx == NULL) {
2331
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2332
- return Qnil;
2333
- }
2334
- if (!rb_obj_is_kind_of(batch, rb_cLLaMABatch)) {
2335
- rb_raise(rb_eArgError, "batch must be a Batch");
2336
- return Qnil;
2337
- }
2338
- LLaMABatchWrapper* batch_ptr = RbLLaMABatch::get_llama_batch(batch);
2339
- if (llama_encode(ptr->ctx, batch_ptr->batch) < 0) {
2340
- rb_raise(rb_eRuntimeError, "Failed to encode");
2341
- return Qnil;
2342
- }
2343
- return Qnil;
2344
- }
2345
-
2346
- static VALUE _llama_context_decode(VALUE self, VALUE batch) {
2347
- LLaMAContextWrapper* ptr = get_llama_context(self);
2348
- if (ptr->ctx == NULL) {
2349
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2350
- return Qnil;
2351
- }
2352
- if (!rb_obj_is_kind_of(batch, rb_cLLaMABatch)) {
2353
- rb_raise(rb_eArgError, "batch must be a Batch");
2354
- return Qnil;
2355
- }
2356
- LLaMABatchWrapper* batch_ptr = RbLLaMABatch::get_llama_batch(batch);
2357
- if (llama_decode(ptr->ctx, batch_ptr->batch) < 0) {
2358
- rb_raise(rb_eRuntimeError, "Failed to decode");
2359
- return Qnil;
2360
- }
2361
- rb_iv_set(self, "@n_tokens", INT2NUM(batch_ptr->batch.n_tokens));
2362
- rb_iv_set(self, "@has_evaluated", Qtrue);
2363
- return Qnil;
2364
- }
2365
-
2366
- static VALUE _llama_context_logits(VALUE self) {
2367
- LLaMAContextWrapper* ptr = get_llama_context(self);
2368
- if (ptr->ctx == NULL) {
2369
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2370
- return Qnil;
2371
- }
2372
- if (rb_iv_get(self, "@has_evaluated") != Qtrue) {
2373
- rb_raise(rb_eRuntimeError, "LLaMA context has not been evaluated");
2374
- return Qnil;
2375
- }
2376
-
2377
- VALUE model = rb_iv_get(self, "@model");
2378
- LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
2379
- VALUE params = rb_iv_get(self, "@params");
2380
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
2381
- const int n_tokens = prms_ptr->params.logits_all ? NUM2INT(rb_iv_get(self, "@n_tokens")) : 1;
2382
- const int n_vocab = llama_n_vocab(model_ptr->model);
2383
- const float* logits = llama_get_logits(ptr->ctx);
2384
- VALUE output = rb_ary_new();
2385
- for (int i = 0; i < n_tokens * n_vocab; i++) {
2386
- rb_ary_push(output, DBL2NUM((double)(logits[i])));
2387
- }
2388
-
2389
- return output;
2390
- }
2391
-
2392
- static VALUE _llama_context_embeddings(VALUE self) {
2393
- LLaMAContextWrapper* ptr = get_llama_context(self);
2394
- if (ptr->ctx == NULL) {
2395
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2396
- return Qnil;
2397
- }
2398
- VALUE model = rb_iv_get(self, "@model");
2399
- LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
2400
- VALUE params = rb_iv_get(self, "@params");
2401
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
2402
- if (!prms_ptr->params.embeddings) {
2403
- rb_raise(rb_eRuntimeError, "embedding parameter is false");
2404
- return Qnil;
2405
- }
2406
- if (rb_iv_get(self, "@has_evaluated") != Qtrue) {
2407
- rb_raise(rb_eRuntimeError, "LLaMA context has not been evaluated");
2408
- return Qnil;
2409
- }
2410
-
2411
- const int n_tokens = NUM2INT(rb_iv_get(self, "@n_tokens"));
2412
- const int n_embd = llama_n_embd(model_ptr->model);
2413
- const float* embd = llama_get_embeddings(ptr->ctx);
2414
- VALUE output = rb_ary_new();
2415
- for (int i = 0; i < n_tokens * n_embd; i++) {
2416
- rb_ary_push(output, DBL2NUM((double)(embd[i])));
2417
- }
2418
-
2419
- return output;
2420
- }
2421
-
2422
- static VALUE _llama_context_embeddings_ith(VALUE self, VALUE ith) {
2423
- if (!RB_INTEGER_TYPE_P(ith)) {
2424
- rb_raise(rb_eArgError, "ith must be an integer");
2425
- return Qnil;
2426
- }
2427
- LLaMAContextWrapper* ptr = get_llama_context(self);
2428
- if (ptr->ctx == NULL) {
2429
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2430
- return Qnil;
2431
- }
2432
- VALUE params = rb_iv_get(self, "@params");
2433
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
2434
- if (!prms_ptr->params.embeddings) {
2435
- rb_raise(rb_eRuntimeError, "embedding parameter is false");
2436
- return Qnil;
2437
- }
2438
-
2439
- VALUE model = rb_iv_get(self, "@model");
2440
- LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
2441
- const int n_embd = llama_n_embd(model_ptr->model);
2442
-
2443
- VALUE output = rb_ary_new();
2444
- const float* embd = llama_get_embeddings_ith(ptr->ctx, NUM2INT(ith));
2445
- for (int i = 0; i < n_embd; i++) {
2446
- rb_ary_push(output, DBL2NUM((double)(embd[i])));
2447
- }
2448
-
2449
- return output;
2450
- }
2451
-
2452
- static VALUE _llama_context_embeddings_seq(VALUE self, VALUE seq_id) {
2453
- if (!RB_INTEGER_TYPE_P(seq_id)) {
2454
- rb_raise(rb_eArgError, "seq_id must be an integer");
2455
- return Qnil;
2456
- }
2457
- LLaMAContextWrapper* ptr = get_llama_context(self);
2458
- if (ptr->ctx == NULL) {
2459
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2460
- return Qnil;
2461
- }
2462
- VALUE params = rb_iv_get(self, "@params");
2463
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
2464
- if (!prms_ptr->params.embeddings) {
2465
- rb_raise(rb_eRuntimeError, "embedding parameter is false");
2466
- return Qnil;
2467
- }
2468
-
2469
- VALUE model = rb_iv_get(self, "@model");
2470
- LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
2471
- const int n_embd = llama_n_embd(model_ptr->model);
2472
-
2473
- VALUE output = rb_ary_new();
2474
- const float* embd = llama_get_embeddings_seq(ptr->ctx, NUM2INT(seq_id));
2475
- for (int i = 0; i < n_embd; i++) {
2476
- rb_ary_push(output, DBL2NUM((double)(embd[i])));
2477
- }
2478
-
2479
- return output;
2480
- }
2481
-
2482
- static VALUE _llama_context_set_embeddings(VALUE self, VALUE embs) {
2483
- LLaMAContextWrapper* ptr = get_llama_context(self);
2484
- if (ptr->ctx == NULL) {
2485
- rb_raise(rb_eArgError, "LLaMA context is not initialized");
2486
- return Qnil;
2487
- }
2488
- llama_set_embeddings(ptr->ctx, RTEST(embs) ? true : false);
2489
- return Qnil;
2490
- }
2491
-
2492
- static VALUE _llama_context_set_n_threads(int argc, VALUE* argv, VALUE self) {
2493
- VALUE kw_args = Qnil;
2494
- ID kw_table[2] = { rb_intern("n_threads"), rb_intern("n_threads_batch") };
2495
- VALUE kw_values[2] = { Qundef, Qundef };
2496
- rb_scan_args(argc, argv, ":", &kw_args);
2497
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
2498
-
2499
- VALUE n_threads = kw_values[0];
2500
- if (!RB_INTEGER_TYPE_P(n_threads)) {
2501
- rb_raise(rb_eArgError, "n_threads must be an integer");
2502
- return Qnil;
2503
- }
2504
- VALUE n_threads_batch = kw_values[1];
2505
- if (!RB_INTEGER_TYPE_P(n_threads_batch)) {
2506
- rb_raise(rb_eArgError, "n_threads_batch must be an integer");
2507
- return Qnil;
2508
- }
2509
-
2510
- LLaMAContextWrapper* ptr = get_llama_context(self);
2511
- if (ptr->ctx == NULL) {
2512
- rb_raise(rb_eArgError, "LLaMA context is not initialized");
2513
- return Qnil;
2514
- }
2515
- llama_set_n_threads(ptr->ctx, NUM2UINT(n_threads), NUM2UINT(n_threads_batch));
2516
- return Qnil;
2517
- }
2518
-
2519
- static VALUE _llama_context_n_ctx(VALUE self) {
2520
- LLaMAContextWrapper* ptr = get_llama_context(self);
2521
- if (ptr->ctx == NULL) {
2522
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2523
- return Qnil;
2524
- }
2525
- return UINT2NUM(llama_n_ctx(ptr->ctx));
2526
- }
2527
-
2528
- static VALUE _llama_context_n_batch(VALUE self) {
2529
- LLaMAContextWrapper* ptr = get_llama_context(self);
2530
- if (ptr->ctx == NULL) {
2531
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2532
- return Qnil;
2533
- }
2534
- return UINT2NUM(llama_n_batch(ptr->ctx));
2535
- }
2536
-
2537
- static VALUE _llama_context_n_ubatch(VALUE self) {
2538
- LLaMAContextWrapper* ptr = get_llama_context(self);
2539
- if (ptr->ctx == NULL) {
2540
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2541
- return Qnil;
2542
- }
2543
- return UINT2NUM(llama_n_ubatch(ptr->ctx));
2544
- }
2545
-
2546
- static VALUE _llama_context_n_seq_max(VALUE self) {
2547
- LLaMAContextWrapper* ptr = get_llama_context(self);
2548
- if (ptr->ctx == NULL) {
2549
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2550
- return Qnil;
2551
- }
2552
- return UINT2NUM(llama_n_seq_max(ptr->ctx));
2553
- }
2554
-
2555
- static VALUE _llama_context_n_threads(VALUE self) {
2556
- LLaMAContextWrapper* ptr = get_llama_context(self);
2557
- if (ptr->ctx == NULL) {
2558
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2559
- return Qnil;
2560
- }
2561
- return UINT2NUM(llama_n_threads(ptr->ctx));
2562
- }
2563
-
2564
- static VALUE _llama_context_n_threads_batch(VALUE self) {
2565
- LLaMAContextWrapper* ptr = get_llama_context(self);
2566
- if (ptr->ctx == NULL) {
2567
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2568
- return Qnil;
2569
- }
2570
- return UINT2NUM(llama_n_threads_batch(ptr->ctx));
2571
- }
2572
-
2573
- static VALUE _llama_context_get_timings(VALUE self) {
2574
- LLaMAContextWrapper* ptr = get_llama_context(self);
2575
- if (ptr->ctx == NULL) {
2576
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2577
- return Qnil;
2578
- }
2579
- VALUE tm_obj = rb_funcall(rb_cLLaMATimings, rb_intern("new"), 0);
2580
- LLaMATimingsWrapper* tm_ptr = RbLLaMATimings::get_llama_timings(tm_obj);
2581
- tm_ptr->timings = llama_get_timings(ptr->ctx);
2582
- return tm_obj;
2583
- }
2584
-
2585
- static VALUE _llama_context_print_timings(VALUE self) {
2586
- LLaMAContextWrapper* ptr = get_llama_context(self);
2587
- if (ptr->ctx == NULL) {
2588
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2589
- return Qnil;
2590
- }
2591
- llama_print_timings(ptr->ctx);
2592
- return Qnil;
2593
- }
2594
-
2595
- static VALUE _llama_context_reset_timings(VALUE self) {
2596
- LLaMAContextWrapper* ptr = get_llama_context(self);
2597
- if (ptr->ctx == NULL) {
2598
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2599
- return Qnil;
2600
- }
2601
- llama_reset_timings(ptr->ctx);
2602
- return Qnil;
2603
- }
2604
-
2605
- static VALUE _llama_context_kv_cache_token_count(VALUE self) {
2606
- LLaMAContextWrapper* ptr = get_llama_context(self);
2607
- if (ptr->ctx == NULL) {
2608
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2609
- return Qnil;
2610
- }
2611
- return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
2612
- }
2613
-
2614
- static VALUE _llama_context_kv_cache_clear(VALUE self) {
2615
- LLaMAContextWrapper* ptr = get_llama_context(self);
2616
- if (ptr->ctx == NULL) {
2617
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2618
- return Qnil;
2619
- }
2620
- llama_kv_cache_clear(ptr->ctx);
2621
- return Qnil;
2622
- }
2623
-
2624
- static VALUE _llama_context_kv_cache_seq_rm(VALUE self, VALUE seq_id, VALUE p0, VALUE p1) {
2625
- LLaMAContextWrapper* ptr = get_llama_context(self);
2626
- if (ptr->ctx == NULL) {
2627
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2628
- return Qnil;
2629
- }
2630
- llama_kv_cache_seq_rm(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1));
2631
- return Qnil;
2632
- }
2633
-
2634
- static VALUE _llama_context_kv_cache_seq_cp(VALUE self, VALUE seq_id_src, VALUE seq_id_dst, VALUE p0, VALUE p1) {
2635
- LLaMAContextWrapper* ptr = get_llama_context(self);
2636
- if (ptr->ctx == NULL) {
2637
- rb_raise(rb_eArgError, "LLaMA context is not initialized");
2638
- return Qnil;
2639
- }
2640
- llama_kv_cache_seq_cp(ptr->ctx, NUM2INT(seq_id_src), NUM2INT(seq_id_dst), NUM2INT(p0), NUM2INT(p1));
2641
- return Qnil;
2642
- }
2643
-
2644
- static VALUE _llama_context_kv_cache_seq_keep(VALUE self, VALUE seq_id) {
2645
- LLaMAContextWrapper* ptr = get_llama_context(self);
2646
- if (ptr->ctx == NULL) {
2647
- rb_raise(rb_eArgError, "LLaMA context is not initialized");
2648
- return Qnil;
2649
- }
2650
- llama_kv_cache_seq_keep(ptr->ctx, NUM2INT(seq_id));
2651
- return Qnil;
2652
- }
2653
-
2654
- static VALUE _llama_context_kv_cache_seq_add(VALUE self, VALUE seq_id, VALUE p0, VALUE p1, VALUE delta) {
2655
- LLaMAContextWrapper* ptr = get_llama_context(self);
2656
- if (ptr->ctx == NULL) {
2657
- rb_raise(rb_eArgError, "LLaMA context is not initialized");
2658
- return Qnil;
2659
- }
2660
- llama_kv_cache_seq_add(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1), NUM2INT(delta));
2661
- return Qnil;
2662
- }
2663
-
2664
- static VALUE _llama_context_kv_cache_seq_div(VALUE self, VALUE seq_id, VALUE p0, VALUE p1, VALUE d) {
2665
- LLaMAContextWrapper* ptr = get_llama_context(self);
2666
- if (ptr->ctx == NULL) {
2667
- rb_raise(rb_eArgError, "LLaMA context is not initialized");
2668
- return Qnil;
2669
- }
2670
- llama_kv_cache_seq_div(ptr->ctx, NUM2INT(seq_id), NUM2INT(p0), NUM2INT(p1), NUM2INT(d));
2671
- return Qnil;
2672
- }
2673
-
2674
- static VALUE _llama_context_kv_cache_seq_pos_max(VALUE self, VALUE seq_id) {
2675
- LLaMAContextWrapper* ptr = get_llama_context(self);
2676
- if (ptr->ctx == NULL) {
2677
- rb_raise(rb_eArgError, "LLaMA context is not initialized");
2678
- return Qnil;
2679
- }
2680
- return INT2NUM(llama_kv_cache_seq_pos_max(ptr->ctx, NUM2INT(seq_id)));
2681
- }
2682
-
2683
- static VALUE _llama_context_kv_cache_defrag(VALUE self) {
2684
- LLaMAContextWrapper* ptr = get_llama_context(self);
2685
- if (ptr->ctx == NULL) {
2686
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2687
- return Qnil;
2688
- }
2689
- llama_kv_cache_defrag(ptr->ctx);
2690
- return Qnil;
2691
- }
2692
-
2693
- static VALUE _llama_context_kv_cache_update(VALUE self) {
2694
- LLaMAContextWrapper* ptr = get_llama_context(self);
2695
- if (ptr->ctx == NULL) {
2696
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2697
- return Qnil;
2698
- }
2699
- llama_kv_cache_update(ptr->ctx);
2700
- return Qnil;
2701
- }
2702
-
2703
- static VALUE _llama_context_set_rng_seed(VALUE self, VALUE seed_) {
2704
- LLaMAContextWrapper* ptr = get_llama_context(self);
2705
- if (ptr->ctx == NULL) {
2706
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2707
- return Qnil;
2708
- }
2709
- if (NUM2INT(seed_) < 0) {
2710
- rb_raise(rb_eArgError, "seed must be a non-negative integer");
2711
- return Qnil;
2712
- }
2713
- const uint32_t seed = NUM2INT(seed_);
2714
- llama_set_rng_seed(ptr->ctx, seed);
2715
- return Qnil;
2716
- }
2717
-
2718
- static VALUE _llama_context_set_causal_attn(VALUE self, VALUE causal_attn) {
2719
- LLaMAContextWrapper* ptr = get_llama_context(self);
2720
- if (ptr->ctx == NULL) {
2721
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2722
- return Qnil;
2723
- }
2724
- llama_set_causal_attn(ptr->ctx, RTEST(causal_attn) ? true : false);
2725
- return Qnil;
2726
- }
2727
-
2728
- static VALUE _llama_context_synchronize(VALUE self) {
2729
- LLaMAContextWrapper* ptr = get_llama_context(self);
2730
- if (ptr->ctx == NULL) {
2731
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2732
- return Qnil;
2733
- }
2734
- llama_synchronize(ptr->ctx);
2735
- return Qnil;
2736
- }
2737
-
2738
- static VALUE _llama_context_load_session_file(int argc, VALUE* argv, VALUE self) {
2739
- VALUE kw_args = Qnil;
2740
- ID kw_table[1] = { rb_intern("session_path") };
2741
- VALUE kw_values[1] = { Qundef };
2742
- VALUE candidates = Qnil;
2743
- VALUE last_n_tokens = Qnil;
2744
- rb_scan_args(argc, argv, ":", &kw_args);
2745
- rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
2746
-
2747
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
2748
- rb_raise(rb_eArgError, "session_path must be a String");
2749
- return Qnil;
2750
- }
2751
-
2752
- VALUE filename = kw_values[0];
2753
-
2754
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2755
- if (ctx_ptr->ctx == NULL) {
2756
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2757
- return Qnil;
2758
- }
2759
-
2760
- VALUE model = rb_iv_get(self, "@model");
2761
- VALUE params = rb_iv_get(self, "@params");
2762
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
2763
- const int n_ctx = prms_ptr->params.n_ctx;
2764
-
2765
- std::vector<llama_token> session_tokens(n_ctx);
2766
- size_t n_token_count_out = 0;
2767
-
2768
- try {
2769
- bool res = llama_load_session_file(ctx_ptr->ctx, StringValueCStr(filename), session_tokens.data(), session_tokens.capacity(), &n_token_count_out);
2770
- if (!res) {
2771
- rb_raise(rb_eRuntimeError, "Failed to load session file");
2772
- return Qnil;
2773
- }
2774
- session_tokens.resize(n_token_count_out);
2775
- } catch (const std::runtime_error& e) {
2776
- rb_raise(rb_eRuntimeError, "%s", e.what());
2777
- return Qnil;
2778
- }
2779
-
2780
- VALUE ary_session_tokens = rb_ary_new2(n_token_count_out);
2781
- for (size_t i = 0; i < n_token_count_out; i++) {
2782
- rb_ary_store(ary_session_tokens, i, INT2NUM(session_tokens[i]));
2783
- }
2784
-
2785
- RB_GC_GUARD(filename);
2786
- return ary_session_tokens;
2787
- }
2788
-
2789
- static VALUE _llama_context_save_session_file(int argc, VALUE* argv, VALUE self) {
2790
- VALUE kw_args = Qnil;
2791
- ID kw_table[2] = { rb_intern("session_path"), rb_intern("session_tokens") };
2792
- VALUE kw_values[2] = { Qundef, Qundef };
2793
- VALUE candidates = Qnil;
2794
- VALUE last_n_tokens = Qnil;
2795
- rb_scan_args(argc, argv, ":", &kw_args);
2796
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
2797
-
2798
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
2799
- rb_raise(rb_eArgError, "session_path must be a String");
2800
- return Qnil;
2801
- }
2802
- if (!RB_TYPE_P(kw_values[1], T_ARRAY)) {
2803
- rb_raise(rb_eArgError, "session_tokens must be an Array");
2804
- return Qnil;
2805
- }
2806
-
2807
- VALUE filename = kw_values[0];
2808
- const size_t sz_session_tokens = RARRAY_LEN(kw_values[1]);
2809
- std::vector<llama_token> session_tokens(sz_session_tokens);
2810
- for (size_t i = 0; i < sz_session_tokens; i++) {
2811
- session_tokens[i] = NUM2INT(rb_ary_entry(kw_values[1], i));
2812
- }
2813
-
2814
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2815
- if (ctx_ptr->ctx == NULL) {
2816
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2817
- return Qnil;
2818
- }
2819
-
2820
- bool res = llama_save_session_file(ctx_ptr->ctx, StringValueCStr(filename), session_tokens.data(), sz_session_tokens);
2821
-
2822
- if (!res) {
2823
- rb_raise(rb_eRuntimeError, "Failed to save session file");
2824
- return Qnil;
2825
- }
2826
-
2827
- RB_GC_GUARD(filename);
2828
- return Qnil;
2829
- }
2830
-
2831
- static VALUE _llama_context_sample_repetition_penalties(int argc, VALUE* argv, VALUE self) {
2832
- VALUE kw_args = Qnil;
2833
- ID kw_table[3] = { rb_intern("penalty_repeat"), rb_intern("penalty_freq"), rb_intern("penalty_present") };
2834
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
2835
- VALUE candidates = Qnil;
2836
- VALUE last_n_tokens = Qnil;
2837
- rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
2838
- rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
2839
-
2840
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
2841
- rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
2842
- return Qnil;
2843
- }
2844
- if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
2845
- rb_raise(rb_eArgError, "last_n_tokens must be an Array");
2846
- return Qnil;
2847
- }
2848
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
2849
- rb_raise(rb_eArgError, "penalty_repeat must be a float");
2850
- return Qnil;
2851
- }
2852
- if (!RB_FLOAT_TYPE_P(kw_values[1])) {
2853
- rb_raise(rb_eArgError, "penalty_freq must be a float");
2854
- return Qnil;
2855
- }
2856
- if (!RB_FLOAT_TYPE_P(kw_values[2])) {
2857
- rb_raise(rb_eArgError, "penalty_present must be a float");
2858
- return Qnil;
2859
- }
2860
-
2861
- const size_t last_tokens_size = RARRAY_LEN(last_n_tokens);
2862
- std::vector<llama_token> last_n_tokens_data(last_tokens_size);
2863
- for (size_t i = 0; i < last_tokens_size; i++) {
2864
- last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
2865
- }
2866
-
2867
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2868
- if (ctx_ptr->ctx == NULL) {
2869
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2870
- return Qnil;
2871
- }
2872
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
2873
- if (cnd_ptr->array.data == nullptr) {
2874
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
2875
- return Qnil;
2876
- }
2877
- const float penalty_repeat = NUM2DBL(kw_values[0]);
2878
- const float penalty_freq = NUM2DBL(kw_values[1]);
2879
- const float penalty_present = NUM2DBL(kw_values[2]);
2880
-
2881
- llama_sample_repetition_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size,
2882
- penalty_repeat, penalty_freq, penalty_present);
2883
-
2884
- return Qnil;
2885
- }
2886
-
2887
- static VALUE _llama_context_sample_apply_guidance(int argc, VALUE* argv, VALUE self) {
2888
- VALUE kw_args = Qnil;
2889
- ID kw_table[3] = { rb_intern("logits"), rb_intern("logits_guidance"), rb_intern("scale") };
2890
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
2891
- rb_scan_args(argc, argv, ":", &kw_args);
2892
- rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
2893
-
2894
- if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
2895
- rb_raise(rb_eArgError, "logits must be an Array");
2896
- return Qnil;
2897
- }
2898
- if (!RB_TYPE_P(kw_values[1], T_ARRAY)) {
2899
- rb_raise(rb_eArgError, "logits_guidance must be an Array");
2900
- return Qnil;
2901
- }
2902
- if (!RB_FLOAT_TYPE_P(kw_values[2])) {
2903
- rb_raise(rb_eArgError, "scale must be a float");
2904
- return Qnil;
2905
- }
2906
-
2907
- const size_t sz_logits = RARRAY_LEN(kw_values[0]);
2908
- std::vector<float> logits(sz_logits);
2909
- for (size_t i = 0; i < sz_logits; i++) {
2910
- logits[i] = NUM2DBL(rb_ary_entry(kw_values[0], i));
2911
- }
2912
-
2913
- const size_t sz_logits_guidance = RARRAY_LEN(kw_values[1]);
2914
- std::vector<float> logits_guidance(sz_logits_guidance);
2915
- for (size_t i = 0; i < sz_logits_guidance; i++) {
2916
- logits_guidance[i] = NUM2DBL(rb_ary_entry(kw_values[1], i));
2917
- }
2918
-
2919
- const float scale = NUM2DBL(kw_values[2]);
2920
-
2921
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2922
- if (ctx_ptr->ctx == NULL) {
2923
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2924
- return Qnil;
2925
- }
2926
-
2927
- llama_sample_apply_guidance(ctx_ptr->ctx, logits.data(), logits_guidance.data(), scale);
2928
-
2929
- return Qnil;
2930
- }
2931
-
2932
- static VALUE _llama_context_sample_softmax(VALUE self, VALUE candidates) {
2933
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
2934
- rb_raise(rb_eArgError, "argument must be a TokenDataArray");
2935
- return Qnil;
2936
- }
2937
-
2938
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2939
- if (ctx_ptr->ctx == NULL) {
2940
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2941
- return Qnil;
2942
- }
2943
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
2944
- if (cnd_ptr->array.data == nullptr) {
2945
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
2946
- return Qnil;
2947
- }
2948
-
2949
- llama_sample_softmax(ctx_ptr->ctx, &(cnd_ptr->array));
2950
-
2951
- return Qnil;
2952
- }
2953
-
2954
- static VALUE _llama_context_sample_top_k(int argc, VALUE* argv, VALUE self) {
2955
- VALUE kw_args = Qnil;
2956
- ID kw_table[2] = { rb_intern("k"), rb_intern("min_keep") };
2957
- VALUE kw_values[2] = { Qundef, Qundef };
2958
- VALUE candidates = Qnil;
2959
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
2960
- rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
2961
-
2962
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
2963
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
2964
- return Qnil;
2965
- }
2966
- if (!RB_INTEGER_TYPE_P(kw_values[0])) {
2967
- rb_raise(rb_eArgError, "k must be an integer");
2968
- return Qnil;
2969
- }
2970
- if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
2971
- rb_raise(rb_eArgError, "min_keep must be an integer");
2972
- return Qnil;
2973
- }
2974
-
2975
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2976
- if (ctx_ptr->ctx == NULL) {
2977
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2978
- return Qnil;
2979
- }
2980
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
2981
- if (cnd_ptr->array.data == nullptr) {
2982
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
2983
- return Qnil;
2984
- }
2985
- const int k = NUM2DBL(kw_values[0]);
2986
- const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
2987
-
2988
- llama_sample_top_k(ctx_ptr->ctx, &(cnd_ptr->array), k, min_keep);
2989
-
2990
- return Qnil;
2991
- }
2992
-
2993
- static VALUE _llama_context_sample_top_p(int argc, VALUE* argv, VALUE self) {
2994
- VALUE kw_args = Qnil;
2995
- ID kw_table[2] = { rb_intern("prob"), rb_intern("min_keep") };
2996
- VALUE kw_values[2] = { Qundef, Qundef };
2997
- VALUE candidates = Qnil;
2998
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
2999
- rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
3000
-
3001
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3002
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3003
- return Qnil;
3004
- }
3005
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
3006
- rb_raise(rb_eArgError, "prob must be a float");
3007
- return Qnil;
3008
- }
3009
- if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
3010
- rb_raise(rb_eArgError, "min_keep must be an integer");
3011
- return Qnil;
3012
- }
3013
-
3014
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3015
- if (ctx_ptr->ctx == NULL) {
3016
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3017
- return Qnil;
3018
- }
3019
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3020
- if (cnd_ptr->array.data == nullptr) {
3021
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3022
- return Qnil;
3023
- }
3024
- const float prob = NUM2DBL(kw_values[0]);
3025
- const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
3026
-
3027
- llama_sample_top_p(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
3028
-
3029
- return Qnil;
3030
- }
3031
-
3032
- static VALUE _llama_context_sample_min_p(int argc, VALUE* argv, VALUE self) {
3033
- VALUE kw_args = Qnil;
3034
- ID kw_table[2] = { rb_intern("prob"), rb_intern("min_keep") };
3035
- VALUE kw_values[2] = { Qundef, Qundef };
3036
- VALUE candidates = Qnil;
3037
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
3038
- rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
3039
-
3040
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3041
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3042
- return Qnil;
3043
- }
3044
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
3045
- rb_raise(rb_eArgError, "prob must be a float");
3046
- return Qnil;
3047
- }
3048
- if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
3049
- rb_raise(rb_eArgError, "min_keep must be an integer");
3050
- return Qnil;
3051
- }
3052
-
3053
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3054
- if (ctx_ptr->ctx == NULL) {
3055
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3056
- return Qnil;
3057
- }
3058
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3059
- if (cnd_ptr->array.data == nullptr) {
3060
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3061
- return Qnil;
3062
- }
3063
- const float prob = NUM2DBL(kw_values[0]);
3064
- const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
3065
-
3066
- llama_sample_min_p(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
3067
-
3068
- return Qnil;
3069
- }
3070
-
3071
- static VALUE _llama_context_sample_tail_free(int argc, VALUE* argv, VALUE self) {
3072
- VALUE kw_args = Qnil;
3073
- ID kw_table[2] = { rb_intern("z"), rb_intern("min_keep") };
3074
- VALUE kw_values[2] = { Qundef, Qundef };
3075
- VALUE candidates = Qnil;
3076
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
3077
- rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
3078
-
3079
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3080
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3081
- return Qnil;
3082
- }
3083
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
3084
- rb_raise(rb_eArgError, "prob must be a float");
3085
- return Qnil;
3086
- }
3087
- if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
3088
- rb_raise(rb_eArgError, "min_keep must be an integer");
3089
- return Qnil;
3090
- }
3091
-
3092
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3093
- if (ctx_ptr->ctx == NULL) {
3094
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3095
- return Qnil;
3096
- }
3097
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3098
- if (cnd_ptr->array.data == nullptr) {
3099
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3100
- return Qnil;
3101
- }
3102
- const float z = NUM2DBL(kw_values[0]);
3103
- const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
3104
-
3105
- llama_sample_tail_free(ctx_ptr->ctx, &(cnd_ptr->array), z, min_keep);
3106
-
3107
- return Qnil;
3108
- }
3109
-
3110
- static VALUE _llama_context_sample_typical(int argc, VALUE* argv, VALUE self) {
3111
- VALUE kw_args = Qnil;
3112
- ID kw_table[2] = { rb_intern("prob"), rb_intern("min_keep") };
3113
- VALUE kw_values[2] = { Qundef, Qundef };
3114
- VALUE candidates = Qnil;
3115
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
3116
- rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
3117
-
3118
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3119
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3120
- return Qnil;
3121
- }
3122
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
3123
- rb_raise(rb_eArgError, "prob must be a float");
3124
- return Qnil;
3125
- }
3126
- if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
3127
- rb_raise(rb_eArgError, "min_keep must be an integer");
3128
- return Qnil;
3129
- }
3130
-
3131
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3132
- if (ctx_ptr->ctx == NULL) {
3133
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3134
- return Qnil;
3135
- }
3136
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3137
- if (cnd_ptr->array.data == nullptr) {
3138
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3139
- return Qnil;
3140
- }
3141
- const float prob = NUM2DBL(kw_values[0]);
3142
- const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
3143
-
3144
- llama_sample_typical(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
3145
-
3146
- return Qnil;
3147
- }
3148
-
3149
- static VALUE _llama_context_sample_temp(int argc, VALUE* argv, VALUE self) {
3150
- VALUE kw_args = Qnil;
3151
- ID kw_table[1] = { rb_intern("temp") };
3152
- VALUE kw_values[1] = { Qundef };
3153
- VALUE candidates = Qnil;
3154
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
3155
- rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
3156
-
3157
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3158
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3159
- return Qnil;
3160
- }
3161
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
3162
- rb_raise(rb_eArgError, "temp must be a float");
3163
- return Qnil;
3164
- }
3165
-
3166
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3167
- if (ctx_ptr->ctx == NULL) {
3168
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3169
- return Qnil;
3170
- }
3171
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3172
- if (cnd_ptr->array.data == nullptr) {
3173
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3174
- return Qnil;
3175
- }
3176
- const float temp = NUM2DBL(kw_values[0]);
3177
-
3178
- llama_sample_temp(ctx_ptr->ctx, &(cnd_ptr->array), temp);
3179
-
3180
- return Qnil;
3181
- }
3182
-
3183
- static VALUE _llama_context_sample_entropy(int argc, VALUE* argv, VALUE self) {
3184
- VALUE kw_args = Qnil;
3185
- ID kw_table[3] = { rb_intern("min_temp"), rb_intern("max_temp"), rb_intern("exponent_val") };
3186
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
3187
- VALUE candidates = Qnil;
3188
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
3189
- rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
3190
-
3191
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3192
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3193
- return Qnil;
3194
- }
3195
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
3196
- rb_raise(rb_eArgError, "min_temp must be a float");
3197
- return Qnil;
3198
- }
3199
- if (!RB_FLOAT_TYPE_P(kw_values[1])) {
3200
- rb_raise(rb_eArgError, "max_temp must be a float");
3201
- return Qnil;
3202
- }
3203
- if (!RB_FLOAT_TYPE_P(kw_values[2])) {
3204
- rb_raise(rb_eArgError, "exponent_val must be a float");
3205
- return Qnil;
3206
- }
3207
-
3208
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3209
- if (ctx_ptr->ctx == NULL) {
3210
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3211
- return Qnil;
3212
- }
3213
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3214
- if (cnd_ptr->array.data == nullptr) {
3215
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3216
- return Qnil;
3217
- }
3218
- const float min_temp = NUM2DBL(kw_values[0]);
3219
- const float max_temp = NUM2DBL(kw_values[1]);
3220
- const float exponent_val = NUM2DBL(kw_values[2]);
3221
-
3222
- llama_sample_entropy(ctx_ptr->ctx, &(cnd_ptr->array), min_temp, max_temp, exponent_val);
3223
-
3224
- return Qnil;
3225
- }
3226
-
3227
- static VALUE _llama_context_sample_token_mirostat(int argc, VALUE* argv, VALUE self) {
3228
- VALUE kw_args = Qnil;
3229
- ID kw_table[4] = { rb_intern("tau"), rb_intern("eta"), rb_intern("m"), rb_intern("mu") };
3230
- VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
3231
- VALUE candidates = Qnil;
3232
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
3233
- rb_get_kwargs(kw_args, kw_table, 4, 0, kw_values);
3234
-
3235
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3236
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3237
- return Qnil;
3238
- }
3239
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
3240
- rb_raise(rb_eArgError, "tau must be a float");
3241
- return Qnil;
3242
- }
3243
- if (!RB_FLOAT_TYPE_P(kw_values[1])) {
3244
- rb_raise(rb_eArgError, "eta must be a float");
3245
- return Qnil;
3246
- }
3247
- if (!RB_INTEGER_TYPE_P(kw_values[2])) {
3248
- rb_raise(rb_eArgError, "m must be an integer");
3249
- return Qnil;
3250
- }
3251
- if (!RB_FLOAT_TYPE_P(kw_values[3])) {
3252
- rb_raise(rb_eArgError, "mu must be a float");
3253
- return Qnil;
3254
- }
3255
-
3256
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3257
- if (ctx_ptr->ctx == NULL) {
3258
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3259
- return Qnil;
3260
- }
3261
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3262
- if (cnd_ptr->array.data == nullptr) {
3263
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3264
- return Qnil;
3265
- }
3266
- const float tau = NUM2DBL(kw_values[0]);
3267
- const float eta = NUM2DBL(kw_values[1]);
3268
- const int m = NUM2INT(kw_values[2]);
3269
- float mu = NUM2DBL(kw_values[3]);
3270
-
3271
- llama_token id = llama_sample_token_mirostat(ctx_ptr->ctx, &(cnd_ptr->array), tau, eta, m, &mu);
3272
-
3273
- VALUE ret = rb_ary_new2(2);
3274
- rb_ary_store(ret, 0, INT2NUM(id));
3275
- rb_ary_store(ret, 1, DBL2NUM(mu));
3276
- return ret;
3277
- }
3278
-
3279
- static VALUE _llama_context_sample_token_mirostat_v2(int argc, VALUE* argv, VALUE self) {
3280
- VALUE kw_args = Qnil;
3281
- ID kw_table[3] = { rb_intern("tau"), rb_intern("eta"), rb_intern("mu") };
3282
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
3283
- VALUE candidates = Qnil;
3284
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
3285
- rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
3286
-
3287
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3288
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3289
- return Qnil;
3290
- }
3291
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
3292
- rb_raise(rb_eArgError, "tau must be a float");
3293
- return Qnil;
3294
- }
3295
- if (!RB_FLOAT_TYPE_P(kw_values[1])) {
3296
- rb_raise(rb_eArgError, "eta must be a float");
3297
- return Qnil;
3298
- }
3299
- if (!RB_FLOAT_TYPE_P(kw_values[2])) {
3300
- rb_raise(rb_eArgError, "mu must be a float");
3301
- return Qnil;
3302
- }
3303
-
3304
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3305
- if (ctx_ptr->ctx == NULL) {
3306
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3307
- return Qnil;
3308
- }
3309
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3310
- if (cnd_ptr->array.data == nullptr) {
3311
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3312
- return Qnil;
3313
- }
3314
- const float tau = NUM2DBL(kw_values[0]);
3315
- const float eta = NUM2DBL(kw_values[1]);
3316
- float mu = NUM2DBL(kw_values[2]);
3317
-
3318
- llama_token id = llama_sample_token_mirostat_v2(ctx_ptr->ctx, &(cnd_ptr->array), tau, eta, &mu);
3319
-
3320
- VALUE ret = rb_ary_new2(2);
3321
- rb_ary_store(ret, 0, INT2NUM(id));
3322
- rb_ary_store(ret, 1, DBL2NUM(mu));
3323
- return ret;
3324
- }
3325
-
3326
- static VALUE _llama_context_sample_token_greedy(VALUE self, VALUE candidates) {
3327
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3328
- if (ctx_ptr->ctx == NULL) {
3329
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3330
- return Qnil;
3331
- }
3332
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3333
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3334
- return Qnil;
3335
- }
3336
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3337
- if (cnd_ptr->array.data == nullptr) {
3338
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3339
- return Qnil;
3340
- }
3341
- llama_token id = llama_sample_token_greedy(ctx_ptr->ctx, &(cnd_ptr->array));
3342
- return INT2NUM(id);
3343
- }
3344
-
3345
- static VALUE _llama_context_sample_token(VALUE self, VALUE candidates) {
3346
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3347
- if (ctx_ptr->ctx == NULL) {
3348
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3349
- return Qnil;
3350
- }
3351
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3352
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3353
- return Qnil;
3354
- }
3355
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3356
- if (cnd_ptr->array.data == nullptr) {
3357
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3358
- return Qnil;
3359
- }
3360
- llama_token id = llama_sample_token(ctx_ptr->ctx, &(cnd_ptr->array));
3361
- return INT2NUM(id);
3362
- }
3363
-
3364
- static VALUE _llama_context_sample_grammar(int argc, VALUE* argv, VALUE self) {
3365
- VALUE kw_args = Qnil;
3366
- ID kw_table[1] = { rb_intern("grammar") };
3367
- VALUE kw_values[1] = { Qundef };
3368
- VALUE candidates = Qnil;
3369
- rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
3370
- rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
3371
-
3372
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
3373
- rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
3374
- return Qnil;
3375
- }
3376
- if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
3377
- rb_raise(rb_eArgError, "grammar must be a Grammar");
3378
- return Qnil;
3379
- }
3380
-
3381
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3382
- if (ctx_ptr->ctx == NULL) {
3383
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3384
- return Qnil;
3385
- }
3386
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
3387
- if (cnd_ptr->array.data == nullptr) {
3388
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
3389
- return Qnil;
3390
- }
3391
- LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
3392
-
3393
- llama_grammar_sample(grm_ptr->grammar, ctx_ptr->ctx, &(cnd_ptr->array));
3394
-
3395
- return Qnil;
3396
- }
3397
-
3398
- static VALUE _llama_context_grammar_accept_token(int argc, VALUE* argv, VALUE self) {
3399
- VALUE kw_args = Qnil;
3400
- ID kw_table[2] = { rb_intern("grammar"), rb_intern("token") };
3401
- VALUE kw_values[2] = { Qundef, Qundef };
3402
- rb_scan_args(argc, argv, ":", &kw_args);
3403
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
3404
-
3405
- if (!rb_obj_is_kind_of(kw_values[0], rb_cLLaMAGrammar)) {
3406
- rb_raise(rb_eArgError, "grammar must be a Grammar");
3407
- return Qnil;
3408
- }
3409
- if (!RB_INTEGER_TYPE_P(kw_values[1])) {
3410
- rb_raise(rb_eArgError, "token must be an Integer");
3411
- return Qnil;
3412
- }
3413
-
3414
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3415
- if (ctx_ptr->ctx == NULL) {
3416
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3417
- return Qnil;
3418
- }
3419
- LLaMAGrammarWrapper* grm_ptr = RbLLaMAGrammar::get_llama_grammar(kw_values[0]);
3420
- llama_token token = NUM2INT(kw_values[1]);
3421
-
3422
- llama_grammar_accept_token(grm_ptr->grammar, ctx_ptr->ctx, token);
3423
-
3424
- return Qnil;
3425
- }
3426
-
3427
- static VALUE _llama_context_apply_control_vector(int argc, VALUE* argv, VALUE self) {
3428
- VALUE kw_args = Qnil;
3429
- ID kw_table[4] = { rb_intern("data"), rb_intern("n_embd"), rb_intern("il_start"), rb_intern("il_end") };
3430
- VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
3431
- rb_scan_args(argc, argv, ":", &kw_args);
3432
- rb_get_kwargs(kw_args, kw_table, 4, 0, kw_values);
3433
-
3434
- if (!RB_TYPE_P(kw_values[0], T_ARRAY) && !NIL_P(kw_values[0])) {
3435
- rb_raise(rb_eArgError, "data must be an Array or nil");
3436
- return Qnil;
3437
- }
3438
- if (!RB_INTEGER_TYPE_P(kw_values[1])) {
3439
- rb_raise(rb_eArgError, "n_embd must be an Integer");
3440
- return Qnil;
3441
- }
3442
- if (!RB_INTEGER_TYPE_P(kw_values[2])) {
3443
- rb_raise(rb_eArgError, "il_start must be an Integer");
3444
- return Qnil;
3445
- }
3446
- if (!RB_INTEGER_TYPE_P(kw_values[3])) {
3447
- rb_raise(rb_eArgError, "il_end must be an Integer");
3448
- return Qnil;
3449
- }
3450
-
3451
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
3452
- if (ctx_ptr->ctx == NULL) {
3453
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3454
- return Qnil;
3455
- }
3456
-
3457
- std::vector<float> data(RARRAY_LEN(kw_values[0]));
3458
- for (size_t i = 0; i < data.size(); i++) {
3459
- data[i] = NUM2DBL(rb_ary_entry(kw_values[0], i));
3460
- }
3461
- const int32_t n_embd = NUM2INT(kw_values[1]);
3462
- const int32_t il_start = NUM2INT(kw_values[2]);
3463
- const int32_t il_end = NUM2INT(kw_values[3]);
3464
-
3465
- int32_t err = 0;
3466
- if (NIL_P(kw_values[0])) {
3467
- err = llama_control_vector_apply(ctx_ptr->ctx, NULL, 0, n_embd, il_start, il_end);
3468
- } else {
3469
- err = llama_control_vector_apply(ctx_ptr->ctx, data.data(), data.size(), n_embd, il_start, il_end);
3470
- }
3471
-
3472
- if (err) {
3473
- rb_raise(rb_eRuntimeError, "Failed to apply control vector");
3474
- return Qnil;
3475
- }
3476
-
3477
- return Qnil;
3478
- }
3479
-
3480
- static VALUE _llama_context_pooling_type(VALUE self) {
3481
- LLaMAContextWrapper* ptr = get_llama_context(self);
3482
- if (ptr->ctx == NULL) {
3483
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
3484
- return Qnil;
3485
- }
3486
- return INT2NUM(static_cast<int>(llama_pooling_type(ptr->ctx)));
3487
- }
3488
- };
3489
-
3490
- const rb_data_type_t RbLLaMAContext::llama_context_type = {
3491
- "RbLLaMAContext",
3492
- { NULL,
3493
- RbLLaMAContext::llama_context_free,
3494
- RbLLaMAContext::llama_context_size },
3495
- NULL,
3496
- NULL,
3497
- RUBY_TYPED_FREE_IMMEDIATELY
3498
- };
3499
-
3500
- // module functions
3501
-
3502
- static VALUE rb_llama_llama_backend_init(VALUE self) {
3503
- llama_backend_init();
3504
-
3505
- return Qnil;
3506
- }
3507
-
3508
- static VALUE rb_llama_llama_backend_free(VALUE self) {
3509
- llama_backend_free();
3510
-
3511
- return Qnil;
3512
- }
3513
-
3514
- static VALUE rb_llama_llama_numa_init(VALUE self, VALUE strategy) {
3515
- if (!RB_INTEGER_TYPE_P(strategy)) {
3516
- rb_raise(rb_eArgError, "strategy must be an integer");
3517
- return Qnil;
3518
- }
3519
-
3520
- llama_numa_init(static_cast<enum ggml_numa_strategy>(NUM2INT(strategy)));
3521
-
3522
- return Qnil;
3523
- }
3524
-
3525
- static VALUE rb_llama_model_quantize(int argc, VALUE* argv, VALUE self) {
3526
- VALUE kw_args = Qnil;
3527
- ID kw_table[3] = { rb_intern("input_path"), rb_intern("output_path"), rb_intern("params") };
3528
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
3529
- rb_scan_args(argc, argv, ":", &kw_args);
3530
- rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
3531
-
3532
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
3533
- rb_raise(rb_eArgError, "input_path must be a string");
3534
- return Qnil;
3535
- }
3536
- if (!RB_TYPE_P(kw_values[1], T_STRING)) {
3537
- rb_raise(rb_eArgError, "output_path must be a string");
3538
- return Qnil;
3539
- }
3540
- if (!rb_obj_is_kind_of(kw_values[2], rb_cLLaMAModelQuantizeParams)) {
3541
- rb_raise(rb_eArgError, "params must be a ModelQuantizeParams");
3542
- return Qnil;
3543
- }
3544
-
3545
- const char* input_path = StringValueCStr(kw_values[0]);
3546
- const char* output_path = StringValueCStr(kw_values[1]);
3547
- LLaMAModelQuantizeParamsWrapper* wrapper = RbLLaMAModelQuantizeParams::get_llama_model_quantize_params(kw_values[2]);
3548
-
3549
- if (llama_model_quantize(input_path, output_path, &(wrapper->params)) != 0) {
3550
- rb_raise(rb_eRuntimeError, "Failed to quantize model");
3551
- return Qnil;
3552
- }
3553
-
3554
- return Qnil;
3555
- }
3556
-
3557
- static VALUE rb_llama_print_system_info(VALUE self) {
3558
- const char* result = llama_print_system_info();
3559
- return rb_utf8_str_new_cstr(result);
3560
- }
3561
-
3562
- static VALUE rb_llama_time_us(VALUE self) {
3563
- return LONG2NUM(llama_time_us());
3564
- }
3565
-
3566
- static VALUE rb_llama_max_devices(VALUE self) {
3567
- return SIZET2NUM(llama_max_devices());
3568
- }
3569
-
3570
- static VALUE rb_llama_supports_mmap(VALUE self) {
3571
- return llama_supports_mmap() ? Qtrue : Qfalse;
3572
- }
3573
-
3574
- static VALUE rb_llama_supports_mlock(VALUE self) {
3575
- return llama_supports_mlock() ? Qtrue : Qfalse;
3576
- }
3577
-
3578
- static VALUE rb_llama_supports_gpu_offload(VALUE self) {
3579
- return llama_supports_gpu_offload() ? Qtrue : Qfalse;
3580
- }
3581
-
3582
- extern "C" void Init_llama_cpp(void) {
3583
- rb_mLLaMACpp = rb_define_module("LLaMACpp");
3584
-
3585
- RbLLaMABatch::define_class(rb_mLLaMACpp);
3586
- RbLLaMATokenData::define_class(rb_mLLaMACpp);
3587
- RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
3588
- RbLLaMAModel::define_class(rb_mLLaMACpp);
3589
- RbLLaMAModelKVOverride::define_class(rb_mLLaMACpp);
3590
- RbLLaMAModelParams::define_class(rb_mLLaMACpp);
3591
- RbLLaMATimings::define_class(rb_mLLaMACpp);
3592
- RbLLaMAContext::define_class(rb_mLLaMACpp);
3593
- RbLLaMAContextParams::define_class(rb_mLLaMACpp);
3594
- RbLLaMAModelQuantizeParams::define_class(rb_mLLaMACpp);
3595
- RbLLaMAGrammarElement::define_class(rb_mLLaMACpp);
3596
- RbLLaMAGrammar::define_class(rb_mLLaMACpp);
3597
-
3598
- rb_define_module_function(rb_mLLaMACpp, "backend_init", rb_llama_llama_backend_init, 0);
3599
- rb_define_module_function(rb_mLLaMACpp, "backend_free", rb_llama_llama_backend_free, 0);
3600
- rb_define_module_function(rb_mLLaMACpp, "numa_init", rb_llama_llama_numa_init, 1);
3601
- rb_define_module_function(rb_mLLaMACpp, "model_quantize", rb_llama_model_quantize, -1);
3602
- rb_define_module_function(rb_mLLaMACpp, "print_system_info", rb_llama_print_system_info, 0);
3603
- rb_define_module_function(rb_mLLaMACpp, "time_us", rb_llama_time_us, 0);
3604
- rb_define_module_function(rb_mLLaMACpp, "max_devices", rb_llama_max_devices, 0);
3605
- rb_define_module_function(rb_mLLaMACpp, "supports_mmap?", rb_llama_supports_mmap, 0);
3606
- rb_define_module_function(rb_mLLaMACpp, "supports_mlock?", rb_llama_supports_mlock, 0);
3607
- rb_define_module_function(rb_mLLaMACpp, "supports_gpu_offload?", rb_llama_supports_gpu_offload, 0);
3608
-
3609
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_NONE", INT2NUM(LLAMA_VOCAB_TYPE_NONE));
3610
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_SPM", INT2NUM(LLAMA_VOCAB_TYPE_SPM));
3611
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_BPE", INT2NUM(LLAMA_VOCAB_TYPE_BPE));
3612
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_WPM", INT2NUM(LLAMA_VOCAB_TYPE_WPM));
3613
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_TYPE_UGM", INT2NUM(LLAMA_VOCAB_TYPE_UGM));
3614
-
3615
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_DEFAULT", INT2NUM(LLAMA_VOCAB_PRE_TYPE_DEFAULT));
3616
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_LLAMA3", INT2NUM(LLAMA_VOCAB_PRE_TYPE_LLAMA3));
3617
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM", INT2NUM(LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM));
3618
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER", INT2NUM(LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER));
3619
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_FALCON", INT2NUM(LLAMA_VOCAB_PRE_TYPE_FALCON));
3620
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_MPT", INT2NUM(LLAMA_VOCAB_PRE_TYPE_MPT));
3621
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_STARCODER", INT2NUM(LLAMA_VOCAB_PRE_TYPE_STARCODER));
3622
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_GPT2", INT2NUM(LLAMA_VOCAB_PRE_TYPE_GPT2));
3623
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_REFACT", INT2NUM(LLAMA_VOCAB_PRE_TYPE_REFACT));
3624
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_COMMAND_R", INT2NUM(LLAMA_VOCAB_PRE_TYPE_COMMAND_R));
3625
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_STABLELM2", INT2NUM(LLAMA_VOCAB_PRE_TYPE_STABLELM2));
3626
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_QWEN2", INT2NUM(LLAMA_VOCAB_PRE_TYPE_QWEN2));
3627
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_OLMO", INT2NUM(LLAMA_VOCAB_PRE_TYPE_OLMO));
3628
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_DBRX", INT2NUM(LLAMA_VOCAB_PRE_TYPE_DBRX));
3629
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_SMAUG", INT2NUM(LLAMA_VOCAB_PRE_TYPE_SMAUG));
3630
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_PORO", INT2NUM(LLAMA_VOCAB_PRE_TYPE_PORO));
3631
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_CHATGLM3", INT2NUM(LLAMA_VOCAB_PRE_TYPE_CHATGLM3));
3632
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_CHATGLM4", INT2NUM(LLAMA_VOCAB_PRE_TYPE_CHATGLM4));
3633
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_VIKING", INT2NUM(LLAMA_VOCAB_PRE_TYPE_VIKING));
3634
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_JAIS", INT2NUM(LLAMA_VOCAB_PRE_TYPE_JAIS));
3635
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_TEKKEN", INT2NUM(LLAMA_VOCAB_PRE_TYPE_TEKKEN));
3636
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_SMOLLM", INT2NUM(LLAMA_VOCAB_PRE_TYPE_SMOLLM));
3637
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_CODESHELL", INT2NUM(LLAMA_VOCAB_PRE_TYPE_CODESHELL));
3638
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_BLOOM", INT2NUM(LLAMA_VOCAB_PRE_TYPE_BLOOM));
3639
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH", INT2NUM(LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH));
3640
- rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_EXAONE", INT2NUM(LLAMA_VOCAB_PRE_TYPE_EXAONE));
3641
-
3642
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_UNDEFINED", INT2NUM(LLAMA_TOKEN_TYPE_UNDEFINED));
3643
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_NORMAL", INT2NUM(LLAMA_TOKEN_TYPE_NORMAL));
3644
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_UNKNOWN", INT2NUM(LLAMA_TOKEN_TYPE_UNKNOWN));
3645
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_CONTROL", INT2NUM(LLAMA_TOKEN_TYPE_CONTROL));
3646
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_USER_DEFINED", INT2NUM(LLAMA_TOKEN_TYPE_USER_DEFINED));
3647
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_UNUSED", INT2NUM(LLAMA_TOKEN_TYPE_UNUSED));
3648
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_BYTE", INT2NUM(LLAMA_TOKEN_TYPE_BYTE));
3649
-
3650
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_UNDEFINED", INT2NUM(LLAMA_TOKEN_ATTR_UNDEFINED));
3651
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_UNKNOWN", INT2NUM(LLAMA_TOKEN_ATTR_UNKNOWN));
3652
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_UNUSED", INT2NUM(LLAMA_TOKEN_ATTR_UNUSED));
3653
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_NORMAL", INT2NUM(LLAMA_TOKEN_ATTR_NORMAL));
3654
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_CONTROL", INT2NUM(LLAMA_TOKEN_ATTR_CONTROL));
3655
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_USER_DEFINED", INT2NUM(LLAMA_TOKEN_ATTR_USER_DEFINED));
3656
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_BYTE", INT2NUM(LLAMA_TOKEN_ATTR_BYTE));
3657
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_NORMALIZED", INT2NUM(LLAMA_TOKEN_ATTR_NORMALIZED));
3658
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_LSTRIP", INT2NUM(LLAMA_TOKEN_ATTR_LSTRIP));
3659
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_RSTRIP", INT2NUM(LLAMA_TOKEN_ATTR_RSTRIP));
3660
- rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_ATTR_SINGLE_WORD", INT2NUM(LLAMA_TOKEN_ATTR_SINGLE_WORD));
3661
-
3662
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_ALL_F32", INT2NUM(LLAMA_FTYPE_ALL_F32));
3663
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_F16", INT2NUM(LLAMA_FTYPE_MOSTLY_F16));
3664
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_0));
3665
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_1", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_1));
3666
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q8_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q8_0));
3667
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_0));
3668
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_1", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_1));
3669
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q2_K", INT2NUM(LLAMA_FTYPE_MOSTLY_Q2_K));
3670
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q3_K_S", INT2NUM(LLAMA_FTYPE_MOSTLY_Q3_K_S));
3671
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q3_K_M", INT2NUM(LLAMA_FTYPE_MOSTLY_Q3_K_M));
3672
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q3_K_L", INT2NUM(LLAMA_FTYPE_MOSTLY_Q3_K_L));
3673
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_K_S", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_K_S));
3674
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_K_M", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_K_M));
3675
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_K_S", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_K_S));
3676
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_K_M", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_K_M));
3677
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q6_K", INT2NUM(LLAMA_FTYPE_MOSTLY_Q6_K));
3678
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ2_XXS", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ2_XXS));
3679
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ2_XS", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ2_XS));
3680
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q2_K_S", INT2NUM(LLAMA_FTYPE_MOSTLY_Q2_K_S));
3681
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ3_XS", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ3_XS));
3682
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ3_XXS", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ3_XXS));
3683
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ1_S", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ1_S));
3684
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ4_NL", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ4_NL));
3685
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ3_S", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ3_S));
3686
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ3_M", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ3_M));
3687
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ4_XS", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ4_XS));
3688
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_IQ1_M", INT2NUM(LLAMA_FTYPE_MOSTLY_IQ1_M));
3689
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_BF16", INT2NUM(LLAMA_FTYPE_MOSTLY_BF16));
3690
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_0_4_4", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_0_4_4));
3691
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_0_4_8", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_0_4_8));
3692
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_0_8_8", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_0_8_8));
3693
-
3694
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_GUESSED", INT2NUM(LLAMA_FTYPE_GUESSED));
3695
-
3696
- rb_define_const(rb_mLLaMACpp, "LLAMA_KV_OVERRIDE_TYPE_INT", INT2NUM(LLAMA_KV_OVERRIDE_TYPE_INT));
3697
- rb_define_const(rb_mLLaMACpp, "LLAMA_KV_OVERRIDE_TYPE_FLOAT", INT2NUM(LLAMA_KV_OVERRIDE_TYPE_FLOAT));
3698
- rb_define_const(rb_mLLaMACpp, "LLAMA_KV_OVERRIDE_TYPE_BOOL", INT2NUM(LLAMA_KV_OVERRIDE_TYPE_BOOL));
3699
- rb_define_const(rb_mLLaMACpp, "LLAMA_KV_OVERRIDE_TYPE_STR", INT2NUM(LLAMA_KV_OVERRIDE_TYPE_STR));
3700
-
3701
- rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_END", INT2NUM(LLAMA_GRETYPE_END));
3702
- rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_ALT", INT2NUM(LLAMA_GRETYPE_ALT));
3703
- rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_RULE_REF", INT2NUM(LLAMA_GRETYPE_RULE_REF));
3704
- rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR", INT2NUM(LLAMA_GRETYPE_CHAR));
3705
- rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_NOT", INT2NUM(LLAMA_GRETYPE_CHAR_NOT));
3706
- rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_RNG_UPPER", INT2NUM(LLAMA_GRETYPE_CHAR_RNG_UPPER));
3707
- rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_ALT", INT2NUM(LLAMA_GRETYPE_CHAR_ALT));
3708
- rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_ANY", INT2NUM(LLAMA_GRETYPE_CHAR_ANY));
3709
-
3710
- rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED", INT2NUM(LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED));
3711
- rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_TYPE_NONE", INT2NUM(LLAMA_ROPE_SCALING_TYPE_NONE));
3712
- rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_TYPE_LINEAR", INT2NUM(LLAMA_ROPE_SCALING_TYPE_LINEAR));
3713
- rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_TYPE_YARN", INT2NUM(LLAMA_ROPE_SCALING_TYPE_YARN));
3714
- rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_TYPE_MAX_VALUE", INT2NUM(LLAMA_ROPE_SCALING_TYPE_MAX_VALUE));
3715
-
3716
- rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_UNSPECIFIED", INT2NUM(LLAMA_POOLING_TYPE_UNSPECIFIED));
3717
- rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_NONE", INT2NUM(LLAMA_POOLING_TYPE_NONE));
3718
- rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_MEAN", INT2NUM(LLAMA_POOLING_TYPE_MEAN));
3719
- rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_CLS", INT2NUM(LLAMA_POOLING_TYPE_CLS));
3720
- rb_define_const(rb_mLLaMACpp, "LLAMA_POOLING_TYPE_LAST", INT2NUM(LLAMA_POOLING_TYPE_LAST));
3721
-
3722
- rb_define_const(rb_mLLaMACpp, "LLAMA_ATTENTION_TYPE_UNSPECIFIED", INT2NUM(LLAMA_ATTENTION_TYPE_UNSPECIFIED));
3723
- rb_define_const(rb_mLLaMACpp, "LLAMA_ATTENTION_TYPE_CAUSAL", INT2NUM(LLAMA_ATTENTION_TYPE_CAUSAL));
3724
- rb_define_const(rb_mLLaMACpp, "LLAMA_ATTENTION_TYPE_NON_CAUSAL", INT2NUM(LLAMA_ATTENTION_TYPE_NON_CAUSAL));
3725
-
3726
- rb_define_const(rb_mLLaMACpp, "LLAMA_SPLIT_MODE_NONE", INT2NUM(LLAMA_SPLIT_MODE_NONE));
3727
- rb_define_const(rb_mLLaMACpp, "LLAMA_SPLIT_MODE_LAYER", INT2NUM(LLAMA_SPLIT_MODE_LAYER));
3728
- rb_define_const(rb_mLLaMACpp, "LLAMA_SPLIT_MODE_ROW", INT2NUM(LLAMA_SPLIT_MODE_ROW));
3729
-
3730
- std::stringstream ss_magic;
3731
- ss_magic << std::showbase << std::hex << LLAMA_FILE_MAGIC_GGLA;
3732
- rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_MAGIC_GGLA", rb_str_new2(ss_magic.str().c_str()));
3733
-
3734
- ss_magic.str("");
3735
- ss_magic.clear(std::stringstream::goodbit);
3736
- ss_magic << std::showbase << std::hex << LLAMA_FILE_MAGIC_GGSN;
3737
- rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_MAGIC_GGSN", rb_str_new2(ss_magic.str().c_str()));
3738
-
3739
- ss_magic.str("");
3740
- ss_magic.clear(std::stringstream::goodbit);
3741
- ss_magic << std::showbase << std::hex << LLAMA_FILE_MAGIC_GGSQ;
3742
- rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_MAGIC_GGSQ", rb_str_new2(ss_magic.str().c_str()));
3743
-
3744
- ss_magic.str("");
3745
- ss_magic.clear(std::stringstream::goodbit);
3746
- ss_magic << std::showbase << std::hex << LLAMA_SESSION_MAGIC;
3747
- rb_define_const(rb_mLLaMACpp, "LLAMA_SESSION_MAGIC", rb_str_new2(ss_magic.str().c_str()));
3748
-
3749
- ss_magic.str("");
3750
- ss_magic.clear(std::stringstream::goodbit);
3751
- ss_magic << std::showbase << std::hex << LLAMA_STATE_SEQ_MAGIC;
3752
- rb_define_const(rb_mLLaMACpp, "LLAMA_STATE_SEQ_MAGIC", rb_str_new2(ss_magic.str().c_str()));
3753
-
3754
- ss_magic.str("");
3755
- ss_magic.clear(std::stringstream::goodbit);
3756
- ss_magic << std::showbase << std::hex << LLAMA_DEFAULT_SEED;
3757
- rb_define_const(rb_mLLaMACpp, "LLAMA_DEFAULT_SEED", rb_str_new2(ss_magic.str().c_str()));
3758
-
3759
- rb_define_const(rb_mLLaMACpp, "LLAMA_SESSION_VERSION", rb_str_new2(std::to_string(LLAMA_SESSION_VERSION).c_str()));
3760
- rb_define_const(rb_mLLaMACpp, "LLAMA_STATE_SEQ_VERSION", rb_str_new2(std::to_string(LLAMA_STATE_SEQ_VERSION).c_str()));
3761
- }