llama_cpp 0.2.2 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
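The headline change in 0.3.0 is the new `LLaMACpp::Model` class: a `Context` is now built from a `Model` object rather than from a model file path, and the model-level methods (`load`, `free`, `empty?`, `apply_lora_from_file`) move to `Model`. A minimal before/after sketch, based on the CHANGELOG and README changes shown below (the model path is a placeholder):

```ruby
require 'llama_cpp'

params = LLaMACpp::ContextParams.new

# 0.2.x: the context loaded the model file directly
# context = LLaMACpp::Context.new(model_path: '/path/to/quantized-model.bin', params: params)

# 0.3.0: load the model via LLaMACpp::Model, then hand it to Context
model = LLaMACpp::Model.new(model_path: '/path/to/quantized-model.bin', params: params)
context = LLaMACpp::Context.new(model: model)

puts LLaMACpp.generate(context, 'Hello, world.')
```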
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: e5e221d4831be790a990b121e6ac780d10b4cbfb85b2a9b4284d9c216f6e5604
- data.tar.gz: fba76ac1a70bfd7b02b8d123c57e4c8096a29ac7f658bb090cda91c6a54752d2
+ metadata.gz: 9e0152eb9e091932225356614b57fad416c2aa96a83316f8585c9ef2872e1504
+ data.tar.gz: 8ea2f00f11be7dd6524bfe69e3181fc63df7c841ed1e2d91b1b2bcafd99d0b66
  SHA512:
- metadata.gz: 994029383219077e134d170177954251c20ede6d1c83843ecd22c42eeae83584079d124b41702f55add7f3f237e9bdb14382fbd37dde2d0e74f8cffcfed1715b
- data.tar.gz: ca4e94b6ddf4e4e9ddabbb2b8309cf4b2b06a881df09fdf4ad96e27c4f1f620ca0024ac46f69d9b474849c074a5c9ba9b0440777a0b52a12413bc356457a02f3
+ metadata.gz: a85a4bdd2d1fd575eb406b9bebdf7f388db33dc42f7a2980ba9a7a6b346b539854d9df5515c9b6968727e76f035a23f59d4bc65bc5525df962dfbdf56d8b3b01
+ data.tar.gz: 33641d622102257dbc1358bde0871a03c595928f5d8cedee512e1df414e4aa93433eadfcd082d4db42046320c1ed7f806dfb3aafd7934a1becb33fe275f9435c
data/CHANGELOG.md CHANGED
@@ -1,3 +1,31 @@
+ ## [[0.3.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.2.2...v0.3.0)] - 2023-06-30
+
+ - Add `no_k_quants` and `qkk_64` config options:
+ ```
+ $ gem install llama_cpp -- --with-no_k_quants
+ ```
+ ```
+ $ gem install llama_cpp -- --with-qkk_64
+ ```
+
+ **Breaking Changes**
+ - Remove `Client` class to concentrate on developing bindings.
+ - Bump bundled llama.cpp from master-7487137 to master-9d23589.
+ - `llama_init_from_file` and `llama_apply_lora_from_file` are deprecated.
+ - Add `Model` class for wrapping `llama_model`.
+ - Move the `apply_lora_from_file`, `free`, `load`, and `empty?` methods from the `Context` class to the `Model` class.
+ - Change the arguments of the `Context` initialize method; it now requires a `Model` object instead of the model's file path.
+ ```ruby
+ require 'llama_cpp'
+
+ params = LLaMACpp::ContextParams.new
+
+ model = LLaMACpp::Model.new(model_path: '/path/to/quantized-model.bin', params: params)
+ context = LLaMACpp::Context.new(model: model)
+
+ LLaMACpp.generate(context, 'Hello, world.')
+ ```
+
  ## [[0.2.2](https://github.com/yoshoku/llama_cpp.rb/compare/v0.2.1...v0.2.2)] - 2023-06-24

  - Bump bundled llama.cpp from master-a09f919 to master-7487137.
data/README.md CHANGED
@@ -20,21 +20,54 @@ If bundler is not being used to manage dependencies, install the gem by executin

  ## Usage

- Prepare the quantized model by refering to [the usage section on the llama.cpp README](https://github.com/ggerganov/llama.cpp#usage) or
- download the qunatized model, for example [ggml-vicuna-7b-4bit](https://github.com/ggerganov/llama.cpp/discussions/643#discussioncomment-5541351), from Hugging Face.
+ Prepare the quantized model by referring to [the usage section on the llama.cpp README](https://github.com/ggerganov/llama.cpp#usage).
+ For example, preparing a quantized model based on [open_llama_7b](https://huggingface.co/openlm-research/open_llama_7b) is as follows:
+
+ ```sh
+ $ cd ~/
+ $ brew install git-lfs
+ $ git lfs install
+ $ git clone https://github.com/ggerganov/llama.cpp.git
+ $ cd llama.cpp
+ $ python3 -m pip install -r requirements.txt
+ $ cd models
+ $ git clone https://huggingface.co/openlm-research/open_llama_7b
+ $ cd ../
+ $ python3 convert.py models/open_llama_7b
+ $ make
+ $ ./quantize ./models/open_llama_7b/ggml-model-f16.bin ./models/open_llama_7b/ggml-model-q4_0.bin q4_0
+ ```
+
+ An example of Ruby code that generates sentences with the quantized model is as follows:

  ```ruby
  require 'llama_cpp'

  params = LLaMACpp::ContextParams.new
- params.seed = 12
+ params.seed = 42

- context = LLaMACpp::Context.new(model_path: '/path/to/quantized-model.bin', params: params)
+ model = LLaMACpp::Model.new(model_path: '/home/user/llama.cpp/models/open_llama_7b/ggml-model-q4_0.bin', params: params)
+ context = LLaMACpp::Context.new(model: model)

- puts LLaMACpp.generate(context, 'Please tell me the largest city in Japan.', n_threads: 4)
- # => "There are two major cities in Japan, Tokyo and Osaka, which have about 30 million populations."
+ puts LLaMACpp.generate(context, 'Hello, World.', n_threads: 4)
  ```

+ ## Examples
+ There is a sample program in the [examples](https://github.com/yoshoku/llama_cpp.rb/tree/main/examples) directory that allows interactive communication like ChatGPT.
+
+ ```sh
+ $ git clone https://github.com/yoshoku/llama_cpp.rb.git
+ $ cd llama_cpp.rb/examples
+ $ bundle install
+ $ ruby chat.rb --model /home/user/llama.cpp/models/open_llama_7b/ggml-model-q4_0.bin --seed 2023
+ ...
+ User: Who is the originator of the Ruby programming language?
+ Bob: The originator of the Ruby programming language is Mr. Yukihiro Matsumoto.
+ User:
+ ```
+
+ ![llama_cpp_chat_example](https://github.com/yoshoku/llama_cpp.rb/assets/5562409/374ae3d8-63a6-498f-ae6e-5552b464bdda)
+
  ## Contributing

  Bug reports and pull requests are welcome on GitHub at https://github.com/yoshoku/llama_cpp.rb.
data/examples/chat.rb CHANGED
@@ -35,7 +35,8 @@ class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation
  params = LLaMACpp::ContextParams.new
  params.seed = options[:seed]
  params.n_gpu_layers = options[:n_gpu_layers]
- context = LLaMACpp::Context.new(model_path: options[:model], params: params)
+ model = LLaMACpp::Model.new(model_path: options[:model], params: params)
+ context = LLaMACpp::Context.new(model: model)

  antiprompt = options[:reverse_prompt] || 'User:'
  start_prompt = read_prompt(options[:file]) || default_prompt(antiprompt)
@@ -16,12 +16,13 @@ class Embedding < Thor # rubocop:disable Style/Documentation
  option :model, type: :string, aliases: '-m', desc: 'path to model file', required: true
  option :prompt, type: :string, aliases: '-p', desc: 'prompt to generate embedding', required: true
  option :n_gpu_layers, type: :numeric, desc: 'number of layers on GPU', default: 0
- def main # rubocop:disable Metrics/AbcSize
+ def main # rubocop:disable Metrics/AbcSize, Metrics/MethodLength
  params = LLaMACpp::ContextParams.new
  params.seed = options[:seed]
  params.n_gpu_layers = options[:n_gpu_layers]
  params.embedding = true
- context = LLaMACpp::Context.new(model_path: options[:model], params: params)
+ model = LLaMACpp::Model.new(model_path: options[:model], params: params)
+ context = LLaMACpp::Context.new(model: model)

  embd_input = context.tokenize(text: options[:prompt], add_bos: true)

@@ -17,6 +17,17 @@ if RUBY_PLATFORM.match?(/darwin|linux|bsd/) && try_compile('#include <stdio.h>',
  $CXXFLAGS << ' -pthread'
  end

+ unless with_config('no_k_quants')
+ $CFLAGS << ' -DGGML_USE_K_QUANTS'
+ $CXXFLAGS << ' -DGGML_USE_K_QUANTS'
+ $srcs << 'k_quants.c'
+ end
+
+ if with_config('qkk_64')
+ $CFLAGS << ' -DGGML_QKK_64'
+ $CXXFLAGS << ' -DGGML_QKK_64'
+ end
+
  if with_config('openblas')
  abort 'libopenblas is not found.' unless have_library('openblas')
  abort 'cblas.h is not found.' unless have_header('cblas.h')
@@ -42,6 +53,7 @@ if with_config('metal')
  $CXXFLAGS << ' -DGGML_USE_METAL'
  $LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders'
  $objs = %w[ggml.o llama.o llama_cpp.o ggml-metal.o]
+ $objs << 'k_quants.o' unless with_config('no_k_quants')
  end

  if with_config('cublas')
@@ -49,6 +61,7 @@ if with_config('cublas')
  $CXXFLAGS << ' -DGGML_USE_CUBLAS -I/usr/local/cuda/include'
  $LDFLAGS << ' -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64'
  $objs = %w[ggml-cuda.o ggml.o llama.o llama_cpp.o]
+ $objs << 'k_quants.o' unless with_config('no_k_quants')
  end

  if with_config('clblast')
@@ -2,6 +2,7 @@
  #include "llama_cpp.h"

  VALUE rb_mLLaMACpp;
+ VALUE rb_cLLaMAModel;
  VALUE rb_cLLaMAContext;
  VALUE rb_cLLaMAContextParams;
  VALUE rb_cLLaMAModelQuantizeParams;
@@ -610,6 +611,202 @@ const rb_data_type_t RbLLaMAModelQuantizeParams::llama_model_quantize_params_typ
  RUBY_TYPED_FREE_IMMEDIATELY
  };

+ class LLaMAModelWrapper {
+ public:
+ struct llama_model* model;
+
+ LLaMAModelWrapper() : model(NULL){};
+
+ ~LLaMAModelWrapper() {
+ if (model != NULL) {
+ llama_free_model(model);
+ }
+ };
+ };
+
+ class RbLLaMAModel {
+ public:
+ static VALUE llama_model_alloc(VALUE self) {
+ LLaMAModelWrapper* ptr = (LLaMAModelWrapper*)ruby_xmalloc(sizeof(LLaMAModelWrapper));
+ new (ptr) LLaMAModelWrapper();
+ return TypedData_Wrap_Struct(self, &llama_model_type, ptr);
+ }
+
+ static void llama_model_free(void* ptr) {
+ ((LLaMAModelWrapper*)ptr)->~LLaMAModelWrapper();
+ ruby_xfree(ptr);
+ }
+
+ static size_t llama_model_size(const void* ptr) {
+ return sizeof(*((LLaMAModelWrapper*)ptr));
+ }
+
+ static LLaMAModelWrapper* get_llama_model(VALUE self) {
+ LLaMAModelWrapper* ptr;
+ TypedData_Get_Struct(self, LLaMAModelWrapper, &llama_model_type, ptr);
+ return ptr;
+ }
+
+ static void define_class(VALUE outer) {
+ rb_cLLaMAModel = rb_define_class_under(outer, "Model", rb_cObject);
+ rb_define_alloc_func(rb_cLLaMAModel, llama_model_alloc);
+ rb_define_method(rb_cLLaMAModel, "initialize", RUBY_METHOD_FUNC(_llama_model_initialize), -1);
+ rb_define_method(rb_cLLaMAModel, "empty?", RUBY_METHOD_FUNC(_llama_model_empty), 0);
+ rb_define_method(rb_cLLaMAModel, "free", RUBY_METHOD_FUNC(_llama_model_free), 0);
+ rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
+ rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
+ }
+
+ private:
+ static const rb_data_type_t llama_model_type;
+
+ static VALUE _llama_model_initialize(int argc, VALUE* argv, VALUE self) {
+ VALUE kw_args = Qnil;
+ ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
+ VALUE kw_values[2] = { Qundef, Qundef };
+ rb_scan_args(argc, argv, ":", &kw_args);
+ rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
+
+ if (kw_values[0] == Qundef && kw_values[1] == Qundef) {
+ rb_iv_set(self, "@params", Qnil);
+ return Qnil;
+ }
+
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
+ rb_raise(rb_eArgError, "model_path must be a string");
+ return Qnil;
+ }
+ if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
+ rb_raise(rb_eArgError, "params must be a ContextParams");
+ return Qnil;
+ }
+
+ VALUE filename = kw_values[0];
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
+ LLaMAModelWrapper* model_ptr = get_llama_model(self);
+
+ try {
+ model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
+ } catch (const std::runtime_error& e) {
+ rb_raise(rb_eRuntimeError, "%s", e.what());
+ return Qnil;
+ }
+
+ if (model_ptr->model == NULL) {
+ rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA model");
+ return Qnil;
+ }
+
+ rb_iv_set(self, "@params", kw_values[1]);
+
+ RB_GC_GUARD(filename);
+ return Qnil;
+ }
+
+ static VALUE _llama_model_empty(VALUE self) {
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ if (ptr->model != NULL) {
+ return Qfalse;
+ }
+ return Qtrue;
+ }
+
+ static VALUE _llama_model_free(VALUE self) {
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ if (ptr->model != NULL) {
+ llama_free_model(ptr->model);
+ ptr->model = NULL;
+ rb_iv_set(self, "@params", Qnil);
+ }
+ return Qnil;
+ }
+
+ static VALUE _llama_model_load(int argc, VALUE* argv, VALUE self) {
+ VALUE kw_args = Qnil;
+ ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
+ VALUE kw_values[2] = { Qundef, Qundef };
+ rb_scan_args(argc, argv, ":", &kw_args);
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
+
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
+ rb_raise(rb_eArgError, "model_path must be a string");
+ return Qnil;
+ }
+ if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
+ rb_raise(rb_eArgError, "params must be a LLaMAContextParams");
+ return Qnil;
+ }
+
+ LLaMAModelWrapper* model_ptr = get_llama_model(self);
+ if (model_ptr->model != NULL) {
+ rb_raise(rb_eRuntimeError, "LLaMA model is already loaded");
+ return Qnil;
+ }
+
+ VALUE filename = kw_values[0];
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
+
+ try {
+ model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
+ } catch (const std::runtime_error& e) {
+ rb_raise(rb_eRuntimeError, "%s", e.what());
+ return Qnil;
+ }
+
+ if (model_ptr->model == NULL) {
+ rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA model");
+ return Qnil;
+ }
+
+ rb_iv_set(self, "@params", kw_values[1]);
+
+ RB_GC_GUARD(filename);
+ return Qnil;
+ }
+
+ static VALUE _llama_model_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
+ VALUE kw_args = Qnil;
+ ID kw_table[3] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads") };
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
+ rb_scan_args(argc, argv, ":", &kw_args);
+ rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
+
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
+ rb_raise(rb_eArgError, "lora_path must be a string");
+ return Qnil;
+ }
+ if (kw_values[1] != Qundef && !RB_TYPE_P(kw_values[1], T_STRING)) {
+ rb_raise(rb_eArgError, "base_model_path must be a string");
+ return Qnil;
+ }
+ if (kw_values[2] != Qundef && !RB_INTEGER_TYPE_P(kw_values[2])) {
+ rb_raise(rb_eArgError, "n_threads must be an integer");
+ return Qnil;
+ }
+
+ const char* lora_path = StringValueCStr(kw_values[0]);
+ const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
+ const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
+
+ LLaMAModelWrapper* ptr = get_llama_model(self);
+ if (llama_model_apply_lora_from_file(ptr->model, lora_path, base_model_path, n_threads) != 0) {
+ rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
+ return Qnil;
+ }
+ return Qnil;
+ };
+ };
+
+ const rb_data_type_t RbLLaMAModel::llama_model_type = {
+ "RbLLaMAModel",
+ { NULL,
+ RbLLaMAModel::llama_model_free,
+ RbLLaMAModel::llama_model_size },
+ NULL,
+ NULL,
+ RUBY_TYPED_FREE_IMMEDIATELY
+ };
+
  class LLaMAContextWrapper {
  public:
  struct llama_context* ctx;
@@ -662,10 +859,6 @@ public:
  rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
  rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
  rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
- rb_define_method(rb_cLLaMAContext, "empty?", RUBY_METHOD_FUNC(_llama_context_empty), 0);
- rb_define_method(rb_cLLaMAContext, "free", RUBY_METHOD_FUNC(_llama_context_free), 0);
- rb_define_method(rb_cLLaMAContext, "load", RUBY_METHOD_FUNC(_llama_context_load), -1);
- rb_define_method(rb_cLLaMAContext, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_context_apply_lora_from_file), -1);
  rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
  rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
  rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
@@ -689,46 +882,37 @@ private:

  static VALUE _llama_context_initialize(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
- ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
- VALUE kw_values[2] = { Qundef, Qundef };
+ ID kw_table[1] = { rb_intern("model") };
+ VALUE kw_values[1] = { Qundef };
  rb_scan_args(argc, argv, ":", &kw_args);
- rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);

- if (kw_values[0] == Qundef && kw_values[1] == Qundef) {
- rb_iv_set(self, "@params", Qnil);
- rb_iv_set(self, "@has_evaluated", Qfalse);
+ VALUE model = kw_values[0];
+ if (!rb_obj_is_kind_of(model, rb_cLLaMAModel)) {
+ rb_raise(rb_eArgError, "model must be a Model");
  return Qnil;
  }

- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
- rb_raise(rb_eArgError, "model_path must be a string");
- return Qnil;
- }
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
- rb_raise(rb_eArgError, "params must be a ContextParams");
+ LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
+ if (model_ptr->model == NULL) {
+ rb_raise(rb_eRuntimeError, "Model is empty");
  return Qnil;
  }

- VALUE filename = kw_values[0];
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
+ VALUE params = rb_iv_get(model, "@params");
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
  LLaMAContextWrapper* ctx_ptr = get_llama_context(self);

- try {
- ctx_ptr->ctx = llama_init_from_file(StringValueCStr(filename), prms_ptr->params);
- } catch (const std::runtime_error& e) {
- rb_raise(rb_eRuntimeError, "%s", e.what());
- return Qnil;
- }
+ ctx_ptr->ctx = llama_new_context_with_model(model_ptr->model, prms_ptr->params);

  if (ctx_ptr->ctx == NULL) {
  rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA context");
  return Qnil;
  }

- rb_iv_set(self, "@params", kw_values[1]);
+ rb_iv_set(self, "@model", model);
  rb_iv_set(self, "@has_evaluated", Qfalse);

- RB_GC_GUARD(filename);
  return Qnil;
  };

@@ -873,7 +1057,9 @@ private:
  return Qnil;
  }

- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(rb_iv_get(self, "@params"));
+ VALUE model = rb_iv_get(self, "@model");
+ VALUE params = rb_iv_get(model, "@params");
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
  const int n_tokens = prms_ptr->params.logits_all ? NUM2INT(rb_iv_get(self, "@n_tokens")) : 1;
  const int n_vocab = llama_n_vocab(ptr->ctx);
  const float* logits = llama_get_logits(ptr->ctx);
@@ -891,7 +1077,9 @@ private:
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
  return Qnil;
  }
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(rb_iv_get(self, "@params"));
+ VALUE model = rb_iv_get(self, "@model");
+ VALUE params = rb_iv_get(model, "@params");
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
  if (!prms_ptr->params.embedding) {
  rb_raise(rb_eRuntimeError, "embedding parameter is false");
  return Qnil;
@@ -995,106 +1183,6 @@ private:
  return Qnil;
  };

- static VALUE _llama_context_empty(VALUE self) {
- LLaMAContextWrapper* ptr = get_llama_context(self);
- if (ptr->ctx != NULL) {
- return Qfalse;
- }
- return Qtrue;
- }
-
- static VALUE _llama_context_free(VALUE self) {
- LLaMAContextWrapper* ptr = get_llama_context(self);
- if (ptr->ctx != NULL) {
- llama_free(ptr->ctx);
- ptr->ctx = NULL;
- rb_iv_set(self, "@params", Qnil);
- rb_iv_set(self, "@has_evaluated", Qfalse);
- }
- return Qnil;
- }
-
- static VALUE _llama_context_load(int argc, VALUE* argv, VALUE self) {
- VALUE kw_args = Qnil;
- ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
- VALUE kw_values[2] = { Qundef, Qundef };
- rb_scan_args(argc, argv, ":", &kw_args);
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
-
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
- rb_raise(rb_eArgError, "model_path must be a string");
- return Qnil;
- }
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
- rb_raise(rb_eArgError, "params must be a LLaMAContextParams");
- return Qnil;
- }
-
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
- if (ctx_ptr->ctx != NULL) {
- rb_raise(rb_eRuntimeError, "LLaMA context is already loaded");
- return Qnil;
- }
-
- VALUE filename = kw_values[0];
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
-
- try {
- ctx_ptr->ctx = llama_init_from_file(StringValueCStr(filename), prms_ptr->params);
- } catch (const std::runtime_error& e) {
- rb_raise(rb_eRuntimeError, "%s", e.what());
- return Qnil;
- }
-
- if (ctx_ptr->ctx == NULL) {
- rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA context");
- return Qnil;
- }
-
- rb_iv_set(self, "@params", kw_values[1]);
- rb_iv_set(self, "@has_evaluated", Qfalse);
-
- RB_GC_GUARD(filename);
- return Qnil;
- };
-
- static VALUE _llama_context_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
- VALUE kw_args = Qnil;
- ID kw_table[3] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads") };
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
- rb_scan_args(argc, argv, ":", &kw_args);
- rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
-
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
- rb_raise(rb_eArgError, "lora_path must be a string");
- return Qnil;
- }
- if (kw_values[1] != Qundef && !RB_TYPE_P(kw_values[1], T_STRING)) {
- rb_raise(rb_eArgError, "base_model_path must be a string");
- return Qnil;
- }
- if (kw_values[2] != Qundef && !RB_INTEGER_TYPE_P(kw_values[2])) {
- rb_raise(rb_eArgError, "n_threads must be an integer");
- return Qnil;
- }
-
- const char* lora_path = StringValueCStr(kw_values[0]);
- const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
- const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
-
- LLaMAContextWrapper* ptr = get_llama_context(self);
- if (ptr->ctx != NULL) {
- rb_raise(rb_eRuntimeError, "LLaMA context is already loaded");
- return Qnil;
- }
-
- if (llama_apply_lora_from_file(ptr->ctx, lora_path, base_model_path, n_threads) != 0) {
- rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
- return Qnil;
- }
- return Qnil;
- };
-
  static VALUE _llama_context_kv_cache_token_count(VALUE self) {
  LLaMAContextWrapper* ptr = get_llama_context(self);
  if (ptr->ctx == NULL) {
@@ -1137,7 +1225,9 @@ private:
  return Qnil;
  }

- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(rb_iv_get(self, "@params"));
+ VALUE model = rb_iv_get(self, "@model");
+ VALUE params = rb_iv_get(model, "@params");
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
  const int n_ctx = prms_ptr->params.n_ctx;

  std::vector<llama_token> session_tokens(n_ctx);
@@ -1664,8 +1754,16 @@ const rb_data_type_t RbLLaMAContext::llama_context_type = {

  // module functions

- static VALUE rb_llama_llama_init_backend(VALUE self) {
- llama_init_backend();
+ static VALUE rb_llama_llama_init_backend(int argc, VALUE* argv, VALUE self) {
+ VALUE kw_args = Qnil;
+ ID kw_table[1] = { rb_intern("numa") };
+ VALUE kw_values[1] = { Qundef };
+ rb_scan_args(argc, argv, ":", &kw_args);
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
+
+ const bool numa = kw_values[0] == Qundef ? false : (RTEST(kw_values[0]) ? true : false);
+ llama_init_backend(numa);
+
  return Qnil;
  }

@@ -1731,10 +1829,11 @@ extern "C" void Init_llama_cpp(void) {

  RbLLaMATokenData::define_class(rb_mLLaMACpp);
  RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
+ RbLLaMAModel::define_class(rb_mLLaMACpp);
  RbLLaMAContext::define_class(rb_mLLaMACpp);
  RbLLaMAContextParams::define_class(rb_mLLaMACpp);

- rb_define_module_function(rb_mLLaMACpp, "init_backend", rb_llama_llama_init_backend, 0);
+ rb_define_module_function(rb_mLLaMACpp, "init_backend", rb_llama_llama_init_backend, -1);
  rb_define_module_function(rb_mLLaMACpp, "model_quantize", rb_llama_model_quantize, -1);
  rb_define_module_function(rb_mLLaMACpp, "token_bos", rb_llama_token_bos, 0);
  rb_define_module_function(rb_mLLaMACpp, "token_eos", rb_llama_token_eos, 0);