llama_cpp 0.7.1 → 0.9.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +18 -0
- data/examples/chat.rb +8 -6
- data/ext/llama_cpp/extconf.rb +2 -2
- data/ext/llama_cpp/llama_cpp.cpp +122 -183
- data/ext/llama_cpp/src/ggml-cuda.cu +188 -20
- data/ext/llama_cpp/src/ggml-metal.m +57 -8
- data/ext/llama_cpp/src/ggml-metal.metal +171 -2
- data/ext/llama_cpp/src/ggml-opencl.cpp +188 -222
- data/ext/llama_cpp/src/ggml.c +375 -93
- data/ext/llama_cpp/src/ggml.h +11 -9
- data/ext/llama_cpp/src/k_quants.c +12 -20
- data/ext/llama_cpp/src/llama.cpp +459 -153
- data/ext/llama_cpp/src/llama.h +34 -33
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +4 -4
- data/sig/llama_cpp.rbs +15 -16
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 683f2d81aff9e82234925ba08cd5b46b56a2283ff8397a6c06ce50d34a95dbfc
|
4
|
+
data.tar.gz: d3005cab273b8d85f47f4cb4314fbab3a540d366a42829e5ec8d2c29576ae09e
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 559f1ba1253a704c38480336decd315c65b4d80e6895ad1dc0faa3b5b81570a1faeaadcb6ec7ee3145f0fff758ab5e38e6cb8163382ce9b693d893deebe9a8f9
|
7
|
+
data.tar.gz: cb3d96b8c3f79cd20d4169a175270e8768c04bcaa24e51cb2c4d7872db88bc6e3349e6b1e93a130b89d21daab8be6e57b5305412059ea722084c7cb7d4a01e93
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,21 @@
|
|
1
|
+
## [[0.9.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.8.0...v0.9.0)] - 2023-10-28
|
2
|
+
|
3
|
+
- Fix missing object file for ggml-backend when building with metal and cublas options.
|
4
|
+
|
5
|
+
**Breaking Changes**
|
6
|
+
- Bump bundled llama.cpp from b1405 to b1429
|
7
|
+
- Move the following methods from Context to Model:
|
8
|
+
- text, score, type, token_bos, token_eos, token_nl, token_prefix, token_middle, token_suffix, and token_eot.
|
9
|
+
- Add `sample_repetition_penalties` method, which combines the `sample_frequency_and_presence_penalties` and `sample_repetition_penalty` methods into one; the two old methods are removed.
|
10
|
+
|
11
|
+
## [[0.8.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.1...v0.8.0)] - 2023-10-21
|
12
|
+
|
13
|
+
**Breaking Changes**
|
14
|
+
- Bump bundled llama.cpp from b1380 to b1405
|
15
|
+
- Add column index argument to `set_seq_id` and `get_seq_id` methods in Batch.
|
16
|
+
- Add `special` keyword argument to `tokenize` method in Model.
|
17
|
+
- Add `n_seq_max` keyword argument to `initialize` method in Batch.
|
18
|
+
|
1
19
|
## [[0.7.1](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.0...v0.7.1)] - 2023-10-14
|
2
20
|
|
3
21
|
- Bump bundled llama.cpp from b1334 to b1380.
|
data/examples/chat.rb
CHANGED
@@ -83,10 +83,12 @@ class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation
|
|
83
83
|
candidates = LLaMACpp::TokenDataArray.new(base_candidates)
|
84
84
|
|
85
85
|
last_n_repeat = [last_n_tokens.size, options[:repeat_last_n], n_ctx].min
|
86
|
-
context.
|
87
|
-
|
88
|
-
|
89
|
-
|
86
|
+
context.sample_repetition_penalties(
|
87
|
+
candidates,
|
88
|
+
last_n_tokens[-last_n_repeat..],
|
89
|
+
penalty_repeat: options[:repeat_penalty],
|
90
|
+
penalty_freq: options[:frequency_penalty],
|
91
|
+
penalty_present: options[:presence_penalty]
|
90
92
|
)
|
91
93
|
|
92
94
|
context.sample_top_k(candidates, k: options[:top_k])
|
@@ -99,8 +101,8 @@ class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation
|
|
99
101
|
last_n_tokens.shift
|
100
102
|
last_n_tokens.push(id)
|
101
103
|
|
102
|
-
if id == context.token_eos
|
103
|
-
id = context.token_nl
|
104
|
+
if id == context.model.token_eos
|
105
|
+
id = context.model.token_nl
|
104
106
|
unless antiprompt.empty?
|
105
107
|
first_antiprompt = context.model.tokenize(text: antiprompt, add_bos: false)
|
106
108
|
embd_input.concat(first_antiprompt)
|
data/ext/llama_cpp/extconf.rb
CHANGED
@@ -53,7 +53,7 @@ if with_config('metal')
|
|
53
53
|
$CFLAGS << ' -DGGML_USE_METAL'
|
54
54
|
$CXXFLAGS << ' -DGGML_USE_METAL'
|
55
55
|
$LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit'
|
56
|
-
$objs = %w[ggml.o ggml-alloc.o ggml-metal.o llama.o llama_cpp.o]
|
56
|
+
$objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-metal.o llama.o llama_cpp.o]
|
57
57
|
$objs << 'k_quants.o' unless with_config('no_k_quants')
|
58
58
|
end
|
59
59
|
|
@@ -61,7 +61,7 @@ if with_config('cublas')
|
|
61
61
|
$CFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
|
62
62
|
$CXXFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
|
63
63
|
$LDFLAGS << ' -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64'
|
64
|
-
$objs = %w[ggml.o ggml-alloc.o ggml-cuda.o llama.o llama_cpp.o]
|
64
|
+
$objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-cuda.o llama.o llama_cpp.o]
|
65
65
|
$objs << 'k_quants.o' unless with_config('no_k_quants')
|
66
66
|
end
|
67
67
|
|
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -63,8 +63,8 @@ public:
|
|
63
63
|
rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
|
64
64
|
rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
|
65
65
|
rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
|
66
|
-
rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id),
|
67
|
-
rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id),
|
66
|
+
rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 3);
|
67
|
+
rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 2);
|
68
68
|
rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
|
69
69
|
rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
|
70
70
|
}
|
@@ -74,10 +74,10 @@ private:
|
|
74
74
|
|
75
75
|
static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
|
76
76
|
VALUE kw_args = Qnil;
|
77
|
-
ID kw_table[
|
78
|
-
VALUE kw_values[
|
77
|
+
ID kw_table[3] = { rb_intern("n_tokens"), rb_intern("embd"), rb_intern("n_seq_max") };
|
78
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
79
79
|
rb_scan_args(argc, argv, ":", &kw_args);
|
80
|
-
rb_get_kwargs(kw_args, kw_table,
|
80
|
+
rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
|
81
81
|
|
82
82
|
if (!RB_INTEGER_TYPE_P(kw_values[0])) {
|
83
83
|
rb_raise(rb_eArgError, "n_tokens must be an integer");
|
@@ -87,12 +87,17 @@ private:
|
|
87
87
|
rb_raise(rb_eArgError, "embd must be an integer");
|
88
88
|
return Qnil;
|
89
89
|
}
|
90
|
+
if (!RB_INTEGER_TYPE_P(kw_values[2])) {
|
91
|
+
rb_raise(rb_eArgError, "n_seq_max must be an integer");
|
92
|
+
return Qnil;
|
93
|
+
}
|
90
94
|
|
91
95
|
const int32_t n_tokens = NUM2INT(kw_values[0]);
|
92
96
|
const int32_t embd = NUM2INT(kw_values[1]);
|
97
|
+
const int32_t n_seq_max = NUM2INT(kw_values[2]);
|
93
98
|
|
94
99
|
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
95
|
-
ptr->batch = llama_batch_init(n_tokens, embd);
|
100
|
+
ptr->batch = llama_batch_init(n_tokens, embd, n_seq_max);
|
96
101
|
|
97
102
|
return Qnil;
|
98
103
|
}
|
@@ -190,25 +195,35 @@ private:
|
|
190
195
|
}
|
191
196
|
|
192
197
|
// seq_id
|
193
|
-
static VALUE _llama_batch_set_seq_id(VALUE self, VALUE
|
198
|
+
static VALUE _llama_batch_set_seq_id(VALUE self, VALUE i_, VALUE j_, VALUE value) {
|
194
199
|
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
195
|
-
const int32_t
|
196
|
-
if (
|
197
|
-
rb_raise(rb_eArgError, "
|
200
|
+
const int32_t i = NUM2INT(i_);
|
201
|
+
if (i < 0 || i >= ptr->batch.n_tokens) {
|
202
|
+
rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
|
203
|
+
return Qnil;
|
204
|
+
}
|
205
|
+
const int32_t j = NUM2INT(j_);
|
206
|
+
if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
|
207
|
+
rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
|
198
208
|
return Qnil;
|
199
209
|
}
|
200
|
-
ptr->batch.seq_id[
|
201
|
-
return INT2NUM(ptr->batch.seq_id[
|
210
|
+
ptr->batch.seq_id[i][j] = NUM2INT(value);
|
211
|
+
return INT2NUM(ptr->batch.seq_id[i][j]);
|
202
212
|
}
|
203
213
|
|
204
|
-
static VALUE _llama_batch_get_seq_id(VALUE self, VALUE
|
214
|
+
static VALUE _llama_batch_get_seq_id(VALUE self, VALUE i_, VALUE j_) {
|
205
215
|
LLaMABatchWrapper* ptr = get_llama_batch(self);
|
206
|
-
const int32_t
|
207
|
-
if (
|
208
|
-
rb_raise(rb_eArgError, "
|
216
|
+
const int32_t i = NUM2INT(i_);
|
217
|
+
if (i < 0 || i >= ptr->batch.n_tokens) {
|
218
|
+
rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
|
209
219
|
return Qnil;
|
210
220
|
}
|
211
|
-
|
221
|
+
const int32_t j = NUM2INT(j_);
|
222
|
+
if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
|
223
|
+
rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
|
224
|
+
return Qnil;
|
225
|
+
}
|
226
|
+
return INT2NUM(ptr->batch.seq_id[i][j]);
|
212
227
|
}
|
213
228
|
|
214
229
|
// logits
|
@@ -1133,6 +1148,16 @@ public:
|
|
1133
1148
|
rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
|
1134
1149
|
rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
|
1135
1150
|
rb_define_method(rb_cLLaMAModel, "n_params", RUBY_METHOD_FUNC(_llama_model_get_model_n_params), 0);
|
1151
|
+
rb_define_method(rb_cLLaMAModel, "text", RUBY_METHOD_FUNC(_llama_model_get_text), 1);
|
1152
|
+
rb_define_method(rb_cLLaMAModel, "score", RUBY_METHOD_FUNC(_llama_model_get_score), 1);
|
1153
|
+
rb_define_method(rb_cLLaMAModel, "type", RUBY_METHOD_FUNC(_llama_model_get_type), 1);
|
1154
|
+
rb_define_method(rb_cLLaMAModel, "token_bos", RUBY_METHOD_FUNC(_llama_model_token_bos), 0);
|
1155
|
+
rb_define_method(rb_cLLaMAModel, "token_eos", RUBY_METHOD_FUNC(_llama_model_token_eos), 0);
|
1156
|
+
rb_define_method(rb_cLLaMAModel, "token_nl", RUBY_METHOD_FUNC(_llama_model_token_nl), 0);
|
1157
|
+
rb_define_method(rb_cLLaMAModel, "token_prefix", RUBY_METHOD_FUNC(_llama_model_token_prefix), 0);
|
1158
|
+
rb_define_method(rb_cLLaMAModel, "token_middle", RUBY_METHOD_FUNC(_llama_model_token_middle), 0);
|
1159
|
+
rb_define_method(rb_cLLaMAModel, "token_suffix", RUBY_METHOD_FUNC(_llama_model_token_suffix), 0);
|
1160
|
+
rb_define_method(rb_cLLaMAModel, "token_eot", RUBY_METHOD_FUNC(_llama_model_token_eot), 0);
|
1136
1161
|
}
|
1137
1162
|
|
1138
1163
|
private:
|
@@ -1319,10 +1344,10 @@ private:
|
|
1319
1344
|
|
1320
1345
|
static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
|
1321
1346
|
VALUE kw_args = Qnil;
|
1322
|
-
ID kw_table[
|
1323
|
-
VALUE kw_values[
|
1347
|
+
ID kw_table[4] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos"), rb_intern("special") };
|
1348
|
+
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
1324
1349
|
rb_scan_args(argc, argv, ":", &kw_args);
|
1325
|
-
rb_get_kwargs(kw_args, kw_table, 1,
|
1350
|
+
rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);
|
1326
1351
|
|
1327
1352
|
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
1328
1353
|
rb_raise(rb_eArgError, "text must be a String");
|
@@ -1336,15 +1361,20 @@ private:
|
|
1336
1361
|
rb_raise(rb_eArgError, "add_bos must be a boolean");
|
1337
1362
|
return Qnil;
|
1338
1363
|
}
|
1364
|
+
if (kw_values[3] != Qundef && (kw_values[3] != Qtrue && kw_values[3] != Qfalse)) {
|
1365
|
+
rb_raise(rb_eArgError, "special must be a boolean");
|
1366
|
+
return Qnil;
|
1367
|
+
}
|
1339
1368
|
|
1340
1369
|
VALUE text_ = kw_values[0];
|
1341
1370
|
std::string text = StringValueCStr(text_);
|
1342
1371
|
const bool add_bos = kw_values[2] == Qtrue ? true : false;
|
1372
|
+
const bool special = kw_values[3] == Qtrue ? true : false;
|
1343
1373
|
const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
|
1344
1374
|
|
1345
1375
|
llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
|
1346
1376
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1347
|
-
const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
|
1377
|
+
const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos, special);
|
1348
1378
|
|
1349
1379
|
if (n_tokens < 0) {
|
1350
1380
|
rb_raise(rb_eRuntimeError, "failed to tokenize. The numebr of tokens (%d) is greater than n_max_tokens.", -n_tokens);
|
@@ -1376,6 +1406,62 @@ private:
|
|
1376
1406
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1377
1407
|
return UINT2NUM(llama_model_n_params(ptr->model));
|
1378
1408
|
}
|
1409
|
+
|
1410
|
+
static VALUE _llama_model_get_text(VALUE self, VALUE token_) {
|
1411
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1412
|
+
const llama_token token = NUM2INT(token_);
|
1413
|
+
const char* text = llama_token_get_text(ptr->model, token);
|
1414
|
+
return rb_utf8_str_new_cstr(text);
|
1415
|
+
}
|
1416
|
+
|
1417
|
+
static VALUE _llama_model_get_score(VALUE self, VALUE token_) {
|
1418
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1419
|
+
const llama_token token = NUM2INT(token_);
|
1420
|
+
const float score = llama_token_get_score(ptr->model, token);
|
1421
|
+
return DBL2NUM(score);
|
1422
|
+
}
|
1423
|
+
|
1424
|
+
static VALUE _llama_model_get_type(VALUE self, VALUE token_) {
|
1425
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1426
|
+
const llama_token token = NUM2INT(token_);
|
1427
|
+
const int type = llama_token_get_type(ptr->model, token);
|
1428
|
+
return INT2NUM(type);
|
1429
|
+
}
|
1430
|
+
|
1431
|
+
static VALUE _llama_model_token_bos(VALUE self) {
|
1432
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1433
|
+
return INT2NUM(llama_token_bos(ptr->model));
|
1434
|
+
}
|
1435
|
+
|
1436
|
+
static VALUE _llama_model_token_eos(VALUE self) {
|
1437
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1438
|
+
return INT2NUM(llama_token_eos(ptr->model));
|
1439
|
+
}
|
1440
|
+
|
1441
|
+
static VALUE _llama_model_token_nl(VALUE self) {
|
1442
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1443
|
+
return INT2NUM(llama_token_nl(ptr->model));
|
1444
|
+
}
|
1445
|
+
|
1446
|
+
static VALUE _llama_model_token_prefix(VALUE self) {
|
1447
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1448
|
+
return INT2NUM(llama_token_prefix(ptr->model));
|
1449
|
+
}
|
1450
|
+
|
1451
|
+
static VALUE _llama_model_token_middle(VALUE self) {
|
1452
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1453
|
+
return INT2NUM(llama_token_middle(ptr->model));
|
1454
|
+
}
|
1455
|
+
|
1456
|
+
static VALUE _llama_model_token_suffix(VALUE self) {
|
1457
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1458
|
+
return INT2NUM(llama_token_suffix(ptr->model));
|
1459
|
+
}
|
1460
|
+
|
1461
|
+
static VALUE _llama_model_token_eot(VALUE self) {
|
1462
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1463
|
+
return INT2NUM(llama_token_eot(ptr->model));
|
1464
|
+
}
|
1379
1465
|
};
|
1380
1466
|
|
1381
1467
|
const rb_data_type_t RbLLaMAModel::llama_model_type = {
|
@@ -1650,16 +1736,6 @@ public:
|
|
1650
1736
|
rb_define_method(rb_cLLaMAContext, "decode", RUBY_METHOD_FUNC(_llama_context_decode), 1);
|
1651
1737
|
rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
|
1652
1738
|
rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
|
1653
|
-
rb_define_method(rb_cLLaMAContext, "text", RUBY_METHOD_FUNC(_llama_context_text), 1);
|
1654
|
-
rb_define_method(rb_cLLaMAContext, "score", RUBY_METHOD_FUNC(_llama_context_score), 1);
|
1655
|
-
rb_define_method(rb_cLLaMAContext, "type", RUBY_METHOD_FUNC(_llama_context_type), 1);
|
1656
|
-
rb_define_method(rb_cLLaMAContext, "token_bos", RUBY_METHOD_FUNC(_llama_context_token_bos), 0);
|
1657
|
-
rb_define_method(rb_cLLaMAContext, "token_eos", RUBY_METHOD_FUNC(_llama_context_token_eos), 0);
|
1658
|
-
rb_define_method(rb_cLLaMAContext, "token_nl", RUBY_METHOD_FUNC(_llama_context_token_nl), 0);
|
1659
|
-
rb_define_method(rb_cLLaMAContext, "token_prefix", RUBY_METHOD_FUNC(_llama_context_token_prefix), 0);
|
1660
|
-
rb_define_method(rb_cLLaMAContext, "token_middle", RUBY_METHOD_FUNC(_llama_context_token_middle), 0);
|
1661
|
-
rb_define_method(rb_cLLaMAContext, "token_suffix", RUBY_METHOD_FUNC(_llama_context_token_suffix), 0);
|
1662
|
-
rb_define_method(rb_cLLaMAContext, "token_eot", RUBY_METHOD_FUNC(_llama_context_token_eot), 0);
|
1663
1739
|
rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
|
1664
1740
|
rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
|
1665
1741
|
rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
|
@@ -1673,8 +1749,7 @@ public:
|
|
1673
1749
|
rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
|
1674
1750
|
rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
|
1675
1751
|
rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
|
1676
|
-
rb_define_method(rb_cLLaMAContext, "
|
1677
|
-
rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
|
1752
|
+
rb_define_method(rb_cLLaMAContext, "sample_repetition_penalties", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalties), -1);
|
1678
1753
|
rb_define_method(rb_cLLaMAContext, "sample_classifier_free_guidance", RUBY_METHOD_FUNC(_llama_context_sample_classifier_free_guidance), -1);
|
1679
1754
|
rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
|
1680
1755
|
rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
|
@@ -1907,102 +1982,6 @@ private:
|
|
1907
1982
|
return output;
|
1908
1983
|
}
|
1909
1984
|
|
1910
|
-
static VALUE _llama_context_text(VALUE self, VALUE token_) {
|
1911
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1912
|
-
if (ptr->ctx == NULL) {
|
1913
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1914
|
-
return Qnil;
|
1915
|
-
}
|
1916
|
-
const llama_token token = NUM2INT(token_);
|
1917
|
-
const char* text = llama_token_get_text(ptr->ctx, token);
|
1918
|
-
return rb_utf8_str_new_cstr(text);
|
1919
|
-
}
|
1920
|
-
|
1921
|
-
static VALUE _llama_context_score(VALUE self, VALUE token_) {
|
1922
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1923
|
-
if (ptr->ctx == NULL) {
|
1924
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1925
|
-
return Qnil;
|
1926
|
-
}
|
1927
|
-
const llama_token token = NUM2INT(token_);
|
1928
|
-
const float score = llama_token_get_score(ptr->ctx, token);
|
1929
|
-
return DBL2NUM(score);
|
1930
|
-
}
|
1931
|
-
|
1932
|
-
static VALUE _llama_context_type(VALUE self, VALUE token_) {
|
1933
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1934
|
-
if (ptr->ctx == NULL) {
|
1935
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1936
|
-
return Qnil;
|
1937
|
-
}
|
1938
|
-
const llama_token token = NUM2INT(token_);
|
1939
|
-
const int type = llama_token_get_type(ptr->ctx, token);
|
1940
|
-
return INT2NUM(type);
|
1941
|
-
}
|
1942
|
-
|
1943
|
-
static VALUE _llama_context_token_bos(VALUE self) {
|
1944
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1945
|
-
if (ptr->ctx == NULL) {
|
1946
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1947
|
-
return Qnil;
|
1948
|
-
}
|
1949
|
-
return INT2NUM(llama_token_bos(ptr->ctx));
|
1950
|
-
}
|
1951
|
-
|
1952
|
-
static VALUE _llama_context_token_eos(VALUE self) {
|
1953
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1954
|
-
if (ptr->ctx == NULL) {
|
1955
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1956
|
-
return Qnil;
|
1957
|
-
}
|
1958
|
-
return INT2NUM(llama_token_eos(ptr->ctx));
|
1959
|
-
}
|
1960
|
-
|
1961
|
-
static VALUE _llama_context_token_nl(VALUE self) {
|
1962
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1963
|
-
if (ptr->ctx == NULL) {
|
1964
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1965
|
-
return Qnil;
|
1966
|
-
}
|
1967
|
-
return INT2NUM(llama_token_nl(ptr->ctx));
|
1968
|
-
}
|
1969
|
-
|
1970
|
-
static VALUE _llama_context_token_prefix(VALUE self) {
|
1971
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1972
|
-
if (ptr->ctx == NULL) {
|
1973
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1974
|
-
return Qnil;
|
1975
|
-
}
|
1976
|
-
return INT2NUM(llama_token_prefix(ptr->ctx));
|
1977
|
-
}
|
1978
|
-
|
1979
|
-
static VALUE _llama_context_token_middle(VALUE self) {
|
1980
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1981
|
-
if (ptr->ctx == NULL) {
|
1982
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1983
|
-
return Qnil;
|
1984
|
-
}
|
1985
|
-
return INT2NUM(llama_token_middle(ptr->ctx));
|
1986
|
-
}
|
1987
|
-
|
1988
|
-
static VALUE _llama_context_token_suffix(VALUE self) {
|
1989
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1990
|
-
if (ptr->ctx == NULL) {
|
1991
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1992
|
-
return Qnil;
|
1993
|
-
}
|
1994
|
-
return INT2NUM(llama_token_suffix(ptr->ctx));
|
1995
|
-
}
|
1996
|
-
|
1997
|
-
static VALUE _llama_context_token_eot(VALUE self) {
|
1998
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1999
|
-
if (ptr->ctx == NULL) {
|
2000
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
2001
|
-
return Qnil;
|
2002
|
-
}
|
2003
|
-
return INT2NUM(llama_token_eot(ptr->ctx));
|
2004
|
-
}
|
2005
|
-
|
2006
1985
|
static VALUE _llama_context_n_ctx(VALUE self) {
|
2007
1986
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2008
1987
|
if (ptr->ctx == NULL) {
|
@@ -2211,14 +2190,14 @@ private:
|
|
2211
2190
|
return Qnil;
|
2212
2191
|
}
|
2213
2192
|
|
2214
|
-
static VALUE
|
2193
|
+
static VALUE _llama_context_sample_repetition_penalties(int argc, VALUE* argv, VALUE self) {
|
2215
2194
|
VALUE kw_args = Qnil;
|
2216
|
-
ID kw_table[
|
2217
|
-
VALUE kw_values[
|
2195
|
+
ID kw_table[3] = { rb_intern("penalty_repeat"), rb_intern("penalty_freq"), rb_intern("penalty_present") };
|
2196
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
2218
2197
|
VALUE candidates = Qnil;
|
2219
2198
|
VALUE last_n_tokens = Qnil;
|
2220
2199
|
rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
|
2221
|
-
rb_get_kwargs(kw_args, kw_table,
|
2200
|
+
rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
|
2222
2201
|
|
2223
2202
|
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
2224
2203
|
rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
|
@@ -2229,56 +2208,15 @@ private:
|
|
2229
2208
|
return Qnil;
|
2230
2209
|
}
|
2231
2210
|
if (!RB_FLOAT_TYPE_P(kw_values[0])) {
|
2232
|
-
rb_raise(rb_eArgError, "
|
2211
|
+
rb_raise(rb_eArgError, "penalty_repeat must be a float");
|
2233
2212
|
return Qnil;
|
2234
2213
|
}
|
2235
|
-
|
2236
|
-
|
2237
|
-
std::vector<llama_token> last_n_tokens_data(last_tokens_size);
|
2238
|
-
for (size_t i = 0; i < last_tokens_size; i++) {
|
2239
|
-
last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
|
2240
|
-
}
|
2241
|
-
|
2242
|
-
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
2243
|
-
if (ctx_ptr->ctx == NULL) {
|
2244
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
2245
|
-
return Qnil;
|
2246
|
-
}
|
2247
|
-
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
2248
|
-
if (cnd_ptr->array.data == nullptr) {
|
2249
|
-
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
2250
|
-
return Qnil;
|
2251
|
-
}
|
2252
|
-
const float penalty = NUM2DBL(kw_values[0]);
|
2253
|
-
|
2254
|
-
llama_sample_repetition_penalty(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, penalty);
|
2255
|
-
|
2256
|
-
return Qnil;
|
2257
|
-
}
|
2258
|
-
|
2259
|
-
static VALUE _llama_context_sample_frequency_and_presence_penalties(int argc, VALUE* argv, VALUE self) {
|
2260
|
-
VALUE kw_args = Qnil;
|
2261
|
-
ID kw_table[2] = { rb_intern("frequency"), rb_intern("presence") };
|
2262
|
-
VALUE kw_values[2] = { Qundef, Qundef };
|
2263
|
-
VALUE candidates = Qnil;
|
2264
|
-
VALUE last_n_tokens = Qnil;
|
2265
|
-
rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
|
2266
|
-
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
2267
|
-
|
2268
|
-
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
2269
|
-
rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
|
2270
|
-
return Qnil;
|
2271
|
-
}
|
2272
|
-
if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
|
2273
|
-
rb_raise(rb_eArgError, "last_n_tokens must be an Array");
|
2274
|
-
return Qnil;
|
2275
|
-
}
|
2276
|
-
if (!RB_FLOAT_TYPE_P(kw_values[0])) {
|
2277
|
-
rb_raise(rb_eArgError, "frequency must be a float");
|
2214
|
+
if (!RB_FLOAT_TYPE_P(kw_values[1])) {
|
2215
|
+
rb_raise(rb_eArgError, "penalty_freq must be a float");
|
2278
2216
|
return Qnil;
|
2279
2217
|
}
|
2280
|
-
if (!RB_FLOAT_TYPE_P(kw_values[
|
2281
|
-
rb_raise(rb_eArgError, "
|
2218
|
+
if (!RB_FLOAT_TYPE_P(kw_values[2])) {
|
2219
|
+
rb_raise(rb_eArgError, "penalty_present must be a float");
|
2282
2220
|
return Qnil;
|
2283
2221
|
}
|
2284
2222
|
|
@@ -2298,11 +2236,12 @@ private:
|
|
2298
2236
|
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
2299
2237
|
return Qnil;
|
2300
2238
|
}
|
2239
|
+
const float penalty_repeat = NUM2DBL(kw_values[0]);
|
2240
|
+
const float penalty_freq = NUM2DBL(kw_values[1]);
|
2241
|
+
const float penalty_present = NUM2DBL(kw_values[2]);
|
2301
2242
|
|
2302
|
-
|
2303
|
-
|
2304
|
-
|
2305
|
-
llama_sample_frequency_and_presence_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, alpha_frequency, alpha_presence);
|
2243
|
+
llama_sample_repetition_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size,
|
2244
|
+
penalty_repeat, penalty_freq, penalty_present);
|
2306
2245
|
|
2307
2246
|
return Qnil;
|
2308
2247
|
}
|