llama_cpp 0.8.0 → 0.9.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/examples/chat.rb +8 -6
- data/ext/llama_cpp/extconf.rb +2 -2
- data/ext/llama_cpp/llama_cpp.cpp +81 -162
- data/ext/llama_cpp/src/ggml-cuda.cu +188 -20
- data/ext/llama_cpp/src/ggml-metal.m +13 -5
- data/ext/llama_cpp/src/ggml-metal.metal +9 -1
- data/ext/llama_cpp/src/ggml-opencl.cpp +161 -169
- data/ext/llama_cpp/src/ggml.c +362 -84
- data/ext/llama_cpp/src/ggml.h +8 -7
- data/ext/llama_cpp/src/llama.cpp +100 -95
- data/ext/llama_cpp/src/llama.h +16 -21
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +4 -4
- data/sig/llama_cpp.rbs +11 -12
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 683f2d81aff9e82234925ba08cd5b46b56a2283ff8397a6c06ce50d34a95dbfc
+  data.tar.gz: d3005cab273b8d85f47f4cb4314fbab3a540d366a42829e5ec8d2c29576ae09e
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 559f1ba1253a704c38480336decd315c65b4d80e6895ad1dc0faa3b5b81570a1faeaadcb6ec7ee3145f0fff758ab5e38e6cb8163382ce9b693d893deebe9a8f9
+  data.tar.gz: cb3d96b8c3f79cd20d4169a175270e8768c04bcaa24e51cb2c4d7872db88bc6e3349e6b1e93a130b89d21daab8be6e57b5305412059ea722084c7cb7d4a01e93
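These digests cover the two archives inside the .gem package: a .gem is a tar file whose members include metadata.gz and data.tar.gz, and checksums.yaml records their SHA256/SHA512 values. A verification sketch using standard rubygems APIs (the local gem path is a placeholder):

```ruby
require 'digest'
require 'rubygems/package'

# Walk the .gem (a tar archive) and digest its gzipped members so the
# values can be compared against the checksums.yaml entries above.
File.open('llama_cpp-0.9.0.gem', 'rb') do |f|
  Gem::Package::TarReader.new(f) do |tar|
    tar.each do |entry|
      next unless entry.full_name.end_with?('.gz')
      puts "#{entry.full_name}: #{Digest::SHA256.hexdigest(entry.read)}"
    end
  end
end
```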
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,13 @@
+## [[0.9.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.8.0...v0.9.0)] - 2023-10-28
+
+- Fix missing object file for ggml-backend when building with metal and cublas options.
+
+**Breaking Changes**
+- Bump bundled llama.cpp from b1405 to b1429.
+- Move the following methods from Context to Model:
+  - text, score, type, token_bos, token_eos, token_nl, token_prefix, token_middle, token_suffix, and token_eot.
+- Add `sample_repetition_penalties` method, which integrates the `sample_frequency_and_presence_penalties` and `sample_repetition_penalty` methods.
+
 ## [[0.8.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.1...v0.8.0)] - 2023-10-21
 
 **Breaking Changes**
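The moved token helpers are a mechanical rename at call sites, as the chat.rb hunk below shows; a minimal before/after sketch (assuming `context` wraps a loaded model, so `Context#model` returns it):

```ruby
# 0.8.0: special-token lookups were Context methods.
eos = context.token_eos
nl  = context.token_nl

# 0.9.0: the same lookups live on Model.
eos = context.model.token_eos
nl  = context.model.token_nl
```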
data/examples/chat.rb
CHANGED
@@ -83,10 +83,12 @@ class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation
       candidates = LLaMACpp::TokenDataArray.new(base_candidates)
 
       last_n_repeat = [last_n_tokens.size, options[:repeat_last_n], n_ctx].min
-      context.
-
-
-
+      context.sample_repetition_penalties(
+        candidates,
+        last_n_tokens[-last_n_repeat..],
+        penalty_repeat: options[:repeat_penalty],
+        penalty_freq: options[:frequency_penalty],
+        penalty_present: options[:presence_penalty]
       )
 
       context.sample_top_k(candidates, k: options[:top_k])
@@ -99,8 +101,8 @@ class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation
       last_n_tokens.shift
       last_n_tokens.push(id)
 
-      if id == context.token_eos
-        id = context.token_nl
+      if id == context.model.token_eos
+        id = context.model.token_nl
         unless antiprompt.empty?
           first_antiprompt = context.model.tokenize(text: antiprompt, add_bos: false)
           embd_input.concat(first_antiprompt)
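The `last_n_tokens[-last_n_repeat..]` slice passed to the new call keeps only the most recent `last_n_repeat` token ids; in plain Ruby:

```ruby
last_n_tokens = [101, 102, 103, 104, 105]
last_n_repeat = 3
last_n_tokens[-last_n_repeat..] # => [103, 104, 105]
```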
data/ext/llama_cpp/extconf.rb
CHANGED
@@ -53,7 +53,7 @@ if with_config('metal')
   $CFLAGS << ' -DGGML_USE_METAL'
   $CXXFLAGS << ' -DGGML_USE_METAL'
   $LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit'
-  $objs = %w[ggml.o ggml-alloc.o ggml-metal.o llama.o llama_cpp.o]
+  $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-metal.o llama.o llama_cpp.o]
   $objs << 'k_quants.o' unless with_config('no_k_quants')
 end
 
@@ -61,7 +61,7 @@ if with_config('cublas')
   $CFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
   $CXXFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
   $LDFLAGS << ' -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64'
-  $objs = %w[ggml.o ggml-alloc.o ggml-cuda.o llama.o llama_cpp.o]
+  $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-cuda.o llama.o llama_cpp.o]
   $objs << 'k_quants.o' unless with_config('no_k_quants')
 end
 
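Both sections are gated by mkmf's `with_config`, which picks up flags passed through to the extension at install time, e.g. `gem install llama_cpp -- --with-metal` or `gem install llama_cpp -- --with-cublas` (standard RubyGems `--` passthrough); the `ggml-backend.o` added to each `$objs` list is the missing object file the changelog's build fix refers to.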
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -1148,6 +1148,16 @@ public:
     rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
     rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
     rb_define_method(rb_cLLaMAModel, "n_params", RUBY_METHOD_FUNC(_llama_model_get_model_n_params), 0);
+    rb_define_method(rb_cLLaMAModel, "text", RUBY_METHOD_FUNC(_llama_model_get_text), 1);
+    rb_define_method(rb_cLLaMAModel, "score", RUBY_METHOD_FUNC(_llama_model_get_score), 1);
+    rb_define_method(rb_cLLaMAModel, "type", RUBY_METHOD_FUNC(_llama_model_get_type), 1);
+    rb_define_method(rb_cLLaMAModel, "token_bos", RUBY_METHOD_FUNC(_llama_model_token_bos), 0);
+    rb_define_method(rb_cLLaMAModel, "token_eos", RUBY_METHOD_FUNC(_llama_model_token_eos), 0);
+    rb_define_method(rb_cLLaMAModel, "token_nl", RUBY_METHOD_FUNC(_llama_model_token_nl), 0);
+    rb_define_method(rb_cLLaMAModel, "token_prefix", RUBY_METHOD_FUNC(_llama_model_token_prefix), 0);
+    rb_define_method(rb_cLLaMAModel, "token_middle", RUBY_METHOD_FUNC(_llama_model_token_middle), 0);
+    rb_define_method(rb_cLLaMAModel, "token_suffix", RUBY_METHOD_FUNC(_llama_model_token_suffix), 0);
+    rb_define_method(rb_cLLaMAModel, "token_eot", RUBY_METHOD_FUNC(_llama_model_token_eot), 0);
   }
 
 private:
@@ -1396,6 +1406,62 @@ private:
     LLaMAModelWrapper* ptr = get_llama_model(self);
     return UINT2NUM(llama_model_n_params(ptr->model));
   }
+
+  static VALUE _llama_model_get_text(VALUE self, VALUE token_) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    const llama_token token = NUM2INT(token_);
+    const char* text = llama_token_get_text(ptr->model, token);
+    return rb_utf8_str_new_cstr(text);
+  }
+
+  static VALUE _llama_model_get_score(VALUE self, VALUE token_) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    const llama_token token = NUM2INT(token_);
+    const float score = llama_token_get_score(ptr->model, token);
+    return DBL2NUM(score);
+  }
+
+  static VALUE _llama_model_get_type(VALUE self, VALUE token_) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    const llama_token token = NUM2INT(token_);
+    const int type = llama_token_get_type(ptr->model, token);
+    return INT2NUM(type);
+  }
+
+  static VALUE _llama_model_token_bos(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_token_bos(ptr->model));
+  }
+
+  static VALUE _llama_model_token_eos(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_token_eos(ptr->model));
+  }
+
+  static VALUE _llama_model_token_nl(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_token_nl(ptr->model));
+  }
+
+  static VALUE _llama_model_token_prefix(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_token_prefix(ptr->model));
+  }
+
+  static VALUE _llama_model_token_middle(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_token_middle(ptr->model));
+  }
+
+  static VALUE _llama_model_token_suffix(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_token_suffix(ptr->model));
+  }
+
+  static VALUE _llama_model_token_eot(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return INT2NUM(llama_token_eot(ptr->model));
+  }
 };
 
 const rb_data_type_t RbLLaMAModel::llama_model_type = {
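Read from Ruby, the new Model bindings mirror the C return conversions above (`rb_utf8_str_new_cstr` yields a String, `DBL2NUM` a Float, `INT2NUM` an Integer); a usage sketch with a hypothetical loaded `model`:

```ruby
id = model.token_eos # => Integer token id
model.text(id)       # => String, UTF-8 text of the token
model.score(id)      # => Float vocabulary score
model.type(id)       # => Integer token-type value
```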
@@ -1670,16 +1736,6 @@ public:
     rb_define_method(rb_cLLaMAContext, "decode", RUBY_METHOD_FUNC(_llama_context_decode), 1);
     rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
     rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
-    rb_define_method(rb_cLLaMAContext, "text", RUBY_METHOD_FUNC(_llama_context_text), 1);
-    rb_define_method(rb_cLLaMAContext, "score", RUBY_METHOD_FUNC(_llama_context_score), 1);
-    rb_define_method(rb_cLLaMAContext, "type", RUBY_METHOD_FUNC(_llama_context_type), 1);
-    rb_define_method(rb_cLLaMAContext, "token_bos", RUBY_METHOD_FUNC(_llama_context_token_bos), 0);
-    rb_define_method(rb_cLLaMAContext, "token_eos", RUBY_METHOD_FUNC(_llama_context_token_eos), 0);
-    rb_define_method(rb_cLLaMAContext, "token_nl", RUBY_METHOD_FUNC(_llama_context_token_nl), 0);
-    rb_define_method(rb_cLLaMAContext, "token_prefix", RUBY_METHOD_FUNC(_llama_context_token_prefix), 0);
-    rb_define_method(rb_cLLaMAContext, "token_middle", RUBY_METHOD_FUNC(_llama_context_token_middle), 0);
-    rb_define_method(rb_cLLaMAContext, "token_suffix", RUBY_METHOD_FUNC(_llama_context_token_suffix), 0);
-    rb_define_method(rb_cLLaMAContext, "token_eot", RUBY_METHOD_FUNC(_llama_context_token_eot), 0);
     rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
     rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
     rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
@@ -1693,8 +1749,7 @@ public:
     rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
     rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
     rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
-    rb_define_method(rb_cLLaMAContext, "sample_repetition_penalty", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalty), -1);
-    rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
+    rb_define_method(rb_cLLaMAContext, "sample_repetition_penalties", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalties), -1);
     rb_define_method(rb_cLLaMAContext, "sample_classifier_free_guidance", RUBY_METHOD_FUNC(_llama_context_sample_classifier_free_guidance), -1);
     rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
     rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
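The replacement method is registered with arity -1 and, as the hunks below show, parses two positional arguments plus three required Float keywords via `rb_scan_args(argc, argv, "2:", ...)` and `rb_get_kwargs`. A call-shape sketch (the penalty values here are placeholders):

```ruby
context.sample_repetition_penalties(
  candidates,           # LLaMACpp::TokenDataArray
  last_n_tokens,        # Array of Integer token ids
  penalty_repeat: 1.1,  # replaces sample_repetition_penalty's penalty:
  penalty_freq: 0.0,    # replaces frequency:
  penalty_present: 0.0  # replaces presence:
)
```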
@@ -1927,102 +1982,6 @@ private:
     return output;
   }
 
-  static VALUE _llama_context_text(VALUE self, VALUE token_) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    const llama_token token = NUM2INT(token_);
-    const char* text = llama_token_get_text(ptr->ctx, token);
-    return rb_utf8_str_new_cstr(text);
-  }
-
-  static VALUE _llama_context_score(VALUE self, VALUE token_) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    const llama_token token = NUM2INT(token_);
-    const float score = llama_token_get_score(ptr->ctx, token);
-    return DBL2NUM(score);
-  }
-
-  static VALUE _llama_context_type(VALUE self, VALUE token_) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    const llama_token token = NUM2INT(token_);
-    const int type = llama_token_get_type(ptr->ctx, token);
-    return INT2NUM(type);
-  }
-
-  static VALUE _llama_context_token_bos(VALUE self) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    return INT2NUM(llama_token_bos(ptr->ctx));
-  }
-
-  static VALUE _llama_context_token_eos(VALUE self) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    return INT2NUM(llama_token_eos(ptr->ctx));
-  }
-
-  static VALUE _llama_context_token_nl(VALUE self) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    return INT2NUM(llama_token_nl(ptr->ctx));
-  }
-
-  static VALUE _llama_context_token_prefix(VALUE self) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    return INT2NUM(llama_token_prefix(ptr->ctx));
-  }
-
-  static VALUE _llama_context_token_middle(VALUE self) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    return INT2NUM(llama_token_middle(ptr->ctx));
-  }
-
-  static VALUE _llama_context_token_suffix(VALUE self) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    return INT2NUM(llama_token_suffix(ptr->ctx));
-  }
-
-  static VALUE _llama_context_token_eot(VALUE self) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    return INT2NUM(llama_token_eot(ptr->ctx));
-  }
-
   static VALUE _llama_context_n_ctx(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
     if (ptr->ctx == NULL) {
@@ -2231,14 +2190,14 @@ private:
     return Qnil;
   }
 
-  static VALUE _llama_context_sample_repetition_penalty(int argc, VALUE* argv, VALUE self) {
+  static VALUE _llama_context_sample_repetition_penalties(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
-    ID kw_table[1] = { rb_intern("penalty") };
-    VALUE kw_values[1] = { Qundef };
+    ID kw_table[3] = { rb_intern("penalty_repeat"), rb_intern("penalty_freq"), rb_intern("penalty_present") };
+    VALUE kw_values[3] = { Qundef, Qundef, Qundef };
     VALUE candidates = Qnil;
     VALUE last_n_tokens = Qnil;
     rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
-    rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
+    rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
 
     if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
       rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
@@ -2249,56 +2208,15 @@ private:
       return Qnil;
     }
     if (!RB_FLOAT_TYPE_P(kw_values[0])) {
-      rb_raise(rb_eArgError, "penalty must be a float");
+      rb_raise(rb_eArgError, "penalty_repeat must be a float");
       return Qnil;
     }
-
-    const size_t last_tokens_size = RARRAY_LEN(last_n_tokens);
-    std::vector<llama_token> last_n_tokens_data(last_tokens_size);
-    for (size_t i = 0; i < last_tokens_size; i++) {
-      last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
-    }
-
-    LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
-    if (ctx_ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-      return Qnil;
-    }
-    LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
-    if (cnd_ptr->array.data == nullptr) {
-      rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
-      return Qnil;
-    }
-    const float penalty = NUM2DBL(kw_values[0]);
-
-    llama_sample_repetition_penalty(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, penalty);
-
-    return Qnil;
-  }
-
-  static VALUE _llama_context_sample_frequency_and_presence_penalties(int argc, VALUE* argv, VALUE self) {
-    VALUE kw_args = Qnil;
-    ID kw_table[2] = { rb_intern("frequency"), rb_intern("presence") };
-    VALUE kw_values[2] = { Qundef, Qundef };
-    VALUE candidates = Qnil;
-    VALUE last_n_tokens = Qnil;
-    rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
-    rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
-
-    if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
-      rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
-      return Qnil;
-    }
-    if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
-      rb_raise(rb_eArgError, "last_n_tokens must be an Array");
-      return Qnil;
-    }
-    if (!RB_FLOAT_TYPE_P(kw_values[0])) {
-      rb_raise(rb_eArgError, "frequency must be a float");
+    if (!RB_FLOAT_TYPE_P(kw_values[1])) {
+      rb_raise(rb_eArgError, "penalty_freq must be a float");
       return Qnil;
     }
-    if (!RB_FLOAT_TYPE_P(kw_values[1])) {
-      rb_raise(rb_eArgError, "presence must be a float");
+    if (!RB_FLOAT_TYPE_P(kw_values[2])) {
+      rb_raise(rb_eArgError, "penalty_present must be a float");
       return Qnil;
     }
 
@@ -2318,11 +2236,12 @@ private:
       rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
       return Qnil;
     }
+    const float penalty_repeat = NUM2DBL(kw_values[0]);
+    const float penalty_freq = NUM2DBL(kw_values[1]);
+    const float penalty_present = NUM2DBL(kw_values[2]);
 
-    const float alpha_frequency = NUM2DBL(kw_values[0]);
-    const float alpha_presence = NUM2DBL(kw_values[1]);
-
-    llama_sample_frequency_and_presence_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, alpha_frequency, alpha_presence);
+    llama_sample_repetition_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size,
+                                      penalty_repeat, penalty_freq, penalty_present);
 
     return Qnil;
   }