llama_cpp 0.7.1 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +18 -0
- data/examples/chat.rb +8 -6
- data/ext/llama_cpp/extconf.rb +2 -2
- data/ext/llama_cpp/llama_cpp.cpp +122 -183
- data/ext/llama_cpp/src/ggml-cuda.cu +188 -20
- data/ext/llama_cpp/src/ggml-metal.m +57 -8
- data/ext/llama_cpp/src/ggml-metal.metal +171 -2
- data/ext/llama_cpp/src/ggml-opencl.cpp +188 -222
- data/ext/llama_cpp/src/ggml.c +375 -93
- data/ext/llama_cpp/src/ggml.h +11 -9
- data/ext/llama_cpp/src/k_quants.c +12 -20
- data/ext/llama_cpp/src/llama.cpp +459 -153
- data/ext/llama_cpp/src/llama.h +34 -33
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +4 -4
- data/sig/llama_cpp.rbs +15 -16
- metadata +3 -3
    
        checksums.yaml
    CHANGED
    
    | @@ -1,7 +1,7 @@ | |
| 1 1 | 
             
            ---
         | 
| 2 2 | 
             
            SHA256:
         | 
| 3 | 
            -
              metadata.gz:  | 
| 4 | 
            -
              data.tar.gz:  | 
| 3 | 
            +
              metadata.gz: 683f2d81aff9e82234925ba08cd5b46b56a2283ff8397a6c06ce50d34a95dbfc
         | 
| 4 | 
            +
              data.tar.gz: d3005cab273b8d85f47f4cb4314fbab3a540d366a42829e5ec8d2c29576ae09e
         | 
| 5 5 | 
             
            SHA512:
         | 
| 6 | 
            -
              metadata.gz:  | 
| 7 | 
            -
              data.tar.gz:  | 
| 6 | 
            +
              metadata.gz: 559f1ba1253a704c38480336decd315c65b4d80e6895ad1dc0faa3b5b81570a1faeaadcb6ec7ee3145f0fff758ab5e38e6cb8163382ce9b693d893deebe9a8f9
         | 
| 7 | 
            +
              data.tar.gz: cb3d96b8c3f79cd20d4169a175270e8768c04bcaa24e51cb2c4d7872db88bc6e3349e6b1e93a130b89d21daab8be6e57b5305412059ea722084c7cb7d4a01e93
         | 
    
        data/CHANGELOG.md
    CHANGED
    
    | @@ -1,3 +1,21 @@ | |
| 1 | 
            +
            ## [[0.9.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.8.0...v0.9.0)] - 2023-10-28
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            - Fix missing object file for ggml-backend when building with metal and cublas options.
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            **Breaking Changes**
         | 
| 6 | 
            +
            - Bump bundled llama.cpp from b1405 to b1429
         | 
| 7 | 
            +
              - Move the following methods from Context to Model:
         | 
| 8 | 
            +
                - text, score, type, token_bos, token_eos, token_nl, token_prefix, token_middle, token_suffix, and token_eot.
         | 
| 9 | 
            +
              - Add `sample_repetition_penalties` method, which integrates sample_frequency_and_presence_penalties and sample_repetition_penalty methods.
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            ## [[0.8.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.1...v0.8.0)] - 2023-10-21
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            **Breaking Changes**
         | 
| 14 | 
            +
            - Bump bundled llama.cpp from b1380 to b1405
         | 
| 15 | 
            +
              - Add column index argument to `set_seq_id` and `get_seq_id` methods in Batch.
         | 
| 16 | 
            +
              - Add `special` keyword argument to `tokenize` method in Model.
         | 
| 17 | 
            +
              - Add `n_seq_max` keyword argument to `initialize` method in Batch.
         | 
| 18 | 
            +
             | 
| 1 19 | 
             
            ## [[0.7.1](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.0...v0.7.1)] - 2023-10-14
         | 
| 2 20 |  | 
| 3 21 | 
             
            - Bump bundled llama.cpp from b1334 to b1380.
         | 
    
        data/examples/chat.rb
    CHANGED
    
    | @@ -83,10 +83,12 @@ class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation | |
| 83 83 | 
             
                    candidates = LLaMACpp::TokenDataArray.new(base_candidates)
         | 
| 84 84 |  | 
| 85 85 | 
             
                    last_n_repeat = [last_n_tokens.size, options[:repeat_last_n], n_ctx].min
         | 
| 86 | 
            -
                    context. | 
| 87 | 
            -
             | 
| 88 | 
            -
                       | 
| 89 | 
            -
                       | 
| 86 | 
            +
                    context.sample_repetition_penalties(
         | 
| 87 | 
            +
                      candidates,
         | 
| 88 | 
            +
                      last_n_tokens[-last_n_repeat..],
         | 
| 89 | 
            +
                      penalty_repeat: options[:repeat_penalty],
         | 
| 90 | 
            +
                      penalty_freq: options[:frequency_penalty],
         | 
| 91 | 
            +
                      penalty_present: options[:presence_penalty]
         | 
| 90 92 | 
             
                    )
         | 
| 91 93 |  | 
| 92 94 | 
             
                    context.sample_top_k(candidates, k: options[:top_k])
         | 
