llama_cpp 0.2.2 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -2,6 +2,7 @@
2
2
  #include "llama_cpp.h"
3
3
 
4
4
  VALUE rb_mLLaMACpp;
5
+ VALUE rb_cLLaMAModel;
5
6
  VALUE rb_cLLaMAContext;
6
7
  VALUE rb_cLLaMAContextParams;
7
8
  VALUE rb_cLLaMAModelQuantizeParams;
@@ -403,6 +404,10 @@ private:
403
404
  // seed
404
405
  static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
405
406
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
407
+ if (NUM2INT(seed) < 0) {
408
+ rb_raise(rb_eArgError, "seed must be a non-negative integer");
409
+ return Qnil;
410
+ }
406
411
  ptr->params.seed = NUM2INT(seed);
407
412
  return INT2NUM(ptr->params.seed);
408
413
  };
@@ -610,6 +615,206 @@ const rb_data_type_t RbLLaMAModelQuantizeParams::llama_model_quantize_params_typ
610
615
  RUBY_TYPED_FREE_IMMEDIATELY
611
616
  };
612
617
 
618
+ class LLaMAModelWrapper {
619
+ public:
620
+ struct llama_model* model;
621
+
622
+ LLaMAModelWrapper() : model(NULL){};
623
+
624
+ ~LLaMAModelWrapper() {
625
+ if (model != NULL) {
626
+ llama_free_model(model);
627
+ }
628
+ };
629
+ };
630
+
631
+ class RbLLaMAModel {
632
+ public:
633
+ static VALUE llama_model_alloc(VALUE self) {
634
+ LLaMAModelWrapper* ptr = (LLaMAModelWrapper*)ruby_xmalloc(sizeof(LLaMAModelWrapper));
635
+ new (ptr) LLaMAModelWrapper();
636
+ return TypedData_Wrap_Struct(self, &llama_model_type, ptr);
637
+ }
638
+
639
+ static void llama_model_free(void* ptr) {
640
+ ((LLaMAModelWrapper*)ptr)->~LLaMAModelWrapper();
641
+ ruby_xfree(ptr);
642
+ }
643
+
644
+ static size_t llama_model_size(const void* ptr) {
645
+ return sizeof(*((LLaMAModelWrapper*)ptr));
646
+ }
647
+
648
+ static LLaMAModelWrapper* get_llama_model(VALUE self) {
649
+ LLaMAModelWrapper* ptr;
650
+ TypedData_Get_Struct(self, LLaMAModelWrapper, &llama_model_type, ptr);
651
+ return ptr;
652
+ }
653
+
654
+ static void define_class(VALUE outer) {
655
+ rb_cLLaMAModel = rb_define_class_under(outer, "Model", rb_cObject);
656
+ rb_define_alloc_func(rb_cLLaMAModel, llama_model_alloc);
657
+ rb_define_method(rb_cLLaMAModel, "initialize", RUBY_METHOD_FUNC(_llama_model_initialize), -1);
658
+ rb_define_method(rb_cLLaMAModel, "empty?", RUBY_METHOD_FUNC(_llama_model_empty), 0);
659
+ rb_define_method(rb_cLLaMAModel, "free", RUBY_METHOD_FUNC(_llama_model_free), 0);
660
+ rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
661
+ rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
662
+ }
663
+
664
+ private:
665
+ static const rb_data_type_t llama_model_type;
666
+
667
+ static VALUE _llama_model_initialize(int argc, VALUE* argv, VALUE self) {
668
+ VALUE kw_args = Qnil;
669
+ ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
670
+ VALUE kw_values[2] = { Qundef, Qundef };
671
+ rb_scan_args(argc, argv, ":", &kw_args);
672
+ rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
673
+
674
+ if (kw_values[0] == Qundef && kw_values[1] == Qundef) {
675
+ rb_iv_set(self, "@params", Qnil);
676
+ return Qnil;
677
+ }
678
+
679
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
680
+ rb_raise(rb_eArgError, "model_path must be a string");
681
+ return Qnil;
682
+ }
683
+ if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
684
+ rb_raise(rb_eArgError, "params must be a ContextParams");
685
+ return Qnil;
686
+ }
687
+
688
+ VALUE filename = kw_values[0];
689
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
690
+ LLaMAModelWrapper* model_ptr = get_llama_model(self);
691
+
692
+ if (prms_ptr->params.seed == LLAMA_DEFAULT_SEED) {
693
+ prms_ptr->params.seed = time(NULL);
694
+ }
695
+
696
+ try {
697
+ model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
698
+ } catch (const std::runtime_error& e) {
699
+ rb_raise(rb_eRuntimeError, "%s", e.what());
700
+ return Qnil;
701
+ }
702
+
703
+ if (model_ptr->model == NULL) {
704
+ rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA model");
705
+ return Qnil;
706
+ }
707
+
708
+ rb_iv_set(self, "@params", kw_values[1]);
709
+
710
+ RB_GC_GUARD(filename);
711
+ return Qnil;
712
+ }
713
+
714
+ static VALUE _llama_model_empty(VALUE self) {
715
+ LLaMAModelWrapper* ptr = get_llama_model(self);
716
+ if (ptr->model != NULL) {
717
+ return Qfalse;
718
+ }
719
+ return Qtrue;
720
+ }
721
+
722
+ static VALUE _llama_model_free(VALUE self) {
723
+ LLaMAModelWrapper* ptr = get_llama_model(self);
724
+ if (ptr->model != NULL) {
725
+ llama_free_model(ptr->model);
726
+ ptr->model = NULL;
727
+ rb_iv_set(self, "@params", Qnil);
728
+ }
729
+ return Qnil;
730
+ }
731
+
732
+ static VALUE _llama_model_load(int argc, VALUE* argv, VALUE self) {
733
+ VALUE kw_args = Qnil;
734
+ ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
735
+ VALUE kw_values[2] = { Qundef, Qundef };
736
+ rb_scan_args(argc, argv, ":", &kw_args);
737
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
738
+
739
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
740
+ rb_raise(rb_eArgError, "model_path must be a string");
741
+ return Qnil;
742
+ }
743
+ if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
744
+ rb_raise(rb_eArgError, "params must be a LLaMAContextParams");
745
+ return Qnil;
746
+ }
747
+
748
+ LLaMAModelWrapper* model_ptr = get_llama_model(self);
749
+ if (model_ptr->model != NULL) {
750
+ rb_raise(rb_eRuntimeError, "LLaMA model is already loaded");
751
+ return Qnil;
752
+ }
753
+
754
+ VALUE filename = kw_values[0];
755
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
756
+
757
+ try {
758
+ model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
759
+ } catch (const std::runtime_error& e) {
760
+ rb_raise(rb_eRuntimeError, "%s", e.what());
761
+ return Qnil;
762
+ }
763
+
764
+ if (model_ptr->model == NULL) {
765
+ rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA model");
766
+ return Qnil;
767
+ }
768
+
769
+ rb_iv_set(self, "@params", kw_values[1]);
770
+
771
+ RB_GC_GUARD(filename);
772
+ return Qnil;
773
+ }
774
+
775
+ static VALUE _llama_model_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
776
+ VALUE kw_args = Qnil;
777
+ ID kw_table[3] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads") };
778
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
779
+ rb_scan_args(argc, argv, ":", &kw_args);
780
+ rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
781
+
782
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
783
+ rb_raise(rb_eArgError, "lora_path must be a string");
784
+ return Qnil;
785
+ }
786
+ if (kw_values[1] != Qundef && !RB_TYPE_P(kw_values[1], T_STRING)) {
787
+ rb_raise(rb_eArgError, "base_model_path must be a string");
788
+ return Qnil;
789
+ }
790
+ if (kw_values[2] != Qundef && !RB_INTEGER_TYPE_P(kw_values[2])) {
791
+ rb_raise(rb_eArgError, "n_threads must be an integer");
792
+ return Qnil;
793
+ }
794
+
795
+ const char* lora_path = StringValueCStr(kw_values[0]);
796
+ const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
797
+ const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
798
+
799
+ LLaMAModelWrapper* ptr = get_llama_model(self);
800
+ if (llama_model_apply_lora_from_file(ptr->model, lora_path, base_model_path, n_threads) != 0) {
801
+ rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
802
+ return Qnil;
803
+ }
804
+ return Qnil;
805
+ };
806
+ };
807
+
808
+ const rb_data_type_t RbLLaMAModel::llama_model_type = {
809
+ "RbLLaMAModel",
810
+ { NULL,
811
+ RbLLaMAModel::llama_model_free,
812
+ RbLLaMAModel::llama_model_size },
813
+ NULL,
814
+ NULL,
815
+ RUBY_TYPED_FREE_IMMEDIATELY
816
+ };
817
+
613
818
  class LLaMAContextWrapper {
614
819
  public:
615
820
  struct llama_context* ctx;
@@ -651,6 +856,7 @@ public:
651
856
  rb_define_alloc_func(rb_cLLaMAContext, llama_context_alloc);
652
857
  rb_define_method(rb_cLLaMAContext, "initialize", RUBY_METHOD_FUNC(_llama_context_initialize), -1);
653
858
  rb_define_method(rb_cLLaMAContext, "eval", RUBY_METHOD_FUNC(_llama_context_eval), -1);
859
+ rb_define_method(rb_cLLaMAContext, "eval_embd", RUBY_METHOD_FUNC(_llama_context_eval_embd), -1);
654
860
  rb_define_method(rb_cLLaMAContext, "eval_export", RUBY_METHOD_FUNC(_llama_context_eval_export), 1);
655
861
  rb_define_method(rb_cLLaMAContext, "tokenize", RUBY_METHOD_FUNC(_llama_context_tokenize), -1);
656
862
  rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
@@ -662,10 +868,6 @@ public:
662
868
  rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
663
869
  rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
664
870
  rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
665
- rb_define_method(rb_cLLaMAContext, "empty?", RUBY_METHOD_FUNC(_llama_context_empty), 0);
666
- rb_define_method(rb_cLLaMAContext, "free", RUBY_METHOD_FUNC(_llama_context_free), 0);
667
- rb_define_method(rb_cLLaMAContext, "load", RUBY_METHOD_FUNC(_llama_context_load), -1);
668
- rb_define_method(rb_cLLaMAContext, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_context_apply_lora_from_file), -1);
669
871
  rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
