llama_cpp 0.2.2 → 0.3.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: e5e221d4831be790a990b121e6ac780d10b4cbfb85b2a9b4284d9c216f6e5604
4
- data.tar.gz: fba76ac1a70bfd7b02b8d123c57e4c8096a29ac7f658bb090cda91c6a54752d2
3
+ metadata.gz: 9e0152eb9e091932225356614b57fad416c2aa96a83316f8585c9ef2872e1504
4
+ data.tar.gz: 8ea2f00f11be7dd6524bfe69e3181fc63df7c841ed1e2d91b1b2bcafd99d0b66
5
5
  SHA512:
6
- metadata.gz: 994029383219077e134d170177954251c20ede6d1c83843ecd22c42eeae83584079d124b41702f55add7f3f237e9bdb14382fbd37dde2d0e74f8cffcfed1715b
7
- data.tar.gz: ca4e94b6ddf4e4e9ddabbb2b8309cf4b2b06a881df09fdf4ad96e27c4f1f620ca0024ac46f69d9b474849c074a5c9ba9b0440777a0b52a12413bc356457a02f3
6
+ metadata.gz: a85a4bdd2d1fd575eb406b9bebdf7f388db33dc42f7a2980ba9a7a6b346b539854d9df5515c9b6968727e76f035a23f59d4bc65bc5525df962dfbdf56d8b3b01
7
+ data.tar.gz: 33641d622102257dbc1358bde0871a03c595928f5d8cedee512e1df414e4aa93433eadfcd082d4db42046320c1ed7f806dfb3aafd7934a1becb33fe275f9435c
data/CHANGELOG.md CHANGED
@@ -1,3 +1,31 @@
1
+ ## [[0.3.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.2.2...v0.3.0)] - 2023-06-30
2
+
3
+ - Add no_k_quants and qkk_64 config options:
4
+ ```
5
+ $ gem install llama_cpp -- --with-no_k_quants
6
+ ```
7
+ ```
8
+ $ gem install llama_cpp -- --with-qkk_64
9
+ ```
10
+
11
+ **Breaking Changes**
12
+ - Remove `Client` class to concentrate on developing bindings.
13
+ - Bump bundled llama.cpp from master-7487137 to master-9d23589.
14
+ - `llama_init_from_file` and `llama_apply_lora_from_file` are deprecated.
15
+ - Add `Model` class for wrapping llama_model.
16
+ - Move the `apply_lora_from_file`, `free`, `load`, and `empty?` methods from the `Context` class to the `Model` class.
17
+ - Change the arguments of the `Context` initialize method; it now requires a `Model` object instead of the model's file path.
18
+ ```ruby
19
+ require 'llama_cpp'
20
+
21
+ params = LLaMACpp::ContextParams.new
22
+
23
+ model = LLaMACpp::Model.new(model_path: '/path/to/quantized-model.bin', params: params)
24
+ context = LLaMACpp::Context.new(model: model)
25
+
26
+ LLaMACpp.generate(context, 'Hello, world.')
27
+ ```
28
+
1
29
  ## [[0.2.2](https://github.com/yoshoku/llama_cpp.rb/compare/v0.2.1...v0.2.2)] - 2023-06-24
2
30
 
3
31
  - Bump bundled llama.cpp from master-a09f919 to master-7487137.
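For reference, the methods listed in the breaking changes as moving to `Model` keep keyword arguments, as defined by the new bindings later in this diff. A minimal sketch of the relocated API (the model and LoRA paths are placeholders):

```ruby
require 'llama_cpp'

params = LLaMACpp::ContextParams.new

# A Model can be created empty and filled later via #load.
model = LLaMACpp::Model.new
model.empty?  # => true
model.load(model_path: '/path/to/quantized-model.bin', params: params)

# apply_lora_from_file now lives on Model; base_model_path and n_threads are optional.
model.apply_lora_from_file(lora_path: '/path/to/lora-adapter.bin', n_threads: 4)

# Explicitly release the underlying llama_model when it is no longer needed.
model.free
```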
data/README.md CHANGED
@@ -20,21 +20,54 @@ If bundler is not being used to manage dependencies, install the gem by executin
20
20
 
21
21
  ## Usage
22
22
 
23
- Prepare the quantized model by refering to [the usage section on the llama.cpp README](https://github.com/ggerganov/llama.cpp#usage) or
24
- download the qunatized model, for example [ggml-vicuna-7b-4bit](https://github.com/ggerganov/llama.cpp/discussions/643#discussioncomment-5541351), from Hugging Face.
23
+ Prepare the quantized model by referring to [the usage section on the llama.cpp README](https://github.com/ggerganov/llama.cpp#usage).
24
+ For example, preparing a quantized model based on [open_llama_7b](https://huggingface.co/openlm-research/open_llama_7b) is as follows:
25
+
26
+ ```sh
27
+ $ cd ~/
28
+ $ brew install git-lfs
29
+ $ git lfs install
30
+ $ git clone https://github.com/ggerganov/llama.cpp.git
31
+ $ cd llama.cpp
32
+ $ python3 -m pip install -r requirements.txt
33
+ $ cd models
34
+ $ git clone https://huggingface.co/openlm-research/open_llama_7b
35
+ $ cd ../
36
+ $ python3 convert.py models/open_llama_7b
37
+ $ make
38
+ $ ./quantize ./models/open_llama_7b/ggml-model-f16.bin ./models/open_llama_7b/ggml-model-q4_0.bin q4_0
39
+ ```
40
+
41
+ An example of Ruby code that generates sentences with the quantized model is as follows:
25
42
 
26
43
  ```ruby
27
44
  require 'llama_cpp'
28
45
 
29
46
  params = LLaMACpp::ContextParams.new
30
- params.seed = 12
47
+ params.seed = 42
31
48
 
32
- context = LLaMACpp::Context.new(model_path: '/path/to/quantized-model.bin', params: params)
49
+ model = LLaMACpp::Model.new(model_path: '/home/user/llama.cpp/models/open_llama_7b/ggml-model-q4_0.bin', params: params)
50
+ context = LLaMACpp::Context.new(model: model)
33
51
 
34
- puts LLaMACpp.generate(context, 'Please tell me the largest city in Japan.', n_threads: 4)
35
- # => "There are two major cities in Japan, Tokyo and Osaka, which have about 30 million populations."
52
+ puts LLaMACpp.generate(context, 'Hello, World.', n_threads: 4)
36
53
  ```
37
54
 
55
+ ## Examples
56
+ There is a sample program in the [examples](https://github.com/yoshoku/llama_cpp.rb/tree/main/examples) directory that allows interactive communication like ChatGPT.
57
+
58
+ ```sh
59
+ $ git clone https://github.com/yoshoku/llama_cpp.rb.git
60
+ $ cd llama_cpp.rb/examples
61
+ $ bundle install
62
+ $ ruby chat.rb --model /home/user/llama.cpp/models/open_llama_7b/ggml-model-q4_0.bin --seed 2023
63
+ ...
64
+ User: Who is the originator of the Ruby programming language?
65
+ Bob: The originator of the Ruby programming language is Mr. Yukihiro Matsumoto.
66
+ User:
67
+ ```
68
+
69
+ ![llama_cpp_chat_example](https://github.com/yoshoku/llama_cpp.rb/assets/5562409/374ae3d8-63a6-498f-ae6e-5552b464bdda)
70
+
38
71
  ## Contributing
39
72
 
40
73
  Bug reports and pull requests are welcome on GitHub at https://github.com/yoshoku/llama_cpp.rb.
data/examples/chat.rb CHANGED
@@ -35,7 +35,8 @@ class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation
35
35
  params = LLaMACpp::ContextParams.new
36
36
  params.seed = options[:seed]
37
37
  params.n_gpu_layers = options[:n_gpu_layers]
38
- context = LLaMACpp::Context.new(model_path: options[:model], params: params)
38
+ model = LLaMACpp::Model.new(model_path: options[:model], params: params)
39
+ context = LLaMACpp::Context.new(model: model)
39
40
 
40
41
  antiprompt = options[:reverse_prompt] || 'User:'
41
42
  start_prompt = read_prompt(options[:file]) || default_prompt(antiprompt)
data/examples/embedding.rb CHANGED
@@ -16,12 +16,13 @@ class Embedding < Thor # rubocop:disable Style/Documentation
16
16
  option :model, type: :string, aliases: '-m', desc: 'path to model file', required: true
17
17
  option :prompt, type: :string, aliases: '-p', desc: 'prompt to generate embedding', required: true
18
18
  option :n_gpu_layers, type: :numeric, desc: 'number of layers on GPU', default: 0
19
- def main # rubocop:disable Metrics/AbcSize
19
+ def main # rubocop:disable Metrics/AbcSize, Metrics/MethodLength
20
20
  params = LLaMACpp::ContextParams.new
21
21
  params.seed = options[:seed]
22
22
  params.n_gpu_layers = options[:n_gpu_layers]
23
23
  params.embedding = true
24
- context = LLaMACpp::Context.new(model_path: options[:model], params: params)
24
+ model = LLaMACpp::Model.new(model_path: options[:model], params: params)
25
+ context = LLaMACpp::Context.new(model: model)
25
26
 
26
27
  embd_input = context.tokenize(text: options[:prompt], add_bos: true)
27
28
 
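The embedding example follows the same Model-then-Context construction. Below is a condensed sketch of that flow; the final `eval` and `embeddings` calls mirror the rest of examples/embedding.rb, which is not shown in this diff, so treat their exact signatures as assumptions:

```ruby
require 'llama_cpp'

params = LLaMACpp::ContextParams.new
params.embedding = true  # enable embedding output, as in examples/embedding.rb

model = LLaMACpp::Model.new(model_path: '/path/to/quantized-model.bin', params: params)
context = LLaMACpp::Context.new(model: model)

# Tokenize the prompt (shown in the hunk above), then evaluate and read the embedding vector.
tokens = context.tokenize(text: 'Hello, World.', add_bos: true)
context.eval(tokens: tokens, n_past: 0)   # assumed signature
p context.embeddings.take(8)              # assumed accessor
```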
data/ext/llama_cpp/extconf.rb CHANGED
@@ -17,6 +17,17 @@ if RUBY_PLATFORM.match?(/darwin|linux|bsd/) && try_compile('#include <stdio.h>',
17
17
  $CXXFLAGS << ' -pthread'
18
18
  end
19
19
 
20
+ unless with_config('no_k_quants')
21
+ $CFLAGS << ' -DGGML_USE_K_QUANTS'
22
+ $CXXFLAGS << ' -DGGML_USE_K_QUANTS'
23
+ $srcs << 'k_quants.c'
24
+ end
25
+
26
+ if with_config('qkk_64')
27
+ $CFLAGS << ' -DGGML_QKK_64'
28
+ $CXXFLAGS << ' -DGGML_QKK_64'
29
+ end
30
+
20
31
  if with_config('openblas')
21
32
  abort 'libopenblas is not found.' unless have_library('openblas')
22
33
  abort 'cblas.h is not found.' unless have_header('cblas.h')
@@ -42,6 +53,7 @@ if with_config('metal')
42
53
  $CXXFLAGS << ' -DGGML_USE_METAL'
43
54
  $LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders'
44
55
  $objs = %w[ggml.o llama.o llama_cpp.o ggml-metal.o]
56
+ $objs << 'k_quants.o' unless with_config('no_k_quants')
45
57
  end
46
58
 
47
59
  if with_config('cublas')
@@ -49,6 +61,7 @@ if with_config('cublas')
49
61
  $CXXFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
50
62
  $LDFLAGS << ' -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64'
51
63
  $objs = %w[ggml-cuda.o ggml.o llama.o llama_cpp.o]
64
+ $objs << 'k_quants.o' unless with_config('no_k_quants')
52
65
  end
53
66
 
54
67
  if with_config('clblast')
data/ext/llama_cpp/llama_cpp.cpp CHANGED
@@ -2,6 +2,7 @@
2
2
  #include "llama_cpp.h"
3
3
 
4
4
  VALUE rb_mLLaMACpp;
5
+ VALUE rb_cLLaMAModel;
5
6
  VALUE rb_cLLaMAContext;
6
7
  VALUE rb_cLLaMAContextParams;
7
8
  VALUE rb_cLLaMAModelQuantizeParams;
@@ -610,6 +611,202 @@ const rb_data_type_t RbLLaMAModelQuantizeParams::llama_model_quantize_params_typ
610
611
  RUBY_TYPED_FREE_IMMEDIATELY
611
612
  };
612
613
 
614
+ class LLaMAModelWrapper {
615
+ public:
616
+ struct llama_model* model;
617
+
618
+ LLaMAModelWrapper() : model(NULL){};
619
+
620
+ ~LLaMAModelWrapper() {
621
+ if (model != NULL) {
622
+ llama_free_model(model);
623
+ }
624
+ };
625
+ };
626
+
627
+ class RbLLaMAModel {
628
+ public:
629
+ static VALUE llama_model_alloc(VALUE self) {
630
+ LLaMAModelWrapper* ptr = (LLaMAModelWrapper*)ruby_xmalloc(sizeof(LLaMAModelWrapper));
631
+ new (ptr) LLaMAModelWrapper();
632
+ return TypedData_Wrap_Struct(self, &llama_model_type, ptr);
633
+ }
634
+
635
+ static void llama_model_free(void* ptr) {
636
+ ((LLaMAModelWrapper*)ptr)->~LLaMAModelWrapper();
637
+ ruby_xfree(ptr);
638
+ }
639
+
640
+ static size_t llama_model_size(const void* ptr) {
641
+ return sizeof(*((LLaMAModelWrapper*)ptr));
642
+ }
643
+
644
+ static LLaMAModelWrapper* get_llama_model(VALUE self) {
645
+ LLaMAModelWrapper* ptr;
646
+ TypedData_Get_Struct(self, LLaMAModelWrapper, &llama_model_type, ptr);
647
+ return ptr;
648
+ }
649
+
650
+ static void define_class(VALUE outer) {
651
+ rb_cLLaMAModel = rb_define_class_under(outer, "Model", rb_cObject);
652
+ rb_define_alloc_func(rb_cLLaMAModel, llama_model_alloc);
653
+ rb_define_method(rb_cLLaMAModel, "initialize", RUBY_METHOD_FUNC(_llama_model_initialize), -1);
654
+ rb_define_method(rb_cLLaMAModel, "empty?", RUBY_METHOD_FUNC(_llama_model_empty), 0);
655
+ rb_define_method(rb_cLLaMAModel, "free", RUBY_METHOD_FUNC(_llama_model_free), 0);
656
+ rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
657
+ rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
658
+ }
659
+
660
+ private:
661
+ static const rb_data_type_t llama_model_type;
662
+
663
+ static VALUE _llama_model_initialize(int argc, VALUE* argv, VALUE self) {
664
+ VALUE kw_args = Qnil;
665
+ ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
666
+ VALUE kw_values[2] = { Qundef, Qundef };
667
+ rb_scan_args(argc, argv, ":", &kw_args);
668
+ rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
669
+
670
+ if (kw_values[0] == Qundef && kw_values[1] == Qundef) {
671
+ rb_iv_set(self, "@params", Qnil);
672
+ return Qnil;
673
+ }
674
+
675
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
676
+ rb_raise(rb_eArgError, "model_path must be a string");
677
+ return Qnil;
678
+ }
679
+ if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
680
+ rb_raise(rb_eArgError, "params must be a ContextParams");
681
+ return Qnil;
682
+ }
683
+
684
+ VALUE filename = kw_values[0];
685
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
686
+ LLaMAModelWrapper* model_ptr = get_llama_model(self);
687
+
688
+ try {
689
+ model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
690
+ } catch (const std::runtime_error& e) {
691
+ rb_raise(rb_eRuntimeError, "%s", e.what());
692
+ return Qnil;
693
+ }
694
+
695
+ if (model_ptr->model == NULL) {
696
+ rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA model");
697
+ return Qnil;
698
+ }
699
+
700
+ rb_iv_set(self, "@params", kw_values[1]);
701
+
702
+ RB_GC_GUARD(filename);
703
+ return Qnil;
704
+ }
705
+
706
+ static VALUE _llama_model_empty(VALUE self) {
707
+ LLaMAModelWrapper* ptr = get_llama_model(self);
708
+ if (ptr->model != NULL) {
709
+ return Qfalse;
710
+ }
711
+ return Qtrue;
712
+ }
713
+
714
+ static VALUE _llama_model_free(VALUE self) {
715
+ LLaMAModelWrapper* ptr = get_llama_model(self);
716
+ if (ptr->model != NULL) {
717
+ llama_free_model(ptr->model);
718
+ ptr->model = NULL;
719
+ rb_iv_set(self, "@params", Qnil);
720
+ }
721
+ return Qnil;
722
+ }
723
+
724
+ static VALUE _llama_model_load(int argc, VALUE* argv, VALUE self) {
725
+ VALUE kw_args = Qnil;
726
+ ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
727
+ VALUE kw_values[2] = { Qundef, Qundef };
728
+ rb_scan_args(argc, argv, ":", &kw_args);
729
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
730
+
731
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
732
+ rb_raise(rb_eArgError, "model_path must be a string");
733
+ return Qnil;
734
+ }
735
+ if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
736
+ rb_raise(rb_eArgError, "params must be a LLaMAContextParams");
737
+ return Qnil;
738
+ }
739
+
740
+ LLaMAModelWrapper* model_ptr = get_llama_model(self);
741
+ if (model_ptr->model != NULL) {
742
+ rb_raise(rb_eRuntimeError, "LLaMA model is already loaded");
743
+ return Qnil;
744
+ }
745
+
746
+ VALUE filename = kw_values[0];
747
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
748
+
749
+ try {
750
+ model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
751
+ } catch (const std::runtime_error& e) {
752
+ rb_raise(rb_eRuntimeError, "%s", e.what());
753
+ return Qnil;
754
+ }
755
+
756
+ if (model_ptr->model == NULL) {
757
+ rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA model");
758
+ return Qnil;
759
+ }
760
+
761
+ rb_iv_set(self, "@params", kw_values[1]);
762
+
763
+ RB_GC_GUARD(filename);
764
+ return Qnil;
765
+ }
766
+
767
+ static VALUE _llama_model_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
768
+ VALUE kw_args = Qnil;
769
+ ID kw_table[3] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads") };
770
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
771
+ rb_scan_args(argc, argv, ":", &kw_args);
772
+ rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
773
+
774
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
775
+ rb_raise(rb_eArgError, "lora_path must be a string");
776
+ return Qnil;
777
+ }
778
+ if (kw_values[1] != Qundef && !RB_TYPE_P(kw_values[1], T_STRING)) {
779
+ rb_raise(rb_eArgError, "base_model_path must be a string");
780
+ return Qnil;
781
+ }
782
+ if (kw_values[2] != Qundef && !RB_INTEGER_TYPE_P(kw_values[2])) {
783
+ rb_raise(rb_eArgError, "n_threads must be an integer");
784
+ return Qnil;
785
+ }
786
+
787
+ const char* lora_path = StringValueCStr(kw_values[0]);
788
+ const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
789
+ const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
790
+
791
+ LLaMAModelWrapper* ptr = get_llama_model(self);
792
+ if (llama_model_apply_lora_from_file(ptr->model, lora_path, base_model_path, n_threads) != 0) {
793
+ rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
794
+ return Qnil;
795
+ }
796
+ return Qnil;
797
+ };
798
+ };
799
+
800
+ const rb_data_type_t RbLLaMAModel::llama_model_type = {
801
+ "RbLLaMAModel",
802
+ { NULL,
803
+ RbLLaMAModel::llama_model_free,
804
+ RbLLaMAModel::llama_model_size },
805
+ NULL,
806
+ NULL,
807
+ RUBY_TYPED_FREE_IMMEDIATELY
808
+ };
809
+
613
810
  class LLaMAContextWrapper {
614
811
  public:
615
812
  struct llama_context* ctx;
@@ -662,10 +859,6 @@ public:
662
859
  rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
663
860
  rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
664
861
  rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
665
- rb_define_method(rb_cLLaMAContext, "empty?", RUBY_METHOD_FUNC(_llama_context_empty), 0);
666
- rb_define_method(rb_cLLaMAContext, "free", RUBY_METHOD_FUNC(_llama_context_free), 0);
667
- rb_define_method(rb_cLLaMAContext, "load", RUBY_METHOD_FUNC(_llama_context_load), -1);
668
- rb_define_method(rb_cLLaMAContext, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_context_apply_lora_from_file), -1);
669
862
  rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
670
863
  rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
671
864
  rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
@@ -689,46 +882,37 @@ private:
689
882
 
690
883
  static VALUE _llama_context_initialize(int argc, VALUE* argv, VALUE self) {
691
884
  VALUE kw_args = Qnil;
692
- ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
693
- VALUE kw_values[2] = { Qundef, Qundef };
885
+ ID kw_table[1] = { rb_intern("model") };
886
+ VALUE kw_values[1] = { Qundef };
694
887
  rb_scan_args(argc, argv, ":", &kw_args);
695
- rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
888
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
696
889
 
697
- if (kw_values[0] == Qundef && kw_values[1] == Qundef) {
698
- rb_iv_set(self, "@params", Qnil);
699
- rb_iv_set(self, "@has_evaluated", Qfalse);
890
+ VALUE model = kw_values[0];
891
+ if (!rb_obj_is_kind_of(model, rb_cLLaMAModel)) {
892
+ rb_raise(rb_eArgError, "model must be a Model");
700
893
  return Qnil;
701
894
  }
702
895
 
703
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
704
- rb_raise(rb_eArgError, "model_path must be a string");
705
- return Qnil;
706
- }
707
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
708
- rb_raise(rb_eArgError, "params must be a ContextParams");
896
+ LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
897
+ if (model_ptr->model == NULL) {
898
+ rb_raise(rb_eRuntimeError, "Model is empty");
709
899
  return Qnil;
710
900
  }
711
901
 
712
- VALUE filename = kw_values[0];
713
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
902
+ VALUE params = rb_iv_get(model, "@params");
903
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
714
904
  LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
715
905
 
716
- try {
717
- ctx_ptr->ctx = llama_init_from_file(StringValueCStr(filename), prms_ptr->params);
718
- } catch (const std::runtime_error& e) {
719
- rb_raise(rb_eRuntimeError, "%s", e.what());
720
- return Qnil;
721
- }
906
+ ctx_ptr->ctx = llama_new_context_with_model(model_ptr->model, prms_ptr->params);
722
907
 
723
908
  if (ctx_ptr->ctx == NULL) {
724
909
  rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA context");
725
910
  return Qnil;
726
911
  }
727
912
 
728
- rb_iv_set(self, "@params", kw_values[1]);
913
+ rb_iv_set(self, "@model", model);
729
914
  rb_iv_set(self, "@has_evaluated", Qfalse);
730
915
 
731
- RB_GC_GUARD(filename);
732
916
  return Qnil;
733
917
  };
734
918
 
@@ -873,7 +1057,9 @@ private:
873
1057
  return Qnil;
874
1058
  }
875
1059
 
876
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(rb_iv_get(self, "@params"));
1060
+ VALUE model = rb_iv_get(self, "@model");
1061
+ VALUE params = rb_iv_get(model, "@params");
1062
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
877
1063
  const int n_tokens = prms_ptr->params.logits_all ? NUM2INT(rb_iv_get(self, "@n_tokens")) : 1;
878
1064
  const int n_vocab = llama_n_vocab(ptr->ctx);
879
1065
  const float* logits = llama_get_logits(ptr->ctx);
@@ -891,7 +1077,9 @@ private:
891
1077
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
892
1078
  return Qnil;
893
1079
  }
894
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(rb_iv_get(self, "@params"));
1080
+ VALUE model = rb_iv_get(self, "@model");
1081
+ VALUE params = rb_iv_get(model, "@params");
1082
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
895
1083
  if (!prms_ptr->params.embedding) {
896
1084
  rb_raise(rb_eRuntimeError, "embedding parameter is false");
897
1085
  return Qnil;
@@ -995,106 +1183,6 @@ private:
995
1183
  return Qnil;
996
1184
  };
997
1185
 
998
- static VALUE _llama_context_empty(VALUE self) {
999
- LLaMAContextWrapper* ptr = get_llama_context(self);
1000
- if (ptr->ctx != NULL) {
1001
- return Qfalse;
1002
- }
1003
- return Qtrue;
1004
- }
1005
-
1006
- static VALUE _llama_context_free(VALUE self) {
1007
- LLaMAContextWrapper* ptr = get_llama_context(self);
1008
- if (ptr->ctx != NULL) {
1009
- llama_free(ptr->ctx);
1010
- ptr->ctx = NULL;
1011
- rb_iv_set(self, "@params", Qnil);
1012
- rb_iv_set(self, "@has_evaluated", Qfalse);
1013
- }
1014
- return Qnil;
1015
- }
1016
-
1017
- static VALUE _llama_context_load(int argc, VALUE* argv, VALUE self) {
1018
- VALUE kw_args = Qnil;
1019
- ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
1020
- VALUE kw_values[2] = { Qundef, Qundef };
1021
- rb_scan_args(argc, argv, ":", &kw_args);
1022
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
1023
-
1024
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1025
- rb_raise(rb_eArgError, "model_path must be a string");
1026
- return Qnil;
1027
- }
1028
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
1029
- rb_raise(rb_eArgError, "params must be a LLaMAContextParams");
1030
- return Qnil;
1031
- }
1032
-
1033
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1034
- if (ctx_ptr->ctx != NULL) {
1035
- rb_raise(rb_eRuntimeError, "LLaMA context is already loaded");
1036
- return Qnil;
1037
- }
1038
-
1039
- VALUE filename = kw_values[0];
1040
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
1041
-
1042
- try {
1043
- ctx_ptr->ctx = llama_init_from_file(StringValueCStr(filename), prms_ptr->params);
1044
- } catch (const std::runtime_error& e) {
1045
- rb_raise(rb_eRuntimeError, "%s", e.what());
1046
- return Qnil;
1047
- }
1048
-
1049
- if (ctx_ptr->ctx == NULL) {
1050
- rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA context");
1051
- return Qnil;
1052
- }
1053
-
1054
- rb_iv_set(self, "@params", kw_values[1]);
1055
- rb_iv_set(self, "@has_evaluated", Qfalse);
1056
-
1057
- RB_GC_GUARD(filename);
1058
- return Qnil;
1059
- };
1060
-
1061
- static VALUE _llama_context_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
1062
- VALUE kw_args = Qnil;
1063
- ID kw_table[3] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads") };
1064
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1065
- rb_scan_args(argc, argv, ":", &kw_args);
1066
- rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
1067
-
1068
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1069
- rb_raise(rb_eArgError, "lora_path must be a string");
1070
- return Qnil;
1071
- }
1072
- if (kw_values[1] != Qundef && !RB_TYPE_P(kw_values[1], T_STRING)) {
1073
- rb_raise(rb_eArgError, "base_model_path must be a string");
1074
- return Qnil;
1075
- }
1076
- if (kw_values[2] != Qundef && !RB_INTEGER_TYPE_P(kw_values[2])) {
1077
- rb_raise(rb_eArgError, "n_threads must be an integer");
1078
- return Qnil;
1079
- }
1080
-
1081
- const char* lora_path = StringValueCStr(kw_values[0]);
1082
- const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
1083
- const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
1084
-
1085
- LLaMAContextWrapper* ptr = get_llama_context(self);
1086
- if (ptr->ctx != NULL) {
1087
- rb_raise(rb_eRuntimeError, "LLaMA context is already loaded");
1088
- return Qnil;
1089
- }
1090
-
1091
- if (llama_apply_lora_from_file(ptr->ctx, lora_path, base_model_path, n_threads) != 0) {
1092
- rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
1093
- return Qnil;
1094
- }
1095
- return Qnil;
1096
- };
1097
-
1098
1186
  static VALUE _llama_context_kv_cache_token_count(VALUE self) {
1099
1187
  LLaMAContextWrapper* ptr = get_llama_context(self);
1100
1188
  if (ptr->ctx == NULL) {
@@ -1137,7 +1225,9 @@ private:
1137
1225
  return Qnil;
1138
1226
  }
1139
1227
 
1140
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(rb_iv_get(self, "@params"));
1228
+ VALUE model = rb_iv_get(self, "@model");
1229
+ VALUE params = rb_iv_get(model, "@params");
1230
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1141
1231
  const int n_ctx = prms_ptr->params.n_ctx;
1142
1232
 
1143
1233
  std::vector<llama_token> session_tokens(n_ctx);
@@ -1664,8 +1754,16 @@ const rb_data_type_t RbLLaMAContext::llama_context_type = {
1664
1754
 
1665
1755
  // module functions
1666
1756
 
1667
- static VALUE rb_llama_llama_init_backend(VALUE self) {
1668
- llama_init_backend();
1757
+ static VALUE rb_llama_llama_init_backend(int argc, VALUE* argv, VALUE self) {
1758
+ VALUE kw_args = Qnil;
1759
+ ID kw_table[1] = { rb_intern("numa") };
1760
+ VALUE kw_values[1] = { Qundef };
1761
+ rb_scan_args(argc, argv, ":", &kw_args);
1762
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
1763
+
1764
+ const bool numa = kw_values[0] == Qundef ? false : (RTEST(kw_values[0]) ? true : false);
1765
+ llama_init_backend(numa);
1766
+
1669
1767
  return Qnil;
1670
1768
  }
1671
1769
 
@@ -1731,10 +1829,11 @@ extern "C" void Init_llama_cpp(void) {
1731
1829
 
1732
1830
  RbLLaMATokenData::define_class(rb_mLLaMACpp);
1733
1831
  RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
1832
+ RbLLaMAModel::define_class(rb_mLLaMACpp);
1734
1833
  RbLLaMAContext::define_class(rb_mLLaMACpp);
1735
1834
  RbLLaMAContextParams::define_class(rb_mLLaMACpp);
1736
1835
 
1737
- rb_define_module_function(rb_mLLaMACpp, "init_backend", rb_llama_llama_init_backend, 0);
1836
+ rb_define_module_function(rb_mLLaMACpp, "init_backend", rb_llama_llama_init_backend, -1);
1738
1837
  rb_define_module_function(rb_mLLaMACpp, "model_quantize", rb_llama_model_quantize, -1);
1739
1838
  rb_define_module_function(rb_mLLaMACpp, "token_bos", rb_llama_token_bos, 0);
1740
1839
  rb_define_module_function(rb_mLLaMACpp, "token_eos", rb_llama_token_eos, 0);
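The reworked `init_backend` binding registered above takes an optional `numa:` keyword that is forwarded to `llama_init_backend`. A minimal sketch of calling it from Ruby:

```ruby
require 'llama_cpp'

# numa defaults to false when omitted; pass true to enable NUMA optimizations.
LLaMACpp.init_backend(numa: true)
```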