llama_cpp 0.2.2 → 0.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +34 -0
- data/README.md +39 -6
- data/examples/chat.rb +2 -1
- data/examples/embedding.rb +3 -2
- data/ext/llama_cpp/extconf.rb +13 -0
- data/ext/llama_cpp/llama_cpp.cpp +305 -133
- data/ext/llama_cpp/src/ggml-cuda.cu +367 -69
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.m +36 -30
- data/ext/llama_cpp/src/ggml-metal.metal +328 -84
- data/ext/llama_cpp/src/ggml-opencl.cpp +352 -175
- data/ext/llama_cpp/src/ggml.c +800 -303
- data/ext/llama_cpp/src/ggml.h +68 -5
- data/ext/llama_cpp/src/k_quants.c +1712 -56
- data/ext/llama_cpp/src/k_quants.h +41 -6
- data/ext/llama_cpp/src/llama-util.h +19 -5
- data/ext/llama_cpp/src/llama.cpp +262 -291
- data/ext/llama_cpp/src/llama.h +49 -11
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +0 -2
- data/sig/llama_cpp.rbs +14 -17
- metadata +2 -3
- data/lib/llama_cpp/client.rb +0 -172
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -2,6 +2,7 @@
|
|
2
2
|
#include "llama_cpp.h"
|
3
3
|
|
4
4
|
VALUE rb_mLLaMACpp;
|
5
|
+
VALUE rb_cLLaMAModel;
|
5
6
|
VALUE rb_cLLaMAContext;
|
6
7
|
VALUE rb_cLLaMAContextParams;
|
7
8
|
VALUE rb_cLLaMAModelQuantizeParams;
|
@@ -403,6 +404,10 @@ private:
|
|
403
404
|
// seed
|
404
405
|
// Stores `seed` in the wrapped context params and returns the stored value.
//
// Raises ArgumentError when `seed` is negative. NUM2INT itself raises when
// `seed` cannot be converted to a C int.
static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
  LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
  // Convert once so validation and assignment see the same value.
  const int seed_value = NUM2INT(seed);
  if (seed_value < 0) {
    // Zero passes the check above, so the requirement is "non-negative",
    // not "positive".
    rb_raise(rb_eArgError, "seed must be a non-negative integer");
    return Qnil;
  }
  ptr->params.seed = seed_value;
  return INT2NUM(ptr->params.seed);
};
|
@@ -610,6 +615,206 @@ const rb_data_type_t RbLLaMAModelQuantizeParams::llama_model_quantize_params_typ
|
|
610
615
|
RUBY_TYPED_FREE_IMMEDIATELY
|
611
616
|
};
|
612
617
|
|
618
|
+
class LLaMAModelWrapper {
|
619
|
+
public:
|
620
|
+
struct llama_model* model;
|
621
|
+
|
622
|
+
LLaMAModelWrapper() : model(NULL){};
|
623
|
+
|
624
|
+
~LLaMAModelWrapper() {
|
625
|
+
if (model != NULL) {
|
626
|
+
llama_free_model(model);
|
627
|
+
}
|
628
|
+
};
|
629
|
+
};
|
630
|
+
|
631
|
+
class RbLLaMAModel {
|
632
|
+
public:
|
633
|
+
static VALUE llama_model_alloc(VALUE self) {
|
634
|
+
LLaMAModelWrapper* ptr = (LLaMAModelWrapper*)ruby_xmalloc(sizeof(LLaMAModelWrapper));
|
635
|
+
new (ptr) LLaMAModelWrapper();
|
636
|
+
return TypedData_Wrap_Struct(self, &llama_model_type, ptr);
|
637
|
+
}
|
638
|
+
|
639
|
+
static void llama_model_free(void* ptr) {
|
640
|
+
((LLaMAModelWrapper*)ptr)->~LLaMAModelWrapper();
|
641
|
+
ruby_xfree(ptr);
|
642
|
+
}
|
643
|
+
|
644
|
+
static size_t llama_model_size(const void* ptr) {
|
645
|
+
return sizeof(*((LLaMAModelWrapper*)ptr));
|
646
|
+
}
|
647
|
+
|
648
|
+
static LLaMAModelWrapper* get_llama_model(VALUE self) {
|
649
|
+
LLaMAModelWrapper* ptr;
|
650
|
+
TypedData_Get_Struct(self, LLaMAModelWrapper, &llama_model_type, ptr);
|
651
|
+
return ptr;
|
652
|
+
}
|
653
|
+
|
654
|
+
static void define_class(VALUE outer) {
|
655
|
+
rb_cLLaMAModel = rb_define_class_under(outer, "Model", rb_cObject);
|
656
|
+
rb_define_alloc_func(rb_cLLaMAModel, llama_model_alloc);
|
657
|
+
rb_define_method(rb_cLLaMAModel, "initialize", RUBY_METHOD_FUNC(_llama_model_initialize), -1);
|
658
|
+
rb_define_method(rb_cLLaMAModel, "empty?", RUBY_METHOD_FUNC(_llama_model_empty), 0);
|
659
|
+
rb_define_method(rb_cLLaMAModel, "free", RUBY_METHOD_FUNC(_llama_model_free), 0);
|
660
|
+
rb_define_method(rb_cLLaMAModel, "load", RUBY_METHOD_FUNC(_llama_model_load), -1);
|
661
|
+
rb_define_method(rb_cLLaMAModel, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_model_apply_lora_from_file), -1);
|
662
|
+
}
|
663
|
+
|
664
|
+
private:
|
665
|
+
static const rb_data_type_t llama_model_type;
|
666
|
+
|
667
|
+
static VALUE _llama_model_initialize(int argc, VALUE* argv, VALUE self) {
|
668
|
+
VALUE kw_args = Qnil;
|
669
|
+
ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
|
670
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
671
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
672
|
+
rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
|
673
|
+
|
674
|
+
if (kw_values[0] == Qundef && kw_values[1] == Qundef) {
|
675
|
+
rb_iv_set(self, "@params", Qnil);
|
676
|
+
return Qnil;
|
677
|
+
}
|
678
|
+
|
679
|
+
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
680
|
+
rb_raise(rb_eArgError, "model_path must be a string");
|
681
|
+
return Qnil;
|
682
|
+
}
|
683
|
+
if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
|
684
|
+
rb_raise(rb_eArgError, "params must be a ContextParams");
|
685
|
+
return Qnil;
|
686
|
+
}
|
687
|
+
|
688
|
+
VALUE filename = kw_values[0];
|
689
|
+
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
|
690
|
+
LLaMAModelWrapper* model_ptr = get_llama_model(self);
|
691
|
+
|
692
|
+
if (prms_ptr->params.seed == LLAMA_DEFAULT_SEED) {
|
693
|
+
prms_ptr->params.seed = time(NULL);
|
694
|
+
}
|
695
|
+
|
696
|
+
try {
|
697
|
+
model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
|
698
|
+
} catch (const std::runtime_error& e) {
|
699
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
700
|
+
return Qnil;
|
701
|
+
}
|
702
|
+
|
703
|
+
if (model_ptr->model == NULL) {
|
704
|
+
rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA model");
|
705
|
+
return Qnil;
|
706
|
+
}
|
707
|
+
|
708
|
+
rb_iv_set(self, "@params", kw_values[1]);
|
709
|
+
|
710
|
+
RB_GC_GUARD(filename);
|
711
|
+
return Qnil;
|
712
|
+
}
|
713
|
+
|
714
|
+
static VALUE _llama_model_empty(VALUE self) {
|
715
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
716
|
+
if (ptr->model != NULL) {
|
717
|
+
return Qfalse;
|
718
|
+
}
|
719
|
+
return Qtrue;
|
720
|
+
}
|
721
|
+
|
722
|
+
static VALUE _llama_model_free(VALUE self) {
|
723
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
724
|
+
if (ptr->model != NULL) {
|
725
|
+
llama_free_model(ptr->model);
|
726
|
+
ptr->model = NULL;
|
727
|
+
rb_iv_set(self, "@params", Qnil);
|
728
|
+
}
|
729
|
+
return Qnil;
|
730
|
+
}
|
731
|
+
|
732
|
+
static VALUE _llama_model_load(int argc, VALUE* argv, VALUE self) {
|
733
|
+
VALUE kw_args = Qnil;
|
734
|
+
ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
|
735
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
736
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
737
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
738
|
+
|
739
|
+
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
740
|
+
rb_raise(rb_eArgError, "model_path must be a string");
|
741
|
+
return Qnil;
|
742
|
+
}
|
743
|
+
if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
|
744
|
+
rb_raise(rb_eArgError, "params must be a LLaMAContextParams");
|
745
|
+
return Qnil;
|
746
|
+
}
|
747
|
+
|
748
|
+
LLaMAModelWrapper* model_ptr = get_llama_model(self);
|
749
|
+
if (model_ptr->model != NULL) {
|
750
|
+
rb_raise(rb_eRuntimeError, "LLaMA model is already loaded");
|
751
|
+
return Qnil;
|
752
|
+
}
|
753
|
+
|
754
|
+
VALUE filename = kw_values[0];
|
755
|
+
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
|
756
|
+
|
757
|
+
try {
|
758
|
+
model_ptr->model = llama_load_model_from_file(StringValueCStr(filename), prms_ptr->params);
|
759
|
+
} catch (const std::runtime_error& e) {
|
760
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
761
|
+
return Qnil;
|
762
|
+
}
|
763
|
+
|
764
|
+
if (model_ptr->model == NULL) {
|
765
|
+
rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA model");
|
766
|
+
return Qnil;
|
767
|
+
}
|
768
|
+
|
769
|
+
rb_iv_set(self, "@params", kw_values[1]);
|
770
|
+
|
771
|
+
RB_GC_GUARD(filename);
|
772
|
+
return Qnil;
|
773
|
+
}
|
774
|
+
|
775
|
+
static VALUE _llama_model_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
|
776
|
+
VALUE kw_args = Qnil;
|
777
|
+
ID kw_table[3] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads") };
|
778
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
779
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
780
|
+
rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
|
781
|
+
|
782
|
+
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
783
|
+
rb_raise(rb_eArgError, "lora_path must be a string");
|
784
|
+
return Qnil;
|
785
|
+
}
|
786
|
+
if (kw_values[1] != Qundef && !RB_TYPE_P(kw_values[1], T_STRING)) {
|
787
|
+
rb_raise(rb_eArgError, "base_model_path must be a string");
|
788
|
+
return Qnil;
|
789
|
+
}
|
790
|
+
if (kw_values[2] != Qundef && !RB_INTEGER_TYPE_P(kw_values[2])) {
|
791
|
+
rb_raise(rb_eArgError, "n_threads must be an integer");
|
792
|
+
return Qnil;
|
793
|
+
}
|
794
|
+
|
795
|
+
const char* lora_path = StringValueCStr(kw_values[0]);
|
796
|
+
const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
|
797
|
+
const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
|
798
|
+
|
799
|
+
LLaMAModelWrapper* ptr = get_llama_model(self);
|
800
|
+
if (llama_model_apply_lora_from_file(ptr->model, lora_path, base_model_path, n_threads) != 0) {
|
801
|
+
rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
|
802
|
+
return Qnil;
|
803
|
+
}
|
804
|
+
return Qnil;
|
805
|
+
};
|
806
|
+
};
|
807
|
+
|
808
|
+
// Typed-data descriptor used by RbLLaMAModel to wrap and unwrap
// LLaMAModelWrapper instances.
const rb_data_type_t RbLLaMAModel::llama_model_type = {
  "RbLLaMAModel",
  {
    NULL,                            // no mark callback
    RbLLaMAModel::llama_model_free,  // free callback
    RbLLaMAModel::llama_model_size   // size callback
  },
  NULL,
  NULL,
  RUBY_TYPED_FREE_IMMEDIATELY
};
|
817
|
+
|
613
818
|
class LLaMAContextWrapper {
|
614
819
|
public:
|
615
820
|
struct llama_context* ctx;
|
@@ -651,6 +856,7 @@ public:
|
|
651
856
|
rb_define_alloc_func(rb_cLLaMAContext, llama_context_alloc);
|
652
857
|
rb_define_method(rb_cLLaMAContext, "initialize", RUBY_METHOD_FUNC(_llama_context_initialize), -1);
|
653
858
|
rb_define_method(rb_cLLaMAContext, "eval", RUBY_METHOD_FUNC(_llama_context_eval), -1);
|
859
|
+
rb_define_method(rb_cLLaMAContext, "eval_embd", RUBY_METHOD_FUNC(_llama_context_eval_embd), -1);
|
654
860
|
rb_define_method(rb_cLLaMAContext, "eval_export", RUBY_METHOD_FUNC(_llama_context_eval_export), 1);
|
655
861
|
rb_define_method(rb_cLLaMAContext, "tokenize", RUBY_METHOD_FUNC(_llama_context_tokenize), -1);
|
656
862
|
rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
|
@@ -662,10 +868,6 @@ public:
|
|
662
868
|
rb_define_method(rb_cLLaMAContext, "n_embd", RUBY_METHOD_FUNC(_llama_context_n_embd), 0);
|
663
869
|
rb_define_method(rb_cLLaMAContext, "print_timings", RUBY_METHOD_FUNC(_llama_context_print_timings), 0);
|
664
870
|
rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
|
665
|
-
rb_define_method(rb_cLLaMAContext, "empty?", RUBY_METHOD_FUNC(_llama_context_empty), 0);
|
666
|
-
rb_define_method(rb_cLLaMAContext, "free", RUBY_METHOD_FUNC(_llama_context_free), 0);
|
667
|
-
rb_define_method(rb_cLLaMAContext, "load", RUBY_METHOD_FUNC(_llama_context_load), -1);
|
668
|
-
rb_define_method(rb_cLLaMAContext, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_context_apply_lora_from_file), -1);
|
669
871
|
rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
|
670
872
|
rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
|
671
873
|
rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
|
@@ -689,46 +891,37 @@ private:
|
|
689
891
|
|
690
892
|
// Context.new(model:)
//
// Builds a llama context from an already-loaded Model, using the context
// params the model stores in @params. On success the model is remembered in
// @model and @has_evaluated is reset to false.
// Raises ArgumentError when `model` is not a Model, and RuntimeError when the
// model is empty or the context cannot be created.
static VALUE _llama_context_initialize(int argc, VALUE* argv, VALUE self) {
  // `model:` is the single, required keyword argument.
  VALUE options = Qnil;
  ID keys[1] = { rb_intern("model") };
  VALUE values[1] = { Qundef };
  rb_scan_args(argc, argv, ":", &options);
  rb_get_kwargs(options, keys, 1, 0, values);

  VALUE model = values[0];
  if (!rb_obj_is_kind_of(model, rb_cLLaMAModel)) {
    rb_raise(rb_eArgError, "model must be a Model");
    return Qnil;
  }

  LLaMAModelWrapper* model_wrapper = RbLLaMAModel::get_llama_model(model);
  if (model_wrapper->model == NULL) {
    rb_raise(rb_eRuntimeError, "Model is empty");
    return Qnil;
  }

  // The context is created with the params recorded on the model.
  VALUE model_params = rb_iv_get(model, "@params");
  LLaMAContextParamsWrapper* params_wrapper = RbLLaMAContextParams::get_llama_context_params(model_params);
  LLaMAContextWrapper* context_wrapper = get_llama_context(self);

  context_wrapper->ctx = llama_new_context_with_model(model_wrapper->model, params_wrapper->params);
  if (context_wrapper->ctx == NULL) {
    rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA context");
    return Qnil;
  }

  rb_iv_set(self, "@model", model);
  rb_iv_set(self, "@has_evaluated", Qfalse);

  return Qnil;
};
|
734
927
|
|
@@ -787,6 +980,61 @@ private:
|
|
787
980
|
return Qnil;
|
788
981
|
};
|
789
982
|
|
983
|
+
// Context#eval_embd(embd:, n_past:, n_tokens: nil, n_threads: 1)
//
// Evaluates the model on an embedding matrix given as a flat Array of Floats.
// `embd` must hold at least n_tokens * n_embd values; when `n_tokens` is
// omitted it defaults to embd.size / n_embd.
// On success @n_tokens is set to n_tokens and @has_evaluated to true.
// Raises ArgumentError on mistyped or too-short input and RuntimeError when
// the context is not initialized or evaluation fails.
static VALUE _llama_context_eval_embd(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[4] = { rb_intern("embd"), rb_intern("n_past"), rb_intern("n_tokens"), rb_intern("n_threads") };
  VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
  rb_scan_args(argc, argv, ":", &kw_args);
  rb_get_kwargs(kw_args, kw_table, 2, 2, kw_values);

  if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
    rb_raise(rb_eArgError, "embd must be an Array");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(kw_values[1])) {
    rb_raise(rb_eArgError, "n_past must be an integer");
    return Qnil;
  }
  if (kw_values[2] != Qundef && !RB_INTEGER_TYPE_P(kw_values[2])) {
    rb_raise(rb_eArgError, "n_tokens must be an integer");
    return Qnil;
  }
  if (kw_values[3] != Qundef && !RB_INTEGER_TYPE_P(kw_values[3])) {
    rb_raise(rb_eArgError, "n_threads must be an integer");
    return Qnil;
  }

  // Validate every element before any C++ object with a destructor exists:
  // rb_raise leaves this function via longjmp and would skip that destructor.
  const size_t embd_len = RARRAY_LEN(kw_values[0]);
  for (size_t i = 0; i < embd_len; i++) {
    if (!RB_FLOAT_TYPE_P(rb_ary_entry(kw_values[0], i))) {
      rb_raise(rb_eArgError, "embd must be an array of floats");
      return Qnil;
    }
  }

  const int n_past = NUM2INT(kw_values[1]);
  const int n_threads = kw_values[3] == Qundef ? 1 : NUM2INT(kw_values[3]);

  LLaMAContextWrapper* ptr = get_llama_context(self);
  if (ptr->ctx == NULL) {
    rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
    return Qnil;
  }

  // llama_eval_embd consumes n_tokens * n_embd floats, so the number of
  // tokens the array can supply is its length divided by n_embd.
  const int n_embd = llama_n_embd(ptr->ctx);
  const size_t max_tokens = n_embd > 0 ? embd_len / (size_t)n_embd : 0;
  const int n_tokens = kw_values[2] == Qundef ? (int)max_tokens : NUM2INT(kw_values[2]);
  if (n_tokens < 1 || (size_t)n_tokens > max_tokens) {
    rb_raise(rb_eArgError, "embd must contain at least n_tokens * n_embd floats");
    return Qnil;
  }

  int status = 0;
  {
    // The buffer lives only inside this scope, which contains no rb_raise, so
    // it is always destroyed normally.
    std::vector<float> embd(embd_len);
    for (size_t i = 0; i < embd_len; i++) {
      embd[i] = NUM2DBL(rb_ary_entry(kw_values[0], i));
    }
    status = llama_eval_embd(ptr->ctx, embd.data(), n_tokens, n_past, n_threads);
  }
  if (status != 0) {
    rb_raise(rb_eRuntimeError, "Failed to evaluate");
    return Qnil;
  }

  rb_iv_set(self, "@n_tokens", INT2NUM(n_tokens));
  rb_iv_set(self, "@has_evaluated", Qtrue);

  return Qnil;
}
|
1037
|
+
|
790
1038
|
static VALUE _llama_context_eval_export(VALUE self, VALUE fname_) {
|
791
1039
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
792
1040
|
if (ptr->ctx == NULL) {
|
@@ -873,7 +1121,9 @@ private:
|
|
873
1121
|
return Qnil;
|
874
1122
|
}
|
875
1123
|
|
876
|
-
|
1124
|
+
VALUE model = rb_iv_get(self, "@model");
|
1125
|
+
VALUE params = rb_iv_get(model, "@params");
|
1126
|
+
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
|
877
1127
|
const int n_tokens = prms_ptr->params.logits_all ? NUM2INT(rb_iv_get(self, "@n_tokens")) : 1;
|
878
1128
|
const int n_vocab = llama_n_vocab(ptr->ctx);
|
879
1129
|
const float* logits = llama_get_logits(ptr->ctx);
|
@@ -891,7 +1141,9 @@ private:
|
|
891
1141
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
892
1142
|
return Qnil;
|
893
1143
|
}
|
894
|
-
|
1144
|
+
VALUE model = rb_iv_get(self, "@model");
|
1145
|
+
VALUE params = rb_iv_get(model, "@params");
|
1146
|
+
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
|
895
1147
|
if (!prms_ptr->params.embedding) {
|
896
1148
|
rb_raise(rb_eRuntimeError, "embedding parameter is false");
|
897
1149
|
return Qnil;
|
@@ -995,106 +1247,6 @@ private:
|
|
995
1247
|
return Qnil;
|
996
1248
|
};
|
997
1249
|
|
998
|
-
static VALUE _llama_context_empty(VALUE self) {
|
999
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1000
|
-
if (ptr->ctx != NULL) {
|
1001
|
-
return Qfalse;
|
1002
|
-
}
|
1003
|
-
return Qtrue;
|
1004
|
-
}
|
1005
|
-
|
1006
|
-
static VALUE _llama_context_free(VALUE self) {
|
1007
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1008
|
-
if (ptr->ctx != NULL) {
|
1009
|
-
llama_free(ptr->ctx);
|
1010
|
-
ptr->ctx = NULL;
|
1011
|
-
rb_iv_set(self, "@params", Qnil);
|
1012
|
-
rb_iv_set(self, "@has_evaluated", Qfalse);
|
1013
|
-
}
|
1014
|
-
return Qnil;
|
1015
|
-
}
|
1016
|
-
|
1017
|
-
static VALUE _llama_context_load(int argc, VALUE* argv, VALUE self) {
|
1018
|
-
VALUE kw_args = Qnil;
|
1019
|
-
ID kw_table[2] = { rb_intern("model_path"), rb_intern("params") };
|
1020
|
-
VALUE kw_values[2] = { Qundef, Qundef };
|
1021
|
-
rb_scan_args(argc, argv, ":", &kw_args);
|
1022
|
-
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
1023
|
-
|
1024
|
-
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
1025
|
-
rb_raise(rb_eArgError, "model_path must be a string");
|
1026
|
-
return Qnil;
|
1027
|
-
}
|
1028
|
-
if (!rb_obj_is_kind_of(kw_values[1], rb_cLLaMAContextParams)) {
|
1029
|
-
rb_raise(rb_eArgError, "params must be a LLaMAContextParams");
|
1030
|
-
return Qnil;
|
1031
|
-
}
|
1032
|
-
|
1033
|
-
LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
|
1034
|
-
if (ctx_ptr->ctx != NULL) {
|
1035
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is already loaded");
|
1036
|
-
return Qnil;
|
1037
|
-
}
|
1038
|
-
|
1039
|
-
VALUE filename = kw_values[0];
|
1040
|
-
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(kw_values[1]);
|
1041
|
-
|
1042
|
-
try {
|
1043
|
-
ctx_ptr->ctx = llama_init_from_file(StringValueCStr(filename), prms_ptr->params);
|
1044
|
-
} catch (const std::runtime_error& e) {
|
1045
|
-
rb_raise(rb_eRuntimeError, "%s", e.what());
|
1046
|
-
return Qnil;
|
1047
|
-
}
|
1048
|
-
|
1049
|
-
if (ctx_ptr->ctx == NULL) {
|
1050
|
-
rb_raise(rb_eRuntimeError, "Failed to initialize LLaMA context");
|
1051
|
-
return Qnil;
|
1052
|
-
}
|
1053
|
-
|
1054
|
-
rb_iv_set(self, "@params", kw_values[1]);
|
1055
|
-
rb_iv_set(self, "@has_evaluated", Qfalse);
|
1056
|
-
|
1057
|
-
RB_GC_GUARD(filename);
|
1058
|
-
return Qnil;
|
1059
|
-
};
|
1060
|
-
|
1061
|
-
static VALUE _llama_context_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
|
1062
|
-
VALUE kw_args = Qnil;
|
1063
|
-
ID kw_table[3] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads") };
|
1064
|
-
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
1065
|
-
rb_scan_args(argc, argv, ":", &kw_args);
|
1066
|
-
rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
|
1067
|
-
|
1068
|
-
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
1069
|
-
rb_raise(rb_eArgError, "lora_path must be a string");
|
1070
|
-
return Qnil;
|
1071
|
-
}
|
1072
|
-
if (kw_values[1] != Qundef && !RB_TYPE_P(kw_values[1], T_STRING)) {
|
1073
|
-
rb_raise(rb_eArgError, "base_model_path must be a string");
|
1074
|
-
return Qnil;
|
1075
|
-
}
|
1076
|
-
if (kw_values[2] != Qundef && !RB_INTEGER_TYPE_P(kw_values[2])) {
|
1077
|
-
rb_raise(rb_eArgError, "n_threads must be an integer");
|
1078
|
-
return Qnil;
|
1079
|
-
}
|
1080
|
-
|
1081
|
-
const char* lora_path = StringValueCStr(kw_values[0]);
|
1082
|
-
const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
|
1083
|
-
const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);
|
1084
|
-
|
1085
|
-
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1086
|
-
if (ptr->ctx != NULL) {
|
1087
|
-
rb_raise(rb_eRuntimeError, "LLaMA context is already loaded");
|
1088
|
-
return Qnil;
|
1089
|
-
}
|
1090
|
-
|
1091
|
-
if (llama_apply_lora_from_file(ptr->ctx, lora_path, base_model_path, n_threads) != 0) {
|
1092
|
-
rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
|
1093
|
-
return Qnil;
|
1094
|
-
}
|
1095
|
-
return Qnil;
|
1096
|
-
};
|
1097
|
-
|
1098
1250
|
static VALUE _llama_context_kv_cache_token_count(VALUE self) {
|
1099
1251
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
1100
1252
|
if (ptr->ctx == NULL) {
|
@@ -1110,7 +1262,11 @@ private:
|
|
1110
1262
|
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
|
1111
1263
|
return Qnil;
|
1112
1264
|
}
|
1113
|
-
|
1265
|
+
if (NUM2INT(seed_) < 0) {
|
1266
|
+
rb_raise(rb_eArgError, "seed must be a non-negative integer");
|
1267
|
+
return Qnil;
|
1268
|
+
}
|
1269
|
+
const uint32_t seed = NUM2INT(seed_);
|
1114
1270
|
llama_set_rng_seed(ptr->ctx, seed);
|
1115
1271
|
return Qnil;
|
1116
1272
|
};
|
@@ -1137,7 +1293,9 @@ private:
|
|
1137
1293
|
return Qnil;
|
1138
1294
|
}
|
1139
1295
|
|
1140
|
-
|
1296
|
+
VALUE model = rb_iv_get(self, "@model");
|
1297
|
+
VALUE params = rb_iv_get(model, "@params");
|
1298
|
+
LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(params);
|
1141
1299
|
const int n_ctx = prms_ptr->params.n_ctx;
|
1142
1300
|
|
1143
1301
|
std::vector<llama_token> session_tokens(n_ctx);
|
@@ -1664,8 +1822,16 @@ const rb_data_type_t RbLLaMAContext::llama_context_type = {
|
|
1664
1822
|
|
1665
1823
|
// module functions
|
1666
1824
|
|
1667
|
-
static VALUE rb_llama_llama_init_backend(VALUE self) {
|
1668
|
-
|
1825
|
+
// Module function init_backend(numa: false)
//
// Calls llama_init_backend, passing true for `numa` only when the keyword is
// given and truthy. Always returns nil.
static VALUE rb_llama_llama_init_backend(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[1] = { rb_intern("numa") };
  VALUE kw_values[1] = { Qundef };
  rb_scan_args(argc, argv, ":", &kw_args);
  rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);

  // RTEST has to be applied to the keyword's value; the bare name `RTEST`
  // never looked at what the caller passed.
  const bool numa = kw_values[0] != Qundef && RTEST(kw_values[0]);
  llama_init_backend(numa);

  return Qnil;
}
|
1671
1837
|
|
@@ -1731,10 +1897,11 @@ extern "C" void Init_llama_cpp(void) {
|
|
1731
1897
|
|
1732
1898
|
RbLLaMATokenData::define_class(rb_mLLaMACpp);
|
1733
1899
|
RbLLaMATokenDataArray::define_class(rb_mLLaMACpp);
|
1900
|
+
RbLLaMAModel::define_class(rb_mLLaMACpp);
|
1734
1901
|
RbLLaMAContext::define_class(rb_mLLaMACpp);
|
1735
1902
|
RbLLaMAContextParams::define_class(rb_mLLaMACpp);
|
1736
1903
|
|
1737
|
-
rb_define_module_function(rb_mLLaMACpp, "init_backend", rb_llama_llama_init_backend,
|
1904
|
+
rb_define_module_function(rb_mLLaMACpp, "init_backend", rb_llama_llama_init_backend, -1);
|
1738
1905
|
rb_define_module_function(rb_mLLaMACpp, "model_quantize", rb_llama_model_quantize, -1);
|
1739
1906
|
rb_define_module_function(rb_mLLaMACpp, "token_bos", rb_llama_token_bos, 0);
|
1740
1907
|
rb_define_module_function(rb_mLLaMACpp, "token_eos", rb_llama_token_eos, 0);
|
@@ -1802,6 +1969,11 @@ extern "C" void Init_llama_cpp(void) {
|
|
1802
1969
|
ss_magic << std::showbase << std::hex << LLAMA_SESSION_MAGIC;
|
1803
1970
|
rb_define_const(rb_mLLaMACpp, "LLAMA_SESSION_MAGIC", rb_str_new2(ss_magic.str().c_str()));
|
1804
1971
|
|
1972
|
+
ss_magic.str("");
|
1973
|
+
ss_magic.clear(std::stringstream::goodbit);
|
1974
|
+
ss_magic << std::showbase << std::hex << LLAMA_DEFAULT_SEED;
|
1975
|
+
rb_define_const(rb_mLLaMACpp, "LLAMA_DEFAULT_SEED", rb_str_new2(ss_magic.str().c_str()));
|
1976
|
+
|
1805
1977
|
rb_define_const(rb_mLLaMACpp, "LLAMA_FILE_VERSION", rb_str_new2(std::to_string(LLAMA_FILE_VERSION).c_str()));
|
1806
1978
|
rb_define_const(rb_mLLaMACpp, "LLAMA_SESSION_VERSION", rb_str_new2(std::to_string(LLAMA_SESSION_VERSION).c_str()));
|
1807
1979
|
}
|