670
872
  rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
671
873
  rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
@@ -689,46 +891,37 @@ private:
689
891
 
690
892
  static VALUE _llama_context_initialize(int argc, VALUE* argv, VALUE self) {
691
893
  VALUE kw_args = Qnil;
692
- ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
693
- VALUE kw_values[2] = { Qundef, Qundef };
894
+ ID kw_table[1] = { rb_intern("model") };
895
+ VALUE kw_values[1] = { Qundef };
694
896
  rb_scan_args(argc, argv, ":", &kw_args);
695
- rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
897
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
696
898
 
697
- if (kw_values[0] == Qundef && kw_values[1] == Qundef) {
698
- rb_iv_set(self, "@params", Qnil);
699
- rb_iv_set(self, "@has_evaluated", Qfalse);
899
+ VALUE model = kw_values[0];
900
+ if (!rb_obj_is_kind_of(model, rb_cLLaMAModel)) {
901
+ rb_raise(rb_eArgError, "model must be a Model");
700
902
  return Qnil;
701
903
  }
702
904
 
703
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
704
- rb_raise(rb_eArgError, "model_path must be a string");
705
- return Qnil;
706
- }
707
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
708
- rb_raise(rb_eArgError, "params must be a ContextParams");
905
+ LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
906
+ if (model_ptr->model == NULL) {
907
+ rb_raise(rb_eRuntimeError, "Model is empty");
709
908
  return Qnil;
710
909
  }
711
910
 
712
- VALUE filename = kw_values[0];
713
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
911
+ VALUE params = rb_iv_get(model, "@params");
912
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
714
913
  LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
715
914
 
716
- try {
717
- ctx_ptr->ctx = llama_init_from_file(StringValueCStr(filename), prms_ptr->params);
718
- } catch (const std::runtime_error& e) {
719
- rb_raise(rb_eRuntimeError, "%s", e.what());
720
- return Qnil;
721
- }
915
+ ctx_ptr->ctx = llama_new_context_with_model(model_ptr->model, prms_ptr->params);
722
916
 
723
917
  if (ctx_ptr->ctx == NULL) {
724
918
  rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA context");
725
919
  return Qnil;
726
920
  }
727
921
 
728
- rb_iv_set(self, "@params", kw_values[1]);
922
+ rb_iv_set(self, "@model", model);
729
923
  rb_iv_set(self, "@has_evaluated", Qfalse);
730
924
 
731
- RB_GC_GUARD(filename);
732
925
  return Qnil;
733
926
  };
