llama_cpp 0.2.2 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,6 +2,7 @@
2
2
  #include "llama_cpp.h"
3
3
 
4
4
  VALUE rb_mLLaMACpp;
5
+ VALUE rb_cLLaMAModel;
5
6
  VALUE rb_cLLaMAContext;
6
7
  VALUE rb_cLLaMAContextParams;
7
8
  VALUE rb_cLLaMAModelQuantizeParams;
@@ -403,6 +404,10 @@ private:
403
404
  // seed
404
405
  static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
405
406
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
407
+ if (NUM2INT(seed) < 0) {
408
+ rb_raise(rb_eArgError, "seed must be positive");
409
+ return Qnil;
410
+ }
406
411
  ptr->params.seed = NUM2INT(seed);
407
412
  return INT2NUM(ptr->params.seed);
408
413
  };
@@ -610,6 +615,206 @@ const rb_data_type_t RbLLaMAModelQuantizeParams::llama_model_quantize_params_typ
610
615
  RUBY_TYPED_FREE_IMMEDIATELY
611
616
  };
612
617
 
618
// Thin owner of a llama.cpp model handle. The destructor releases the native
// model, so freeing the Ruby-managed wrapper (see RbLLaMAModel::llama_model_free)
// is enough to release the underlying weights.
class LLaMAModelWrapper {
public:
  struct llama_model* model;  // NULL until a model file has been loaded

  LLaMAModelWrapper() : model(NULL){};