| @@ -99,8 +101,8 @@ class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation | |
| 99 101 | 
             
                    last_n_tokens.shift
         | 
| 100 102 | 
             
                    last_n_tokens.push(id)
         | 
| 101 103 |  | 
| 102 | 
            -
                    if id == context.token_eos
         | 
| 103 | 
            -
                      id = context.token_nl
         | 
| 104 | 
            +
                    if id == context.model.token_eos
         | 
| 105 | 
            +
                      id = context.model.token_nl
         | 
| 104 106 | 
             
                      unless antiprompt.empty?
         | 
| 105 107 | 
             
                        first_antiprompt = context.model.tokenize(text: antiprompt, add_bos: false)
         | 
| 106 108 | 
             
                        embd_input.concat(first_antiprompt)
         | 
    
        data/ext/llama_cpp/extconf.rb
    CHANGED
    
    | @@ -53,7 +53,7 @@ if with_config('metal') | |
| 53 53 | 
             
              $CFLAGS << ' -DGGML_USE_METAL'
         | 
| 54 54 | 
             
              $CXXFLAGS << ' -DGGML_USE_METAL'
         | 
| 55 55 | 
             
              $LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit'
         | 
| 56 | 
            -
              $objs = %w[ggml.o ggml-alloc.o ggml-metal.o llama.o llama_cpp.o]
         | 
| 56 | 
            +
              $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-metal.o llama.o llama_cpp.o]
         | 
| 57 57 | 
             
              $objs << 'k_quants.o' unless with_config('no_k_quants')
         | 
| 58 58 | 
             
            end
         | 
| 59 59 |  | 
| @@ -61,7 +61,7 @@ if with_config('cublas') | |
| 61 61 | 
             
              $CFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
         | 
| 62 62 | 
             
              $CXXFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
         | 
| 63 63 | 
             
              $LDFLAGS << ' -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64'
         | 
| 64 | 
            -
              $objs = %w[ggml.o ggml-alloc.o ggml-cuda.o llama.o llama_cpp.o]
         | 
| 64 | 
            +
              $objs = %w[ggml.o ggml-backend.o ggml-alloc.o ggml-cuda.o llama.o llama_cpp.o]
         | 
| 65 65 | 
             
              $objs << 'k_quants.o' unless with_config('no_k_quants')
         | 
| 66 66 | 
             
            end
         | 
| 67 67 |  | 
    
        data/ext/llama_cpp/llama_cpp.cpp
    CHANGED
    
    | @@ -63,8 +63,8 @@ public: | |
| 63 63 | 
             
                rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
         | 
| 64 64 | 
             
                rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
         | 
| 65 65 | 
             
                rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
         | 
| 66 | 
            -
                rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id),  | 
| 67 | 
            -
                rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id),  | 
| 66 | 
            +
                rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 3);
         | 
| 67 | 
            +
                rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 2);
         | 
| 68 68 | 
             
                rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
         | 
| 69 69 | 
             
                rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
         | 
| 70 70 | 
             
              }
         | 
| @@ -74,10 +74,10 @@ private: | |
| 74 74 |  | 
| 75 75 | 
             
              static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
         | 
| 76 76 | 
             
                VALUE kw_args = Qnil;
         | 