734
927
 
@@ -787,6 +980,61 @@ private:
787
980
  return Qnil;
788
981
  };
789
982
 
983
+ static VALUE _llama_context_eval_embd(int argc, VALUE* argv, VALUE self) {
984
+ VALUE kw_args = Qnil;
985
+ ID kw_table[4] = { rb_intern("embd"), rb_intern("n_past"), rb_intern("n_tokens"), rb_intern("n_threads") };
986
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
987
+ rb_scan_args(argc, argv, ":", &kw_args);
988
+ rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
989
+
990
+ if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
991
+ rb_raise(rb_eArgError, "tokens must be an Array");
992
+ return Qnil;
993
+ }
994
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
995
+ rb_raise(rb_eArgError, "n_past must be an integer");
996
+ return Qnil;
997
+ }
998
+ if (kw_values[2] != Qundef && !RB_INTEGER_TYPE_P(kw_values[2])) {
999
+ rb_raise(rb_eArgError, "n_tokens must be an integer");
1000
+ return Qnil;
1001
+ }
1002
+ if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
1003
+ rb_raise(rb_eArgError, "n_threads must be an integer");
1004
+ return Qnil;
1005
+ }
1006
+
1007
+ const size_t tokens_len = RARRAY_LEN(kw_values[0]);
1008
+ std::vector<float> embd(tokens_len);
1009
+ for (size_t i = 0; i < tokens_len; i++) {
1010
+ VALUE el = rb_ary_entry(kw_values[0], i);
1011
+ if (!RB_FLOAT_TYPE_P(el)) {
1012
+ rb_raise(rb_eArgError, "embd must be an array of floats");
1013
+ return Qnil;
1014
+ }
1015
+ embd[i] = NUM2DBL(el);
1016
+ }
1017
+
1018
+ const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
1019
+ const int n_past = NUM2INT(kw_values[1]);
1020
+ const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
1021
+
1022
+ LLaMAContextWrapper* ptr = get_llama_context(self);
1023
+ if (ptr->ctx == NULL) {
1024
+ rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1025
+ return Qnil;
1026
+ }
1027
+ if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past, n_threads) != 0) {
1028
+ rb_raise(rb_eRuntimeError, "Failed to evaluate");
1029
+ return Qnil;
1030
+ }
1031
+
1032
+ rb_iv_set(self, "@n_tokens", INT2NUM(n_tokens));
1033
+ rb_iv_set(self, "@has_evaluated", Qtrue);
1034
+
1035
+ return Qnil;
1036
+ }
1037
+
790
1038
  static VALUE _llama_context_eval_export(VALUE self, VALUE fname_) {
791
1039
  LLaMAContextWrapper* ptr = get_llama_context(self);
792
1040
  if (ptr->ctx == NULL) {
@@ -873,7 +1121,9 @@ private:
873
1121
  return Qnil;
874
1122
  }
875
1123
 
876
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(rb_iv_get(self, "@params"));
1124
+ VALUE model = rb_iv_get(self, "@model");
1125
+ VALUE params = rb_iv_get(model, "@params");
1126
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
877
1127
  const int n_tokens = prms_ptr->params.logits_all ? NUM2INT(rb_iv_get(self, "@n_tokens")) : 1;
878
1128
  const int n_vocab = llama_n_vocab(ptr->ctx);
879
1129
  const float* logits = llama_get_logits(ptr->ctx);
@@ -891,7 +1141,9 @@ private:
891
1141
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
892
1142
  return Qnil;
893
1143
  }