  ~LLaMAModelWrapper() {
    if (model != NULL) {
      llama_free_model(model);
    }
  };
};
630
+
631
+ class RbLLaMAModel {
632
+ public:
633
+ static VALUE llama_model_alloc(VALUE self) {
634
+ LLaMAModelWrapper* ptr = (LLaMAModelWrapper*)ruby_xmalloc(sizeof(LLaMAModelWrapper));
635
+ new (ptr) LLaMAModelWrapper();
636
+ return TypedData_Wrap_Struct(self, &llama_model_type, ptr);
637
+ }
638
+
639
+ static void llama_model_free(void* ptr) {
640
+ ((LLaMAModelWrapper*)ptr)->~LLaMAModelWrapper();
641
+ ruby_xfree(ptr);
642
+ }
643
+
644
+ static size_t llama_model_size(const void* ptr) {
645
+ return sizeof(*((LLaMAModelWrapper*)ptr));
646
+ }
647
+
648
+ static LLaMAModelWrapper* get_llama_model(VALUE self) {
649
+ LLaMAModelWrapper* ptr;
650
+ TypedData_Get_Struct(self, LLaMAModelWrapper, &llama_model_type, ptr);
651
+ return ptr;
652
+ }
653
+
654
+ static void define_class(VALUE outer) {
655
+ rb_cLLaMAModel = rb_define_class_under(outer, "Model", rb_cObject);
656
+ rb_define_alloc_func(rb_cLLaMAModel, llama_model_alloc);
657
+ rb_define_method(rb_cLLaMAModel, "initialize", RUBY_METHOD_FUNC(_llama_model_initialize), -1);
658
+ rb_define_method(rb_cLLaMAModel, "empty?", RUBY_METHOD_FUNC(_llama_model_empty), 0);
659
+ rb_define_method(rb_cLLaMAModel, "free", RUBY_METHOD_FUNC(_llama_model_free), 0);
660
+ rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
661
+ rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
662
+ }
663
+
664
+ private:
665
+ static const rb_data_type_t llama_model_type;
666
+
667
+ static VALUE _llama_model_initialize(int argc, VALUE* argv, VALUE self) {
668
+ VALUE kw_args = Qnil;
669
+ ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
670
+ VALUE kw_values[2] = { Qundef, Qundef };
671
+ rb_scan_args(argc, argv, ":", &kw_args);
672
+ rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
673
+
674
+ if (kw_values[0] == Qundef && kw_values[1] == Qundef) {
675
+ rb_iv_set(self, "@params", Qnil);
676
+ return Qnil;
677
+ }
678
+
679
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
680
+ rb_raise(rb_eArgError, "model_path must be a string");
681
+ return Qnil;
682
+ }
683
+ if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
684
+ rb_raise(rb_eArgError, "params must be a ContextParams");
685
+ return Qnil;
686
+ }
687
+
688
+ VALUE filename = kw_values[0];
689
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
690
+ LLaMAModelWrapper* model_ptr = get_llama_model(self);
691
+
692
+ if (prms_ptr->params.seed == LLAMA_DEFAULT_SEED) {
693
+ prms_ptr->params.seed = time(NULL);
694
+ }
695
+
696
+ try {
697
+ model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
698
+ } catch (const std::runtime_error& e) {
699
+ rb_raise(rb_eRuntimeError, "%s", e.what());
700
+ return Qnil;
701
+ }
702
+
703
+ if (model_ptr->model == NULL) {
704
+ rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA model");
705
+ return Qnil;
706
+ }
707
+
708
+ rb_iv_set(self, "@params", kw_values[1]);
709
+
710
+ RB_GC_GUARD(filename);
711
+ return Qnil;
712
+ }
713
+
714
+ static VALUE _llama_model_empty(VALUE self) {
715
+ LLaMAModelWrapper* ptr = get_llama_model(self);
716
+ if (ptr->model != NULL) {
717
+ return Qfalse;
718
+ }
719
+ return Qtrue;
720
+ }
721
+
722
+ static VALUE _llama_model_free(VALUE self) {
723
+ LLaMAModelWrapper* ptr = get_llama_model(self);
724
+ if (ptr->model != NULL) {
725
+ llama_free_model(ptr->model);
726
+ ptr->model = NULL;
727
+ rb_iv_set(self, "@params", Qnil);
728
+ }
729
+ return Qnil;
730
+ }
731
+
732
+ static VALUE _llama_model_load(int argc, VALUE* argv, VALUE self) {
733
+ VALUE kw_args = Qnil;
734
+ ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
735
+ VALUE kw_values[2] = { Qundef, Qundef };
736
+ rb_scan_args(argc, argv, ":", &kw_args);
737
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
738
+
739
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
740
+ rb_raise(rb_eArgError, "model_path must be a string");
741
+ return Qnil;
742
+ }
743
+ if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
744
+ rb_raise(rb_eArgError, "params must be a LLaMAContextParams");
745
+ return Qnil;
746
+ }
747
+
748
+ LLaMAModelWrapper* model_ptr = get_llama_model(self);
749
+ if (model_ptr->model != NULL) {
750
+ rb_raise(rb_eRuntimeError, "LLaMA model is already loaded");
751
+ return Qnil;
752
+ }
753
+
754
+ VALUE filename = kw_values[0];
755
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
756
+
757
+ try {
758
+ model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
759
+ } catch (const std::runtime_error& e) {
760
+ rb_raise(rb_eRuntimeError, "%s", e.what());
761
+ return Qnil;
762
+ }
763
+
764
+ if (model_ptr->model == NULL) {
765
+ rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA model");
766
+ return Qnil;
767
+ }
768
+
769
+ rb_iv_set(self, "@params", kw_values[1]);
770
+
771
+ RB_GC_GUARD(filename);
772
+ return Qnil;
773
+ }
774
+
775
+ static VALUE _llama_model_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
776
+ VALUE kw_args = Qnil;
777
+ ID kw_table[3] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads") };
778
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
779
+ rb_scan_args(argc, argv, ":", &kw_args);
780
+ rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
781
+
782
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
783
+ rb_raise(rb_eArgError, "lora_path must be a string");
784
+ return Qnil;
785
+ }
786
+ if (kw_values[1] != Qundef && !RB_TYPE_P(kw_values[1], T_STRING)) {
787
+ rb_raise(rb_eArgError, "base_model_path must be a string");
788
+ return Qnil;
789
+ }
790
+ if (kw_values[2] != Qundef && !RB_INTEGER_TYPE_P(kw_values[2])) {
791
+ rb_raise(rb_eArgError, "n_threads must be an integer");
792
+ return Qnil;
793
+ }
794
+
795
+ const char* lora_path = StringValueCStr(kw_values[0]);
796
+ const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
797
+ const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
798
+
799
+ LLaMAModelWrapper* ptr = get_llama_model(self);
800
+ if (llama_model_apply_lora_from_file(ptr->model, lora_path, base_model_path, n_threads) != 0) {
801
+ rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
802
+ return Qnil;
803
+ }
804
+ return Qnil;
805
+ };
806
+ };
807
+
808
// TypedData descriptor for LLaMACpp::Model. dmark is NULL (the wrapper holds
// no Ruby objects); dfree/dsize delegate to RbLLaMAModel's GC callbacks.
const rb_data_type_t RbLLaMAModel::llama_model_type = {
  "RbLLaMAModel",
  { NULL,
    RbLLaMAModel::llama_model_free,
    RbLLaMAModel::llama_model_size },
  NULL,
  NULL,
  RUBY_TYPED_FREE_IMMEDIATELY
};
817
+
613
818
  class LLaMAContextWrapper {
614
819
  public:
615
820
  struct llama_context* ctx;
@@ -651,6 +856,7 @@ public:
651
856
  rb_define_alloc_func(rb_cLLaMAContext, llama_context_alloc);
652
857
  rb_define_method(rb_cLLaMAContext, "initialize", RUBY_METHOD_FUNC(_llama_context_initialize), -1);
653
858
  rb_define_method(rb_cLLaMAContext, "eval", RUBY_METHOD_FUNC(_llama_context_eval), -1);
859
+ rb_define_method(rb_cLLaMAContext, "eval_embd", RUBY_METHOD_FUNC(_llama_context_eval_embd), -1);
654
860
  rb_define_method(rb_cLLaMAContext, "eval_export", RUBY_METHOD_FUNC(_llama_context_eval_export), 1);
655
861
  rb_define_method(rb_cLLaMAContext, "tokenize", RUBY_METHOD_FUNC(_llama_context_tokenize), -1);
656
862
  rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
@@ -662,10 +868,6 @@ public:
662
868
  rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
663
869
  rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
664
870
  rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
665
- rb_define_method(rb_cLLaMAContext, "empty?", RUBY_METHOD_FUNC(_llama_context_empty), 0);
666
- rb_define_method(rb_cLLaMAContext, "free", RUBY_METHOD_FUNC(_llama_context_free), 0);
667
- rb_define_method(rb_cLLaMAContext, "load", RUBY_METHOD_FUNC(_llama_context_load), -1);
668
- rb_define_method(rb_cLLaMAContext, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_context_apply_lora_from_file), -1);
669
871
  rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