| 77 | 
            -
                ID kw_table[ | 
| 78 | 
            -
                VALUE kw_values[ | 
| 77 | 
            +
                ID kw_table[3] = { rb_intern("n_tokens"), rb_intern("embd"), rb_intern("n_seq_max") };
         | 
| 78 | 
            +
                VALUE kw_values[3] = { Qundef, Qundef, Qundef };
         | 
| 79 79 | 
             
                rb_scan_args(argc, argv, ":", &kw_args);
         | 
| 80 | 
            -
                rb_get_kwargs(kw_args, kw_table,  | 
| 80 | 
            +
                rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
         | 
| 81 81 |  | 
| 82 82 | 
             
                if (!RB_INTEGER_TYPE_P(kw_values[0])) {
         | 
| 83 83 | 
             
                  rb_raise(rb_eArgError, "n_tokens must be an integer");
         | 
| @@ -87,12 +87,17 @@ private: | |
| 87 87 | 
             
                  rb_raise(rb_eArgError, "embd must be an integer");
         | 
| 88 88 | 
             
                  return Qnil;
         | 
| 89 89 | 
             
                }
         | 
| 90 | 
            +
                if (!RB_INTEGER_TYPE_P(kw_values[2])) {
         | 
| 91 | 
            +
                  rb_raise(rb_eArgError, "n_seq_max must be an integer");
         | 
| 92 | 
            +
                  return Qnil;
         | 
| 93 | 
            +
                }
         | 
| 90 94 |  | 
| 91 95 | 
             
                const int32_t n_tokens = NUM2INT(kw_values[0]);
         | 
| 92 96 | 
             
                const int32_t embd = NUM2INT(kw_values[1]);
         | 
| 97 | 
            +
                const int32_t n_seq_max = NUM2INT(kw_values[2]);
         | 
| 93 98 |  | 
| 94 99 | 
             
                LLaMABatchWrapper* ptr = get_llama_batch(self);
         | 
| 95 | 
            -
                ptr->batch = llama_batch_init(n_tokens, embd);
         | 
| 100 | 
            +
                ptr->batch = llama_batch_init(n_tokens, embd, n_seq_max);
         | 
| 96 101 |  | 
| 97 102 | 
             
                return Qnil;
         | 
| 98 103 | 
             
              }
         | 
| @@ -190,25 +195,35 @@ private: | |
| 190 195 | 
             
              }
         | 
| 191 196 |  | 
| 192 197 | 
             
              // seq_id
         | 
| 193 | 
            -
              static VALUE _llama_batch_set_seq_id(VALUE self, VALUE  | 
| 198 | 
            +
              static VALUE _llama_batch_set_seq_id(VALUE self, VALUE i_, VALUE j_, VALUE value) {
         | 
| 194 199 | 
             
                LLaMABatchWrapper* ptr = get_llama_batch(self);
         | 
| 195 | 
            -
                const int32_t  | 
| 196 | 
            -
                if ( | 
| 197 | 
            -
                  rb_raise(rb_eArgError, " | 
| 200 | 
            +
                const int32_t i = NUM2INT(i_);
         | 
| 201 | 
            +
                if (i < 0 || i >= ptr->batch.n_tokens) {
         | 
| 202 | 
            +
                  rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
         | 
| 203 | 
            +
                  return Qnil;
         | 
| 204 | 
            +
                }
         | 
| 205 | 
            +
                const int32_t j = NUM2INT(j_);
         | 
| 206 | 
            +
                if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
         | 
| 207 | 
            +
                  rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
         | 
| 198 208 | 
             
                  return Qnil;
         | 
| 199 209 | 
             
                }
         | 
| 200 | 
            -
                ptr->batch.seq_id[ | 
| 201 | 
            -
                return INT2NUM(ptr->batch.seq_id[ | 
| 210 | 
            +
                ptr->batch.seq_id[i][j] = NUM2INT(value);
         | 
| 211 | 
            +
                return INT2NUM(ptr->batch.seq_id[i][j]);
         | 
| 202 212 | 
             
              }
         | 
| 203 213 |  | 
| 204 | 
            -
              static VALUE _llama_batch_get_seq_id(VALUE self, VALUE  | 
| 214 | 
            +
              static VALUE _llama_batch_get_seq_id(VALUE self, VALUE i_, VALUE j_) {
         | 
| 205 215 | 
             
                LLaMABatchWrapper* ptr = get_llama_batch(self);
         | 
| 206 | 
            -
                const int32_t  | 
| 207 | 
            -
                if ( | 
| 208 | 
            -
                  rb_raise(rb_eArgError, " | 
| 216 | 
            +
                const int32_t i = NUM2INT(i_);
         | 
| 217 | 
            +
                if (i < 0 || i >= ptr->batch.n_tokens) {
         | 
| 218 | 
            +
                  rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
         | 
| 209 219 | 
             
                  return Qnil;
         | 
| 210 220 | 
             
                }
         | 
| 211 | 
            -
                 | 
| 221 | 
            +
                const int32_t j = NUM2INT(j_);
         | 
| 222 | 
            +
                if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
         | 
| 223 | 
            +
                  rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
         | 
| 224 | 
            +
                  return Qnil;
         | 
| 225 | 
            +
                }
         | 
| 226 | 
            +
                return INT2NUM(ptr->batch.seq_id[i][j]);
         | 
| 212 227 | 
             
              }
         | 
| 213 228 |  | 
| 214 229 | 
             
              // logits
         | 
| @@ -1133,6 +1148,16 @@ public: | |
| 1133 1148 | 
             
                rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
         | 
| 1134 1149 | 
             
                rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
         | 
| 1135 1150 | 
             
                rb_define_method(rb_cLLaMAModel, "n_params", RUBY_METHOD_FUNC(_llama_model_get_model_n_params), 0);
         | 
| 1151 | 
            +
                rb_define_method(rb_cLLaMAModel, "text", RUBY_METHOD_FUNC(_llama_model_get_text), 1);
         | 
| 1152 | 
            +
                rb_define_method(rb_cLLaMAModel, "score", RUBY_METHOD_FUNC(_llama_model_get_score), 1);
         | 
| 1153 | 
            +
                rb_define_method(rb_cLLaMAModel, "type", RUBY_METHOD_FUNC(_llama_model_get_type), 1);
         | 
| 1154 | 
            +
                rb_define_method(rb_cLLaMAModel, "token_bos", RUBY_METHOD_FUNC(_llama_model_token_bos), 0);
         | 
| 1155 | 
            +
                rb_define_method(rb_cLLaMAModel, "token_eos", RUBY_METHOD_FUNC(_llama_model_token_eos), 0);
         | 
| 1156 | 
            +
                rb_define_method(rb_cLLaMAModel, "token_nl", RUBY_METHOD_FUNC(_llama_model_token_nl), 0);
         | 
| 1157 | 
            +
                rb_define_method(rb_cLLaMAModel, "token_prefix", RUBY_METHOD_FUNC(_llama_model_token_prefix), 0);
         | 
| 1158 | 
            +
                rb_define_method(rb_cLLaMAModel, "token_middle", RUBY_METHOD_FUNC(_llama_model_token_middle), 0);
         | 
| 1159 | 
            +
                rb_define_method(rb_cLLaMAModel, "token_suffix", RUBY_METHOD_FUNC(_llama_model_token_suffix), 0);
         | 
| 1160 | 
            +
                rb_define_method(rb_cLLaMAModel, "token_eot", RUBY_METHOD_FUNC(_llama_model_token_eot), 0);
         | 
| 1136 1161 | 
             
              }
         | 
| 1137 1162 |  | 
| 1138 1163 | 
             
            private:
         | 
| @@ -1319,10 +1344,10 @@ private: | |
| 1319 1344 |  | 
| 1320 1345 | 
             
              static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
         | 
| 1321 1346 | 
             
                VALUE kw_args = Qnil;
         | 
| 1322 | 
            -
                ID kw_table[ | 
| 1323 | 
            -
                VALUE kw_values[ | 
| 1347 | 
            +
                ID kw_table[4] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos"), rb_intern("special") };
         | 
| 1348 | 
            +
                VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
         | 
| 1324 1349 | 
             
                rb_scan_args(argc, argv, ":", &kw_args);
         | 
| 1325 | 
            -
                rb_get_kwargs(kw_args, kw_table, 1,  | 
| 1350 | 
            +
                rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);
         | 
| 1326 1351 |  | 
| 1327 1352 | 
             
                if (!RB_TYPE_P(kw_values[0], T_STRING)) {
         | 
| 1328 1353 | 
             
                  rb_raise(rb_eArgError, "text must be a String");
         | 
| @@ -1336,15 +1361,20 @@ private: | |
| 1336 1361 | 
             
                  rb_raise(rb_eArgError, "add_bos must be a boolean");
         | 
| 1337 1362 | 
             
                  return Qnil;
         | 
| 1338 1363 | 
             
                }
         | 
| 1364 | 
            +
                if (kw_values[3] != Qundef && (kw_values[3] != Qtrue && kw_values[3] != Qfalse)) {
         | 
| 1365 | 
            +
                  rb_raise(rb_eArgError, "special must be a boolean");
         | 
| 1366 | 
            +
                  return Qnil;
         | 
| 1367 | 
            +
                }
         | 
| 1339 1368 |  | 
| 1340 1369 | 
             
                VALUE text_ = kw_values[0];
         | 
| 1341 1370 | 
             
                std::string text = StringValueCStr(text_);
         | 
| 1342 1371 | 
             
                const bool add_bos = kw_values[2] == Qtrue ? true : false;
         | 
| 1372 | 
            +
                const bool special = kw_values[3] == Qtrue ? true : false;
         | 
| 1343 1373 | 
             
                const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
         | 
| 1344 1374 |  | 
| 1345 1375 | 
             
                llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
         | 
| 1346 1376 | 
             
                LLaMAModelWrapper* ptr = get_llama_model(self);
         | 
| 1347 | 
            -
                const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
         | 
| 1377 | 
            +
                const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos, special);
         | 
| 1348 1378 |  | 
| 1349 1379 | 
             
                if (n_tokens < 0) {
         | 
| 1350 1380 | 
             
                  rb_raise(rb_eRuntimeError, "failed to tokenize. The numebr of tokens (%d) is greater than n_max_tokens.", -n_tokens);
         | 
| @@ -1376,6 +1406,62 @@ private: | |
| 1376 1406 | 
             
                LLaMAModelWrapper* ptr = get_llama_model(self);
         | 
| 1377 1407 | 
             
                return UINT2NUM(llama_model_n_params(ptr->model));
         | 
| 1378 1408 | 
             
              }
         | 
| 1409 | 
            +
             | 
| 1410 | 
            +
              static VALUE _llama_model_get_text(VALUE self, VALUE token_) {
         | 
| 1411 | 
            +
                LLaMAModelWrapper* ptr = get_llama_model(self);
         | 
| 1412 | 
            +
                const llama_token token = NUM2INT(token_);
         | 
| 1413 | 
            +
                const char* text = llama_token_get_text(ptr->model, token);
         | 
| 1414 | 
            +
                return rb_utf8_str_new_cstr(text);
         | 
| 1415 | 
            +
              }
         | 
| 1416 | 
            +
             | 
| 1417 | 
            +
              static VALUE _llama_model_get_score(VALUE self, VALUE token_) {
         | 
| 1418 | 
            +
                LLaMAModelWrapper* ptr = get_llama_model(self);
         | 
| 1419 | 
            +
                const llama_token token = NUM2INT(token_);
         | 
| 1420 | 
            +
                const float score = llama_token_get_score(ptr->model, token);
         | 
| 1421 | 
            +
                return DBL2NUM(score);
         | 
| 1422 | 
            +
              }
         | 
| 1423 | 
            +
             | 
| 1424 | 
            +
              static VALUE _llama_model_get_type(VALUE self, VALUE token_) {
         | 
| 1425 | 
            +
                LLaMAModelWrapper* ptr = get_llama_model(self);
         | 
| 1426 | 
            +
                const llama_token token = NUM2INT(token_);
         | 
| 1427 | 
            +
                const int type = llama_token_get_type(ptr->model, token);
         | 
| 1428 | 
            +
                return INT2NUM(type);
         | 
| 1429 | 
            +
              }
         | 
| 1430 | 
            +
             | 
| 1431 | 
            +
              static VALUE _llama_model_token_bos(VALUE self) {
         | 
| 1432 | 
            +
                LLaMAModelWrapper* ptr = get_llama_model(self);
         | 
| 1433 | 
            +
                return INT2NUM(llama_token_bos(ptr->model));
         | 
| 1434 | 
            +
              }
         | 
| 1435 | 
            +
             | 
| 1436 | 
            +
              static VALUE _llama_model_token_eos(VALUE self) {
         | 
| 1437 | 
            +
                LLaMAModelWrapper* ptr = get_llama_model(self);
         | 
| 1438 | 
            +
                return INT2NUM(llama_token_eos(ptr->model));
         | 
| 1439 | 
            +
              }
         | 
| 1440 | 
            +
             | 
| 1441 | 
            +
              static VALUE _llama_model_token_nl(VALUE self) {
         | 
| 1442 | 
            +
                LLaMAModelWrapper* ptr = get_llama_model(self);
         | 
| 1443 | 
            +
                return INT2NUM(llama_token_nl(ptr->model));
         | 
| 1444 | 
            +
              }
         | 
| 1445 | 
            +
             | 
| 1446 | 
            +
              static VALUE _llama_model_token_prefix(VALUE self) {
         | 
| 1447 | 
            +
                LLaMAModelWrapper* ptr = get_llama_model(self);
         | 
| 1448 | 
            +
                return INT2NUM(llama_token_prefix(ptr->model));
         | 
| 1449 | 
            +
              }
         | 
| 1450 | 
            +
             | 
| 1451 | 
            +
              static VALUE _llama_model_token_middle(VALUE self) {
         | 
| 1452 | 
            +
                LLaMAModelWrapper* ptr = get_llama_model(self);
         | 
| 1453 | 
            +
                return INT2NUM(llama_token_middle(ptr->model));
         | 
| 1454 | 
            +
              }
         | 
| 1455 | 
            +
             | 
| 1456 | 
            +
              static VALUE _llama_model_token_suffix(VALUE self) {
         | 
| 1457 | 
            +
                LLaMAModelWrapper* ptr = get_llama_model(self);
         | 
| 1458 | 
            +
                return INT2NUM(llama_token_suffix(ptr->model));
         | 
| 1459 | 
            +
              }
         | 
| 1460 | 
            +
             | 
| 1461 | 
            +
              static VALUE _llama_model_token_eot(VALUE self) {
         | 
| 1462 | 
            +
                LLaMAModelWrapper* ptr = get_llama_model(self);
         | 
| 1463 | 
            +
                return INT2NUM(llama_token_eot(ptr->model));
         | 
| 1464 | 
            +
              }
         | 
| 1379 1465 | 
             
            };
         | 
| 1380 1466 |  | 
| 1381 1467 | 
             
            const rb_data_type_t RbLLaMAModel::llama_model_type = {
         | 
| @@ -1650,16 +1736,6 @@ public: | |
| 1650 1736 | 
             
                rb_define_method(rb_cLLaMAContext, "decode", RUBY_METHOD_FUNC(_llama_context_decode), 1);
         | 
| 1651 1737 | 
             
                rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
         | 
| 1652 1738 | 
             
                rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
         | 
| 1653 | 
            -
                rb_define_method(rb_cLLaMAContext, "text", RUBY_METHOD_FUNC(_llama_context_text), 1);
         | 
| 1654 | 
            -
                rb_define_method(rb_cLLaMAContext, "score", RUBY_METHOD_FUNC(_llama_context_score), 1);
         | 
| 1655 | 
            -
                rb_define_method(rb_cLLaMAContext, "type", RUBY_METHOD_FUNC(_llama_context_type), 1);
         | 
| 1656 | 
            -
                rb_define_method(rb_cLLaMAContext, "token_bos", RUBY_METHOD_FUNC(_llama_context_token_bos), 0);
         | 
| 1657 | 
            -
                rb_define_method(rb_cLLaMAContext, "token_eos", RUBY_METHOD_FUNC(_llama_context_token_eos), 0);
         | 
| 1658 | 
            -
                rb_define_method(rb_cLLaMAContext, "token_nl", RUBY_METHOD_FUNC(_llama_context_token_nl), 0);
         | 
| 1659 | 
            -
                rb_define_method(rb_cLLaMAContext, "token_prefix", RUBY_METHOD_FUNC(_llama_context_token_prefix), 0);
         | 
| 1660 | 
            -
                rb_define_method(rb_cLLaMAContext, "token_middle", RUBY_METHOD_FUNC(_llama_context_token_middle), 0);
         | 
| 1661 | 
            -
                rb_define_method(rb_cLLaMAContext, "token_suffix", RUBY_METHOD_FUNC(_llama_context_token_suffix), 0);
         | 
| 1662 | 
            -
                rb_define_method(rb_cLLaMAContext, "token_eot", RUBY_METHOD_FUNC(_llama_context_token_eot), 0);
         | 
| 1663 1739 | 
             
                rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
         | 
| 1664 1740 | 
             
                rb_define_method(rb_cLLaMAContext, "timings", RUBY_METHOD_FUNC(_llama_context_get_timings), 0);
         | 
| 1665 1741 | 
             
                rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
         | 
| @@ -1673,8 +1749,7 @@ public: | |
| 1673 1749 | 
             
                rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
         | 
| 1674 1750 | 
             
                rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
         | 
| 1675 1751 | 
             
                rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
         | 
| 1676 | 
            -
                rb_define_method(rb_cLLaMAContext, " | 
| 1677 | 
            -
                rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
         | 
| 1752 | 
            +
                rb_define_method(rb_cLLaMAContext, "sample_repetition_penalties", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalties), -1);
         | 
| 1678 1753 | 
             
                rb_define_method(rb_cLLaMAContext, "sample_classifier_free_guidance", RUBY_METHOD_FUNC(_llama_context_sample_classifier_free_guidance), -1);
         | 
| 1679 1754 | 
             
                rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
         | 
| 1680 1755 | 
             
                rb_define_method(rb_cLLaMAContext, "sample_top_k", RUBY_METHOD_FUNC(_llama_context_sample_top_k), -1);
         | 
| @@ -1907,102 +1982,6 @@ private: | |
| 1907 1982 | 
             
                return output;
         | 
| 1908 1983 | 
             
              }
         | 
| 1909 1984 |  | 
| 1910 | 
            -
              static VALUE _llama_context_text(VALUE self, VALUE token_) {
         | 
| 1911 | 
            -
                LLaMAContextWrapper* ptr = get_llama_context(self);
         | 
| 1912 | 
            -
                if (ptr->ctx == NULL) {
         | 
| 1913 | 
            -
                  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
         | 
| 1914 | 
            -
                  return Qnil;
         | 
| 1915 | 
            -
                }
         | 
| 1916 | 
            -
                const llama_token token = NUM2INT(token_);
         | 
| 1917 | 
            -
                const char* text = llama_token_get_text(ptr->ctx, token);
         | 
| 1918 | 
            -
                return rb_utf8_str_new_cstr(text);
         | 
| 1919 | 
            -
              }
         | 
| 1920 | 
            -
             | 
| 1921 | 
            -
              static VALUE _llama_context_score(VALUE self, VALUE token_) {
         | 
| 1922 | 
            -
                LLaMAContextWrapper* ptr = get_llama_context(self);
         | 
| 1923 | 
            -
                if (ptr->ctx == NULL) {
         | 
| 1924 | 
            -
                  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
         | 
| 1925 | 
            -
                  return Qnil;
         | 
| 1926 | 
            -
                }
         | 
| 1927 | 
            -
                const llama_token token = NUM2INT(token_);
         | 
| 1928 | 
            -
                const float score = llama_token_get_score(ptr->ctx, token);
         | 
| 1929 | 
            -
                return DBL2NUM(score);
         | 
| 1930 | 
            -
              }
         | 
| 1931 | 
            -
             | 
| 1932 | 
            -
              static VALUE _llama_context_type(VALUE self, VALUE token_) {
         | 
| 1933 | 
            -
                LLaMAContextWrapper* ptr = get_llama_context(self);
         | 
| 1934 | 
            -
                if (ptr->ctx == NULL) {
         | 
| 1935 | 
            -
                  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
         | 
| 1936 | 
            -
                  return Qnil;
         | 
| 1937 | 
            -
                }
         | 
| 1938 | 
            -
                const llama_token token = NUM2INT(token_);
         | 
| 1939 | 
            -
                const int type = llama_token_get_type(ptr->ctx, token);
         | 
| 1940 | 
            -
                return INT2NUM(type);
         | 
| 1941 | 
            -
              }
         | 
| 1942 | 
            -
             | 
| 1943 | 
            -
              static VALUE _llama_context_token_bos(VALUE self) {
         | 
| 1944 | 
            -
                LLaMAContextWrapper* ptr = get_llama_context(self);
         | 
| 1945 | 
            -
                if (ptr->ctx == NULL) {
         | 
| 1946 | 
            -
                  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
         | 
| 1947 | 
            -
                  return Qnil;
         | 
| 1948 | 
            -
                }
         | 
| 1949 | 
            -
                return INT2NUM(llama_token_bos(ptr->ctx));
         | 
| 1950 | 
            -
              }
         | 
| 1951 | 
            -
             | 
| 1952 | 
            -
              static VALUE _llama_context_token_eos(VALUE self) {
         | 
| 1953 | 
            -
                LLaMAContextWrapper* ptr = get_llama_context(self);
         | 
| 1954 | 
            -
                if (ptr->ctx == NULL) {
         | 
| 1955 | 
            -
                  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
         | 
| 1956 | 
            -
                  return Qnil;
         | 
| 1957 | 
            -
                }
         | 
| 1958 | 
            -
                return INT2NUM(llama_token_eos(ptr->ctx));
         | 
| 1959 | 
            -
              }
         | 
| 1960 | 
            -
             | 
| 1961 | 
            -
              static VALUE _llama_context_token_nl(VALUE self) {
         | 
| 1962 | 
            -
                LLaMAContextWrapper* ptr = get_llama_context(self);
         | 
| 1963 | 
            -
                if (ptr->ctx == NULL) {
         | 
| 1964 | 
            -
                  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
         | 
| 1965 | 
            -
                  return Qnil;
         | 
| 1966 | 
            -
                }
         | 
| 1967 | 
            -
                return INT2NUM(llama_token_nl(ptr->ctx));
         | 
| 1968 | 
            -
              }
         | 
| 1969 | 
            -
             | 
| 1970 | 
            -
              static VALUE _llama_context_token_prefix(VALUE self) {
         | 
| 1971 | 
            -
                LLaMAContextWrapper* ptr = get_llama_context(self);
         | 
| 1972 | 
            -
                if (ptr->ctx == NULL) {
         | 
| 1973 | 
            -
                  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
         | 
| 1974 | 
            -
                  return Qnil;
         | 
| 1975 | 
            -
                }
         | 
| 1976 | 
            -
                return INT2NUM(llama_token_prefix(ptr->ctx));
         | 
| 1977 | 
            -
              }
         | 
| 1978 | 
            -
             | 
| 1979 | 
            -
              static VALUE _llama_context_token_middle(VALUE self) {
         | 
| 1980 | 
            -
                LLaMAContextWrapper* ptr = get_llama_context(self);
         | 
| 1981 | 
            -
                if (ptr->ctx == NULL) {
         | 
| 1982 | 
            -
                  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
         | 
| 1983 | 
            -
                  return Qnil;
         | 
| 1984 | 
            -
                }
         | 
| 1985 | 
            -
                return INT2NUM(llama_token_middle(ptr->ctx));
         | 
| 1986 | 
            -
              }
         | 
| 1987 | 
            -
             | 
| 1988 | 
            -
              static VALUE _llama_context_token_suffix(VALUE self) {
         | 
| 1989 | 
            -
                LLaMAContextWrapper* ptr = get_llama_context(self);
         | 
| 1990 | 
            -
                if (ptr->ctx == NULL) {
         | 
| 1991 | 
            -
                  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
         | 
| 1992 | 
            -
                  return Qnil;
         | 
| 1993 | 
            -
                }
         | 
| 1994 | 
            -
                return INT2NUM(llama_token_suffix(ptr->ctx));
         | 
| 1995 | 
            -
              }
         | 
| 1996 | 
            -
             | 
| 1997 | 
            -
              static VALUE _llama_context_token_eot(VALUE self) {
         | 
| 1998 | 
            -
                LLaMAContextWrapper* ptr = get_llama_context(self);
         | 
| 1999 | 
            -
                if (ptr->ctx == NULL) {
         | 
| 2000 | 
            -
                  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
         | 
| 2001 | 
            -
                  return Qnil;
         | 
| 2002 | 
            -
                }
         | 
| 2003 | 
            -
                return INT2NUM(llama_token_eot(ptr->ctx));
         | 
| 2004 | 
            -
              }
         | 
| 2005 | 
            -
             | 
| 2006 1985 | 
             
              static VALUE _llama_context_n_ctx(VALUE self) {
         | 
| 2007 1986 | 
             
                LLaMAContextWrapper* ptr = get_llama_context(self);
         | 
| 2008 1987 | 
             
                if (ptr->ctx == NULL) {
         | 
| @@ -2211,14 +2190,14 @@ private: | |
| 2211 2190 | 
             
                return Qnil;
         | 
| 2212 2191 | 
             
              }
         | 
| 2213 2192 |  | 
| 2214 | 
            -
              static VALUE  | 
| 2193 | 
            +
              static VALUE _llama_context_sample_repetition_penalties(int argc, VALUE* argv, VALUE self) {
         | 
| 2215 2194 | 
             
                VALUE kw_args = Qnil;
         | 
| 2216 | 
            -
                ID kw_table[ | 
| 2217 | 
            -
                VALUE kw_values[ | 
| 2195 | 
            +
                ID kw_table[3] = { rb_intern("penalty_repeat"), rb_intern("penalty_freq"), rb_intern("penalty_present") };
         | 
| 2196 | 
            +
                VALUE kw_values[3] = { Qundef, Qundef, Qundef };
         | 
| 2218 2197 | 
             
                VALUE candidates = Qnil;
         | 
| 2219 2198 | 
             
                VALUE last_n_tokens = Qnil;
         | 
| 2220 2199 | 
             
                rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
         | 
| 2221 | 
            -
                rb_get_kwargs(kw_args, kw_table,  | 
| 2200 | 
            +
                rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
         | 
| 2222 2201 |  | 
| 2223 2202 | 
             
                if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
         | 
| 2224 2203 | 
             
                  rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
         | 
| @@ -2229,56 +2208,15 @@ private: | |
| 2229 2208 | 
             
                  return Qnil;
         | 
| 2230 2209 | 
             
                }
         | 
| 2231 2210 | 
             
                if (!RB_FLOAT_TYPE_P(kw_values[0])) {
         | 
| 2232 | 
            -
                  rb_raise(rb_eArgError, " | 
| 2211 | 
            +
                  rb_raise(rb_eArgError, "penalty_repeat must be a float");
         | 
| 2233 2212 | 
             
                  return Qnil;
         | 
| 2234 2213 | 
             
                }
         | 
| 2235 | 
            -
             | 
| 2236 | 
            -
             | 
| 2237 | 
            -
                std::vector<llama_token> last_n_tokens_data(last_tokens_size);
         | 
| 2238 | 
            -
                for (size_t i = 0; i < last_tokens_size; i++) {
         | 
| 2239 | 
            -
                  last_n_tokens_data[i] = NUM2INT(rb_ary_entry(last_n_tokens, i));
         | 
| 2240 | 
            -
                }
         | 
| 2241 | 
            -
             | 
| 2242 | 
            -
                LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
         | 
| 2243 | 
            -
                if (ctx_ptr->ctx == NULL) {
         | 
| 2244 | 
            -
                  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
         | 
| 2245 | 
            -
                  return Qnil;
         | 
| 2246 | 
            -
                }
         | 
| 2247 | 
            -
                LLaMATokenDataArrayWrapper* cnd_ptr = RbLLaMATokenDataArray::get_llama_token_data_array(candidates);
         | 
| 2248 | 
            -
                if (cnd_ptr->array.data == nullptr) {
         | 
| 2249 | 
            -
                  rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
         | 
| 2250 | 
            -
                  return Qnil;
         | 
| 2251 | 
            -
                }
         | 
| 2252 | 
            -
                const float penalty = NUM2DBL(kw_values[0]);
         | 
| 2253 | 
            -
             | 
| 2254 | 
            -
                llama_sample_repetition_penalty(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, penalty);
         | 
| 2255 | 
            -
             | 
| 2256 | 
            -
                return Qnil;
         | 
| 2257 | 
            -
              }
         | 
| 2258 | 
            -
             | 
| 2259 | 
            -
              static VALUE _llama_context_sample_frequency_and_presence_penalties(int argc, VALUE* argv, VALUE self) {
         | 
| 2260 | 
            -
                VALUE kw_args = Qnil;
         | 
| 2261 | 
            -
                ID kw_table[2] = { rb_intern("frequency"), rb_intern("presence") };
         | 
| 2262 | 
            -
                VALUE kw_values[2] = { Qundef, Qundef };
         | 
| 2263 | 
            -
                VALUE candidates = Qnil;
         | 
| 2264 | 
            -
                VALUE last_n_tokens = Qnil;
         | 
| 2265 | 
            -
                rb_scan_args(argc, argv, "2:", &candidates, &last_n_tokens, &kw_args);
         | 
| 2266 | 
            -
                rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
         | 
| 2267 | 
            -
             | 
| 2268 | 
            -
                if (!rb_obj_is_kind_of(candidates, rb_cLLaMATokenDataArray)) {
         | 
| 2269 | 
            -
                  rb_raise(rb_eArgError, "candidates must be a TokenDataArray");
         | 
| 2270 | 
            -
                  return Qnil;
         | 
| 2271 | 
            -
                }
         | 
| 2272 | 
            -
                if (!RB_TYPE_P(last_n_tokens, T_ARRAY)) {
         | 
| 2273 | 
            -
                  rb_raise(rb_eArgError, "last_n_tokens must be an Array");
         | 
| 2274 | 
            -
                  return Qnil;
         | 
| 2275 | 
            -
                }
         | 
| 2276 | 
            -
                if (!RB_FLOAT_TYPE_P(kw_values[0])) {
         | 
| 2277 | 
            -
                  rb_raise(rb_eArgError, "frequency must be a float");
         | 
| 2214 | 
            +
                if (!RB_FLOAT_TYPE_P(kw_values[1])) {
         | 
| 2215 | 
            +
                  rb_raise(rb_eArgError, "penalty_freq must be a float");
         | 
| 2278 2216 | 
             
                  return Qnil;
         | 
| 2279 2217 | 
             
                }
         | 
| 2280 | 
            -
                if (!RB_FLOAT_TYPE_P(kw_values[ | 
| 2281 | 
            -
                  rb_raise(rb_eArgError, " | 
| 2218 | 
            +
                if (!RB_FLOAT_TYPE_P(kw_values[2])) {
         | 
| 2219 | 
            +
                  rb_raise(rb_eArgError, "penalty_present must be a float");
         | 
| 2282 2220 | 
             
                  return Qnil;
         | 
| 2283 2221 | 
             
                }
         | 
| 2284 2222 |  | 
| @@ -2298,11 +2236,12 @@ private: | |
| 2298 2236 | 
             
                  rb_raise(rb_eRuntimeError, "TokenDataArray is empty");
         | 
| 2299 2237 | 
             
                  return Qnil;
         | 
| 2300 2238 | 
             
                }
         | 
| 2239 | 
            +
                const float penalty_repeat = NUM2DBL(kw_values[0]);
         | 
| 2240 | 
            +
                const float penalty_freq = NUM2DBL(kw_values[1]);
         | 
| 2241 | 
            +
                const float penalty_present = NUM2DBL(kw_values[2]);
         | 
| 2301 2242 |  | 
| 2302 | 
            -
                 | 
| 2303 | 
            -
             | 
| 2304 | 
            -
             | 
| 2305 | 
            -
                llama_sample_frequency_and_presence_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size, alpha_frequency, alpha_presence);
         | 
| 2243 | 
            +
                llama_sample_repetition_penalties(ctx_ptr->ctx, &(cnd_ptr->array), last_n_tokens_data.data(), last_tokens_size,
         | 
| 2244 | 
            +
                    penalty_repeat, penalty_freq, penalty_present);
         | 
| 2306 2245 |  | 
| 2307 2246 | 
             
                return Qnil;
         | 
| 2308 2247 | 
             
              }
         |