894
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(rb_iv_get(self, "@params"));
1144
+ VALUE model = rb_iv_get(self, "@model");
1145
+ VALUE params = rb_iv_get(model, "@params");
1146
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
895
1147
  if (!prms_ptr->params.embedding) {
896
1148
  rb_raise(rb_eRuntimeError, "embedding parameter is false");
897
1149
  return Qnil;
@@ -995,106 +1247,6 @@ private:
995
1247
  return Qnil;
996
1248
  };
997
1249
 
998
- static VALUE _llama_context_empty(VALUE self) {
999
- LLaMAContextWrapper* ptr = get_llama_context(self);
1000
- if (ptr->ctx != NULL) {
1001
- return Qfalse;
1002
- }
1003
- return Qtrue;
1004
- }
1005
-
1006
- static VALUE _llama_context_free(VALUE self) {
1007
- LLaMAContextWrapper* ptr = get_llama_context(self);
1008
- if (ptr->ctx != NULL) {
1009
- llama_free(ptr->ctx);
1010
- ptr->ctx = NULL;
1011
- rb_iv_set(self, "@params", Qnil);
1012
- rb_iv_set(self, "@has_evaluated", Qfalse);
1013
- }
1014
- return Qnil;
1015
- }
1016
-
1017
- static VALUE _llama_context_load(int argc, VALUE* argv, VALUE self) {
1018
- VALUE kw_args = Qnil;
1019
- ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
1020
- VALUE kw_values[2] = { Qundef, Qundef };
1021
- rb_scan_args(argc, argv, ":", &kw_args);
1022
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
1023
-
1024
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1025
- rb_raise(rb_eArgError, "model_path must be a string");
1026
- return Qnil;
1027
- }
1028
- if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
1029
- rb_raise(rb_eArgError, "params must be a LLaMAContextParams");
1030
- return Qnil;
1031
- }
1032
-
1033
- LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
1034
- if (ctx_ptr->ctx != NULL) {
1035
- rb_raise(rb_eRuntimeError, "LLaMA context is already loaded");
1036
- return Qnil;
1037
- }
1038
-
1039
- VALUE filename = kw_values[0];
1040
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
1041
-
1042
- try {
1043
- ctx_ptr->ctx = llama_init_from_file(StringValueCStr(filename), prms_ptr->params);
1044
- } catch (const std::runtime_error& e) {
1045
- rb_raise(rb_eRuntimeError, "%s", e.what());
1046
- return Qnil;
1047
- }
1048
-
1049
- if (ctx_ptr->ctx == NULL) {
1050
- rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA context");
1051
- return Qnil;
1052
- }
1053
-
1054
- rb_iv_set(self, "@params", kw_values[1]);
1055
- rb_iv_set(self, "@has_evaluated", Qfalse);
1056
-
1057
- RB_GC_GUARD(filename);
1058
- return Qnil;
1059
- };
1060
-
1061
- static VALUE _llama_context_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
1062
- VALUE kw_args = Qnil;
1063
- ID kw_table[3] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads") };
1064
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
1065
- rb_scan_args(argc, argv, ":", &kw_args);
1066
- rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
1067
-
1068
- if (!RB_TYPE_P(kw_values[0], T_STRING)) {
1069
- rb_raise(rb_eArgError, "lora_path must be a string");
1070
- return Qnil;
1071
- }
1072
- if (kw_values[1] != Qundef && !RB_TYPE_P(kw_values[1], T_STRING)) {
1073
- rb_raise(rb_eArgError, "base_model_path must be a string");
1074
- return Qnil;
1075
- }
1076
- if (kw_values[2] != Qundef && !RB_INTEGER_TYPE_P(kw_values[2])) {
1077
- rb_raise(rb_eArgError, "n_threads must be an integer");
1078
- return Qnil;
1079
- }
1080
-
1081
- const char* lora_path = StringValueCStr(kw_values[0]);
1082
- const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
1083
- const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
1084
-
1085
- LLaMAContextWrapper* ptr = get_llama_context(self);
1086
- if (ptr->ctx != NULL) {
1087
- rb_raise(rb_eRuntimeError, "LLaMA context is already loaded");
1088
- return Qnil;
1089
- }
1090
-
1091
- if (llama_apply_lora_from_file(ptr->ctx, lora_path, base_model_path, n_threads) != 0) {
1092
- rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
1093
- return Qnil;
1094
- }
1095
- return Qnil;
1096
- };
1097
-
1098
1250
  static VALUE _llama_context_kv_cache_token_count(VALUE self) {
1099
1251
  LLaMAContextWrapper* ptr = get_llama_context(self);
1100
1252
  if (ptr->ctx == NULL) {
@@ -1110,7 +1262,11 @@ private:
1110
1262
  rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
1111
1263
  return Qnil;
1112
1264
  }
1113
- const int seed = NUM2INT(seed_);
1265
+ if (NUM2INT(seed_) < 0) {
1266
+ rb_raise(rb_eArgError, "seed must be a non-negative integer");
1267
+ return Qnil;
1268
+ }
1269
+ const uint32_t seed = NUM2INT(seed_);
1114
1270
  llama_set_rng_seed(ptr->ctx, seed);
1115
1271
  return Qnil;
1116
1272
  };