670
872
  rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
671
873
  rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
@@ -689,46 +891,37 @@ private:
689
891
 
690
892
  // Context.new(model:) — the context is now built from an already-loaded
  // Model object rather than from a file path; the model keyword is mandatory.
  // Raises ArgumentError when model: is not a Model, RuntimeError when the
  // model is empty or llama_new_context_with_model fails.
  static VALUE _llama_context_initialize(int argc, VALUE* argv, VALUE self) {
    VALUE kw_args = Qnil;
    ID kw_table[1] = { rb_intern("model") };
    VALUE kw_values[1] = { Qundef };
    rb_scan_args(argc, argv, ":", &kw_args);
    rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);

    VALUE model = kw_values[0];
    if (!rb_obj_is_kind_of(model, rb_cLLaMAModel)) {
      rb_raise(rb_eArgError, "model must be a Model");
      return Qnil;
    }

    LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
    if (model_ptr->model == NULL) {
      rb_raise(rb_eRuntimeError, "Model is empty");
      return Qnil;
    }

    // Context parameters live on the Model (@params is set by
    // Model#initialize / Model#load).
    VALUE params = rb_iv_get(model, "@params");
    LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
    LLaMAContextWrapper* ctx_ptr = get_llama_context(self);

    ctx_ptr->ctx = llama_new_context_with_model(model_ptr->model, prms_ptr->params);

    if (ctx_ptr->ctx == NULL) {
      rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA context");
      return Qnil;
    }

    // Hold a reference to the model so it cannot be GC'd (and its native
    // handle freed) while this context is alive.
    rb_iv_set(self, "@model", model);
    rb_iv_set(self, "@has_evaluated", Qfalse);

    return Qnil;
  };
