llama_cpp 0.7.1 → 0.9.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 6688a7296f7a7e7ba4aa593b2d9b792beb1d569f7f2e0e872e1dbda64a336b57
- data.tar.gz: 3f683714c3b11b8f247d9ef40774b90e297c25f3bf2ab478e763bda9c983d73a
+ metadata.gz: 683f2d81aff9e82234925ba08cd5b46b56a2283ff8397a6c06ce50d34a95dbfc
+ data.tar.gz: d3005cab273b8d85f47f4cb4314fbab3a540d366a42829e5ec8d2c29576ae09e
  SHA512:
- metadata.gz: d7dc061516e688624f4090b956fd40999c9e2e5d2ae41fe8a1baac3caaf61ed9aef3ef31e8ca971e0a210a592cb3618f67533483e5808e2e9205e2ba9a7dfcf8
- data.tar.gz: aae1a4952d19aa186aa2ea97ce59af1dac7295f5430108aaf6545949218851b31c266472cf6111a62f7a5784c5f23fd3e3697f1181d5e659c217975890eed299
+ metadata.gz: 559f1ba1253a704c38480336decd315c65b4d80e6895ad1dc0faa3b5b81570a1faeaadcb6ec7ee3145f0fff758ab5e38e6cb8163382ce9b693d893deebe9a8f9
+ data.tar.gz: cb3d96b8c3f79cd20d4169a175270e8768c04bcaa24e51cb2c4d7872db88bc6e3349e6b1e93a130b89d21daab8be6e57b5305412059ea722084c7cb7d4a01e93
data/CHANGELOG.md CHANGED
@@ -1,3 +1,21 @@
+ ## [[0.9.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.8.0...v0.9.0)] - 2023-10-28
+
+ - Fix missing object file for ggml-backend when building with the metal and cublas options.
+
+ **Breaking Changes**
+ - Bump bundled llama.cpp from b1405 to b1429.
+ - Move the following methods from Context to Model:
+   - `text`, `score`, `type`, `token_bos`, `token_eos`, `token_nl`, `token_prefix`, `token_middle`, `token_suffix`, and `token_eot`.
+ - Add `sample_repetition_penalties` method, which merges the `sample_repetition_penalty` and `sample_frequency_and_presence_penalties` methods.
+
+ ## [[0.8.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.1...v0.8.0)] - 2023-10-21
+
+ **Breaking Changes**
+ - Bump bundled llama.cpp from b1380 to b1405.
+ - Add column index argument to `set_seq_id` and `get_seq_id` methods in Batch.
+ - Add `special` keyword argument to `tokenize` method in Model.
+ - Add `n_seq_max` keyword argument to `initialize` method in Batch.
+
  ## [[0.7.1](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.0...v0.7.1)] - 2023-10-14

  - Bump bundled llama.cpp from b1334 to b1380.
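Taken together, the 0.9.0 entries amount to a short migration for typical callers. A minimal sketch, assuming a `context` (with its `model`), a `candidates` TokenDataArray, and a `recent_tokens` Array set up as in examples/chat.rb below; the penalty values are placeholders:

```ruby
# 0.8.0 and earlier:
#   eos = context.token_eos
#   context.sample_repetition_penalty(candidates, recent_tokens, penalty: 1.1)
#   context.sample_frequency_and_presence_penalties(
#     candidates, recent_tokens, frequency: 0.0, presence: 0.0
#   )

# 0.9.0: token queries live on Model, and one call applies all three penalties.
eos = context.model.token_eos
context.sample_repetition_penalties(
  candidates, recent_tokens,
  penalty_repeat: 1.1, penalty_freq: 0.0, penalty_present: 0.0
)
```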
data/examples/chat.rb CHANGED
@@ -83,10 +83,12 @@ class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation
  candidates = LLaMACpp::TokenDataArray.new(base_candidates)

  last_n_repeat = [last_n_tokens.size, options[:repeat_last_n], n_ctx].min
- context.sample_repetition_penalty(candidates, last_n_tokens[-last_n_repeat..], penalty: options[:repeat_penalty])
- context.sample_frequency_and_presence_penalties(
-   candidates, last_n_tokens[-last_n_repeat..],
-   frequency: options[:frequency_penalty], presence: options[:presence_penalty]
+ context.sample_repetition_penalties(
+   candidates,
+   last_n_tokens[-last_n_repeat..],
+   penalty_repeat: options[:repeat_penalty],
+   penalty_freq: options[:frequency_penalty],
+   penalty_present: options[:presence_penalty]
  )

  context.sample_top_k(candidates, k: options[:top_k])
@@ -99,8 +101,8 @@ class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation
  last_n_tokens.shift
  last_n_tokens.push(id)

- if id == context.token_eos
-   id = context.token_nl
+ if id == context.model.token_eos
+   id = context.model.token_nl
    unless antiprompt.empty?
      first_antiprompt = context.model.tokenize(text: antiprompt, add_bos: false)
      embd_input.concat(first_antiprompt)
data/ext/llama_cpp/extconf.rb CHANGED
@@ -53,7 +53,7 @@ if with_config('metal')
  $CFLAGS << ' -DGGML_USE_METAL'
  $CXXFLAGS << ' -DGGML_USE_METAL'
  $LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit'
- $objs = %w[ggml.o ggml-alloc.o ggml-metal.o llama.o llama_cpp.o]
+ $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-metal.o llama.o llama_cpp.o]
  $objs << 'k_quants.o' unless with_config('no_k_quants')
  end

@@ -61,7 +61,7 @@ if with_config('cublas')
  $CFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
  $CXXFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
  $LDFLAGS << ' -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64'
- $objs = %w[ggml.o ggml-alloc.o ggml-cuda.o llama.o llama_cpp.o]
+ $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-cuda.o llama.o llama_cpp.o]
  $objs << 'k_quants.o' unless with_config('no_k_quants')
  end
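Both build paths now add `ggml-backend.o` to `$objs`, which is the fix the 0.9.0 changelog entry above describes: the bundled llama.cpp carries a separate ggml-backend translation unit, and without its object file the Metal and cuBLAS variants (selected through mkmf's `with_config`, typically `gem install llama_cpp -- --with-metal` or `-- --with-cublas`) would presumably fail at link time.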
data/ext/llama_cpp/llama_cpp.cpp CHANGED
@@ -63,8 +63,8 @@ public:
  rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
  rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
  rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
- rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 2);
- rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 1);
+ rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 3);
+ rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 2);
  rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
  rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
  }
@@ -74,10 +74,10 @@ private:

  static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
    VALUE kw_args = Qnil;
-   ID kw_table[2] = { rb_intern("n_tokens"), rb_intern("embd") };
-   VALUE kw_values[2] = { Qundef, Qundef };
+   ID kw_table[3] = { rb_intern("n_tokens"), rb_intern("embd"), rb_intern("n_seq_max") };
+   VALUE kw_values[3] = { Qundef, Qundef, Qundef };
    rb_scan_args(argc, argv, ":", &kw_args);
-   rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
+   rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);

    if (!RB_INTEGER_TYPE_P(kw_values[0])) {
      rb_raise(rb_eArgError, "n_tokens must be an integer");
@@ -87,12 +87,17 @@ private:
      rb_raise(rb_eArgError, "embd must be an integer");
      return Qnil;
    }
+   if (!RB_INTEGER_TYPE_P(kw_values[2])) {
+     rb_raise(rb_eArgError, "n_seq_max must be an integer");
+     return Qnil;
+   }

    const int32_t n_tokens = NUM2INT(kw_values[0]);
    const int32_t embd = NUM2INT(kw_values[1]);
+   const int32_t n_seq_max = NUM2INT(kw_values[2]);

    LLaMABatchWrapper* ptr = get_llama_batch(self);
-   ptr->batch = llama_batch_init(n_tokens, embd);
+   ptr->batch = llama_batch_init(n_tokens, embd, n_seq_max);

    return Qnil;
  }
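On the Ruby side, all three keywords are now required when constructing a Batch (the `rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values)` call above declares no optional keys). A minimal sketch; the sizes are placeholders:

```ruby
require 'llama_cpp'

# As in llama.cpp's llama_batch_init, embd: 0 allocates token storage
# rather than embedding storage.
batch = LLaMACpp::Batch.new(n_tokens: 512, embd: 0, n_seq_max: 1)
```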
@@ -190,25 +195,35 @@ private:
  }

  // seq_id
- static VALUE _llama_batch_set_seq_id(VALUE self, VALUE idx, VALUE value) {
+ static VALUE _llama_batch_set_seq_id(VALUE self, VALUE i_, VALUE j_, VALUE value) {
    LLaMABatchWrapper* ptr = get_llama_batch(self);
-   const int32_t id = NUM2INT(idx);
-   if (id < 0 || id >= ptr->batch.n_tokens) {
-     rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+   const int32_t i = NUM2INT(i_);
+   if (i < 0 || i >= ptr->batch.n_tokens) {
+     rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
+     return Qnil;
+   }
+   const int32_t j = NUM2INT(j_);
+   if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
+     rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
      return Qnil;
    }
-   ptr->batch.seq_id[id] = NUM2INT(value);
-   return INT2NUM(ptr->batch.seq_id[id]);
+   ptr->batch.seq_id[i][j] = NUM2INT(value);
+   return INT2NUM(ptr->batch.seq_id[i][j]);
  }

- static VALUE _llama_batch_get_seq_id(VALUE self, VALUE idx) {
+ static VALUE _llama_batch_get_seq_id(VALUE self, VALUE i_, VALUE j_) {
    LLaMABatchWrapper* ptr = get_llama_batch(self);
-   const int32_t id = NUM2INT(idx);
-   if (id < 0 || id >= ptr->batch.n_tokens) {
-     rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+   const int32_t i = NUM2INT(i_);
+   if (i < 0 || i >= ptr->batch.n_tokens) {
+     rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
      return Qnil;
    }
-   return INT2NUM(ptr->batch.seq_id[id]);
+   const int32_t j = NUM2INT(j_);
+   if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
+     rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
+     return Qnil;
+   }
+   return INT2NUM(ptr->batch.seq_id[i][j]);
  }

  // logits
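With this change, each token row carries a column of sequence ids, so both accessors take a row index and a column index from Ruby. A small sketch continuing from the `batch` above (indices and the id value are placeholders):

```ruby
# 0.7.x took a single index: batch.set_seq_id(0, 7) / batch.get_seq_id(0)
batch.set_seq_id(0, 0, 7)   # token row 0, sequence slot 0 := sequence id 7
batch.get_seq_id(0, 0)      # => 7
```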
@@ -1133,6 +1148,16 @@ public:
  rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
  rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
  rb_define_method(rb_cLLaMAModel, "n_params", RUBY_METHOD_FUNC(_llama_model_get_model_n_params), 0);
+ rb_define_method(rb_cLLaMAModel, "text", RUBY_METHOD_FUNC(_llama_model_get_text), 1);
+ rb_define_method(rb_cLLaMAModel, "score", RUBY_METHOD_FUNC(_llama_model_get_score), 1);
+ rb_define_method(rb_cLLaMAModel, "type", RUBY_METHOD_FUNC(_llama_model_get_type), 1);
+ rb_define_method(rb_cLLaMAModel, "token_bos", RUBY_METHOD_FUNC(_llama_model_token_bos), 0);
+ rb_define_method(rb_cLLaMAModel, "token_eos", RUBY_METHOD_FUNC(_llama_model_token_eos), 0);
+ rb_define_method(rb_cLLaMAModel, "token_nl", RUBY_METHOD_FUNC(_llama_model_token_nl), 0);
+ rb_define_method(rb_cLLaMAModel, "token_prefix", RUBY_METHOD_FUNC(_llama_model_token_prefix), 0);
+ rb_define_method(rb_cLLaMAModel, "token_middle", RUBY_METHOD_FUNC(_llama_model_token_middle), 0);
+ rb_define_method(rb_cLLaMAModel, "token_suffix", RUBY_METHOD_FUNC(_llama_model_token_suffix), 0);
+ rb_define_method(rb_cLLaMAModel, "token_eot", RUBY_METHOD_FUNC(_llama_model_token_eot), 0);
  }

  private:
@@ -1319,10 +1344,10 @@ private:

  static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
    VALUE kw_args = Qnil;
-   ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
-   VALUE kw_values[3] = { Qundef, Qundef, Qundef };
+   ID kw_table[4] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos"), rb_intern("special") };
+   VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
    rb_scan_args(argc, argv, ":", &kw_args);
-   rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
+   rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);

    if (!RB_TYPE_P(kw_values[0], T_STRING)) {
      rb_raise(rb_eArgError, "text must be a String");
@@ -1336,15 +1361,20 @@ private:
      rb_raise(rb_eArgError, "add_bos must be a boolean");
      return Qnil;
    }
+   if (kw_values[3] != Qundef && (kw_values[3] != Qtrue && kw_values[3] != Qfalse)) {
+     rb_raise(rb_eArgError, "special must be a boolean");
+     return Qnil;
+   }

    VALUE text_ = kw_values[0];
    std::string text = StringValueCStr(text_);
    const bool add_bos = kw_values[2] == Qtrue ? true : false;
+   const bool special = kw_values[3] == Qtrue ? true : false;
    const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);

    llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
    LLaMAModelWrapper* ptr = get_llama_model(self);
-   const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
+   const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos, special);

    if (n_tokens < 0) {
      rb_raise(rb_eRuntimeError, "failed to tokenize. The number of tokens (%d) is greater than n_max_tokens.", -n_tokens);
@@ -1376,6 +1406,62 @@ private:
    LLaMAModelWrapper* ptr = get_llama_model(self);
    return UINT2NUM(llama_model_n_params(ptr->model));
  }
+
+ static VALUE _llama_model_get_text(VALUE self, VALUE token_) {
+   LLaMAModelWrapper* ptr = get_llama_model(self);
+   const llama_token token = NUM2INT(token_);
+   const char* text = llama_token_get_text(ptr->model, token);
+   return rb_utf8_str_new_cstr(text);
+ }
+
+ static VALUE _llama_model_get_score(VALUE self, VALUE token_) {
+   LLaMAModelWrapper* ptr = get_llama_model(self);
+   const llama_token token = NUM2INT(token_);
+   const float score = llama_token_get_score(ptr->model, token);
+   return DBL2NUM(score);
+ }
+
+ static VALUE _llama_model_get_type(VALUE self, VALUE token_) {
+   LLaMAModelWrapper* ptr = get_llama_model(self);
+   const llama_token token = NUM2INT(token_);
+   const int type = llama_token_get_type(ptr->model, token);
+   return INT2NUM(type);
+ }
+
+ static VALUE _llama_model_token_bos(VALUE self) {
+   LLaMAModelWrapper* ptr = get_llama_model(self);
+   return INT2NUM(llama_token_bos(ptr->model));
+ }
+
+ static VALUE _llama_model_token_eos(VALUE self) {
+   LLaMAModelWrapper* ptr = get_llama_model(self);
+   return INT2NUM(llama_token_eos(ptr->model));
+ }
+
+ static VALUE _llama_model_token_nl(VALUE self) {
+   LLaMAModelWrapper* ptr = get_llama_model(self);
+   return INT2NUM(llama_token_nl(ptr->model));
+ }
+
+ static VALUE _llama_model_token_prefix(VALUE self) {
+   LLaMAModelWrapper* ptr = get_llama_model(self);
+   return INT2NUM(llama_token_prefix(ptr->model));
+ }
+
+ static VALUE _llama_model_token_middle(VALUE self) {
+   LLaMAModelWrapper* ptr = get_llama_model(self);
+   return INT2NUM(llama_token_middle(ptr->model));
+ }
+
+ static VALUE _llama_model_token_suffix(VALUE self) {
+   LLaMAModelWrapper* ptr = get_llama_model(self);
+   return INT2NUM(llama_token_suffix(ptr->model));
+ }
+
+ static VALUE _llama_model_token_eot(VALUE self) {
+   LLaMAModelWrapper* ptr = get_llama_model(self);
+   return INT2NUM(llama_token_eot(ptr->model));
+ }
  };

  const rb_data_type_t RbLLaMAModel::llama_model_type = {
  const rb_data_type_t RbLLaMAModel::llama_model_type = {
@@ -1650,16 +1736,6 @@ public:
1650
1736
  rb_define_method(rb_cLLaMAContext, "decode", RUBY_METHOD_FUNC(_llama_context_decode), 1);
1651
1737
  rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
1652
1738
  rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
1653
- rb_define_method(rb_cLLaMAContext, "text", RUBY_METHOD_FUNC(_llama_context_text), 1);
1654
- rb_define_method(rb_cLLaMAContext, "score", RUBY_METHOD_FUNC(_llama_context_score), 1);
1655
- rb_define_method(rb_cLLaMAContext, "type", RUBY_METHOD_FUNC(_llama_context_type), 1);
1656
- rb_define_method(rb_cLLaMAContext, "token_bos", RUBY_METHOD_FUNC(_llama_context_token_bos), 0);
1657
- rb_define_method(rb_cLLaMAContext, "token_eos", RUBY_METHOD_FUNC(_llama_context_token_eos), 0);
1658
- rb_define_method(rb_cLLaMAContext, "token_nl", RUBY_METHOD_FUNC(_llama_context_token_nl), 0);
1659
- rb_define_method(rb_cLLaMAContext, "token_prefix", RUBY_METHOD_FUNC(_llama_context_token_prefix), 0);
1660
- rb_define_method(rb_cLLaMAContext, "token_middle", RUBY_METHOD_FUNC(_llama_context_token_middle), 0);
1661
- rb_define_method(rb_cLLaMAContext, "token_suffix", RUBY_METHOD_FUNC(_llama_context_token_suffix), 0);
1662
- rb_define_method(rb_cLLaMAContext, "token_eot", RUBY_METHOD_FUNC(_llama_context_token_eot), 0);
1663
1739
  rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
1664
1740
  rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
1665
1741
  rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
@@ -1673,8 +1749,7 @@ public:
1673
1749
  rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
1674
1750
  rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
1675
1751
  rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
1676
- rb_define_method(rb_cLLaMAContext, "sample_repetition_penalty", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalty), -1);
1677
- rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
1752
+ rb_define_method(rb_cLLaMAContext, "sample_repetition_penalties", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalties), -1);
1678
1753
  rb_define_method(rb_cLLaMAContext, "sample_classifier_free_guidance", RUBY_METHOD_FUNC(_llama_context_sample_classifier_free_guidance), -1);
1679
1754
  rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
1680
1755
  rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
@@ -1907,102 +1982,6 @@ private:
    return output;
  }

- static VALUE _llama_context_text(VALUE self, VALUE token_) {
-   LLaMAContextWrapper* ptr = get_llama_context(self);
-   if (ptr->ctx == NULL) {
-     rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-     return Qnil;
-   }
-   const llama_token token = NUM2INT(token_);
-   const char* text = llama_token_get_text(ptr->ctx, token);
-   return rb_utf8_str_new_cstr(text);
- }
-
- static VALUE _llama_context_score(VALUE self, VALUE token_) {
-   LLaMAContextWrapper* ptr = get_llama_context(self);
-   if (ptr->ctx == NULL) {
-     rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-     return Qnil;
-   }
-   const llama_token token = NUM2INT(token_);
-   const float score = llama_token_get_score(ptr->ctx, token);
-   return DBL2NUM(score);
- }
-
- static VALUE _llama_context_type(VALUE self, VALUE token_) {
-   LLaMAContextWrapper* ptr = get_llama_context(self);
-   if (ptr->ctx == NULL) {
-     rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-     return Qnil;
-   }
-   const llama_token token = NUM2INT(token_);
-   const int type = llama_token_get_type(ptr->ctx, token);
-   return INT2NUM(type);
- }
-
- static VALUE _llama_context_token_bos(VALUE self) {
-   LLaMAContextWrapper* ptr = get_llama_context(self);
-   if (ptr->ctx == NULL) {
-     rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-     return Qnil;
-   }
-   return INT2NUM(llama_token_bos(ptr->ctx));
- }
-
- static VALUE _llama_context_token_eos(VALUE self) {
-   LLaMAContextWrapper* ptr = get_llama_context(self);
-   if (ptr->ctx == NULL) {
-     rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-     return Qnil;
-   }
-   return INT2NUM(llama_token_eos(ptr->ctx));
- }
-
- static VALUE _llama_context_token_nl(VALUE self) {
-   LLaMAContextWrapper* ptr = get_llama_context(self);
-   if (ptr->ctx == NULL) {
-     rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-     return Qnil;
-   }
-   return INT2NUM(llama_token_nl(ptr->ctx));
- }
-
- static VALUE _llama_context_token_prefix(VALUE self) {
-   LLaMAContextWrapper* ptr = get_llama_context(self);
-   if (ptr->ctx == NULL) {
-     rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-     return Qnil;
-   }
-   return INT2NUM(llama_token_prefix(ptr->ctx));
- }
-
- static VALUE _llama_context_token_middle(VALUE self) {
-   LLaMAContextWrapper* ptr = get_llama_context(self);
-   if (ptr->ctx == NULL) {
-     rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-     return Qnil;
-   }
-   return INT2NUM(llama_token_middle(ptr->ctx));
- }
-
- static VALUE _llama_context_token_suffix(VALUE self) {
-   LLaMAContextWrapper* ptr = get_llama_context(self);
-   if (ptr->ctx == NULL) {
-     rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-     return Qnil;
-   }
-   return INT2NUM(llama_token_suffix(ptr->ctx));
- }
-
- static VALUE _llama_context_token_eot(VALUE self) {
-   LLaMAContextWrapper* ptr = get_llama_context(self);
-   if (ptr->ctx == NULL) {
-     rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-     return Qnil;
-   }
-   return INT2NUM(llama_token_eot(ptr->ctx));
- }
-
  static VALUE _llama_context_n_ctx(VALUE self) {
    LLaMAContextWrapper* ptr = get_llama_context(self);
    if (ptr->ctx == NULL) {
@@ -2211,14 +2190,14 @@ private:
    return Qnil;
  }

- static VALUE _llama_context_sample_repetition_penalty(int argc, VALUE* argv, VALUE self) {
+ static VALUE _llama_context_sample_repetition_penalties(int argc, VALUE* argv, VALUE self) {
    VALUE kw_args = Qnil;
-   ID kw_table[1] = { rb_intern("penalty") };
-   VALUE kw_values[1] = { Qundef };
+   ID kw_table[3] = { rb_intern("penalty_repeat"), rb_intern("penalty_freq"), rb_intern("penalty_present") };
+   VALUE kw_values[3] = { Qundef, Qundef, Qundef };
    VALUE candidates = Qnil;
    VALUE last_n_tokens = Qnil;
    rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
-   rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
+   rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);

    if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
      rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
@@ -2229,56 +2208,15 @@ private:
      return Qnil;
    }
    if (!RB_FLOAT_TYPE_P(kw_values[0])) {
-     rb_raise(rb_eArgError, "penalty must be a float");
+     rb_raise(rb_eArgError, "penalty_repeat must be a float");
      return Qnil;
    }
-
-   const size_t last_tokens_size = RARRAY_LEN(last_n_tokens);
-   std::vector<llama_token> last_n_tokens_data(last_tokens_size);
-   for (size_t i = 0; i < last_tokens_size; i++) {
-     last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
-   }
-
-   LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
-   if (ctx_ptr->ctx == NULL) {
-     rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
-     return Qnil;
-   }
-   LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
-   if (cnd_ptr->array.data == nullptr) {
-     rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
-     return Qnil;
-   }
-   const float penalty = NUM2DBL(kw_values[0]);
-
-   llama_sample_repetition_penalty(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, penalty);
-
-   return Qnil;
- }
-
- static VALUE _llama_context_sample_frequency_and_presence_penalties(int argc, VALUE* argv, VALUE self) {
-   VALUE kw_args = Qnil;
-   ID kw_table[2] = { rb_intern("frequency"), rb_intern("presence") };
-   VALUE kw_values[2] = { Qundef, Qundef };
-   VALUE candidates = Qnil;
-   VALUE last_n_tokens = Qnil;
-   rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
-   rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
-
-   if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
-     rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
-     return Qnil;
-   }
-   if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
-     rb_raise(rb_eArgError, "last_n_tokens must be an Array");
-     return Qnil;
-   }
-   if (!RB_FLOAT_TYPE_P(kw_values[0])) {
-     rb_raise(rb_eArgError, "frequency must be a float");
+   if (!RB_FLOAT_TYPE_P(kw_values[1])) {
+     rb_raise(rb_eArgError, "penalty_freq must be a float");
      return Qnil;
    }
-   if (!RB_FLOAT_TYPE_P(kw_values[1])) {
-     rb_raise(rb_eArgError, "presence must be a float");
+   if (!RB_FLOAT_TYPE_P(kw_values[2])) {
+     rb_raise(rb_eArgError, "penalty_present must be a float");
      return Qnil;
    }

@@ -2298,11 +2236,12 @@ private:
      rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
      return Qnil;
    }
+   const float penalty_repeat = NUM2DBL(kw_values[0]);
+   const float penalty_freq = NUM2DBL(kw_values[1]);
+   const float penalty_present = NUM2DBL(kw_values[2]);

-   const float alpha_frequency = NUM2DBL(kw_values[0]);
-   const float alpha_presence = NUM2DBL(kw_values[1]);
-
-   llama_sample_frequency_and_presence_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, alpha_frequency, alpha_presence);
+   llama_sample_repetition_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size,
+                                     penalty_repeat, penalty_freq, penalty_present);

    return Qnil;
  }
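All three penalty keywords are required (the `rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values)` call above leaves no optional slots), so callers migrating from the old single-purpose methods must pass frequency and presence values even when they are zero, exactly as examples/chat.rb now does:

```ruby
context.sample_repetition_penalties(
  candidates, last_n_tokens[-last_n_repeat..],
  penalty_repeat: options[:repeat_penalty],
  penalty_freq: options[:frequency_penalty],
  penalty_present: options[:presence_penalty]
)
```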