llama_cpp 0.7.1 → 0.9.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 6688a7296f7a7e7ba4aa593b2d9b792beb1d569f7f2e0e872e1dbda64a336b57
- data.tar.gz: 3f683714c3b11b8f247d9ef40774b90e297c25f3bf2ab478e763bda9c983d73a
+ metadata.gz: 683f2d81aff9e82234925ba08cd5b46b56a2283ff8397a6c06ce50d34a95dbfc
+ data.tar.gz: d3005cab273b8d85f47f4cb4314fbab3a540d366a42829e5ec8d2c29576ae09e
  SHA512:
- metadata.gz: d7dc061516e688624f4090b956fd40999c9e2e5d2ae41fe8a1baac3caaf61ed9aef3ef31e8ca971e0a210a592cb3618f67533483e5808e2e9205e2ba9a7dfcf8
- data.tar.gz: aae1a4952d19aa186aa2ea97ce59af1dac7295f5430108aaf6545949218851b31c266472cf6111a62f7a5784c5f23fd3e3697f1181d5e659c217975890eed299
+ metadata.gz: 559f1ba1253a704c38480336decd315c65b4d80e6895ad1dc0faa3b5b81570a1faeaadcb6ec7ee3145f0fff758ab5e38e6cb8163382ce9b693d893deebe9a8f9
+ data.tar.gz: cb3d96b8c3f79cd20d4169a175270e8768c04bcaa24e51cb2c4d7872db88bc6e3349e6b1e93a130b89d21daab8be6e57b5305412059ea722084c7cb7d4a01e93
data/CHANGELOG.md CHANGED
@@ -1,3 +1,21 @@
+ ## [[0.9.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.8.0...v0.9.0)] - 2023-10-28
+
+ - Fix missing object file for ggml-backend when building with the metal and cublas options.
+
+ **Breaking Changes**
+ - Bump bundled llama.cpp from b1405 to b1429.
+ - Move the following methods from Context to Model:
+   - text, score, type, token_bos, token_eos, token_nl, token_prefix, token_middle, token_suffix, and token_eot.
+ - Add `sample_repetition_penalties` method, which merges the `sample_repetition_penalty` and `sample_frequency_and_presence_penalties` methods.
+
+ ## [[0.8.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.1...v0.8.0)] - 2023-10-21
+
+ **Breaking Changes**
+ - Bump bundled llama.cpp from b1380 to b1405.
+ - Add a column index argument to the `set_seq_id` and `get_seq_id` methods in Batch.
+ - Add `special` keyword argument to the `tokenize` method in Model.
+ - Add `n_seq_max` keyword argument to the `initialize` method in Batch.
+
  ## [[0.7.1](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.0...v0.7.1)] - 2023-10-14

  - Bump bundled llama.cpp from b1334 to b1380.
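Taken together, the 0.9.0 entries above move token introspection from Context to Model and fold the two penalty samplers into one. A minimal migration sketch (the model path is a placeholder and the setup follows the gem's bundled examples):

```ruby
require 'llama_cpp'

# Setup as in the gem's bundled examples; the path is a placeholder.
model = LLaMACpp::Model.new(model_path: '/path/to/model.gguf',
                            params: LLaMACpp::ModelParams.new)
context = LLaMACpp::Context.new(model: model,
                                params: LLaMACpp::ContextParams.new)

# 0.8.x: context.token_eos, context.token_nl, context.text(id), ...
# 0.9.0: the same methods live on Model (reachable via context.model).
eos = model.token_eos          # == context.model.token_eos
puts model.text(eos)           # token string for the EOS id
puts model.score(eos)          # vocabulary score for the EOS id
```

The sampler change is visible in the examples/chat.rb hunk below: a single `sample_repetition_penalties` call with `penalty_repeat:`, `penalty_freq:`, and `penalty_present:` replaces the former pair of calls.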
data/examples/chat.rb CHANGED
@@ -83,10 +83,12 @@ class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation
  candidates = LLaMACpp::TokenDataArray.new(base_candidates)

  last_n_repeat = [last_n_tokens.size, options[:repeat_last_n], n_ctx].min
- context.sample_repetition_penalty(candidates, last_n_tokens[-last_n_repeat..], penalty: options[:repeat_penalty])
- context.sample_frequency_and_presence_penalties(
- candidates, last_n_tokens[-last_n_repeat..],
- frequency: options[:frequency_penalty], presence: options[:presence_penalty]
+ context.sample_repetition_penalties(
+ candidates,
+ last_n_tokens[-last_n_repeat..],
+ penalty_repeat: options[:repeat_penalty],
+ penalty_freq: options[:frequency_penalty],
+ penalty_present: options[:presence_penalty]
  )

  context.sample_top_k(candidates, k: options[:top_k])
@@ -99,8 +101,8 @@ class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation
  last_n_tokens.shift
  last_n_tokens.push(id)

- if id == context.token_eos
- id = context.token_nl
+ if id == context.model.token_eos
+ id = context.model.token_nl
  unless antiprompt.empty?
  first_antiprompt = context.model.tokenize(text: antiprompt, add_bos: false)
  embd_input.concat(first_antiprompt)
data/ext/llama_cpp/extconf.rb CHANGED
@@ -53,7 +53,7 @@ if with_config('metal')
  $CFLAGS << ' -DGGML_USE_METAL'
  $CXXFLAGS << ' -DGGML_USE_METAL'
  $LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit'
- $objs = %w[ggml.o ggml-alloc.o ggml-metal.o llama.o llama_cpp.o]
+ $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-metal.o llama.o llama_cpp.o]
  $objs << 'k_quants.o' unless with_config('no_k_quants')
  end

@@ -61,7 +61,7 @@ if with_config('cublas')
  $CFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
  $CXXFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
  $LDFLAGS << ' -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64'
- $objs = %w[ggml.o ggml-alloc.o ggml-cuda.o llama.o llama_cpp.o]
+ $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-cuda.o llama.o llama_cpp.o]
  $objs << 'k_quants.o' unless with_config('no_k_quants')
  end

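For context, `with_config('metal')` and `with_config('cublas')` are driven by the install-time flags documented in the gem's README, e.g. `gem install llama_cpp -- --with-metal` or `gem install llama_cpp -- --with-cublas`; without one of those flags, the branches patched above are never taken, which is why the missing `ggml-backend.o` only surfaced in metal and cublas builds.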
data/ext/llama_cpp/llama_cpp.cpp CHANGED
@@ -63,8 +63,8 @@ public:
  rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
  rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
  rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
- rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 2);
- rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 1);
+ rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 3);
+ rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 2);
  rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
  rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
  }
@@ -74,10 +74,10 @@ private:

  static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
- ID kw_table[2] = { rb_intern("n_tokens"), rb_intern("embd") };
- VALUE kw_values[2] = { Qundef, Qundef };
+ ID kw_table[3] = { rb_intern("n_tokens"), rb_intern("embd"), rb_intern("n_seq_max") };
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
  rb_scan_args(argc, argv, ":", &kw_args);
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
+ rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);

  if (!RB_INTEGER_TYPE_P(kw_values[0])) {
  rb_raise(rb_eArgError, "n_tokens must be an integer");
@@ -87,12 +87,17 @@ private:
  rb_raise(rb_eArgError, "embd must be an integer");
  return Qnil;
  }
+ if (!RB_INTEGER_TYPE_P(kw_values[2])) {
+ rb_raise(rb_eArgError, "n_seq_max must be an integer");
+ return Qnil;
+ }

  const int32_t n_tokens = NUM2INT(kw_values[0]);
  const int32_t embd = NUM2INT(kw_values[1]);
+ const int32_t n_seq_max = NUM2INT(kw_values[2]);

  LLaMABatchWrapper* ptr = get_llama_batch(self);
- ptr->batch = llama_batch_init(n_tokens, embd);
+ ptr->batch = llama_batch_init(n_tokens, embd, n_seq_max);

  return Qnil;
  }
@@ -190,25 +195,35 @@ private:
  }

  // seq_id
- static VALUE _llama_batch_set_seq_id(VALUE self, VALUE idx, VALUE value) {
+ static VALUE _llama_batch_set_seq_id(VALUE self, VALUE i_, VALUE j_, VALUE value) {
  LLaMABatchWrapper* ptr = get_llama_batch(self);
- const int32_t id = NUM2INT(idx);
- if (id < 0 || id >= ptr->batch.n_tokens) {
- rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+ const int32_t i = NUM2INT(i_);
+ if (i < 0 || i >= ptr->batch.n_tokens) {
+ rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
+ return Qnil;
+ }
+ const int32_t j = NUM2INT(j_);
+ if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
+ rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
  return Qnil;
  }
- ptr->batch.seq_id[id] = NUM2INT(value);
- return INT2NUM(ptr->batch.seq_id[id]);
+ ptr->batch.seq_id[i][j] = NUM2INT(value);
+ return INT2NUM(ptr->batch.seq_id[i][j]);
  }

- static VALUE _llama_batch_get_seq_id(VALUE self, VALUE idx) {
+ static VALUE _llama_batch_get_seq_id(VALUE self, VALUE i_, VALUE j_) {
  LLaMABatchWrapper* ptr = get_llama_batch(self);
- const int32_t id = NUM2INT(idx);
- if (id < 0 || id >= ptr->batch.n_tokens) {
- rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+ const int32_t i = NUM2INT(i_);
+ if (i < 0 || i >= ptr->batch.n_tokens) {
+ rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
  return Qnil;
  }
- return INT2NUM(ptr->batch.seq_id[id]);
+ const int32_t j = NUM2INT(j_);
+ if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
+ rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
+ return Qnil;
+ }
+ return INT2NUM(ptr->batch.seq_id[i][j]);
  }

  // logits
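A minimal sketch of the 0.8.0 Batch API implied by the bindings above (all names come straight from the registrations; nothing here is beyond what the diff shows):

```ruby
require 'llama_cpp'

# n_seq_max is a new required keyword; all three values must be Integers,
# or the initializer raises ArgumentError.
batch = LLaMACpp::Batch.new(n_tokens: 32, embd: 0, n_seq_max: 1)

# seq_id is now two-dimensional: a token index i in [0, n_tokens) and a
# column index j in [0, n_seq_id[i]). Out-of-range indices raise ArgumentError.
#   batch.set_seq_id(i, j, seq)   # was: batch.set_seq_id(i, seq)
#   batch.get_seq_id(i, j)        # was: batch.get_seq_id(i)
```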
@@ -1133,6 +1148,16 @@ public:
  rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
  rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
  rb_define_method(rb_cLLaMAModel, "n_params", RUBY_METHOD_FUNC(_llama_model_get_model_n_params), 0);
+ rb_define_method(rb_cLLaMAModel, "text", RUBY_METHOD_FUNC(_llama_model_get_text), 1);
+ rb_define_method(rb_cLLaMAModel, "score", RUBY_METHOD_FUNC(_llama_model_get_score), 1);
+ rb_define_method(rb_cLLaMAModel, "type", RUBY_METHOD_FUNC(_llama_model_get_type), 1);
+ rb_define_method(rb_cLLaMAModel, "token_bos", RUBY_METHOD_FUNC(_llama_model_token_bos), 0);
+ rb_define_method(rb_cLLaMAModel, "token_eos", RUBY_METHOD_FUNC(_llama_model_token_eos), 0);
+ rb_define_method(rb_cLLaMAModel, "token_nl", RUBY_METHOD_FUNC(_llama_model_token_nl), 0);
+ rb_define_method(rb_cLLaMAModel, "token_prefix", RUBY_METHOD_FUNC(_llama_model_token_prefix), 0);
+ rb_define_method(rb_cLLaMAModel, "token_middle", RUBY_METHOD_FUNC(_llama_model_token_middle), 0);
+ rb_define_method(rb_cLLaMAModel, "token_suffix", RUBY_METHOD_FUNC(_llama_model_token_suffix), 0);
+ rb_define_method(rb_cLLaMAModel, "token_eot", RUBY_METHOD_FUNC(_llama_model_token_eot), 0);
  }

  private:
@@ -1319,10 +1344,10 @@ private:

  static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
- ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
+ ID kw_table[4] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos"), rb_intern("special") };
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
  rb_scan_args(argc, argv, ":", &kw_args);
- rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
+ rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);

  if (!RB_TYPE_P(kw_values[0], T_STRING)) {
  rb_raise(rb_eArgError, "text must be a String");
@@ -1336,15 +1361,20 @@ private:
  rb_raise(rb_eArgError, "add_bos must be a boolean");
  return Qnil;
  }
+ if (kw_values[3] != Qundef && (kw_values[3] != Qtrue && kw_values[3] != Qfalse)) {
+ rb_raise(rb_eArgError, "special must be a boolean");
+ return Qnil;
+ }

  VALUE text_ = kw_values[0];
  std::string text = StringValueCStr(text_);
  const bool add_bos = kw_values[2] == Qtrue ? true : false;
+ const bool special = kw_values[3] == Qtrue ? true : false;
  const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);

  llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
  LLaMAModelWrapper* ptr = get_llama_model(self);
- const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
+ const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos, special);

  if (n_tokens < 0) {
  rb_raise(rb_eRuntimeError, "failed to tokenize. The number of tokens (%d) is greater than n_max_tokens.", -n_tokens);
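A short usage sketch of the extended `tokenize` binding above; the model path is a placeholder, and whether a literal like `'</s>'` maps onto a single special-token id depends on the model's vocabulary:

```ruby
require 'llama_cpp'

model = LLaMACpp::Model.new(model_path: '/path/to/model.gguf',
                            params: LLaMACpp::ModelParams.new)

# Default behavior: special-token text is tokenized as plain text.
model.tokenize(text: 'Hello, world.', add_bos: true)

# special: true lets tokens such as '</s>' written out in the text be
# parsed as single special-token ids (subject to the model's vocabulary).
model.tokenize(text: '</s>', add_bos: false, special: true)
```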
@@ -1376,6 +1406,62 @@ private:
  LLaMAModelWrapper* ptr = get_llama_model(self);
  return UINT2NUM(llama_model_n_params(ptr->model));
  }
+
+ static VALUE _llama_model_get_text(VALUE self, VALUE token_) {
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ const llama_token token = NUM2INT(token_);
+ const char* text = llama_token_get_text(ptr->model, token);
+ return rb_utf8_str_new_cstr(text);
+ }
+
+ static VALUE _llama_model_get_score(VALUE self, VALUE token_) {
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ const llama_token token = NUM2INT(token_);
+ const float score = llama_token_get_score(ptr->model, token);
+ return DBL2NUM(score);
+ }
+
+ static VALUE _llama_model_get_type(VALUE self, VALUE token_) {
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ const llama_token token = NUM2INT(token_);
+ const int type = llama_token_get_type(ptr->model, token);
+ return INT2NUM(type);
+ }
+
+ static VALUE _llama_model_token_bos(VALUE self) {
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ return INT2NUM(llama_token_bos(ptr->model));
+ }
+
+ static VALUE _llama_model_token_eos(VALUE self) {
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ return INT2NUM(llama_token_eos(ptr->model));
+ }
+
+ static VALUE _llama_model_token_nl(VALUE self) {
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ return INT2NUM(llama_token_nl(ptr->model));
+ }
+
+ static VALUE _llama_model_token_prefix(VALUE self) {
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ return INT2NUM(llama_token_prefix(ptr->model));
+ }
+
+ static VALUE _llama_model_token_middle(VALUE self) {
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ return INT2NUM(llama_token_middle(ptr->model));
+ }
+
+ static VALUE _llama_model_token_suffix(VALUE self) {
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ return INT2NUM(llama_token_suffix(ptr->model));
+ }
+
+ static VALUE _llama_model_token_eot(VALUE self) {
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ return INT2NUM(llama_token_eot(ptr->model));
+ }
  };

  const rb_data_type_t RbLLaMAModel::llama_model_type = {
@@ -1650,16 +1736,6 @@ public:
  rb_define_method(rb_cLLaMAContext, "decode", RUBY_METHOD_FUNC(_llama_context_decode), 1);
  rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
  rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
- rb_define_method(rb_cLLaMAContext, "text", RUBY_METHOD_FUNC(_llama_context_text), 1);
- rb_define_method(rb_cLLaMAContext, "score", RUBY_METHOD_FUNC(_llama_context_score), 1);
- rb_define_method(rb_cLLaMAContext, "type", RUBY_METHOD_FUNC(_llama_context_type), 1);
- rb_define_method(rb_cLLaMAContext, "token_bos", RUBY_METHOD_FUNC(_llama_context_token_bos), 0);
- rb_define_method(rb_cLLaMAContext, "token_eos", RUBY_METHOD_FUNC(_llama_context_token_eos), 0);
- rb_define_method(rb_cLLaMAContext, "token_nl", RUBY_METHOD_FUNC(_llama_context_token_nl), 0);
- rb_define_method(rb_cLLaMAContext, "token_prefix", RUBY_METHOD_FUNC(_llama_context_token_prefix), 0);
- rb_define_method(rb_cLLaMAContext, "token_middle", RUBY_METHOD_FUNC(_llama_context_token_middle), 0);
- rb_define_method(rb_cLLaMAContext, "token_suffix", RUBY_METHOD_FUNC(_llama_context_token_suffix), 0);
- rb_define_method(rb_cLLaMAContext, "token_eot", RUBY_METHOD_FUNC(_llama_context_token_eot), 0);
  rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
  rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
  rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
@@ -1673,8 +1749,7 @@ public:
  rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
  rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
  rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
- rb_define_method(rb_cLLaMAContext, "sample_repetition_penalty", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalty), -1);
- rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
+ rb_define_method(rb_cLLaMAContext, "sample_repetition_penalties", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalties), -1);
  rb_define_method(rb_cLLaMAContext, "sample_classifier_free_guidance", RUBY_METHOD_FUNC(_llama_context_sample_classifier_free_guidance), -1);
  rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
  rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
@@ -1907,102 +1982,6 @@ private:
  return output;
  }

- static VALUE _llama_context_text(VALUE self, VALUE token_) {
- LLaMAContextWrapper* ptr = get_llama_context(self);
- if (ptr->ctx == NULL) {
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
- return Qnil;
- }
- const llama_token token = NUM2INT(token_);
- const char* text = llama_token_get_text(ptr->ctx, token);
- return rb_utf8_str_new_cstr(text);
- }
-
- static VALUE _llama_context_score(VALUE self, VALUE token_) {
- LLaMAContextWrapper* ptr = get_llama_context(self);
- if (ptr->ctx == NULL) {
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
- return Qnil;
- }
- const llama_token token = NUM2INT(token_);
- const float score = llama_token_get_score(ptr->ctx, token);
- return DBL2NUM(score);
- }
-
- static VALUE _llama_context_type(VALUE self, VALUE token_) {
- LLaMAContextWrapper* ptr = get_llama_context(self);
- if (ptr->ctx == NULL) {
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
- return Qnil;
- }
- const llama_token token = NUM2INT(token_);
- const int type = llama_token_get_type(ptr->ctx, token);
- return INT2NUM(type);
- }
-
- static VALUE _llama_context_token_bos(VALUE self) {
- LLaMAContextWrapper* ptr = get_llama_context(self);
- if (ptr->ctx == NULL) {
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
- return Qnil;
- }
- return INT2NUM(llama_token_bos(ptr->ctx));
- }
-
- static VALUE _llama_context_token_eos(VALUE self) {
- LLaMAContextWrapper* ptr = get_llama_context(self);
- if (ptr->ctx == NULL) {
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
- return Qnil;
- }
- return INT2NUM(llama_token_eos(ptr->ctx));
- }
-
- static VALUE _llama_context_token_nl(VALUE self) {
- LLaMAContextWrapper* ptr = get_llama_context(self);
- if (ptr->ctx == NULL) {
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
- return Qnil;
- }
- return INT2NUM(llama_token_nl(ptr->ctx));
- }
-
- static VALUE _llama_context_token_prefix(VALUE self) {
- LLaMAContextWrapper* ptr = get_llama_context(self);
- if (ptr->ctx == NULL) {
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
- return Qnil;
- }
- return INT2NUM(llama_token_prefix(ptr->ctx));
- }
-
- static VALUE _llama_context_token_middle(VALUE self) {
- LLaMAContextWrapper* ptr = get_llama_context(self);
- if (ptr->ctx == NULL) {
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
- return Qnil;
- }
- return INT2NUM(llama_token_middle(ptr->ctx));
- }
-
- static VALUE _llama_context_token_suffix(VALUE self) {
- LLaMAContextWrapper* ptr = get_llama_context(self);
- if (ptr->ctx == NULL) {
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
- return Qnil;
- }
- return INT2NUM(llama_token_suffix(ptr->ctx));
- }
-
- static VALUE _llama_context_token_eot(VALUE self) {
- LLaMAContextWrapper* ptr = get_llama_context(self);
- if (ptr->ctx == NULL) {
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
- return Qnil;
- }
- return INT2NUM(llama_token_eot(ptr->ctx));
- }
-
  static VALUE _llama_context_n_ctx(VALUE self) {
  LLaMAContextWrapper* ptr = get_llama_context(self);
  if (ptr->ctx == NULL) {
@@ -2211,14 +2190,14 @@ private:
  return Qnil;
  }

- static VALUE _llama_context_sample_repetition_penalty(int argc, VALUE* argv, VALUE self) {
+ static VALUE _llama_context_sample_repetition_penalties(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
- ID kw_table[1] = { rb_intern("penalty") };
- VALUE kw_values[1] = { Qundef };
+ ID kw_table[3] = { rb_intern("penalty_repeat"), rb_intern("penalty_freq"), rb_intern("penalty_present") };
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
  VALUE candidates = Qnil;
  VALUE last_n_tokens = Qnil;
  rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
- rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
+ rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);

  if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
  rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
@@ -2229,56 +2208,15 @@ private:
  return Qnil;
  }
  if (!RB_FLOAT_TYPE_P(kw_values[0])) {
- rb_raise(rb_eArgError, "penalty must be a float");
+ rb_raise(rb_eArgError, "penalty_repeat must be a float");
  return Qnil;
  }
-
- const size_t last_tokens_size = RARRAY_LEN(last_n_tokens);
- std::vector<llama_token> last_n_tokens_data(last_tokens_size);
- for (size_t i = 0; i < last_tokens_size; i++) {
- last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
- }
-
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
- if (ctx_ptr->ctx == NULL) {
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
- return Qnil;
- }
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
- if (cnd_ptr->array.data == nullptr) {
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
- return Qnil;
- }
- const float penalty = NUM2DBL(kw_values[0]);
-
- llama_sample_repetition_penalty(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, penalty);
-
- return Qnil;
- }
-
- static VALUE _llama_context_sample_frequency_and_presence_penalties(int argc, VALUE* argv, VALUE self) {
- VALUE kw_args = Qnil;
- ID kw_table[2] = { rb_intern("frequency"), rb_intern("presence") };
- VALUE kw_values[2] = { Qundef, Qundef };
- VALUE candidates = Qnil;
- VALUE last_n_tokens = Qnil;
- rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
-
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
- rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
- return Qnil;
- }
- if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
- rb_raise(rb_eArgError, "last_n_tokens must be an Array");
- return Qnil;
- }
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
- rb_raise(rb_eArgError, "frequency must be a float");
+ if (!RB_FLOAT_TYPE_P(kw_values[1])) {
+ rb_raise(rb_eArgError, "penalty_freq must be a float");
  return Qnil;
  }
- if (!RB_FLOAT_TYPE_P(kw_values[1])) {
- rb_raise(rb_eArgError, "presence must be a float");
+ if (!RB_FLOAT_TYPE_P(kw_values[2])) {
+ rb_raise(rb_eArgError, "penalty_present must be a float");
  return Qnil;
  }

@@ -2298,11 +2236,12 @@ private:
  rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
  return Qnil;
  }
+ const float penalty_repeat = NUM2DBL(kw_values[0]);
+ const float penalty_freq = NUM2DBL(kw_values[1]);
+ const float penalty_present = NUM2DBL(kw_values[2]);

- const float alpha_frequency = NUM2DBL(kw_values[0]);
- const float alpha_presence = NUM2DBL(kw_values[1]);
-
- llama_sample_frequency_and_presence_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, alpha_frequency, alpha_presence);
+ llama_sample_repetition_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size,
+ penalty_repeat, penalty_freq, penalty_present);

  return Qnil;
  }
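One behavioral detail follows directly from the validation above: all three penalty keywords are checked with RB_FLOAT_TYPE_P, so integer arguments raise rather than being coerced. A hedged sketch, assuming `context`, `candidates`, and `last_n_tokens` are set up as in examples/chat.rb:

```ruby
recent = last_n_tokens.last(64)

# OK: Float arguments pass the RB_FLOAT_TYPE_P checks above.
context.sample_repetition_penalties(candidates, recent,
                                    penalty_repeat: 1.1,
                                    penalty_freq: 0.0,
                                    penalty_present: 0.0)

# Raises ArgumentError ("penalty_repeat must be a float"): Integers are rejected.
context.sample_repetition_penalties(candidates, recent,
                                    penalty_repeat: 1,
                                    penalty_freq: 0,
                                    penalty_present: 0)
```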