927
 
@@ -787,6 +980,61 @@ private:
787
980
  return Qnil;
788
981
  };
789
982
 
983
+ static VALUE _llama_context_eval_embd(int argc, VALUE* argv, VALUE self) {
984
+ VALUE kw_args = Qnil;
985
+ ID kw_table[4] = { rb_intern("embd"), rb_intern("n_past"), rb_intern("n_tokens"), rb_intern("n_threads") };
986
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
987
+ rb_scan_args(argc, argv, ":", &kw_args);
988
+ rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
989
+
990
+ if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
991
+ rb_raise(rb_eArgError, "tokens must be an Array");
992
+ return Qnil;
993
+ }
994
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
995
+ rb_raise(rb_eArgError, "n_past must be an integer");
996
+ return Qnil;
997
+ }
998
+ if (kw_values[2] != Qundef && !RB_INTEGER_TYPE_P(kw_values[2])) {
999
+ rb_raise(rb_eArgError, "n_tokens must be an integer");
1000
+ return Qnil;
1001
+ }
1002
+ if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
1003
+ rb_raise(rb_eArgError, "n_threads must be an integer");
1004
+ return Qnil;
1005
+ }
1006
+
1007
+ const size_t tokens_len = RARRAY_LEN(kw_values[0]);
1008
+ std::vector<float> embd(tokens_len);
1009
+ for (size_t i = 0; i < tokens_len; i++) {
1010
+ VALUE el = rb_ary_entry(kw_values[0], i);
1011
+ if (!RB_FLOAT_TYPE_P(el)) {
1012
+ rb_raise(rb_eArgError, "embd must be an array of floats");
1013
+ return Qnil;
1014
+ }
1015
+ embd[i] = NUM2DBL(el);
1016
+ }
1017
+
1018
+ const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
1019
+ const int n_past = NUM2INT(kw_values[1]);
1020
+ const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
1021
+
1022
+ LLaMAContextWrapper* ptr = get_llama_context(self);
1023
+ if (ptr->ctx == NULL) {
1024
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1025
+ return Qnil;
1026
+ }
1027
+ if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past, n_threads) != 0) {
1028
+ rb_raise(rb_eRuntimeError, "Failed to evaluate");
1029
+ return Qnil;
1030
+ }
1031
+
1032
+ rb_iv_set(self, "@n_tokens", INT2NUM(n_tokens));
1033
+ rb_iv_set(self, "@has_evaluated", Qtrue);
1034
+
1035
+ return Qnil;
1036
+ }
1037
+
790
1038
  static VALUE _llama_context_eval_export(VALUE self, VALUE fname_) {
791
1039
  LLaMAContextWrapper* ptr = get_llama_context(self);
792
1040
  if (ptr->ctx == NULL) {
@@ -873,7 +1121,9 @@ private:
873
1121
  return Qnil;
874
1122
  }
875
1123
 
876
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(rb_iv_get(self, "@params"));
1124
+ VALUE model = rb_iv_get(self, "@model");
1125
+ VALUE params = rb_iv_get(model, "@params");
1126
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
877
1127
  const int n_tokens = prms_ptr->params.logits_all ? NUM2INT(rb_iv_get(self, "@n_tokens")) : 1;
878
1128
  const int n_vocab = llama_n_vocab(ptr->ctx);
879
1129
  const float* logits = llama_get_logits(ptr->ctx);
@@ -891,7 +1141,9 @@ private:
891
1141
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
892
1142
  return Qnil;
893
1143
  }
894
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(rb_iv_get(self, "@params"));
1144
+ VALUE model = rb_iv_get(self, "@model");
1145
+ VALUE params = rb_iv_get(model, "@params");
1146
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
895
1147
  if (!prms_ptr->params.embedding) {
896
1148
  rb_raise(rb_eRuntimeError, "embedding parameter is false");
897
1149
  return Qnil;
@@ -995,106 +1247,6 @@ private:
995
1247
  return Qnil;
996
1248
  };
997
1249
 
