llama_cpp 0.7.1 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +18 -0
- data/examples/chat.rb +8 -6
- data/ext/llama_cpp/extconf.rb +2 -2
- data/ext/llama_cpp/llama_cpp.cpp +122 -183
- data/ext/llama_cpp/src/ggml-cuda.cu +188 -20
- data/ext/llama_cpp/src/ggml-metal.m +57 -8
- data/ext/llama_cpp/src/ggml-metal.metal +171 -2
- data/ext/llama_cpp/src/ggml-opencl.cpp +188 -222
- data/ext/llama_cpp/src/ggml.c +375 -93
- data/ext/llama_cpp/src/ggml.h +11 -9
- data/ext/llama_cpp/src/k_quants.c +12 -20
- data/ext/llama_cpp/src/llama.cpp +459 -153
- data/ext/llama_cpp/src/llama.h +34 -33
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +4 -4
- data/sig/llama_cpp.rbs +15 -16
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 683f2d81aff9e82234925ba08cd5b46b56a2283ff8397a6c06ce50d34a95dbfc
+  data.tar.gz: d3005cab273b8d85f47f4cb4314fbab3a540d366a42829e5ec8d2c29576ae09e
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 559f1ba1253a704c38480336decd315c65b4d80e6895ad1dc0faa3b5b81570a1faeaadcb6ec7ee3145f0fff758ab5e38e6cb8163382ce9b693d893deebe9a8f9
+  data.tar.gz: cb3d96b8c3f79cd20d4169a175270e8768c04bcaa24e51cb2c4d7872db88bc6e3349e6b1e93a130b89d21daab8be6e57b5305412059ea722084c7cb7d4a01e93
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,21 @@
+## [[0.9.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.8.0...v0.9.0)] - 2023-10-28
+
+- Fix missing object file for ggml-backend when building with metal and cublas options.
+
+**Breaking Changes**
+- Bump bundled llama.cpp from b1405 to b1429
+- Move following methods from Context to Model:
+  - text, score, type, token_bos, token_eos, token_nl, token_prefix, token_middle, token_suffix, and token_eot.
+- Add `sample_repetition_penalties` method, which integrates sample_frequency_and_presence_penalties and sample_repetition_penalty methods.
+
+## [[0.8.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.1...v0.8.0)] - 2023-10-21
+
+**Breaking Changes**
+- Bump bundled llama.cpp from b1380 to b1405
+- Add column index argument to `set_seq_id` and `get_seq_id` methods in Batch.
+- Add `special` keyword argument to `tokenize` method in Model.
+- Add `n_seq_max` keyword argument to `initialize` method in Batch.
+
 ## [[0.7.1](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.0...v0.7.1)] - 2023-10-14
 
 - Bump bundled llama.cpp from b1334 to b1380.
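As a quick orientation before the source diffs below: the 0.9.0 entry replaces the former `sample_repetition_penalty` and `sample_frequency_and_presence_penalties` calls with a single `sample_repetition_penalties` method. A minimal usage sketch, assuming an already-built `context`, a `candidates` TokenDataArray, and an array of recent token ids; the penalty values are illustrative placeholders, not taken from this diff:

```ruby
# Sketch only: apply repetition, frequency, and presence penalties in one call (0.9.0 API).
context.sample_repetition_penalties(
  candidates,            # LLaMACpp::TokenDataArray built from the current logits
  recent_tokens,         # Array of recently generated token ids
  penalty_repeat: 1.1,   # multiplicative repetition penalty
  penalty_freq: 0.0,     # frequency penalty (0.0 disables)
  penalty_present: 0.0   # presence penalty (0.0 disables)
)
```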
data/examples/chat.rb
CHANGED
@@ -83,10 +83,12 @@ class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation
       candidates = LLaMACpp::TokenDataArray.new(base_candidates)
 
       last_n_repeat = [last_n_tokens.size, options[:repeat_last_n], n_ctx].min
-      context.
-
-
-
+      context.sample_repetition_penalties(
+        candidates,
+        last_n_tokens[-last_n_repeat..],
+        penalty_repeat: options[:repeat_penalty],
+        penalty_freq: options[:frequency_penalty],
+        penalty_present: options[:presence_penalty]
       )
 
       context.sample_top_k(candidates, k: options[:top_k])
@@ -99,8 +101,8 @@ class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation
       last_n_tokens.shift
       last_n_tokens.push(id)
 
-      if id == context.token_eos
-        id = context.token_nl
+      if id == context.model.token_eos
+        id = context.model.token_nl
        unless antiprompt.empty?
          first_antiprompt = context.model.tokenize(text: antiprompt, add_bos: false)
          embd_input.concat(first_antiprompt)
data/ext/llama_cpp/extconf.rb
CHANGED
@@ -53,7 +53,7 @@ if with_config('metal')
   $CFLAGS << ' -DGGML_USE_METAL'
   $CXXFLAGS << ' -DGGML_USE_METAL'
   $LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit'
-  $objs = %w[ggml.o ggml-alloc.o ggml-metal.o llama.o llama_cpp.o]
+  $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-metal.o llama.o llama_cpp.o]
   $objs << 'k_quants.o' unless with_config('no_k_quants')
 end
 
@@ -61,7 +61,7 @@ if with_config('cublas')
   $CFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
   $CXXFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
   $LDFLAGS << ' -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64'
-  $objs = %w[ggml.o ggml-alloc.o ggml-cuda.o llama.o llama_cpp.o]
+  $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-cuda.o llama.o llama_cpp.o]
   $objs << 'k_quants.o' unless with_config('no_k_quants')
 end
 
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -63,8 +63,8 @@ public:
     rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
     rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
     rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
-    rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id),
-    rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id),
+    rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 3);
+    rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 2);
     rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
     rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
   }
@@ -74,10 +74,10 @@ private:
 
   static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
-    ID kw_table[
-    VALUE kw_values[
+    ID kw_table[3] = { rb_intern("n_tokens"), rb_intern("embd"), rb_intern("n_seq_max") };
+    VALUE kw_values[3] = { Qundef, Qundef, Qundef };
     rb_scan_args(argc, argv, ":", &kw_args);
-    rb_get_kwargs(kw_args, kw_table,
+    rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
 
     if (!RB_INTEGER_TYPE_P(kw_values[0])) {
       rb_raise(rb_eArgError, "n_tokens must be an integer");
@@ -87,12 +87,17 @@ private:
       rb_raise(rb_eArgError, "embd must be an integer");
       return Qnil;
     }
+    if (!RB_INTEGER_TYPE_P(kw_values[2])) {
+      rb_raise(rb_eArgError, "n_seq_max must be an integer");
+      return Qnil;
+    }
 
     const int32_t n_tokens = NUM2INT(kw_values[0]);
     const int32_t embd = NUM2INT(kw_values[1]);
+    const int32_t n_seq_max = NUM2INT(kw_values[2]);
 
     LLaMABatchWrapper* ptr = get_llama_batch(self);
-    ptr->batch = llama_batch_init(n_tokens, embd);
+    ptr->batch = llama_batch_init(n_tokens, embd, n_seq_max);
 
     return Qnil;
   }
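The two hunks above add the `n_seq_max` keyword to Batch#initialize and pass it through to `llama_batch_init`. A minimal construction sketch, assuming the class is exposed as `LLaMACpp::Batch` as elsewhere in this gem; the sizes are illustrative placeholders:

```ruby
# Sketch only: a batch holding up to 512 token ids (embd: 0 means token ids,
# not embeddings), with room for one sequence id per token.
batch = LLaMACpp::Batch.new(n_tokens: 512, embd: 0, n_seq_max: 1)
```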
@@ -190,25 +195,35 @@ private:
   }
 
   // seq_id
-  static VALUE _llama_batch_set_seq_id(VALUE self, VALUE
+  static VALUE _llama_batch_set_seq_id(VALUE self, VALUE i_, VALUE j_, VALUE value) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
-    const int32_t
-    if (
-      rb_raise(rb_eArgError, "
+    const int32_t i = NUM2INT(i_);
+    if (i < 0 || i >= ptr->batch.n_tokens) {
+      rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
+      return Qnil;
+    }
+    const int32_t j = NUM2INT(j_);
+    if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
+      rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
       return Qnil;
     }
-    ptr->batch.seq_id[
-    return INT2NUM(ptr->batch.seq_id[
+    ptr->batch.seq_id[i][j] = NUM2INT(value);
+    return INT2NUM(ptr->batch.seq_id[i][j]);
   }
 
-  static VALUE _llama_batch_get_seq_id(VALUE self, VALUE
+  static VALUE _llama_batch_get_seq_id(VALUE self, VALUE i_, VALUE j_) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
-    const int32_t
-    if (
-      rb_raise(rb_eArgError, "
+    const int32_t i = NUM2INT(i_);
+    if (i < 0 || i >= ptr->batch.n_tokens) {
+      rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
       return Qnil;
     }
-
+    const int32_t j = NUM2INT(j_);
+    if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
+      rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
+      return Qnil;
+    }
+    return INT2NUM(ptr->batch.seq_id[i][j]);
   }
 
   // logits
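With the column index added above, `seq_id` is now addressed per token and per slot. An illustrative sketch, assuming token 0 of the batch has at least one sequence-id slot (the setter validates the column index against `n_seq_id[i]`, as the hunk shows); the index values are placeholders:

```ruby
# Sketch only: write sequence id 0 into slot 0 of token 0, then read it back.
batch.set_seq_id(0, 0, 0)
batch.get_seq_id(0, 0)  # => 0
```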
@@ -1133,6 +1148,16 @@ public:
     rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
     rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
     rb_define_method(rb_cLLaMAModel, "n_params", RUBY_METHOD_FUNC(_llama_model_get_model_n_params), 0);
+    rb_define_method(rb_cLLaMAModel, "text", RUBY_METHOD_FUNC(_llama_model_get_text), 1);
+    rb_define_method(rb_cLLaMAModel, "score", RUBY_METHOD_FUNC(_llama_model_get_score), 1);
+    rb_define_method(rb_cLLaMAModel, "type", RUBY_METHOD_FUNC(_llama_model_get_type), 1);
+    rb_define_method(rb_cLLaMAModel, "token_bos", RUBY_METHOD_FUNC(_llama_model_token_bos), 0);
+    rb_define_method(rb_cLLaMAModel, "token_eos", RUBY_METHOD_FUNC(_llama_model_token_eos), 0);
+    rb_define_method(rb_cLLaMAModel, "token_nl", RUBY_METHOD_FUNC(_llama_model_token_nl), 0);
+    rb_define_method(rb_cLLaMAModel, "token_prefix", RUBY_METHOD_FUNC(_llama_model_token_prefix), 0);
+    rb_define_method(rb_cLLaMAModel, "token_middle", RUBY_METHOD_FUNC(_llama_model_token_middle), 0);
+    rb_define_method(rb_cLLaMAModel, "token_suffix", RUBY_METHOD_FUNC(_llama_model_token_suffix), 0);
+    rb_define_method(rb_cLLaMAModel, "token_eot", RUBY_METHOD_FUNC(_llama_model_token_eot), 0);
   }
 
 private:
@@ -1319,10 +1344,10 @@ private:
 
   static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
-    ID kw_table[
-    VALUE kw_values[
+    ID kw_table[4] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos"), rb_intern("special") };
+    VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
     rb_scan_args(argc, argv, ":", &kw_args);
-    rb_get_kwargs(kw_args, kw_table, 1,
+    rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);
 
     if (!RB_TYPE_P(kw_values[0], T_STRING)) {
       rb_raise(rb_eArgError, "text must be a String");
@@ -1336,15 +1361,20 @@ private:
       rb_raise(rb_eArgError, "add_bos must be a boolean");
       return Qnil;
     }
+    if (kw_values[3] != Qundef && (kw_values[3] != Qtrue && kw_values[3] != Qfalse)) {
+      rb_raise(rb_eArgError, "special must be a boolean");
+      return Qnil;
+    }
 
     VALUE text_ = kw_values[0];
     std::string text = StringValueCStr(text_);
     const bool add_bos = kw_values[2] == Qtrue ? true : false;
+    const bool special = kw_values[3] == Qtrue ? true : false;
     const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
 
     llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
     LLaMAModelWrapper* ptr = get_llama_model(self);
-    const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
+    const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos, special);
 
     if (n_tokens < 0) {
       rb_raise(rb_eRuntimeError, "failed to tokenize. The numebr of tokens (%d) is greater than n_max_tokens.", -n_tokens);
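The new `special` flag above is forwarded to `llama_tokenize`, controlling whether special tokens in the input text are parsed as such. A minimal call sketch, assuming a loaded `model`; the prompt string is an illustrative placeholder:

```ruby
# Sketch only: tokenize a prompt with a BOS token prepended and special tokens parsed.
tokens = model.tokenize(text: 'Hello, world!', add_bos: true, special: true)
```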
@@ -1376,6 +1406,62 @@ private:
     LLaMAModelWrapper* ptr = get_llama_model(self);
     return UINT2NUM(llama_model_n_params(ptr->model));
   }
+
+  static VALUE _llama_model_get_text(VALUE self, VALUE token_) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    const llama_token token = NUM2INT(token_);
+    const char* text = llama_token_get_text(ptr->model, token);
+    return rb_utf8_str_new_cstr(text);
+  }
+
+  static VALUE _llama_model_get_score(VALUE self, VALUE token_) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    const llama_token token = NUM2INT(token_);
+    const float score = llama_token_get_score(ptr->model, token);
+    return DBL2NUM(score);
+  }
+
+  static VALUE _llama_model_get_type(VALUE self, VALUE token_) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    const llama_token token = NUM2INT(token_);
+    const int type = llama_token_get_type(ptr->model, token);
+    return INT2NUM(type);
+  }
+
+  static VALUE _llama_model_token_bos(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_token_bos(ptr->model));
+  }
+
+  static VALUE _llama_model_token_eos(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_token_eos(ptr->model));
+  }
+
+  static VALUE _llama_model_token_nl(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_token_nl(ptr->model));
+  }
+
+  static VALUE _llama_model_token_prefix(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_token_prefix(ptr->model));
+  }
+
+  static VALUE _llama_model_token_middle(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_token_middle(ptr->model));
+  }
+
+  static VALUE _llama_model_token_suffix(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_token_suffix(ptr->model));
+  }
+
+  static VALUE _llama_model_token_eot(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_token_eot(ptr->model));
+  }
 };
 
 const rb_data_type_t RbLLaMAModel::llama_model_type = {
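These vocabulary accessors now live on Model, so they can be called without creating a Context. A small sketch, assuming a loaded `LLaMACpp::Model` in `model` (construction elided; return values depend on the model's vocabulary):

```ruby
# Sketch only: query vocabulary metadata directly from the model.
bos = model.token_bos   # Integer id of the BOS token
model.text(bos)         # the token's piece as a UTF-8 String
model.score(bos)        # the token's score (Float)
model.type(bos)         # the token's type (Integer)
```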
@@ -1650,16 +1736,6 @@ public:
     rb_define_method(rb_cLLaMAContext, "decode", RUBY_METHOD_FUNC(_llama_context_decode), 1);
     rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
     rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
-    rb_define_method(rb_cLLaMAContext, "text", RUBY_METHOD_FUNC(_llama_context_text), 1);
-    rb_define_method(rb_cLLaMAContext, "score", RUBY_METHOD_FUNC(_llama_context_score), 1);
-    rb_define_method(rb_cLLaMAContext, "type", RUBY_METHOD_FUNC(_llama_context_type), 1);
-    rb_define_method(rb_cLLaMAContext, "token_bos", RUBY_METHOD_FUNC(_llama_context_token_bos), 0);
-    rb_define_method(rb_cLLaMAContext, "token_eos", RUBY_METHOD_FUNC(_llama_context_token_eos), 0);
-    rb_define_method(rb_cLLaMAContext, "token_nl", RUBY_METHOD_FUNC(_llama_context_token_nl), 0);
-    rb_define_method(rb_cLLaMAContext, "token_prefix", RUBY_METHOD_FUNC(_llama_context_token_prefix), 0);
-    rb_define_method(rb_cLLaMAContext, "token_middle", RUBY_METHOD_FUNC(_llama_context_token_middle), 0);
-    rb_define_method(rb_cLLaMAContext, "token_suffix", RUBY_METHOD_FUNC(_llama_context_token_suffix), 0);
-    rb_define_method(rb_cLLaMAContext, "token_eot", RUBY_METHOD_FUNC(_llama_context_token_eot), 0);
     rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
     rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
     rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
@@ -1673,8 +1749,7 @@ public:
     rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
     rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
     rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
-    rb_define_method(rb_cLLaMAContext, "
-    rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
+    rb_define_method(rb_cLLaMAContext, "sample_repetition_penalties", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalties), -1);
     rb_define_method(rb_cLLaMAContext, "sample_classifier_free_guidance", RUBY_METHOD_FUNC(_llama_context_sample_classifier_free_guidance), -1);
     rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
     rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
@@ -1907,102 +1982,6 @@ private:
     return output;
   }
 
-  static VALUE _llama_context_text(VALUE self, VALUE token_) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    const llama_token token = NUM2INT(token_);
-    const char* text = llama_token_get_text(ptr->ctx, token);
-    return rb_utf8_str_new_cstr(text);
-  }
-
-  static VALUE _llama_context_score(VALUE self, VALUE token_) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    const llama_token token = NUM2INT(token_);
-    const float score = llama_token_get_score(ptr->ctx, token);
-    return DBL2NUM(score);
-  }
-
-  static VALUE _llama_context_type(VALUE self, VALUE token_) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    const llama_token token = NUM2INT(token_);
-    const int type = llama_token_get_type(ptr->ctx, token);
-    return INT2NUM(type);
-  }
-
-  static VALUE _llama_context_token_bos(VALUE self) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    return INT2NUM(llama_token_bos(ptr->ctx));
-  }
-
-  static VALUE _llama_context_token_eos(VALUE self) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    return INT2NUM(llama_token_eos(ptr->ctx));
-  }
-
-  static VALUE _llama_context_token_nl(VALUE self) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    return INT2NUM(llama_token_nl(ptr->ctx));
-  }
-
-  static VALUE _llama_context_token_prefix(VALUE self) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    return INT2NUM(llama_token_prefix(ptr->ctx));
-  }
-
-  static VALUE _llama_context_token_middle(VALUE self) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    return INT2NUM(llama_token_middle(ptr->ctx));
-  }
-
-  static VALUE _llama_context_token_suffix(VALUE self) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    return INT2NUM(llama_token_suffix(ptr->ctx));
-  }
-
-  static VALUE _llama_context_token_eot(VALUE self) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    return INT2NUM(llama_token_eot(ptr->ctx));
-  }
-
   static VALUE _llama_context_n_ctx(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
     if (ptr->ctx == NULL) {
@@ -2211,14 +2190,14 @@ private:
     return Qnil;
   }
 
-  static VALUE
+  static VALUE _llama_context_sample_repetition_penalties(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
-    ID kw_table[
-    VALUE kw_values[
+    ID kw_table[3] = { rb_intern("penalty_repeat"), rb_intern("penalty_freq"), rb_intern("penalty_present") };
+    VALUE kw_values[3] = { Qundef, Qundef, Qundef };
     VALUE candidates = Qnil;
     VALUE last_n_tokens = Qnil;
     rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
-    rb_get_kwargs(kw_args, kw_table,
+    rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
 
     if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
       rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
@@ -2229,56 +2208,15 @@ private:
       return Qnil;
     }
     if (!RB_FLOAT_TYPE_P(kw_values[0])) {
-      rb_raise(rb_eArgError, "
+      rb_raise(rb_eArgError, "penalty_repeat must be a float");
       return Qnil;
     }
-
-
-    std::vector<llama_token> last_n_tokens_data(last_tokens_size);
-    for (size_t i = 0; i < last_tokens_size; i++) {
-      last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
-    }
-
-    LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
-    if (ctx_ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
-    if (cnd_ptr->array.data == nullptr) {
-      rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
-      return Qnil;
-    }
-    const float penalty = NUM2DBL(kw_values[0]);
-
-    llama_sample_repetition_penalty(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, penalty);
-
-    return Qnil;
-  }
-
-  static VALUE _llama_context_sample_frequency_and_presence_penalties(int argc, VALUE* argv, VALUE self) {
-    VALUE kw_args = Qnil;
-    ID kw_table[2] = { rb_intern("frequency"), rb_intern("presence") };
-    VALUE kw_values[2] = { Qundef, Qundef };
-    VALUE candidates = Qnil;
-    VALUE last_n_tokens = Qnil;
-    rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
-    rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
-
-    if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
-      rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
-      return Qnil;
-    }
-    if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
-      rb_raise(rb_eArgError, "last_n_tokens must be an Array");
-      return Qnil;
-    }
-    if (!RB_FLOAT_TYPE_P(kw_values[0])) {
-      rb_raise(rb_eArgError, "frequency must be a float");
+    if (!RB_FLOAT_TYPE_P(kw_values[1])) {
+      rb_raise(rb_eArgError, "penalty_freq must be a float");
       return Qnil;
     }
-    if (!RB_FLOAT_TYPE_P(kw_values[
-      rb_raise(rb_eArgError, "
+    if (!RB_FLOAT_TYPE_P(kw_values[2])) {
+      rb_raise(rb_eArgError, "penalty_present must be a float");
       return Qnil;
     }
 
@@ -2298,11 +2236,12 @@ private:
       rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
       return Qnil;
     }
+    const float penalty_repeat = NUM2DBL(kw_values[0]);
+    const float penalty_freq = NUM2DBL(kw_values[1]);
+    const float penalty_present = NUM2DBL(kw_values[2]);
 
-
-
-
-    llama_sample_frequency_and_presence_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, alpha_frequency, alpha_presence);
+    llama_sample_repetition_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size,
+                                      penalty_repeat, penalty_freq, penalty_present);
 
     return Qnil;
   }