llama_cpp 0.8.0 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/examples/chat.rb +8 -6
- data/ext/llama_cpp/extconf.rb +2 -2
- data/ext/llama_cpp/llama_cpp.cpp +81 -162
- data/ext/llama_cpp/src/ggml-cuda.cu +188 -20
- data/ext/llama_cpp/src/ggml-metal.m +13 -5
- data/ext/llama_cpp/src/ggml-metal.metal +9 -1
- data/ext/llama_cpp/src/ggml-opencl.cpp +161 -169
- data/ext/llama_cpp/src/ggml.c +362 -84
- data/ext/llama_cpp/src/ggml.h +8 -7
- data/ext/llama_cpp/src/llama.cpp +100 -95
- data/ext/llama_cpp/src/llama.h +16 -21
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +4 -4
- data/sig/llama_cpp.rbs +11 -12
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 683f2d81aff9e82234925ba08cd5b46b56a2283ff8397a6c06ce50d34a95dbfc
|
4
|
+
data.tar.gz: d3005cab273b8d85f47f4cb4314fbab3a540d366a42829e5ec8d2c29576ae09e
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 559f1ba1253a704c38480336decd315c65b4d80e6895ad1dc0faa3b5b81570a1faeaadcb6ec7ee3145f0fff758ab5e38e6cb8163382ce9b693d893deebe9a8f9
|
7
|
+
data.tar.gz: cb3d96b8c3f79cd20d4169a175270e8768c04bcaa24e51cb2c4d7872db88bc6e3349e6b1e93a130b89d21daab8be6e57b5305412059ea722084c7cb7d4a01e93
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,13 @@
|
|
1
|
+
## [[0.9.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.8.0...v0.9.0)] - 2023-10-28
|
2
|
+
|
3
|
+
- Fix missing object file for ggml-backend when building with metal and cublas options.
|
4
|
+
|
5
|
+
**Breaking Changes**
|
6
|
+
- Bump bundled llama.cpp from b1405 to b1429
|
7
|
+
- Move following methods from Context to Model:
|
8
|
+
- text, score, type, token_bos, token_eos, token_nl, token_prefix, token_middle, token_suffix, and token_eot.
|
9
|
+
- Add `sample_repetition_penalties` method, which integrates sample_frequency_and_presence_penalties and sample_repetition_penalty methods.
|
10
|
+
|
1
11
|
## [[0.8.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.1...v0.8.0)] - 2023-10-21
|
2
12
|
|
3
13
|
**Breaking Changes**
|
data/examples/chat.rb
CHANGED
@@ -83,10 +83,12 @@ class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation
|
|
83
83
|
candidates = LLaMACpp::TokenDataArray.new(base_candidates)
|
84
84
|
|
85
85
|
last_n_repeat = [last_n_tokens.size, options[:repeat_last_n], n_ctx].min
|
86
|
-
context.
|
87
|
-
|
88
|
-
|
89
|
-
|
86
|
+
context.sample_repetition_penalties(
|
87
|
+
candidates,
|
88
|
+
last_n_tokens[-last_n_repeat..],
|
89
|
+
penalty_repeat: options[:repeat_penalty],
|
90
|
+
penalty_freq: options[:frequency_penalty],
|
91
|
+
penalty_present: options[:presence_penalty]
|
90
92
|
)
|
91
93
|
|
92
94
|
context.sample_top_k(candidates, k: options[:top_k])
|
@@ -99,8 +101,8 @@ class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation
|
|
99
101
|
last_n_tokens.shift
|
100
102
|
last_n_tokens.push(id)
|
101
103
|
|
102
|
-
if id == context.token_eos
|
103
|
-
id = context.token_nl
|
104
|
+
if id == context.model.token_eos
|
105
|
+
id = context.model.token_nl
|
104
106
|
unless antiprompt.empty?
|
105
107
|
first_antiprompt = context.model.tokenize(text: antiprompt, add_bos: false)
|
106
108
|
embd_input.concat(first_antiprompt)
|
data/ext/llama_cpp/extconf.rb
CHANGED
@@ -53,7 +53,7 @@ if with_config('metal')
|
|
53
53
|
$CFLAGS << ' -DGGML_USE_METAL'
|
54
54
|
$CXXFLAGS << ' -DGGML_USE_METAL'
|
55
55
|
$LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit'
|
56
|
-
$objs = %w[ggml.o ggml-alloc.o ggml-metal.o llama.o llama_cpp.o]
|
56
|
+
$objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-metal.o llama.o llama_cpp.o]
|
57
57
|
$objs << 'k_quants.o' unless with_config('no_k_quants')
|
58
58
|
end
|
59
59
|
|
@@ -61,7 +61,7 @@ if with_config('cublas')
|
|
61
61
|
$CFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
|
62
62
|
$CXXFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
|
63
63
|
$LDFLAGS << ' -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64'
|
64
|
-
$objs = %w[ggml.o ggml-alloc.o ggml-cuda.o llama.o llama_cpp.o]
|
64
|
+
$objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-cuda.o llama.o llama_cpp.o]
|
65
65
|
$objs << 'k_quants.o' unless with_config('no_k_quants')
|
66
66
|
end
|
67
67
|
|
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -1148,6 +1148,16 @@ public:
|
|
1148
1148
|
rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
|
1149
1149
|
rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
|
1150
1150
|
rb_define_method(rb_cLLaMAModel, "n_params", RUBY_METHOD_FUNC(_llama_model_get_model_n_params), 0);
|
1151
|
+
rb_define_method(rb_cLLaMAModel, "text", RUBY_METHOD_FUNC(_llama_model_get_text), 1);
|
1152
|
+
rb_define_method(rb_cLLaMAModel, "score", RUBY_METHOD_FUNC(_llama_model_get_score), 1);
|
1153
|
+
rb_define_method(rb_cLLaMAModel, "type", RUBY_METHOD_FUNC(_llama_model_get_type), 1);
|
1154
|
+
rb_define_method(rb_cLLaMAModel, "token_bos", RUBY_METHOD_FUNC(_llama_model_token_bos), 0);
|
1155
|
+
rb_define_method(rb_cLLaMAModel, "token_eos", RUBY_METHOD_FUNC(_llama_model_token_eos), 0);
|
1156
|
+
rb_define_method(rb_cLLaMAModel, "token_nl", RUBY_METHOD_FUNC(_llama_model_token_nl), 0);
|
1157
|
+
rb_define_method(rb_cLLaMAModel, "token_prefix", RUBY_METHOD_FUNC(_llama_model_token_prefix), 0);
|
1158
|
+
rb_define_method(rb_cLLaMAModel, "token_middle", RUBY_METHOD_FUNC(_llama_model_token_middle), 0);
|
1159
|
+
rb_define_method(rb_cLLaMAModel, "token_suffix", RUBY_METHOD_FUNC(_llama_model_token_suffix), 0);
|
1160
|
+
rb_define_method(rb_cLLaMAModel, "token_eot", RUBY_METHOD_FUNC(_llama_model_token_eot), 0);
|
1151
1161
|
}
|
1152
1162
|
|
1153
1163
|
private:
|
@@ -1396,6 +1406,62 @@ private:
|
|
1396
1406
|
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1397
1407
|
return UINT2NUM(llama_model_n_params(ptr->model));
|
1398
1408
|
}
|
1409
|
+
|
1410
|
+
static VALUE _llama_model_get_text(VALUE self, VALUE token_) {
|
1411
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1412
|
+
const llama_token token = NUM2INT(token_);
|
1413
|
+
const char* text = llama_token_get_text(ptr->model, token);
|
1414
|
+
return rb_utf8_str_new_cstr(text);
|
1415
|
+
}
|
1416
|
+
|
1417
|
+
static VALUE _llama_model_get_score(VALUE self, VALUE token_) {
|
1418
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1419
|
+
const llama_token token = NUM2INT(token_);
|
1420
|
+
const float score = llama_token_get_score(ptr->model, token);
|
1421
|
+
return DBL2NUM(score);
|
1422
|
+
}
|
1423
|
+
|
1424
|
+
static VALUE _llama_model_get_type(VALUE self, VALUE token_) {
|
1425
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1426
|
+
const llama_token token = NUM2INT(token_);
|
1427
|
+
const int type = llama_token_get_type(ptr->model, token);
|
1428
|
+
return INT2NUM(type);
|
1429
|
+
}
|
1430
|
+
|
1431
|
+
static VALUE _llama_model_token_bos(VALUE self) {
|
1432
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1433
|
+
return INT2NUM(llama_token_bos(ptr->model));
|
1434
|
+
}
|
1435
|
+
|
1436
|
+
static VALUE _llama_model_token_eos(VALUE self) {
|
1437
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1438
|
+
return INT2NUM(llama_token_eos(ptr->model));
|
1439
|
+
}
|
1440
|
+
|
1441
|
+
static VALUE _llama_model_token_nl(VALUE self) {
|
1442
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1443
|
+
return INT2NUM(llama_token_nl(ptr->model));
|
1444
|
+
}
|
1445
|
+
|
1446
|
+
static VALUE _llama_model_token_prefix(VALUE self) {
|
1447
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1448
|
+
return INT2NUM(llama_token_prefix(ptr->model));
|
1449
|
+
}
|
1450
|
+
|
1451
|
+
static VALUE _llama_model_token_middle(VALUE self) {
|
1452
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1453
|
+
return INT2NUM(llama_token_middle(ptr->model));
|
1454
|
+
}
|
1455
|
+
|
1456
|
+
static VALUE _llama_model_token_suffix(VALUE self) {
|
1457
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1458
|
+
return INT2NUM(llama_token_suffix(ptr->model));
|
1459
|
+
}
|
1460
|
+
|
1461
|
+
static VALUE _llama_model_token_eot(VALUE self) {
|
1462
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
1463
|
+
return INT2NUM(llama_token_eot(ptr->model));
|
1464
|
+
}
|
1399
1465
|
};
|
1400
1466
|
|
1401
1467
|
const rb_data_type_t RbLLaMAModel::llama_model_type = {
|
@@ -1670,16 +1736,6 @@ public:
|
|
1670
1736
|
rb_define_method(rb_cLLaMAContext, "decode", RUBY_METHOD_FUNC(_llama_context_decode), 1);
|
1671
1737
|
rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
|
1672
1738
|
rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
|
1673
|
-
rb_define_method(rb_cLLaMAContext, "text", RUBY_METHOD_FUNC(_llama_context_text), 1);
|
1674
|
-
rb_define_method(rb_cLLaMAContext, "score", RUBY_METHOD_FUNC(_llama_context_score), 1);
|
1675
|
-
rb_define_method(rb_cLLaMAContext, "type", RUBY_METHOD_FUNC(_llama_context_type), 1);
|
1676
|
-
rb_define_method(rb_cLLaMAContext, "token_bos", RUBY_METHOD_FUNC(_llama_context_token_bos), 0);
|
1677
|
-
rb_define_method(rb_cLLaMAContext, "token_eos", RUBY_METHOD_FUNC(_llama_context_token_eos), 0);
|
1678
|
-
rb_define_method(rb_cLLaMAContext, "token_nl", RUBY_METHOD_FUNC(_llama_context_token_nl), 0);
|
1679
|
-
rb_define_method(rb_cLLaMAContext, "token_prefix", RUBY_METHOD_FUNC(_llama_context_token_prefix), 0);
|
1680
|
-
rb_define_method(rb_cLLaMAContext, "token_middle", RUBY_METHOD_FUNC(_llama_context_token_middle), 0);
|
1681
|
-
rb_define_method(rb_cLLaMAContext, "token_suffix", RUBY_METHOD_FUNC(_llama_context_token_suffix), 0);
|
1682
|
-
rb_define_method(rb_cLLaMAContext, "token_eot", RUBY_METHOD_FUNC(_llama_context_token_eot), 0);
|
1683
1739
|
rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
|
1684
1740
|
rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
|
1685
1741
|
rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
|
@@ -1693,8 +1749,7 @@ public:
|
|
1693
1749
|
rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
|
1694
1750
|
rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
|
1695
1751
|
rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
|
1696
|
-
rb_define_method(rb_cLLaMAContext, "sample_repetition_penalty", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalty), -1);
|
1697
|
-
rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
|
1752
|
+
rb_define_method(rb_cLLaMAContext, "sample_repetition_penalties", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalties), -1);
|
1698
1753
|
rb_define_method(rb_cLLaMAContext, "sample_classifier_free_guidance", RUBY_METHOD_FUNC(_llama_context_sample_classifier_free_guidance), -1);
|
1699
1754
|
rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
|
1700
1755
|
rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
|
@@ -1927,102 +1982,6 @@ private:
|
|
1927
1982
|
return output;
|
1928
1983
|
}
|
1929
1984
|
|
1930
|
-
static VALUE _llama_context_text(VALUE self, VALUE token_) {
|
1931
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1932
|
-
if (ptr->ctx == NULL) {
|
1933
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1934
|
-
return Qnil;
|
1935
|
-
}
|
1936
|
-
const llama_token token = NUM2INT(token_);
|
1937
|
-
const char* text = llama_token_get_text(ptr->ctx, token);
|
1938
|
-
return rb_utf8_str_new_cstr(text);
|
1939
|
-
}
|
1940
|
-
|
1941
|
-
static VALUE _llama_context_score(VALUE self, VALUE token_) {
|
1942
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1943
|
-
if (ptr->ctx == NULL) {
|
1944
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1945
|
-
return Qnil;
|
1946
|
-
}
|
1947
|
-
const llama_token token = NUM2INT(token_);
|
1948
|
-
const float score = llama_token_get_score(ptr->ctx, token);
|
1949
|
-
return DBL2NUM(score);
|
1950
|
-
}
|
1951
|
-
|
1952
|
-
static VALUE _llama_context_type(VALUE self, VALUE token_) {
|
1953
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1954
|
-
if (ptr->ctx == NULL) {
|
1955
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1956
|
-
return Qnil;
|
1957
|
-
}
|
1958
|
-
const llama_token token = NUM2INT(token_);
|
1959
|
-
const int type = llama_token_get_type(ptr->ctx, token);
|
1960
|
-
return INT2NUM(type);
|
1961
|
-
}
|
1962
|
-
|
1963
|
-
static VALUE _llama_context_token_bos(VALUE self) {
|
1964
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1965
|
-
if (ptr->ctx == NULL) {
|
1966
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1967
|
-
return Qnil;
|
1968
|
-
}
|
1969
|
-
return INT2NUM(llama_token_bos(ptr->ctx));
|
1970
|
-
}
|
1971
|
-
|
1972
|
-
static VALUE _llama_context_token_eos(VALUE self) {
|
1973
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1974
|
-
if (ptr->ctx == NULL) {
|
1975
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1976
|
-
return Qnil;
|
1977
|
-
}
|
1978
|
-
return INT2NUM(llama_token_eos(ptr->ctx));
|
1979
|
-
}
|
1980
|
-
|
1981
|
-
static VALUE _llama_context_token_nl(VALUE self) {
|
1982
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1983
|
-
if (ptr->ctx == NULL) {
|
1984
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1985
|
-
return Qnil;
|
1986
|
-
}
|
1987
|
-
return INT2NUM(llama_token_nl(ptr->ctx));
|
1988
|
-
}
|
1989
|
-
|
1990
|
-
static VALUE _llama_context_token_prefix(VALUE self) {
|
1991
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1992
|
-
if (ptr->ctx == NULL) {
|
1993
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1994
|
-
return Qnil;
|
1995
|
-
}
|
1996
|
-
return INT2NUM(llama_token_prefix(ptr->ctx));
|
1997
|
-
}
|
1998
|
-
|
1999
|
-
static VALUE _llama_context_token_middle(VALUE self) {
|
2000
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2001
|
-
if (ptr->ctx == NULL) {
|
2002
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
2003
|
-
return Qnil;
|
2004
|
-
}
|
2005
|
-
return INT2NUM(llama_token_middle(ptr->ctx));
|
2006
|
-
}
|
2007
|
-
|
2008
|
-
static VALUE _llama_context_token_suffix(VALUE self) {
|
2009
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2010
|
-
if (ptr->ctx == NULL) {
|
2011
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
2012
|
-
return Qnil;
|
2013
|
-
}
|
2014
|
-
return INT2NUM(llama_token_suffix(ptr->ctx));
|
2015
|
-
}
|
2016
|
-
|
2017
|
-
static VALUE _llama_context_token_eot(VALUE self) {
|
2018
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2019
|
-
if (ptr->ctx == NULL) {
|
2020
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
2021
|
-
return Qnil;
|
2022
|
-
}
|
2023
|
-
return INT2NUM(llama_token_eot(ptr->ctx));
|
2024
|
-
}
|
2025
|
-
|
2026
1985
|
static VALUE _llama_context_n_ctx(VALUE self) {
|
2027
1986
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
2028
1987
|
if (ptr->ctx == NULL) {
|
@@ -2231,14 +2190,14 @@ private:
|
|
2231
2190
|
return Qnil;
|
2232
2191
|
}
|
2233
2192
|
|
2234
|
-
static VALUE _llama_context_sample_repetition_penalty(int argc, VALUE* argv, VALUE self) {
|
2193
|
+
static VALUE _llama_context_sample_repetition_penalties(int argc, VALUE* argv, VALUE self) {
|
2235
2194
|
VALUE kw_args = Qnil;
|
2236
|
-
ID kw_table[1] = { rb_intern("penalty") };
|
2237
|
-
VALUE kw_values[1] = { Qundef };
|
2195
|
+
ID kw_table[3] = { rb_intern("penalty_repeat"), rb_intern("penalty_freq"), rb_intern("penalty_present") };
|
2196
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
2238
2197
|
VALUE candidates = Qnil;
|
2239
2198
|
VALUE last_n_tokens = Qnil;
|
2240
2199
|
rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
|
2241
|
-
rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
|
2200
|
+
rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
|
2242
2201
|
|
2243
2202
|
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
2244
2203
|
rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
|
@@ -2249,56 +2208,15 @@ private:
|
|
2249
2208
|
return Qnil;
|
2250
2209
|
}
|
2251
2210
|
if (!RB_FLOAT_TYPE_P(kw_values[0])) {
|
2252
|
-
rb_raise(rb_eArgError, "penalty must be a float");
|
2211
|
+
rb_raise(rb_eArgError, "penalty_repeat must be a float");
|
2253
2212
|
return Qnil;
|
2254
2213
|
}
|
2255
|
-
|
2256
|
-
|
2257
|
-
std::vector<llama_token> last_n_tokens_data(last_tokens_size);
|
2258
|
-
for (size_t i = 0; i < last_tokens_size; i++) {
|
2259
|
-
last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
|
2260
|
-
}
|
2261
|
-
|
2262
|
-
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
2263
|
-
if (ctx_ptr->ctx == NULL) {
|
2264
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
2265
|
-
return Qnil;
|
2266
|
-
}
|
2267
|
-
LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
|
2268
|
-
if (cnd_ptr->array.data == nullptr) {
|
2269
|
-
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
2270
|
-
return Qnil;
|
2271
|
-
}
|
2272
|
-
const float penalty = NUM2DBL(kw_values[0]);
|
2273
|
-
|
2274
|
-
llama_sample_repetition_penalty(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, penalty);
|
2275
|
-
|
2276
|
-
return Qnil;
|
2277
|
-
}
|
2278
|
-
|
2279
|
-
static VALUE _llama_context_sample_frequency_and_presence_penalties(int argc, VALUE* argv, VALUE self) {
|
2280
|
-
VALUE kw_args = Qnil;
|
2281
|
-
ID kw_table[2] = { rb_intern("frequency"), rb_intern("presence") };
|
2282
|
-
VALUE kw_values[2] = { Qundef, Qundef };
|
2283
|
-
VALUE candidates = Qnil;
|
2284
|
-
VALUE last_n_tokens = Qnil;
|
2285
|
-
rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
|
2286
|
-
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
2287
|
-
|
2288
|
-
if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
|
2289
|
-
rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
|
2290
|
-
return Qnil;
|
2291
|
-
}
|
2292
|
-
if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
|
2293
|
-
rb_raise(rb_eArgError, "last_n_tokens must be an Array");
|
2294
|
-
return Qnil;
|
2295
|
-
}
|
2296
|
-
if (!RB_FLOAT_TYPE_P(kw_values[0])) {
|
2297
|
-
rb_raise(rb_eArgError, "frequency must be a float");
|
2214
|
+
if (!RB_FLOAT_TYPE_P(kw_values[1])) {
|
2215
|
+
rb_raise(rb_eArgError, "penalty_freq must be a float");
|
2298
2216
|
return Qnil;
|
2299
2217
|
}
|
2300
|
-
if (!RB_FLOAT_TYPE_P(kw_values[1])) {
|
2301
|
-
rb_raise(rb_eArgError, "presence must be a float");
|
2218
|
+
if (!RB_FLOAT_TYPE_P(kw_values[2])) {
|
2219
|
+
rb_raise(rb_eArgError, "penalty_present must be a float");
|
2302
2220
|
return Qnil;
|
2303
2221
|
}
|
2304
2222
|
|
@@ -2318,11 +2236,12 @@ private:
|
|
2318
2236
|
rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
|
2319
2237
|
return Qnil;
|
2320
2238
|
}
|
2239
|
+
const float penalty_repeat = NUM2DBL(kw_values[0]);
|
2240
|
+
const float penalty_freq = NUM2DBL(kw_values[1]);
|
2241
|
+
const float penalty_present = NUM2DBL(kw_values[2]);
|
2321
2242
|
|
2322
|
-
|
2323
|
-
|
2324
|
-
|
2325
|
-
llama_sample_frequency_and_presence_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, alpha_frequency, alpha_presence);
|
2243
|
+
llama_sample_repetition_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size,
|
2244
|
+
penalty_repeat, penalty_freq, penalty_present);
|
2326
2245
|
|
2327
2246
|
return Qnil;
|
2328
2247
|
}
|