llama_cpp 0.8.0 → 0.9.1

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 8045208b5f7801979212a4f6ed395217e78f06bcfbc2d0362aaaa04c529745cd
4
- data.tar.gz: 4011dfe279d8d4041c6c79dc5a6bad199777f83b5f0559f11ccd2f68c957e462
3
+ metadata.gz: dae7507ce41f18e3fd0fb2d7445275a387a3914068aa9eef922f260de699970a
4
+ data.tar.gz: d66cc2629aeca3285bc10988f8c410fb8cf5b7f1fe6f835b5dc60e9dcab4be9d
5
5
  SHA512:
6
- metadata.gz: d15e74da491773961006eca8ca6c6d80b30ffc995c56a9140961be0002eb09134f1a029c4e8ee192497fb7256fe36cf1c3ed928967ce57ece4c7a0904392c8fe
7
- data.tar.gz: a863596304ddb9ac5e4be2b2b65bebc7d3913705b8a0f516debfee0ca213f9dca69707edda8d70cfafb15500fcb6e70cffb6d5d1119302d24e05059c50f0da77
6
+ metadata.gz: 3e3e92aa38413877620947ec7996494cd720a3c211fcdf1973ce0d7a9a7e8803e293e2ce2f601b11e35858c5b4ef6b00d716069e322ea8d6b4c93412990fd746
7
+ data.tar.gz: 20a1e9e0e5812da9b00787afbf0f3aa0b762c8168f54ce3b7f2f25ff5b61cca5b2e7ab5faa065fbc3e266468d1c5747b8e0779fc7e073cc66240d1f3085e71c7
data/CHANGELOG.md CHANGED
@@ -1,3 +1,22 @@
1
+ ## [[0.9.1](https://github.com/yoshoku/llama_cpp.rb/compare/v0.9.0...v0.9.1)] - 2023-11-03
2
+
3
+ - Bump bundled llama.cpp from b1429 to b1472
4
+ - Rename `kv_cache_tokens_rm` method to `kv_cache_clear` in Context.
5
+ - Add `sample_min_p` method to Context.
6
+ - Add `rope_scaling_type`, `rope_freq_base`, `rope_freq_scale`, `yarn_ext_factor`, `yarn_attn_factor`, `yarn_beta_fast`, `yarn_beta_slow`, and `yarn_orig_ctx` to ContextParams.
7
+ - Add `pure` to ModelQuantizeParams.
8
+ - Add constants for RoPE scaling type.
9
+
10
+ ## [[0.9.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.8.0...v0.9.0)] - 2023-10-28
11
+
12
+ - Fix missing object file for ggml-backend when building with metal and cublas options.
13
+
14
+ **Breaking Changes**
15
+ - Bump bundled llama.cpp from b1405 to b1429
16
+ - Move the following methods from Context to Model:
17
+ - text, score, type, token_bos, token_eos, token_nl, token_prefix, token_middle, token_suffix, and token_eot.
18
+ - Add `sample_repetition_penalties` method, which integrates the `sample_frequency_and_presence_penalties` and `sample_repetition_penalty` methods.
19
+
1
20
  ## [[0.8.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.1...v0.8.0)] - 2023-10-21
2
21
 
3
22
  **Breaking Changes**
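The 0.9 series moves the token helpers from Context onto Model, and 0.9.1 renames the KV-cache reset method, so 0.8.0 call sites need small updates. A minimal migration sketch, assuming the `Model.new(model_path:, params:)` and `Context.new(model:, params:)` constructor signatures from earlier releases and a placeholder GGUF path:

```ruby
require 'llama_cpp'

model = LLaMACpp::Model.new(
  model_path: '/path/to/model.gguf',          # hypothetical path
  params: LLaMACpp::ModelParams.new
)
context = LLaMACpp::Context.new(model: model, params: LLaMACpp::ContextParams.new)

# 0.8.0: context.token_eos / context.token_nl
# 0.9.0: the token helpers live on Model
eos = context.model.token_eos
nl  = context.model.token_nl

# 0.9.1 renames kv_cache_tokens_rm to kv_cache_clear and drops its arguments.
context.kv_cache_clear
```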
data/examples/chat.rb CHANGED
@@ -83,10 +83,12 @@ class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation
83
83
  candidates = LLaMACpp::TokenDataArray.new(base_candidates)
84
84
 
85
85
  last_n_repeat = [last_n_tokens.size, options[:repeat_last_n], n_ctx].min
86
- context.sample_repetition_penalty(candidates, last_n_tokens[-last_n_repeat..], penalty: options[:repeat_penalty])
87
- context.sample_frequency_and_presence_penalties(
88
- candidates, last_n_tokens[-last_n_repeat..],
89
- frequency: options[:frequency_penalty], presence: options[:presence_penalty]
86
+ context.sample_repetition_penalties(
87
+ candidates,
88
+ last_n_tokens[-last_n_repeat..],
89
+ penalty_repeat: options[:repeat_penalty],
90
+ penalty_freq: options[:frequency_penalty],
91
+ penalty_present: options[:presence_penalty]
90
92
  )
91
93
 
92
94
  context.sample_top_k(candidates, k: options[:top_k])
@@ -99,8 +101,8 @@ class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation
99
101
  last_n_tokens.shift
100
102
  last_n_tokens.push(id)
101
103
 
102
- if id == context.token_eos
103
- id = context.token_nl
104
+ if id == context.model.token_eos
105
+ id = context.model.token_nl
104
106
  unless antiprompt.empty?
105
107
  first_antiprompt = context.model.tokenize(text: antiprompt, add_bos: false)
106
108
  embd_input.concat(first_antiprompt)
data/ext/llama_cpp/extconf.rb CHANGED
@@ -5,7 +5,7 @@ require 'fileutils'
5
5
 
6
6
  abort 'libstdc++ is not found.' unless have_library('stdc++')
7
7
 
8
- $srcs = %w[ggml.c ggml-backend.c ggml-alloc.c llama.cpp llama_cpp.cpp]
8
+ $srcs = %w[ggml.c ggml-backend.c ggml-alloc.c ggml-quants.c llama.cpp llama_cpp.cpp]
9
9
  $srcs << 'ggml-opencl.cpp' if with_config('clblast')
10
10
  $srcs << 'ggml-mpi.c' if with_config('mpi')
11
11
  $CFLAGS << ' -w -DNDEBUG'
@@ -18,12 +18,6 @@ if RUBY_PLATFORM.match?(/darwin|linux|bsd/) && try_compile('#include <stdio.h>',
18
18
  $CXXFLAGS << ' -pthread'
19
19
  end
20
20
 
21
- unless with_config('no_k_quants')
22
- $CFLAGS << ' -DGGML_USE_K_QUANTS'
23
- $CXXFLAGS << ' -DGGML_USE_K_QUANTS'
24
- $srcs << 'k_quants.c'
25
- end
26
-
27
21
  if with_config('qkk_64')
28
22
  $CFLAGS << ' -DGGML_QKK_64'
29
23
  $CXXFLAGS << ' -DGGML_QKK_64'
@@ -53,16 +47,14 @@ if with_config('metal')
53
47
  $CFLAGS << ' -DGGML_USE_METAL'
54
48
  $CXXFLAGS << ' -DGGML_USE_METAL'
55
49
  $LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit'
56
- $objs = %w[ggml.o ggml-alloc.o ggml-metal.o llama.o llama_cpp.o]
57
- $objs << 'k_quants.o' unless with_config('no_k_quants')
50
+ $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-quants.o ggml-metal.o llama.o llama_cpp.o]
58
51
  end
59
52
 
60
53
  if with_config('cublas')
61
54
  $CFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
62
55
  $CXXFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
63
56
  $LDFLAGS << ' -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64'
64
- $objs = %w[ggml.o ggml-alloc.o ggml-cuda.o llama.o llama_cpp.o]
65
- $objs << 'k_quants.o' unless with_config('no_k_quants')
57
+ $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-quants.o ggml-cuda.o llama.o llama_cpp.o]
66
58
  end
67
59
 
68
60
  if with_config('clblast')
data/ext/llama_cpp/llama_cpp.cpp CHANGED
@@ -796,10 +796,22 @@ public:
796
796
  rb_define_method(rb_cLLaMAContextParams, "n_threads", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads), 0);
797
797
  rb_define_method(rb_cLLaMAContextParams, "n_threads_batch=", RUBY_METHOD_FUNC(_llama_context_params_set_n_threads_batch), 1);
798
798
  rb_define_method(rb_cLLaMAContextParams, "n_threads_batch", RUBY_METHOD_FUNC(_llama_context_params_get_n_threads_batch), 0);
799
+ rb_define_method(rb_cLLaMAContextParams, "rope_scaling_type=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_scaling_type), 1);
800
+ rb_define_method(rb_cLLaMAContextParams, "rope_scaling_type", RUBY_METHOD_FUNC(_llama_context_params_get_rope_scaling_type), 0);
799
801
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_base=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_base), 1);
800
802
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_base", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_base), 0);
801
803
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale=", RUBY_METHOD_FUNC(_llama_context_params_set_rope_freq_scale), 1);
802
804
  rb_define_method(rb_cLLaMAContextParams, "rope_freq_scale", RUBY_METHOD_FUNC(_llama_context_params_get_rope_freq_scale), 0);
805
+ rb_define_method(rb_cLLaMAContextParams, "yarn_ext_factor=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_ext_factor), 1);
806
+ rb_define_method(rb_cLLaMAContextParams, "yarn_ext_factor", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_ext_factor), 0);
807
+ rb_define_method(rb_cLLaMAContextParams, "yarn_attn_factor=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_attn_factor), 1);
808
+ rb_define_method(rb_cLLaMAContextParams, "yarn_attn_factor", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_attn_factor), 0);
809
+ rb_define_method(rb_cLLaMAContextParams, "yarn_beta_fast=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_beta_fast), 1);
810
+ rb_define_method(rb_cLLaMAContextParams, "yarn_beta_fast", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_beta_fast), 0);
811
+ rb_define_method(rb_cLLaMAContextParams, "yarn_beta_slow=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_beta_slow), 1);
812
+ rb_define_method(rb_cLLaMAContextParams, "yarn_beta_slow", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_beta_slow), 0);
813
+ rb_define_method(rb_cLLaMAContextParams, "yarn_orig_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_yarn_orig_ctx), 1);
814
+ rb_define_method(rb_cLLaMAContextParams, "yarn_orig_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_yarn_orig_ctx), 0);
803
815
  rb_define_method(rb_cLLaMAContextParams, "mul_mat_q=", RUBY_METHOD_FUNC(_llama_context_params_set_mul_mat_q), 1);
804
816
  rb_define_method(rb_cLLaMAContextParams, "mul_mat_q", RUBY_METHOD_FUNC(_llama_context_params_get_mul_mat_q), 0);
805
817
  rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
@@ -883,6 +895,18 @@ private:
883
895
  return INT2NUM(ptr->params.n_threads_batch);
884
896
  }
885
897
 
898
+ // rope_scaling_type
899
+ static VALUE _llama_context_params_set_rope_scaling_type(VALUE self, VALUE scaling_type) {
900
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
901
+ ptr->params.rope_scaling_type = NUM2INT(scaling_type);
902
+ return INT2NUM(ptr->params.rope_scaling_type);
903
+ }
904
+
905
+ static VALUE _llama_context_params_get_rope_scaling_type(VALUE self) {
906
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
907
+ return INT2NUM(ptr->params.rope_scaling_type);
908
+ }
909
+
886
910
  // rope_freq_base
887
911
  static VALUE _llama_context_params_set_rope_freq_base(VALUE self, VALUE rope_freq_base) {
888
912
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -907,6 +931,66 @@ private:
907
931
  return DBL2NUM(ptr->params.rope_freq_scale);
908
932
  }
909
933
 
934
+ // yarn_ext_factor
935
+ static VALUE _llama_context_params_set_yarn_ext_factor(VALUE self, VALUE yarn_ext_factor) {
936
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
937
+ ptr->params.yarn_ext_factor = NUM2DBL(yarn_ext_factor);
938
+ return DBL2NUM(ptr->params.yarn_ext_factor);
939
+ }
940
+
941
+ static VALUE _llama_context_params_get_yarn_ext_factor(VALUE self) {
942
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
943
+ return DBL2NUM(ptr->params.yarn_ext_factor);
944
+ }
945
+
946
+ // yarn_attn_factor
947
+ static VALUE _llama_context_params_set_yarn_attn_factor(VALUE self, VALUE yarn_attn_factor) {
948
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
949
+ ptr->params.yarn_attn_factor = NUM2DBL(yarn_attn_factor);
950
+ return DBL2NUM(ptr->params.yarn_attn_factor);
951
+ }
952
+
953
+ static VALUE _llama_context_params_get_yarn_attn_factor(VALUE self) {
954
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
955
+ return DBL2NUM(ptr->params.yarn_attn_factor);
956
+ }
957
+
958
+ // yarn_beta_fast
959
+ static VALUE _llama_context_params_set_yarn_beta_fast(VALUE self, VALUE yarn_beta_fast) {
960
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
961
+ ptr->params.yarn_beta_fast = NUM2DBL(yarn_beta_fast);
962
+ return DBL2NUM(ptr->params.yarn_beta_fast);
963
+ }
964
+
965
+ static VALUE _llama_context_params_get_yarn_beta_fast(VALUE self) {
966
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
967
+ return DBL2NUM(ptr->params.yarn_beta_fast);
968
+ }
969
+
970
+ // yarn_beta_slow
971
+ static VALUE _llama_context_params_set_yarn_beta_slow(VALUE self, VALUE yarn_beta_slow) {
972
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
973
+ ptr->params.yarn_beta_slow = NUM2DBL(yarn_beta_slow);
974
+ return DBL2NUM(ptr->params.yarn_beta_slow);
975
+ }
976
+
977
+ static VALUE _llama_context_params_get_yarn_beta_slow(VALUE self) {
978
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
979
+ return DBL2NUM(ptr->params.yarn_beta_slow);
980
+ }
981
+
982
+ // yarn_orig_ctx
983
+ static VALUE _llama_context_params_set_yarn_orig_ctx(VALUE self, VALUE yarn_orig_ctx) {
984
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
985
+ ptr->params.yarn_orig_ctx = NUM2UINT(yarn_orig_ctx);
986
+ return UINT2NUM(ptr->params.yarn_orig_ctx);
987
+ }
988
+
989
+ static VALUE _llama_context_params_get_yarn_orig_ctx(VALUE self) {
990
+ LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
991
+ return UINT2NUM(ptr->params.yarn_orig_ctx);
992
+ }
993
+
910
994
  // mul_mat_q
911
995
  static VALUE _llama_context_params_set_mul_mat_q(VALUE self, VALUE mul_mat_q) {
912
996
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
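The hunks above register the new ContextParams attributes as plain readers and writers. A short sketch of setting them from Ruby; the numeric values are illustrative assumptions rather than tuned recommendations, and the `LLAMA_ROPE_SCALING_*` constants are the ones added to the LLaMACpp module in this release:

```ruby
require 'llama_cpp'

params = LLaMACpp::ContextParams.new

# RoPE scaling strategy, selected via the new module constants.
params.rope_scaling_type = LLaMACpp::LLAMA_ROPE_SCALING_YARN

# RoPE frequency settings (illustrative values).
params.rope_freq_base  = 10_000.0
params.rope_freq_scale = 1.0

# YaRN parameters exposed in 0.9.1; defaults come from the bundled llama.cpp.
params.yarn_ext_factor  = -1.0
params.yarn_attn_factor = 1.0
params.yarn_beta_fast   = 32.0
params.yarn_beta_slow   = 1.0
params.yarn_orig_ctx    = 4096

puts params.rope_scaling_type
```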
@@ -1011,6 +1095,8 @@ public:
1011
1095
  rb_define_method(rb_cLLaMAModelQuantizeParams, "quantize_output_tensor", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_quantize_output_tensor), 0);
1012
1096
  rb_define_method(rb_cLLaMAModelQuantizeParams, "only_copy=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_only_copy), 1);
1013
1097
  rb_define_method(rb_cLLaMAModelQuantizeParams, "only_copy", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_only_copy), 0);
1098
+ rb_define_method(rb_cLLaMAModelQuantizeParams, "pure=", RUBY_METHOD_FUNC(_llama_model_quantize_params_set_pure), 1);
1099
+ rb_define_method(rb_cLLaMAModelQuantizeParams, "pure", RUBY_METHOD_FUNC(_llama_model_quantize_params_get_pure), 0);
1014
1100
  }
1015
1101
 
1016
1102
  private:
@@ -1083,6 +1169,18 @@ private:
1083
1169
  LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1084
1170
  return ptr->params.only_copy ? Qtrue : Qfalse;
1085
1171
  }
1172
+
1173
+ // pure
1174
+ static VALUE _llama_model_quantize_params_set_pure(VALUE self, VALUE pure) {
1175
+ LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1176
+ ptr->params.pure = RTEST(pure) ? true : false;
1177
+ return ptr->params.pure ? Qtrue : Qfalse;
1178
+ }
1179
+
1180
+ static VALUE _llama_model_quantize_params_get_pure(VALUE self) {
1181
+ LLaMAModelQuantizeParamsWrapper* ptr = get_llama_model_quantize_params(self);
1182
+ return ptr->params.pure ? Qtrue : Qfalse;
1183
+ }
1086
1184
  };
1087
1185
 
1088
1186
  const rb_data_type_t RbLLaMAModelQuantizeParams::llama_model_quantize_params_type = {
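`pure` is exposed as a boolean accessor on ModelQuantizeParams; in the bundled llama.cpp it requests quantizing all tensors to the default type instead of the usual mixed scheme. A minimal sketch of configuring the params object (the call that actually performs quantization is omitted):

```ruby
require 'llama_cpp'

qparams = LLaMACpp::ModelQuantizeParams.new
qparams.pure      = true    # truthy values are stored as true
qparams.only_copy = false

puts qparams.pure           # => true
```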
@@ -1148,6 +1246,16 @@ public:
1148
1246
  rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
1149
1247
  rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
1150
1248
  rb_define_method(rb_cLLaMAModel, "n_params", RUBY_METHOD_FUNC(_llama_model_get_model_n_params), 0);
1249
+ rb_define_method(rb_cLLaMAModel, "text", RUBY_METHOD_FUNC(_llama_model_get_text), 1);
1250
+ rb_define_method(rb_cLLaMAModel, "score", RUBY_METHOD_FUNC(_llama_model_get_score), 1);
1251
+ rb_define_method(rb_cLLaMAModel, "type", RUBY_METHOD_FUNC(_llama_model_get_type), 1);
1252
+ rb_define_method(rb_cLLaMAModel, "token_bos", RUBY_METHOD_FUNC(_llama_model_token_bos), 0);
1253
+ rb_define_method(rb_cLLaMAModel, "token_eos", RUBY_METHOD_FUNC(_llama_model_token_eos), 0);
1254
+ rb_define_method(rb_cLLaMAModel, "token_nl", RUBY_METHOD_FUNC(_llama_model_token_nl), 0);
1255
+ rb_define_method(rb_cLLaMAModel, "token_prefix", RUBY_METHOD_FUNC(_llama_model_token_prefix), 0);
1256
+ rb_define_method(rb_cLLaMAModel, "token_middle", RUBY_METHOD_FUNC(_llama_model_token_middle), 0);
1257
+ rb_define_method(rb_cLLaMAModel, "token_suffix", RUBY_METHOD_FUNC(_llama_model_token_suffix), 0);
1258
+ rb_define_method(rb_cLLaMAModel, "token_eot", RUBY_METHOD_FUNC(_llama_model_token_eot), 0);
1151
1259
  }
1152
1260
 
1153
1261
  private:
@@ -1396,6 +1504,62 @@ private:
1396
1504
  LLaMAModelWrapper* ptr = get_llama_model(self);
1397
1505
  return UINT2NUM(llama_model_n_params(ptr->model));
1398
1506
  }
1507
+
1508
+ static VALUE _llama_model_get_text(VALUE self, VALUE token_) {
1509
+ LLaMAModelWrapper* ptr = get_llama_model(self);
1510
+ const llama_token token = NUM2INT(token_);
1511
+ const char* text = llama_token_get_text(ptr->model, token);
1512
+ return rb_utf8_str_new_cstr(text);
1513
+ }
1514
+
1515
+ static VALUE _llama_model_get_score(VALUE self, VALUE token_) {
1516
+ LLaMAModelWrapper* ptr = get_llama_model(self);
1517
+ const llama_token token = NUM2INT(token_);
1518
+ const float score = llama_token_get_score(ptr->model, token);
1519
+ return DBL2NUM(score);
1520
+ }
1521
+
1522
+ static VALUE _llama_model_get_type(VALUE self, VALUE token_) {
1523
+ LLaMAModelWrapper* ptr = get_llama_model(self);
1524
+ const llama_token token = NUM2INT(token_);
1525
+ const int type = llama_token_get_type(ptr->model, token);
1526
+ return INT2NUM(type);
1527
+ }
1528
+
1529
+ static VALUE _llama_model_token_bos(VALUE self) {
1530
+ LLaMAModelWrapper* ptr = get_llama_model(self);
1531
+ return INT2NUM(llama_token_bos(ptr->model));
1532
+ }
1533
+
1534
+ static VALUE _llama_model_token_eos(VALUE self) {
1535
+ LLaMAModelWrapper* ptr = get_llama_model(self);
1536
+ return INT2NUM(llama_token_eos(ptr->model));
1537
+ }
1538
+
1539
+ static VALUE _llama_model_token_nl(VALUE self) {
1540
+ LLaMAModelWrapper* ptr = get_llama_model(self);
1541
+ return INT2NUM(llama_token_nl(ptr->model));
1542
+ }
1543
+
1544
+ static VALUE _llama_model_token_prefix(VALUE self) {
1545
+ LLaMAModelWrapper* ptr = get_llama_model(self);
1546
+ return INT2NUM(llama_token_prefix(ptr->model));
1547
+ }
1548
+
1549
+ static VALUE _llama_model_token_middle(VALUE self) {
1550
+ LLaMAModelWrapper* ptr = get_llama_model(self);
1551
+ return INT2NUM(llama_token_middle(ptr->model));
1552
+ }
1553
+
1554
+ static VALUE _llama_model_token_suffix(VALUE self) {
1555
+ LLaMAModelWrapper* ptr = get_llama_model(self);
1556
+ return INT2NUM(llama_token_suffix(ptr->model));
1557
+ }
1558
+
1559
+ static VALUE _llama_model_token_eot(VALUE self) {
1560
+ LLaMAModelWrapper* ptr = get_llama_model(self);
1561
+ return INT2NUM(llama_token_eot(ptr->model));
1562
+ }
1399
1563
  };
1400
1564
 
1401
1565
  const rb_data_type_t RbLLaMAModel::llama_model_type = {
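With these methods the vocabulary lookups are available directly on Model. A small sketch, assuming the `Model.new(model_path:, params:)` constructor and a placeholder model path:

```ruby
require 'llama_cpp'

model = LLaMACpp::Model.new(
  model_path: '/path/to/model.gguf',   # hypothetical path
  params: LLaMACpp::ModelParams.new
)

# Special tokens are now looked up on the model rather than the context.
bos = model.token_bos
eos = model.token_eos

# Per-token vocabulary data.
puts model.text(bos)    # text piece for the BOS token
puts model.score(bos)   # its score in the vocabulary
puts model.type(bos)    # its token type as an integer
```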
@@ -1670,22 +1834,12 @@ public:
1670
1834
  rb_define_method(rb_cLLaMAContext, "decode", RUBY_METHOD_FUNC(_llama_context_decode), 1);
1671
1835
  rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
1672
1836
  rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
1673
- rb_define_method(rb_cLLaMAContext, "text", RUBY_METHOD_FUNC(_llama_context_text), 1);
1674
- rb_define_method(rb_cLLaMAContext, "score", RUBY_METHOD_FUNC(_llama_context_score), 1);
1675
- rb_define_method(rb_cLLaMAContext, "type", RUBY_METHOD_FUNC(_llama_context_type), 1);
1676
- rb_define_method(rb_cLLaMAContext, "token_bos", RUBY_METHOD_FUNC(_llama_context_token_bos), 0);
1677
- rb_define_method(rb_cLLaMAContext, "token_eos", RUBY_METHOD_FUNC(_llama_context_token_eos), 0);
1678
- rb_define_method(rb_cLLaMAContext, "token_nl", RUBY_METHOD_FUNC(_llama_context_token_nl), 0);
1679
- rb_define_method(rb_cLLaMAContext, "token_prefix", RUBY_METHOD_FUNC(_llama_context_token_prefix), 0);
1680
- rb_define_method(rb_cLLaMAContext, "token_middle", RUBY_METHOD_FUNC(_llama_context_token_middle), 0);
1681
- rb_define_method(rb_cLLaMAContext, "token_suffix", RUBY_METHOD_FUNC(_llama_context_token_suffix), 0);
1682
- rb_define_method(rb_cLLaMAContext, "token_eot", RUBY_METHOD_FUNC(_llama_context_token_eot), 0);
1683
1837
  rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
1684
1838
  rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
1685
1839
  rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
1686
1840
  rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
1687
1841
  rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
1688
- rb_define_method(rb_cLLaMAContext, "kv_cache_tokens_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_tokens_rm), 2);
1842
+ rb_define_method(rb_cLLaMAContext, "kv_cache_clear", RUBY_METHOD_FUNC(_llama_context_kv_cache_clear), 0);
1689
1843
  rb_define_method(rb_cLLaMAContext, "kv_cache_seq_rm", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_rm), 3);
1690
1844
  rb_define_method(rb_cLLaMAContext, "kv_cache_seq_cp", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_cp), 4);
1691
1845
  rb_define_method(rb_cLLaMAContext, "kv_cache_seq_keep", RUBY_METHOD_FUNC(_llama_context_kv_cache_seq_keep), 1);
@@ -1693,12 +1847,12 @@ public:
1693
1847
  rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
1694
1848
  rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
1695
1849
  rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
1696
- rb_define_method(rb_cLLaMAContext, "sample_repetition_penalty", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalty), -1);
1697
- rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
1850
+ rb_define_method(rb_cLLaMAContext, "sample_repetition_penalties", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalties), -1);
1698
1851
  rb_define_method(rb_cLLaMAContext, "sample_classifier_free_guidance", RUBY_METHOD_FUNC(_llama_context_sample_classifier_free_guidance), -1);
1699
1852
  rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
1700
1853
  rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
1701
1854
  rb_define_method(rb_cLLaMAContext, "sample_top_p", RUBY_METHOD_FUNC(_llama_context_sample_top_p), -1);
1855
+ rb_define_method(rb_cLLaMAContext, "sample_min_p", RUBY_METHOD_FUNC(_llama_context_sample_min_p), -1);
1702
1856
  rb_define_method(rb_cLLaMAContext, "sample_tail_free", RUBY_METHOD_FUNC(_llama_context_sample_tail_free), -1);
1703
1857
  rb_define_method(rb_cLLaMAContext, "sample_typical", RUBY_METHOD_FUNC(_llama_context_sample_typical), -1);
1704
1858
  rb_define_method(rb_cLLaMAContext, "sample_temp", RUBY_METHOD_FUNC(_llama_context_sample_temp), -1);
@@ -1927,102 +2081,6 @@ private:
1927
2081
  return output;
1928
2082
  }
1929
2083
 
1930
- static VALUE _llama_context_text(VALUE self, VALUE token_) {
1931
- LLaMAContextWrapper* ptr = get_llama_context(self);
1932
- if (ptr->ctx == NULL) {
1933
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1934
- return Qnil;
1935
- }
1936
- const llama_token token = NUM2INT(token_);
1937
- const char* text = llama_token_get_text(ptr->ctx, token);
1938
- return rb_utf8_str_new_cstr(text);
1939
- }
1940
-
1941
- static VALUE _llama_context_score(VALUE self, VALUE token_) {
1942
- LLaMAContextWrapper* ptr = get_llama_context(self);
1943
- if (ptr->ctx == NULL) {
1944
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1945
- return Qnil;
1946
- }
1947
- const llama_token token = NUM2INT(token_);
1948
- const float score = llama_token_get_score(ptr->ctx, token);
1949
- return DBL2NUM(score);
1950
- }
1951
-
1952
- static VALUE _llama_context_type(VALUE self, VALUE token_) {
1953
- LLaMAContextWrapper* ptr = get_llama_context(self);
1954
- if (ptr->ctx == NULL) {
1955
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1956
- return Qnil;
1957
- }
1958
- const llama_token token = NUM2INT(token_);
1959
- const int type = llama_token_get_type(ptr->ctx, token);
1960
- return INT2NUM(type);
1961
- }
1962
-
1963
- static VALUE _llama_context_token_bos(VALUE self) {
1964
- LLaMAContextWrapper* ptr = get_llama_context(self);
1965
- if (ptr->ctx == NULL) {
1966
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1967
- return Qnil;
1968
- }
1969
- return INT2NUM(llama_token_bos(ptr->ctx));
1970
- }
1971
-
1972
- static VALUE _llama_context_token_eos(VALUE self) {
1973
- LLaMAContextWrapper* ptr = get_llama_context(self);
1974
- if (ptr->ctx == NULL) {
1975
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1976
- return Qnil;
1977
- }
1978
- return INT2NUM(llama_token_eos(ptr->ctx));
1979
- }
1980
-
1981
- static VALUE _llama_context_token_nl(VALUE self) {
1982
- LLaMAContextWrapper* ptr = get_llama_context(self);
1983
- if (ptr->ctx == NULL) {
1984
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1985
- return Qnil;
1986
- }
1987
- return INT2NUM(llama_token_nl(ptr->ctx));
1988
- }
1989
-
1990
- static VALUE _llama_context_token_prefix(VALUE self) {
1991
- LLaMAContextWrapper* ptr = get_llama_context(self);
1992
- if (ptr->ctx == NULL) {
1993
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1994
- return Qnil;
1995
- }
1996
- return INT2NUM(llama_token_prefix(ptr->ctx));
1997
- }
1998
-
1999
- static VALUE _llama_context_token_middle(VALUE self) {
2000
- LLaMAContextWrapper* ptr = get_llama_context(self);
2001
- if (ptr->ctx == NULL) {
2002
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2003
- return Qnil;
2004
- }
2005
- return INT2NUM(llama_token_middle(ptr->ctx));
2006
- }
2007
-
2008
- static VALUE _llama_context_token_suffix(VALUE self) {
2009
- LLaMAContextWrapper* ptr = get_llama_context(self);
2010
- if (ptr->ctx == NULL) {
2011
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2012
- return Qnil;
2013
- }
2014
- return INT2NUM(llama_token_suffix(ptr->ctx));
2015
- }
2016
-
2017
- static VALUE _llama_context_token_eot(VALUE self) {
2018
- LLaMAContextWrapper* ptr = get_llama_context(self);
2019
- if (ptr->ctx == NULL) {
2020
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2021
- return Qnil;
2022
- }
2023
- return INT2NUM(llama_token_eot(ptr->ctx));
2024
- }
2025
-
2026
2084
  static VALUE _llama_context_n_ctx(VALUE self) {
2027
2085
  LLaMAContextWrapper* ptr = get_llama_context(self);
2028
2086
  if (ptr->ctx == NULL) {
@@ -2073,13 +2131,13 @@ private:
2073
2131
  return INT2NUM(llama_get_kv_cache_token_count(ptr->ctx));
2074
2132
  }
2075
2133
 
2076
- static VALUE _llama_context_kv_cache_tokens_rm(VALUE self, VALUE c0, VALUE c1) {
2134
+ static VALUE _llama_context_kv_cache_clear(VALUE self) {
2077
2135
  LLaMAContextWrapper* ptr = get_llama_context(self);
2078
2136
  if (ptr->ctx == NULL) {
2079
2137
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2080
2138
  return Qnil;
2081
2139
  }
2082
- llama_kv_cache_tokens_rm(ptr->ctx, NUM2INT(c0), NUM2INT(c1));
2140
+ llama_kv_cache_clear(ptr->ctx);
2083
2141
  return Qnil;
2084
2142
  }
2085
2143
 
@@ -2231,14 +2289,14 @@ private:
2231
2289
  return Qnil;
2232
2290
  }
2233
2291
 
2234
- static VALUE _llama_context_sample_repetition_penalty(int argc, VALUE* argv, VALUE self) {
2292
+ static VALUE _llama_context_sample_repetition_penalties(int argc, VALUE* argv, VALUE self) {
2235
2293
  VALUE kw_args = Qnil;
2236
- ID kw_table[1] = { rb_intern("penalty") };
2237
- VALUE kw_values[1] = { Qundef };
2294
+ ID kw_table[3] = { rb_intern("penalty_repeat"), rb_intern("penalty_freq"), rb_intern("penalty_present") };
2295
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
2238
2296
  VALUE candidates = Qnil;
2239
2297
  VALUE last_n_tokens = Qnil;
2240
2298
  rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
2241
- rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
2299
+ rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
2242
2300
 
2243
2301
  if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
2244
2302
  rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
@@ -2249,56 +2307,15 @@ private:
2249
2307
  return Qnil;
2250
2308
  }
2251
2309
  if (!RB_FLOAT_TYPE_P(kw_values[0])) {
2252
- rb_raise(rb_eArgError, "penalty must be a float");
2253
- return Qnil;
2254
- }
2255
-
2256
- const size_t last_tokens_size = RARRAY_LEN(last_n_tokens);
2257
- std::vector<llama_token> last_n_tokens_data(last_tokens_size);
2258
- for (size_t i = 0; i < last_tokens_size; i++) {
2259
- last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
2260
- }
2261
-
2262
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2263
- if (ctx_ptr->ctx == NULL) {
2264
- rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2310
+ rb_raise(rb_eArgError, "penalty_repeat must be a float");
2265
2311
  return Qnil;
2266
2312
  }
2267
- LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
2268
- if (cnd_ptr->array.data == nullptr) {
2269
- rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
2270
- return Qnil;
2271
- }
2272
- const float penalty = NUM2DBL(kw_values[0]);
2273
-
2274
- llama_sample_repetition_penalty(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, penalty);
2275
-
2276
- return Qnil;
2277
- }
2278
-
2279
- static VALUE _llama_context_sample_frequency_and_presence_penalties(int argc, VALUE* argv, VALUE self) {
2280
- VALUE kw_args = Qnil;
2281
- ID kw_table[2] = { rb_intern("frequency"), rb_intern("presence") };
2282
- VALUE kw_values[2] = { Qundef, Qundef };
2283
- VALUE candidates = Qnil;
2284
- VALUE last_n_tokens = Qnil;
2285
- rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
2286
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
2287
-
2288
- if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
2289
- rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
2290
- return Qnil;
2291
- }
2292
- if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
2293
- rb_raise(rb_eArgError, "last_n_tokens must be an Array");
2294
- return Qnil;
2295
- }
2296
- if (!RB_FLOAT_TYPE_P(kw_values[0])) {
2297
- rb_raise(rb_eArgError, "frequency must be a float");
2313
+ if (!RB_FLOAT_TYPE_P(kw_values[1])) {
2314
+ rb_raise(rb_eArgError, "penalty_freq must be a float");
2298
2315
  return Qnil;
2299
2316
  }
2300
- if (!RB_FLOAT_TYPE_P(kw_values[1])) {
2301
- rb_raise(rb_eArgError, "presence must be a float");
2317
+ if (!RB_FLOAT_TYPE_P(kw_values[2])) {
2318
+ rb_raise(rb_eArgError, "penalty_present must be a float");
2302
2319
  return Qnil;
2303
2320
  }
2304
2321
 
@@ -2318,11 +2335,12 @@ private:
2318
2335
  rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
2319
2336
  return Qnil;
2320
2337
  }
2338
+ const float penalty_repeat = NUM2DBL(kw_values[0]);
2339
+ const float penalty_freq = NUM2DBL(kw_values[1]);
2340
+ const float penalty_present = NUM2DBL(kw_values[2]);
2321
2341
 
2322
- const float alpha_frequency = NUM2DBL(kw_values[0]);
2323
- const float alpha_presence = NUM2DBL(kw_values[1]);
2324
-
2325
- llama_sample_frequency_and_presence_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, alpha_frequency, alpha_presence);
2342
+ llama_sample_repetition_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size,
2343
+ penalty_repeat, penalty_freq, penalty_present);
2326
2344
 
2327
2345
  return Qnil;
2328
2346
  }
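The merged sampler takes three required float keyword arguments, as the kwarg table above shows. A sketch of the new call, assuming the usual model/context setup, a placeholder model path, and that `TokenData.new(id:, logit:, p:)` builds candidate entries as in examples/chat.rb:

```ruby
require 'llama_cpp'

model = LLaMACpp::Model.new(model_path: '/path/to/model.gguf',  # hypothetical path
                            params: LLaMACpp::ModelParams.new)
context = LLaMACpp::Context.new(model: model, params: LLaMACpp::ContextParams.new)

# Dummy candidate set; in practice this is built from Context#logits.
candidates = LLaMACpp::TokenDataArray.new(
  (0...4).map { |i| LLaMACpp::TokenData.new(id: i, logit: 0.5 * i, p: 0.0) }
)
recent_tokens = [1, 2, 2, 3]

# 0.8.x used sample_repetition_penalty(penalty:) plus
# sample_frequency_and_presence_penalties(frequency:, presence:).
# 0.9.0 merges them into one call; all three keywords are required floats.
context.sample_repetition_penalties(
  candidates, recent_tokens,
  penalty_repeat: 1.1, penalty_freq: 0.0, penalty_present: 0.0
)
```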
@@ -2467,6 +2485,45 @@ private:
2467
2485
  return Qnil;
2468
2486
  }
2469
2487
 
2488
+ static VALUE _llama_context_sample_min_p(int argc, VALUE* argv, VALUE self) {
2489
+ VALUE kw_args = Qnil;
2490
+ ID kw_table[2] = { rb_intern("prob"), rb_intern("min_keep") };
2491
+ VALUE kw_values[2] = { Qundef, Qundef };
2492
+ VALUE candidates = Qnil;
2493
+ rb_scan_args(argc, argv, "1:", &candidates, &kw_args);
2494
+ rb_get_kwargs(kw_args, kw_table, 1, 1, kw_values);
2495
+
2496
+ if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
2497
+ rb_raise(rb_eArgError, "1st argument must be a TokenDataArray");
2498
+ return Qnil;
2499
+ }
2500
+ if (!RB_FLOAT_TYPE_P(kw_values[0])) {
2501
+ rb_raise(rb_eArgError, "prob must be a float");
2502
+ return Qnil;
2503
+ }
2504
+ if (kw_values[1] != Qundef && !RB_INTEGER_TYPE_P(kw_values[1])) {
2505
+ rb_raise(rb_eArgError, "min_keep must be an integer");
2506
+ return Qnil;
2507
+ }
2508
+
2509
+ LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
2510
+ if (ctx_ptr->ctx == NULL) {
2511
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
2512
+ return Qnil;
2513
+ }
2514
+ LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
2515
+ if (cnd_ptr->array.data == nullptr) {
2516
+ rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
2517
+ return Qnil;
2518
+ }
2519
+ const float prob = NUM2DBL(kw_values[0]);
2520
+ const size_t min_keep = kw_values[1] != Qundef ? NUM2SIZET(kw_values[1]) : 1;
2521
+
2522
+ llama_sample_min_p(ctx_ptr->ctx, &(cnd_ptr->array), prob, min_keep);
2523
+
2524
+ return Qnil;
2525
+ }
2526
+
2470
2527
  static VALUE _llama_context_sample_tail_free(int argc, VALUE* argv, VALUE self) {
2471
2528
  VALUE kw_args = Qnil;
2472
2529
  ID kw_table[2] = { rb_intern("z"), rb_intern("min_keep") };
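`sample_min_p` takes the candidates array plus a required `prob:` float and an optional `min_keep:` integer that defaults to 1. A minimal sketch under the same placeholder-model assumption as the previous example:

```ruby
require 'llama_cpp'

model = LLaMACpp::Model.new(model_path: '/path/to/model.gguf',  # hypothetical path
                            params: LLaMACpp::ModelParams.new)
context = LLaMACpp::Context.new(model: model, params: LLaMACpp::ContextParams.new)

# Dummy candidate set; in practice this is built from Context#logits.
candidates = LLaMACpp::TokenDataArray.new(
  (0...4).map { |i| LLaMACpp::TokenData.new(id: i, logit: 0.5 * i, p: 0.0) }
)

# Drop candidates whose probability is below prob * p(best token).
context.sample_min_p(candidates, prob: 0.05, min_keep: 1)
```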
@@ -2962,6 +3019,12 @@ extern "C" void Init_llama_cpp(void) {
2962
3019
  rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_RNG_UPPER", INT2NUM(LLAMA_GRETYPE_CHAR_RNG_UPPER));
2963
3020
  rb_define_const(rb_mLLaMACpp, "LLAMA_GRETYPE_CHAR_ALT", INT2NUM(LLAMA_GRETYPE_CHAR_ALT));
2964
3021
 
3022
+ rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_UNSPECIFIED", INT2NUM(LLAMA_ROPE_SCALING_UNSPECIFIED));
3023
+ rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_NONE", INT2NUM(LLAMA_ROPE_SCALING_NONE));
3024
+ rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_LINEAR", INT2NUM(LLAMA_ROPE_SCALING_LINEAR));
3025
+ rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_YARN", INT2NUM(LLAMA_ROPE_SCALING_YARN));
3026
+ rb_define_const(rb_mLLaMACpp, "LLAMA_ROPE_SCALING_MAX_VALUE", INT2NUM(LLAMA_ROPE_SCALING_MAX_VALUE));
3027
+
2965
3028
  std::stringstream ss_magic;
2966
3029
  ss_magic << std::showbase << std::hex << LLAMA_FILE_MAGIC_GGSN;
2967
3030
  rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_MAGIC_GGSN", rb_str_new2(ss_magic.str().c_str()));