998
- static VALUE _llama_context_empty(VALUE self) {
999
- LLaMAContextWrapper* ptr = get_llama_context(self);
1000
- if (ptr->ctx != NULL) {
1001
- return Qfalse;
1002
- }
1003
- return Qtrue;
1004
- }
1005
-
1006
- static VALUE _llama_context_free(VALUE self) {
1007
- LLaMAContextWrapper* ptr = get_llama_context(self);
1008
- if (ptr->ctx != NULL) {
1009
- llama_free(ptr->ctx);
1010
- ptr->ctx = NULL;
1011
- rb_iv_set(self, "@params", Qnil);
1012
- rb_iv_set(self, "@has_evaluated", Qfalse);
1013
- }
1014
- return Qnil;
1015
- }
1016
-
1017
- static VALUE _llama_context_load(int argc, VALUE* argv, VALUE self) {
1018
- VALUE kw_args = Qnil;
1019
- ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
1020
- VALUE kw_values[2] = { Qundef, Qundef };
1021
- rb_scan_args(argc, argv, ":", &kw_args);
1022
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
1023
-
1024
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1025
- rb_raise(rb_eArgError, "model_path must be a string");
1026
- return Qnil;
1027
- }
1028
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
1029
- rb_raise(rb_eArgError, "params must be a LLaMAContextParams");
1030
- return Qnil;
1031
- }
1032
-
1033
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1034
- if (ctx_ptr->ctx != NULL) {
1035
- rb_raise(rb_eRuntimeError, "LLaMA context is already loaded");
1036
- return Qnil;
1037
- }
1038
-
1039
- VALUE filename = kw_values[0];
1040
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
1041
-
1042
- try {
1043
- ctx_ptr->ctx = llama_init_from_file(StringValueCStr(filename), prms_ptr->params);
1044
- } catch (const std::runtime_error& e) {
1045
- rb_raise(rb_eRuntimeError, "%s", e.what());
1046
- return Qnil;
1047
- }
1048
-
1049
- if (ctx_ptr->ctx == NULL) {
1050
- rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA context");
1051
- return Qnil;
1052
- }
1053
-
1054
- rb_iv_set(self, "@params", kw_values[1]);
1055
- rb_iv_set(self, "@has_evaluated", Qfalse);
1056
-
1057
- RB_GC_GUARD(filename);
1058
- return Qnil;
1059
- };
1060
-
1061
- static VALUE _llama_context_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
1062
- VALUE kw_args = Qnil;
1063
- ID kw_table[3] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads") };
1064
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1065
- rb_scan_args(argc, argv, ":", &kw_args);
1066
- rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
1067
-
1068
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1069
- rb_raise(rb_eArgError, "lora_path must be a string");
1070
- return Qnil;
1071
- }
1072
- if (kw_values[1] != Qundef && !RB_TYPE_P(kw_values[1], T_STRING)) {
1073
- rb_raise(rb_eArgError, "base_model_path must be a string");
1074
- return Qnil;
1075
- }
1076
- if (kw_values[2] != Qundef && !RB_INTEGER_TYPE_P(kw_values[2])) {
1077
- rb_raise(rb_eArgError, "n_threads must be an integer");
1078
- return Qnil;
1079
- }
1080
-
1081
- const char* lora_path = StringValueCStr(kw_values[0]);
1082
- const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
1083
- const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
1084
-
1085
- LLaMAContextWrapper* ptr = get_llama_context(self);
1086
- if (ptr->ctx != NULL) {
1087
- rb_raise(rb_eRuntimeError, "LLaMA context is already loaded");
1088
- return Qnil;
1089
- }
1090
-
1091
- if (llama_apply_lora_from_file(ptr->ctx, lora_path, base_model_path, n_threads) != 0) {
1092
- rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
1093
- return Qnil;
1094
- }
1095
- return Qnil;
1096
- };
1097
-
1098
1250
  static VALUE _llama_context_kv_cache_token_count(VALUE self) {
1099
1251
  LLaMAContextWrapper* ptr = get_llama_context(self);
1100
1252
  if (ptr->ctx == NULL) {
@@ -1110,7 +1262,11 @@ private:
1110
1262
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1111
1263
  return Qnil;
1112
1264
  }
1113
- const int seed = NUM2INT(seed_);
1265
+ if (NUM2INT(seed_) < 0) {
1266
+ rb_raise(rb_eArgError, "seed must be a non-negative integer");
1267
+ return Qnil;
1268
+ }
1269
+ const uint32_t seed = NUM2INT(seed_);
1114
1270
  llama_set_rng_seed(ptr->ctx, seed);
1115
1271
  return Qnil;
1116
1272
  };
