llama_cpp 0.2.2 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +34 -0
- data/README.md +39 -6
- data/examples/chat.rb +2 -1
- data/examples/embedding.rb +3 -2
- data/ext/llama_cpp/extconf.rb +13 -0
- data/ext/llama_cpp/llama_cpp.cpp +305 -133
- data/ext/llama_cpp/src/ggml-cuda.cu +367 -69
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.m +36 -30
- data/ext/llama_cpp/src/ggml-metal.metal +328 -84
- data/ext/llama_cpp/src/ggml-opencl.cpp +352 -175
- data/ext/llama_cpp/src/ggml.c +800 -303
- data/ext/llama_cpp/src/ggml.h +68 -5
- data/ext/llama_cpp/src/k_quants.c +1712 -56
- data/ext/llama_cpp/src/k_quants.h +41 -6
- data/ext/llama_cpp/src/llama-util.h +19 -5
- data/ext/llama_cpp/src/llama.cpp +262 -291
- data/ext/llama_cpp/src/llama.h +49 -11
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +0 -2
- data/sig/llama_cpp.rbs +14 -17
- metadata +2 -3
- data/lib/llama_cpp/client.rb +0 -172
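The headline change in this release is the extraction of model loading from `Context` into a new `LLaMACpp::Model` class, visible in the `llama_cpp.cpp` diff below. A minimal migration sketch in Ruby, assuming a local GGML model file at `model.bin` (the path is illustrative):

```ruby
require 'llama_cpp'

LLaMACpp.init_backend

params = LLaMACpp::ContextParams.new
params.seed = 42

# 0.2.2: context = LLaMACpp::Context.new(model_path: 'model.bin', params: params)
# 0.3.1: load the model first, then build a context from it.
model = LLaMACpp::Model.new(model_path: 'model.bin', params: params)
context = LLaMACpp::Context.new(model: model)
```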
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
```diff
@@ -2,6 +2,7 @@
 #include "llama_cpp.h"
 
 VALUE rb_mLLaMACpp;
+VALUE rb_cLLaMAModel;
 VALUE rb_cLLaMAContext;
 VALUE rb_cLLaMAContextParams;
 VALUE rb_cLLaMAModelQuantizeParams;
```
```diff
@@ -403,6 +404,10 @@ private:
   // seed
   static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
+    if (NUM2INT(seed) < 0) {
+      rb_raise(rb_eArgError, "seed must be positive");
+      return Qnil;
+    }
     ptr->params.seed = NUM2INT(seed);
     return INT2NUM(ptr->params.seed);
   };
```
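`ContextParams#seed=` now rejects negative seeds up front (note the guard still allows 0 even though the message says "positive"). A quick sketch of the new behavior:

```ruby
params = LLaMACpp::ContextParams.new
params.seed = 123 # => 123
params.seed = -1  # raises ArgumentError ("seed must be positive")
```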
```diff
@@ -610,6 +615,206 @@ const rb_data_type_t RbLLaMAModelQuantizeParams::llama_model_quantize_params_type = {
   RUBY_TYPED_FREE_IMMEDIATELY
 };
 
+class LLaMAModelWrapper {
+public:
+  struct llama_model* model;
+
+  LLaMAModelWrapper() : model(NULL){};
+
+  ~LLaMAModelWrapper() {
+    if (model != NULL) {
+      llama_free_model(model);
+    }
+  };
+};
+
+class RbLLaMAModel {
+public:
+  static VALUE llama_model_alloc(VALUE self) {
+    LLaMAModelWrapper* ptr = (LLaMAModelWrapper*)ruby_xmalloc(sizeof(LLaMAModelWrapper));
+    new (ptr) LLaMAModelWrapper();
+    return TypedData_Wrap_Struct(self, &llama_model_type, ptr);
+  }
+
+  static void llama_model_free(void* ptr) {
+    ((LLaMAModelWrapper*)ptr)->~LLaMAModelWrapper();
+    ruby_xfree(ptr);
+  }
+
+  static size_t llama_model_size(const void* ptr) {
+    return sizeof(*((LLaMAModelWrapper*)ptr));
+  }
+
+  static LLaMAModelWrapper* get_llama_model(VALUE self) {
+    LLaMAModelWrapper* ptr;
+    TypedData_Get_Struct(self, LLaMAModelWrapper, &llama_model_type, ptr);
+    return ptr;
+  }
+
+  static void define_class(VALUE outer) {
+    rb_cLLaMAModel = rb_define_class_under(outer, "Model", rb_cObject);
+    rb_define_alloc_func(rb_cLLaMAModel, llama_model_alloc);
+    rb_define_method(rb_cLLaMAModel, "initialize", RUBY_METHOD_FUNC(_llama_model_initialize), -1);
+    rb_define_method(rb_cLLaMAModel, "empty?", RUBY_METHOD_FUNC(_llama_model_empty), 0);
+    rb_define_method(rb_cLLaMAModel, "free", RUBY_METHOD_FUNC(_llama_model_free), 0);
+    rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
+    rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
+  }
+
+private:
+  static const rb_data_type_t llama_model_type;
+
+  static VALUE _llama_model_initialize(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
+    VALUE kw_values[2] = { Qundef, Qundef };
+    rb_scan_args(argc, argv, ":", &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
+
+    if (kw_values[0] == Qundef && kw_values[1] == Qundef) {
+      rb_iv_set(self, "@params", Qnil);
+      return Qnil;
+    }
+
+    if (!RB_TYPE_P(kw_values[0], T_STRING)) {
+      rb_raise(rb_eArgError, "model_path must be a string");
+      return Qnil;
+    }
+    if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
+      rb_raise(rb_eArgError, "params must be a ContextParams");
+      return Qnil;
+    }
+
+    VALUE filename = kw_values[0];
+    LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
+    LLaMAModelWrapper* model_ptr = get_llama_model(self);
+
+    if (prms_ptr->params.seed == LLAMA_DEFAULT_SEED) {
+      prms_ptr->params.seed = time(NULL);
+    }
+
+    try {
+      model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
+    } catch (const std::runtime_error& e) {
+      rb_raise(rb_eRuntimeError, "%s", e.what());
+      return Qnil;
+    }
+
+    if (model_ptr->model == NULL) {
+      rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA model");
+      return Qnil;
+    }
+
+    rb_iv_set(self, "@params", kw_values[1]);
+
+    RB_GC_GUARD(filename);
+    return Qnil;
+  }
+
+  static VALUE _llama_model_empty(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    if (ptr->model != NULL) {
+      return Qfalse;
+    }
+    return Qtrue;
+  }
+
+  static VALUE _llama_model_free(VALUE self) {
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    if (ptr->model != NULL) {
+      llama_free_model(ptr->model);
+      ptr->model = NULL;
+      rb_iv_set(self, "@params", Qnil);
+    }
+    return Qnil;
+  }
+
+  static VALUE _llama_model_load(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
+    VALUE kw_values[2] = { Qundef, Qundef };
+    rb_scan_args(argc, argv, ":", &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
+
+    if (!RB_TYPE_P(kw_values[0], T_STRING)) {
+      rb_raise(rb_eArgError, "model_path must be a string");
+      return Qnil;
+    }
+    if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
+      rb_raise(rb_eArgError, "params must be a LLaMAContextParams");
+      return Qnil;
+    }
+
+    LLaMAModelWrapper* model_ptr = get_llama_model(self);
+    if (model_ptr->model != NULL) {
+      rb_raise(rb_eRuntimeError, "LLaMA model is already loaded");
+      return Qnil;
+    }
+
+    VALUE filename = kw_values[0];
+    LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
+
+    try {
+      model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
+    } catch (const std::runtime_error& e) {
+      rb_raise(rb_eRuntimeError, "%s", e.what());
+      return Qnil;
+    }
+
+    if (model_ptr->model == NULL) {
+      rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA model");
+      return Qnil;
+    }
+
+    rb_iv_set(self, "@params", kw_values[1]);
+
+    RB_GC_GUARD(filename);
+    return Qnil;
+  }
+
+  static VALUE _llama_model_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[3] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads") };
+    VALUE kw_values[3] = { Qundef, Qundef, Qundef };
+    rb_scan_args(argc, argv, ":", &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
+
+    if (!RB_TYPE_P(kw_values[0], T_STRING)) {
+      rb_raise(rb_eArgError, "lora_path must be a string");
+      return Qnil;
+    }
+    if (kw_values[1] != Qundef && !RB_TYPE_P(kw_values[1], T_STRING)) {
+      rb_raise(rb_eArgError, "base_model_path must be a string");
+      return Qnil;
+    }
+    if (kw_values[2] != Qundef && !RB_INTEGER_TYPE_P(kw_values[2])) {
+      rb_raise(rb_eArgError, "n_threads must be an integer");
+      return Qnil;
+    }
+
+    const char* lora_path = StringValueCStr(kw_values[0]);
+    const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
+    const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
+
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    if (llama_model_apply_lora_from_file(ptr->model, lora_path, base_model_path, n_threads) != 0) {
+      rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
+      return Qnil;
+    }
+    return Qnil;
+  };
+};
+
+const rb_data_type_t RbLLaMAModel::llama_model_type = {
+  "RbLLaMAModel",
+  { NULL,
+    RbLLaMAModel::llama_model_free,
+    RbLLaMAModel::llama_model_size },
+  NULL,
+  NULL,
+  RUBY_TYPED_FREE_IMMEDIATELY
+};
+
 class LLaMAContextWrapper {
 public:
   struct llama_context* ctx;
```
```diff
@@ -651,6 +856,7 @@ public:
     rb_define_alloc_func(rb_cLLaMAContext, llama_context_alloc);
     rb_define_method(rb_cLLaMAContext, "initialize", RUBY_METHOD_FUNC(_llama_context_initialize), -1);
     rb_define_method(rb_cLLaMAContext, "eval", RUBY_METHOD_FUNC(_llama_context_eval), -1);
+    rb_define_method(rb_cLLaMAContext, "eval_embd", RUBY_METHOD_FUNC(_llama_context_eval_embd), -1);
     rb_define_method(rb_cLLaMAContext, "eval_export", RUBY_METHOD_FUNC(_llama_context_eval_export), 1);
     rb_define_method(rb_cLLaMAContext, "tokenize", RUBY_METHOD_FUNC(_llama_context_tokenize), -1);
     rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
@@ -662,10 +868,6 @@ public:
     rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
     rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
     rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
-    rb_define_method(rb_cLLaMAContext, "empty?", RUBY_METHOD_FUNC(_llama_context_empty), 0);
-    rb_define_method(rb_cLLaMAContext, "free", RUBY_METHOD_FUNC(_llama_context_free), 0);
-    rb_define_method(rb_cLLaMAContext, "load", RUBY_METHOD_FUNC(_llama_context_load), -1);
-    rb_define_method(rb_cLLaMAContext, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_context_apply_lora_from_file), -1);
     rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
     rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
     rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
```
```diff
@@ -689,46 +891,37 @@ private:
 
   static VALUE _llama_context_initialize(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
-    ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
-    VALUE kw_values[2] = { Qundef, Qundef };
+    ID kw_table[1] = { rb_intern("model") };
+    VALUE kw_values[1] = { Qundef };
     rb_scan_args(argc, argv, ":", &kw_args);
-    rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
+    rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
 
-    if (kw_values[0] == Qundef && kw_values[1] == Qundef) {
-      rb_iv_set(self, "@params", Qnil);
-      rb_iv_set(self, "@has_evaluated", Qfalse);
+    VALUE model = kw_values[0];
+    if (!rb_obj_is_kind_of(model, rb_cLLaMAModel)) {
+      rb_raise(rb_eArgError, "model must be a Model");
       return Qnil;
     }
 
-    if (!RB_TYPE_P(kw_values[0], T_STRING)) {
-      rb_raise(rb_eArgError, "model_path must be a string");
-      return Qnil;
-    }
-    if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
-      rb_raise(rb_eArgError, "params must be a ContextParams");
+    LLaMAModelWrapper* model_ptr = RbLLaMAModel::get_llama_model(model);
+    if (model_ptr->model == NULL) {
+      rb_raise(rb_eRuntimeError, "Model is empty");
       return Qnil;
     }
 
-    VALUE filename = kw_values[0];
-    LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
+    VALUE params = rb_iv_get(model, "@params");
+    LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
     LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
 
-    try {
-      ctx_ptr->ctx = llama_init_from_file(StringValueCStr(filename), prms_ptr->params);
-    } catch (const std::runtime_error& e) {
-      rb_raise(rb_eRuntimeError, "%s", e.what());
-      return Qnil;
-    }
+    ctx_ptr->ctx = llama_new_context_with_model(model_ptr->model, prms_ptr->params);
 
     if (ctx_ptr->ctx == NULL) {
      rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA context");
       return Qnil;
     }
 
-    rb_iv_set(self, "@params", kw_values[1]);
+    rb_iv_set(self, "@model", model);
     rb_iv_set(self, "@has_evaluated", Qfalse);
 
-    RB_GC_GUARD(filename);
     return Qnil;
   };
 
```
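`Context.new` thus takes a single `model:` keyword and pulls its `ContextParams` out of the model rather than accepting `model_path:`/`params:` itself, raising unless it is given a non-empty `Model`:

```ruby
LLaMACpp::Context.new(model: 'model.bin')
# raises ArgumentError ("model must be a Model")

model = LLaMACpp::Model.new(model_path: 'model.bin', params: params)
context = LLaMACpp::Context.new(model: model) # params come from the model
```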
```diff
@@ -787,6 +980,61 @@ private:
     return Qnil;
   };
 
+  static VALUE _llama_context_eval_embd(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[4] = { rb_intern("embd"), rb_intern("n_past"), rb_intern("n_tokens"), rb_intern("n_threads") };
+    VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
+    rb_scan_args(argc, argv, ":", &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);
+
+    if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
+      rb_raise(rb_eArgError, "tokens must be an Array");
+      return Qnil;
+    }
+    if (!RB_INTEGER_TYPE_P(kw_values[1])) {
+      rb_raise(rb_eArgError, "n_past must be an integer");
+      return Qnil;
+    }
+    if (kw_values[2] != Qundef && !RB_INTEGER_TYPE_P(kw_values[2])) {
+      rb_raise(rb_eArgError, "n_tokens must be an integer");
+      return Qnil;
+    }
+    if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
+      rb_raise(rb_eArgError, "n_threads must be an integer");
+      return Qnil;
+    }
+
+    const size_t tokens_len = RARRAY_LEN(kw_values[0]);
+    std::vector<float> embd(tokens_len);
+    for (size_t i = 0; i < tokens_len; i++) {
+      VALUE el = rb_ary_entry(kw_values[0], i);
+      if (!RB_FLOAT_TYPE_P(el)) {
+        rb_raise(rb_eArgError, "embd must be an array of floats");
+        return Qnil;
+      }
+      embd[i] = NUM2DBL(el);
+    }
+
+    const int n_tokens = kw_values[2] == Qundef ? (int)tokens_len : NUM2INT(kw_values[2]);
+    const int n_past = NUM2INT(kw_values[1]);
+    const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);
+
+    LLaMAContextWrapper* ptr = get_llama_context(self);
+    if (ptr->ctx == NULL) {
+      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+      return Qnil;
+    }
+    if (llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past, n_threads) != 0) {
+      rb_raise(rb_eRuntimeError, "Failed to evaluate");
+      return Qnil;
+    }
+
+    rb_iv_set(self, "@n_tokens", INT2NUM(n_tokens));
+    rb_iv_set(self, "@has_evaluated", Qtrue);
+
+    return Qnil;
+  }
+
   static VALUE _llama_context_eval_export(VALUE self, VALUE fname_) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
     if (ptr->ctx == NULL) {
```
```diff
@@ -873,7 +1121,9 @@ private:
       return Qnil;
     }
 
-    LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(rb_iv_get(self, "@params"));
+    VALUE model = rb_iv_get(self, "@model");
+    VALUE params = rb_iv_get(model, "@params");
+    LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
     const int n_tokens = prms_ptr->params.logits_all ? NUM2INT(rb_iv_get(self, "@n_tokens")) : 1;
     const int n_vocab = llama_n_vocab(ptr->ctx);
     const float* logits = llama_get_logits(ptr->ctx);
@@ -891,7 +1141,9 @@ private:
       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
       return Qnil;
     }
-    LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(rb_iv_get(self, "@params"));
+    VALUE model = rb_iv_get(self, "@model");
+    VALUE params = rb_iv_get(model, "@params");
+    LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
     if (!prms_ptr->params.embedding) {
       rb_raise(rb_eRuntimeError, "embedding parameter is false");
       return Qnil;
```
```diff
@@ -995,106 +1247,6 @@ private:
     return Qnil;
   };
 
-  static VALUE _llama_context_empty(VALUE self) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx != NULL) {
-      return Qfalse;
-    }
-    return Qtrue;
-  }
-
-  static VALUE _llama_context_free(VALUE self) {
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx != NULL) {
-      llama_free(ptr->ctx);
-      ptr->ctx = NULL;
-      rb_iv_set(self, "@params", Qnil);
-      rb_iv_set(self, "@has_evaluated", Qfalse);
-    }
-    return Qnil;
-  }
-
-  static VALUE _llama_context_load(int argc, VALUE* argv, VALUE self) {
-    VALUE kw_args = Qnil;
-    ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
-    VALUE kw_values[2] = { Qundef, Qundef };
-    rb_scan_args(argc, argv, ":", &kw_args);
-    rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
-
-    if (!RB_TYPE_P(kw_values[0], T_STRING)) {
-      rb_raise(rb_eArgError, "model_path must be a string");
-      return Qnil;
-    }
-    if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
-      rb_raise(rb_eArgError, "params must be a LLaMAContextParams");
-      return Qnil;
-    }
-
-    LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
-    if (ctx_ptr->ctx != NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is already loaded");
-      return Qnil;
-    }
-
-    VALUE filename = kw_values[0];
-    LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
-
-    try {
-      ctx_ptr->ctx = llama_init_from_file(StringValueCStr(filename), prms_ptr->params);
-    } catch (const std::runtime_error& e) {
-      rb_raise(rb_eRuntimeError, "%s", e.what());
-      return Qnil;
-    }
-
-    if (ctx_ptr->ctx == NULL) {
-      rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA context");
-      return Qnil;
-    }
-
-    rb_iv_set(self, "@params", kw_values[1]);
-    rb_iv_set(self, "@has_evaluated", Qfalse);
-
-    RB_GC_GUARD(filename);
-    return Qnil;
-  };
-
-  static VALUE _llama_context_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
-    VALUE kw_args = Qnil;
-    ID kw_table[3] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads") };
-    VALUE kw_values[3] = { Qundef, Qundef, Qundef };
-    rb_scan_args(argc, argv, ":", &kw_args);
-    rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
-
-    if (!RB_TYPE_P(kw_values[0], T_STRING)) {
-      rb_raise(rb_eArgError, "lora_path must be a string");
-      return Qnil;
-    }
-    if (kw_values[1] != Qundef && !RB_TYPE_P(kw_values[1], T_STRING)) {
-      rb_raise(rb_eArgError, "base_model_path must be a string");
-      return Qnil;
-    }
-    if (kw_values[2] != Qundef && !RB_INTEGER_TYPE_P(kw_values[2])) {
-      rb_raise(rb_eArgError, "n_threads must be an integer");
-      return Qnil;
-    }
-
-    const char* lora_path = StringValueCStr(kw_values[0]);
-    const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
-    const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
-
-    LLaMAContextWrapper* ptr = get_llama_context(self);
-    if (ptr->ctx != NULL) {
-      rb_raise(rb_eRuntimeError, "LLaMA context is already loaded");
-      return Qnil;
-    }
-
-    if (llama_apply_lora_from_file(ptr->ctx, lora_path, base_model_path, n_threads) != 0) {
-      rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
-      return Qnil;
-    }
-    return Qnil;
-  };
-
   static VALUE _llama_context_kv_cache_token_count(VALUE self) {
     LLaMAContextWrapper* ptr = get_llama_context(self);
     if (ptr->ctx == NULL) {
```
```diff
@@ -1110,7 +1262,11 @@ private:
       rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
       return Qnil;
     }
-    const int seed = NUM2INT(seed_);
+    if (NUM2INT(seed_) < 0) {
+      rb_raise(rb_eArgError, "seed must be a non-negative integer");
+      return Qnil;
+    }
+    const uint32_t seed = NUM2INT(seed_);
     llama_set_rng_seed(ptr->ctx, seed);
     return Qnil;
   };
```
```diff
@@ -1137,7 +1293,9 @@ private:
       return Qnil;
     }
 
-    LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(rb_iv_get(self, "@params"));
+    VALUE model = rb_iv_get(self, "@model");
+    VALUE params = rb_iv_get(model, "@params");
+    LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
     const int n_ctx = prms_ptr->params.n_ctx;
 
     std::vector<llama_token> session_tokens(n_ctx);
```
```diff
@@ -1664,8 +1822,16 @@ const rb_data_type_t RbLLaMAContext::llama_context_type = {
 
 // module functions
 
-static VALUE rb_llama_llama_init_backend(VALUE self) {
-  llama_init_backend();
+static VALUE rb_llama_llama_init_backend(int argc, VALUE* argv, VALUE self) {
+  VALUE kw_args = Qnil;
+  ID kw_table[1] = { rb_intern("numa") };
+  VALUE kw_values[1] = { Qundef };
+  rb_scan_args(argc, argv, ":", &kw_args);
+  rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
+
+  const bool numa = kw_values[0] == Qundef ? false : (RTEST(kw_values[0]) ? true : false);
+  llama_init_backend(numa);
+
   return Qnil;
 }
```
```diff
@@ -1731,10 +1897,11 @@ extern "C" void Init_llama_cpp(void) {
 
   RbLLaMATokenData::define_class(rb_mLLaMACpp);
   RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
+  RbLLaMAModel::define_class(rb_mLLaMACpp);
   RbLLaMAContext::define_class(rb_mLLaMACpp);
   RbLLaMAContextParams::define_class(rb_mLLaMACpp);
 
-  rb_define_module_function(rb_mLLaMACpp, "init_backend", rb_llama_llama_init_backend, 0);
+  rb_define_module_function(rb_mLLaMACpp, "init_backend", rb_llama_llama_init_backend, -1);
   rb_define_module_function(rb_mLLaMACpp, "model_quantize", rb_llama_model_quantize, -1);
   rb_define_module_function(rb_mLLaMACpp, "token_bos", rb_llama_token_bos, 0);
   rb_define_module_function(rb_mLLaMACpp, "token_eos", rb_llama_token_eos, 0);
@@ -1802,6 +1969,11 @@ extern "C" void Init_llama_cpp(void) {
   ss_magic << std::showbase << std::hex << LLAMA_SESSION_MAGIC;
   rb_define_const(rb_mLLaMACpp, "LLAMA_SESSION_MAGIC", rb_str_new2(ss_magic.str().c_str()));
 
+  ss_magic.str("");
+  ss_magic.clear(std::stringstream::goodbit);
+  ss_magic << std::showbase << std::hex << LLAMA_DEFAULT_SEED;
+  rb_define_const(rb_mLLaMACpp, "LLAMA_DEFAULT_SEED", rb_str_new2(ss_magic.str().c_str()));
+
   rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_VERSION", rb_str_new2(std::to_string(LLAMA_FILE_VERSION).c_str()));
   rb_define_const(rb_mLLaMACpp, "LLAMA_SESSION_VERSION", rb_str_new2(std::to_string(LLAMA_SESSION_VERSION).c_str()));
 }
```