@@ -1137,7 +1293,9 @@ private:
1137
1293
  return Qnil;
1138
1294
  }
1139
1295
 
1140
- LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(rb_iv_get(self, "@params"));
1296
+ VALUE model = rb_iv_get(self, "@model");
1297
+ VALUE params = rb_iv_get(model, "@params");
1298
+ LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
1141
1299
  const int n_ctx = prms_ptr->params.n_ctx;
1142
1300
 
1143
1301
  std::vector<llama_token> session_tokens(n_ctx);
@@ -1664,8 +1822,16 @@ const rb_data_type_t RbLLaMAContext::llama_context_type = {
1664
1822
 
1665
1823
  // module functions
1666
1824
 
1667
- static VALUE rb_llama_llama_init_backend(VALUE self) {
1668
- llama_init_backend();
1825
+ static VALUE rb_llama_llama_init_backend(int argc, VALUE* argv, VALUE self) {
1826
+ VALUE kw_args = Qnil;
1827
+ ID kw_table[1] = { rb_intern("numa") };
1828
+ VALUE kw_values[1] = { Qundef };
1829
+ rb_scan_args(argc, argv, ":", &kw_args);
1830
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
1831
+
1832
+ const bool numa = kw_values[0] == Qundef ? false : (RTEST(kw_values[0]) ? true : false);
1833
+ llama_init_backend(numa);
1834
+
1669
1835
  return Qnil;
1670
1836
  }
1671
1837
 
@@ -1731,10 +1897,11 @@ extern "C" void Init_llama_cpp(void) {
1731
1897
 
1732
1898
  RbLLaMATokenData::define_class(rb_mLLaMACpp);
1733
1899
  RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
1900
+ RbLLaMAModel::define_class(rb_mLLaMACpp);
1734
1901
  RbLLaMAContext::define_class(rb_mLLaMACpp);
1735
1902
  RbLLaMAContextParams::define_class(rb_mLLaMACpp);
1736
1903
 
1737
- rb_define_module_function(rb_mLLaMACpp, "init_backend", rb_llama_llama_init_backend, 0);
1904
+ rb_define_module_function(rb_mLLaMACpp, "init_backend", rb_llama_llama_init_backend, -1);
1738
1905
  rb_define_module_function(rb_mLLaMACpp, "model_quantize", rb_llama_model_quantize, -1);
1739
1906
  rb_define_module_function(rb_mLLaMACpp, "token_bos", rb_llama_token_bos, 0);
1740
1907
  rb_define_module_function(rb_mLLaMACpp, "token_eos", rb_llama_token_eos, 0);
@@ -1802,6 +1969,11 @@ extern "C" void Init_llama_cpp(void) {
1802
1969
  ss_magic << std::showbase << std::hex << LLAMA_SESSION_MAGIC;
1803
1970
  rb_define_const(rb_mLLaMACpp, "LLAMA_SESSION_MAGIC", rb_str_new2(ss_magic.str().c_str()));
1804
1971
 
1972
+ ss_magic.str("");
1973
+ ss_magic.clear(std::stringstream::goodbit);
1974
+ ss_magic << std::showbase << std::hex << LLAMA_DEFAULT_SEED;
1975
+ rb_define_const(rb_mLLaMACpp, "LLAMA_DEFAULT_SEED", rb_str_new2(ss_magic.str().c_str()));
1976
+
1805
1977
  rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_VERSION", rb_str_new2(std::to_string(LLAMA_FILE_VERSION).c_str()));
1806
1978
  rb_define_const(rb_mLLaMACpp, "LLAMA_SESSION_VERSION", rb_str_new2(std::to_string(LLAMA_SESSION_VERSION).c_str()));
1807
1979
  }