@@ -1137,7 +1293,9 @@ private:
1137
1293
  return Qnil;
1138
1294
  }
1139
1295
 
1140
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(rb_iv_get(self, "@params"));
1296
+ VALUE model = rb_iv_get(self, "@model");
1297
+ VALUE params = rb_iv_get(model, "@params");
1298
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1141
1299
  const int n_ctx = prms_ptr->params.n_ctx;
1142
1300
 
1143
1301
  std::vector<llama_token> session_tokens(n_ctx);
@@ -1664,8 +1822,16 @@ const rb_data_type_t RbLLaMAContext::llama_context_type = {
1664
1822
 
1665
1823
  // module functions
1666
1824
 
1667
- static VALUE rb_llama_llama_init_backend(VALUE self) {
1668
- llama_init_backend();
1825
+ static VALUE rb_llama_llama_init_backend(int argc, VALUE* argv, VALUE self) {
1826
+ VALUE kw_args = Qnil;
1827
+ ID kw_table[1] = { rb_intern("numa") };
1828
+ VALUE kw_values[1] = { Qundef };
1829
+ rb_scan_args(argc, argv, ":", &kw_args);
1830
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
1831
+
1832
+ const bool numa = kw_values[0] == Qundef ? false : (RTEST ? true : false);
1833
+ llama_init_backend(numa);
1834
+
1669
1835
  return Qnil;
1670
1836
  }
1671
1837
 
@@ -1731,10 +1897,11 @@ extern "C" void Init_llama_cpp(void) {
1731
1897
 
1732
1898
  RbLLaMATokenData::define_class(rb_mLLaMACpp);
1733
1899
  RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
1900
+ RbLLaMAModel::define_class(rb_mLLaMACpp);
1734
1901
  RbLLaMAContext::define_class(rb_mLLaMACpp);
1735
1902
  RbLLaMAContextParams::define_class(rb_mLLaMACpp);
1736
1903
 
1737
- rb_define_module_function(rb_mLLaMACpp, "init_backend", rb_llama_llama_init_backend, 0);
1904
+ rb_define_module_function(rb_mLLaMACpp, "init_backend", rb_llama_llama_init_backend, -1);
1738
1905
  rb_define_module_function(rb_mLLaMACpp, "model_quantize", rb_llama_model_quantize, -1);
1739
1906
  rb_define_module_function(rb_mLLaMACpp, "token_bos", rb_llama_token_bos, 0);
1740
1907
  rb_define_module_function(rb_mLLaMACpp, "token_eos", rb_llama_token_eos, 0);
@@ -1802,6 +1969,11 @@ extern "C" void Init_llama_cpp(void) {
1802
1969
  ss_magic << std::showbase << std::hex << LLAMA_SESSION_MAGIC;
1803
1970
  rb_define_const(rb_mLLaMACpp, "LLAMA_SESSION_MAGIC", rb_str_new2(ss_magic.str().c_str()));
1804
1971
 
1972
+ ss_magic.str("");
1973
+ ss_magic.clear(std::stringstream::goodbit);
1974
+ ss_magic << std::showbase << std::hex << LLAMA_DEFAULT_SEED;
1975
+ rb_define_const(rb_mLLaMACpp, "LLAMA_DEFAULT_SEED", rb_str_new2(ss_magic.str().c_str()));
1976
+
1805
1977
  rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_VERSION", rb_str_new2(std::to_string(LLAMA_FILE_VERSION).c_str()));
1806
1978
  rb_define_const(rb_mLLaMACpp, "LLAMA_SESSION_VERSION", rb_str_new2(std::to_string(LLAMA_SESSION_VERSION).c_str()));
1807
1979
  }