cui-llama.rn 1.2.3 → 1.2.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +0 -2
- package/android/src/main/CMakeLists.txt +1 -0
- package/android/src/main/java/com/rnllama/LlamaContext.java +0 -3
- package/android/src/main/jni.cpp +9 -11
- package/cpp/common.cpp +85 -75
- package/cpp/common.h +127 -91
- package/cpp/ggml-aarch64.c +269 -0
- package/cpp/ggml-alloc.c +17 -19
- package/cpp/ggml-backend-impl.h +4 -15
- package/cpp/ggml-backend.cpp +1697 -1626
- package/cpp/ggml-backend.h +13 -25
- package/cpp/ggml-cpp.h +38 -0
- package/cpp/ggml-cpu.c +13720 -0
- package/cpp/ggml-cpu.h +150 -0
- package/cpp/ggml-impl.h +95 -0
- package/cpp/ggml-metal.m +185 -71
- package/cpp/ggml-quants.c +38 -51
- package/cpp/ggml.c +4468 -19500
- package/cpp/ggml.h +26 -146
- package/cpp/json-schema-to-grammar.cpp +1 -1
- package/cpp/llama-sampling.cpp +742 -249
- package/cpp/llama-sampling.h +21 -2
- package/cpp/llama-vocab.cpp +49 -9
- package/cpp/llama-vocab.h +35 -11
- package/cpp/llama.cpp +2468 -2307
- package/cpp/llama.h +65 -32
- package/cpp/log.cpp +50 -50
- package/cpp/log.h +18 -18
- package/cpp/rn-llama.hpp +23 -22
- package/cpp/sampling.cpp +117 -118
- package/cpp/sampling.h +20 -20
- package/cpp/sgemm.cpp +57 -0
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +0 -1
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +0 -1
package/cpp/llama-sampling.cpp
CHANGED
@@ -63,6 +63,30 @@ static void llama_log_softmax(float * array, size_t size) {
|
|
63
63
|
}
|
64
64
|
*/
|
65
65
|
|
66
|
+
static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
|
67
|
+
if (temp <= 0.0f) {
|
68
|
+
// find the token with the highest logit and set the rest to -inf
|
69
|
+
size_t max_i = 0;
|
70
|
+
float max_l = cur_p->data[0].logit;
|
71
|
+
|
72
|
+
for (size_t i = 1; i < cur_p->size; ++i) {
|
73
|
+
if (cur_p->data[i ].logit > max_l) {
|
74
|
+
cur_p->data[max_i].logit = -INFINITY;
|
75
|
+
max_i = i;
|
76
|
+
max_l = cur_p->data[i].logit;
|
77
|
+
} else {
|
78
|
+
cur_p->data[i].logit = -INFINITY;
|
79
|
+
}
|
80
|
+
}
|
81
|
+
|
82
|
+
return;
|
83
|
+
}
|
84
|
+
|
85
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
86
|
+
cur_p->data[i].logit /= temp;
|
87
|
+
}
|
88
|
+
}
|
89
|
+
|
66
90
|
static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
|
67
91
|
LM_GGML_ASSERT(cur_p->size > 0);
|
68
92
|
|
@@ -89,7 +113,7 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
|
|
89
113
|
}
|
90
114
|
|
91
115
|
static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
|
92
|
-
// TODO: move bucket sort to separate function so that top_p/
|
116
|
+
// TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast
|
93
117
|
// if (k >= (int32_t)cur_p->size) {
|
94
118
|
// return;
|
95
119
|
// }
|
@@ -428,6 +452,9 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
|
|
428
452
|
|
429
453
|
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
430
454
|
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
455
|
+
|
456
|
+
llama_sampler_softmax_impl(cur_p);
|
457
|
+
|
431
458
|
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
|
432
459
|
}
|
433
460
|
|
@@ -707,245 +734,6 @@ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
|
|
707
734
|
};
|
708
735
|
}
|
709
736
|
|
710
|
-
// xtc
|
711
|
-
|
712
|
-
struct llama_sampler_xtc {
|
713
|
-
const uint32_t seed;
|
714
|
-
std::mt19937 rng;
|
715
|
-
const float xtc_p;
|
716
|
-
const float xtc_t;
|
717
|
-
const size_t min_keep;
|
718
|
-
};
|
719
|
-
|
720
|
-
static const char * llama_sampler_xtc_name(const struct llama_sampler * /* smpl */) {
|
721
|
-
return "xtc";
|
722
|
-
}
|
723
|
-
|
724
|
-
static void llama_sampler_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
725
|
-
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
726
|
-
|
727
|
-
size_t min_keep = ctx -> min_keep;
|
728
|
-
std::mt19937 rng = ctx -> rng;
|
729
|
-
|
730
|
-
float xtc_threshold = ctx -> xtc_t;
|
731
|
-
float xtc_probability = ctx -> xtc_p;
|
732
|
-
|
733
|
-
|
734
|
-
if(xtc_threshold <= 0.0f || !cur_p-> size) {
|
735
|
-
return;
|
736
|
-
}
|
737
|
-
|
738
|
-
bool xtc_applied = false;
|
739
|
-
const int64_t t_start_sample_us = lm_ggml_time_us();
|
740
|
-
llama_sampler_softmax_impl(cur_p);
|
741
|
-
|
742
|
-
// unsorted iteration
|
743
|
-
if (!cur_p->sorted) {
|
744
|
-
std::vector<llama_token_data> top_tokens, low_tokens;
|
745
|
-
|
746
|
-
// split candidates into two arrays for low and high tokens
|
747
|
-
for (size_t i = 0; i < cur_p->size; ++i) {
|
748
|
-
if (cur_p->data[i].logit >= xtc_threshold) {
|
749
|
-
top_tokens.push_back(cur_p->data[i]);
|
750
|
-
} else {
|
751
|
-
low_tokens.push_back(cur_p-> data[i]);
|
752
|
-
}
|
753
|
-
}
|
754
|
-
// if there is only one or no top_tokens, do not truncate
|
755
|
-
|
756
|
-
if (top_tokens.size() <= 1) {
|
757
|
-
return;
|
758
|
-
}
|
759
|
-
|
760
|
-
// sort top_tokens
|
761
|
-
std::sort(top_tokens.begin(), top_tokens.end(), [](const llama_token_data & a, const llama_token_data & b) {
|
762
|
-
return a.logit < b.logit;
|
763
|
-
});
|
764
|
-
|
765
|
-
// insert top_tokens with probability. Always insert lowest top_token
|
766
|
-
low_tokens.push_back(top_tokens[0]);
|
767
|
-
std::uniform_real_distribution<float> random_float(0.0 , 1.0);
|
768
|
-
for (size_t i = 1; i < top_tokens.size(); ++i) {
|
769
|
-
if(random_float(rng) <= xtc_probability) {
|
770
|
-
low_tokens.push_back(top_tokens[i]);
|
771
|
-
}
|
772
|
-
}
|
773
|
-
if(low_tokens.size() >= min_keep) {
|
774
|
-
memcpy(cur_p->data, low_tokens.data(), low_tokens.size()*sizeof(llama_token_data));
|
775
|
-
cur_p->size = low_tokens.size();
|
776
|
-
xtc_applied = true;
|
777
|
-
}
|
778
|
-
}
|
779
|
-
// sorted iteration
|
780
|
-
|
781
|
-
if (!xtc_applied) {
|
782
|
-
// Sort the logits in descending order
|
783
|
-
if (!cur_p->sorted) {
|
784
|
-
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
|
785
|
-
return a.logit > b.logit;
|
786
|
-
});
|
787
|
-
cur_p->sorted = true;
|
788
|
-
}
|
789
|
-
|
790
|
-
// find last token over threshold
|
791
|
-
|
792
|
-
size_t last_index = 0;
|
793
|
-
|
794
|
-
for (; last_index < cur_p -> size; ++last_index) {
|
795
|
-
if(cur_p -> data[last_index].p < xtc_threshold) {
|
796
|
-
break;
|
797
|
-
}
|
798
|
-
}
|
799
|
-
|
800
|
-
// check if only 1 token above threshold
|
801
|
-
if(last_index <= 1) {
|
802
|
-
return;
|
803
|
-
}
|
804
|
-
last_index--;
|
805
|
-
// items beyond safe index will be ignored
|
806
|
-
size_t safe_index = cur_p -> size;
|
807
|
-
|
808
|
-
// remove tokens until last threshold item
|
809
|
-
std::uniform_real_distribution<float> random_float(0.0 , 1.0);
|
810
|
-
for (size_t i = 0; i < last_index; i++) {
|
811
|
-
if(random_float(rng) < xtc_probability) {
|
812
|
-
std::swap(cur_p-> data[i], cur_p->data[safe_index - 1]);
|
813
|
-
safe_index--;
|
814
|
-
if (cur_p-> sorted) {
|
815
|
-
cur_p -> sorted = false;
|
816
|
-
}
|
817
|
-
}
|
818
|
-
}
|
819
|
-
cur_p -> size = safe_index;
|
820
|
-
}
|
821
|
-
}
|
822
|
-
|
823
|
-
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
|
824
|
-
const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
|
825
|
-
return llama_sampler_init_xtc(ctx->xtc_p, ctx->xtc_t, ctx->min_keep, ctx->seed);
|
826
|
-
}
|
827
|
-
|
828
|
-
static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
|
829
|
-
delete (const llama_sampler_xtc *) smpl->ctx;
|
830
|
-
}
|
831
|
-
|
832
|
-
static struct llama_sampler_i llama_sampler_xtc_i = {
|
833
|
-
/* .name = */ llama_sampler_xtc_name,
|
834
|
-
/* .accept = */ nullptr,
|
835
|
-
/* .apply = */ llama_sampler_xtc_apply,
|
836
|
-
/* .reset = */ nullptr,
|
837
|
-
/* .clone = */ llama_sampler_xtc_clone,
|
838
|
-
/* .free = */ llama_sampler_xtc_free,
|
839
|
-
};
|
840
|
-
|
841
|
-
struct llama_sampler * llama_sampler_init_xtc(float xtc_p, float xtc_t, size_t min_keep, uint32_t seed) {
|
842
|
-
return new llama_sampler {
|
843
|
-
/* .iface = */ &llama_sampler_xtc_i,
|
844
|
-
/* .ctx = */ new llama_sampler_xtc {
|
845
|
-
/* .seed = */ seed,
|
846
|
-
/* .rng = */ std::mt19937(seed),
|
847
|
-
/* .xtc_p = */ xtc_p,
|
848
|
-
/* .xtc_t = */ xtc_t,
|
849
|
-
/* .min_keep = */ min_keep
|
850
|
-
},
|
851
|
-
};
|
852
|
-
}
|
853
|
-
|
854
|
-
// tail-free
|
855
|
-
|
856
|
-
struct llama_sampler_tail_free {
|
857
|
-
const float z;
|
858
|
-
const size_t min_keep;
|
859
|
-
};
|
860
|
-
|
861
|
-
static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) {
|
862
|
-
return "tail-free";
|
863
|
-
}
|
864
|
-
|
865
|
-
static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
866
|
-
const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
|
867
|
-
|
868
|
-
if (ctx->z >= 1.0f || cur_p->size <= 2) {
|
869
|
-
return;
|
870
|
-
}
|
871
|
-
|
872
|
-
llama_sampler_softmax_impl(cur_p);
|
873
|
-
|
874
|
-
// Compute the first and second derivatives
|
875
|
-
std::vector<float> first_derivatives(cur_p->size - 1);
|
876
|
-
std::vector<float> second_derivatives(cur_p->size - 2);
|
877
|
-
|
878
|
-
for (size_t i = 0; i < first_derivatives.size(); ++i) {
|
879
|
-
first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
|
880
|
-
}
|
881
|
-
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
882
|
-
second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
|
883
|
-
}
|
884
|
-
|
885
|
-
// Calculate absolute value of second derivatives
|
886
|
-
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
887
|
-
second_derivatives[i] = std::abs(second_derivatives[i]);
|
888
|
-
}
|
889
|
-
|
890
|
-
// Normalize the second derivatives
|
891
|
-
{
|
892
|
-
const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
|
893
|
-
|
894
|
-
if (second_derivatives_sum > 1e-6f) {
|
895
|
-
for (float & value : second_derivatives) {
|
896
|
-
value /= second_derivatives_sum;
|
897
|
-
}
|
898
|
-
} else {
|
899
|
-
for (float & value : second_derivatives) {
|
900
|
-
value = 1.0f / second_derivatives.size();
|
901
|
-
}
|
902
|
-
}
|
903
|
-
}
|
904
|
-
|
905
|
-
float cum_sum = 0.0f;
|
906
|
-
size_t last_idx = cur_p->size;
|
907
|
-
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
908
|
-
cum_sum += second_derivatives[i];
|
909
|
-
|
910
|
-
// Check if the running sum is greater than z or if we have kept at least min_keep tokens
|
911
|
-
if (cum_sum > ctx->z && i >= ctx->min_keep) {
|
912
|
-
last_idx = i;
|
913
|
-
break;
|
914
|
-
}
|
915
|
-
}
|
916
|
-
|
917
|
-
// Resize the output vector to keep only the tokens above the tail location
|
918
|
-
cur_p->size = last_idx;
|
919
|
-
}
|
920
|
-
|
921
|
-
static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
|
922
|
-
const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
|
923
|
-
return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
|
924
|
-
}
|
925
|
-
|
926
|
-
static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
|
927
|
-
delete (llama_sampler_tail_free *) smpl->ctx;
|
928
|
-
}
|
929
|
-
|
930
|
-
static struct llama_sampler_i llama_sampler_tail_free_i = {
|
931
|
-
/* .name = */ llama_sampler_tail_free_name,
|
932
|
-
/* .accept = */ nullptr,
|
933
|
-
/* .apply = */ llama_sampler_tail_free_apply,
|
934
|
-
/* .reset = */ nullptr,
|
935
|
-
/* .clone = */ llama_sampler_tail_free_clone,
|
936
|
-
/* .free = */ llama_sampler_tail_free_free,
|
937
|
-
};
|
938
|
-
|
939
|
-
struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
|
940
|
-
return new llama_sampler {
|
941
|
-
/* .iface = */ &llama_sampler_tail_free_i,
|
942
|
-
/* .ctx = */ new llama_sampler_tail_free {
|
943
|
-
/* .z = */ z,
|
944
|
-
/*. min_keep = */ min_keep,
|
945
|
-
},
|
946
|
-
};
|
947
|
-
}
|
948
|
-
|
949
737
|
// typical
|
950
738
|
|
951
739
|
struct llama_sampler_typical {
|
@@ -1057,9 +845,8 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*
|
|
1057
845
|
|
1058
846
|
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
1059
847
|
const auto * ctx = (llama_sampler_temp *) smpl->ctx;
|
1060
|
-
|
1061
|
-
|
1062
|
-
}
|
848
|
+
|
849
|
+
llama_sampler_temp_impl(cur_p, ctx->temp);
|
1063
850
|
}
|
1064
851
|
|
1065
852
|
static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
|
@@ -1106,6 +893,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
|
|
1106
893
|
if (ctx->delta > 0) {
|
1107
894
|
const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
|
1108
895
|
const float max_temp = ctx->temp + ctx->delta;
|
896
|
+
|
1109
897
|
float exponent_val = ctx->exponent;
|
1110
898
|
|
1111
899
|
// no need to do anything if there is only one (or zero) candidates
|
@@ -1143,9 +931,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
|
|
1143
931
|
#endif
|
1144
932
|
|
1145
933
|
// Apply the dynamically calculated temperature scaling
|
1146
|
-
|
1147
|
-
cur_p->data[i].logit /= dyn_temp;
|
1148
|
-
}
|
934
|
+
llama_sampler_temp_impl(cur_p, dyn_temp);
|
1149
935
|
|
1150
936
|
// Re-compute softmax probabilities after scaling logits with dynamic temperature
|
1151
937
|
const double max_l_double = cur_p->data[0].logit;
|
@@ -1169,9 +955,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
|
|
1169
955
|
}
|
1170
956
|
#endif
|
1171
957
|
} else {
|
1172
|
-
|
1173
|
-
cur_p->data[i].logit /= ctx->temp;
|
1174
|
-
}
|
958
|
+
llama_sampler_temp_impl(cur_p, ctx->temp);
|
1175
959
|
}
|
1176
960
|
}
|
1177
961
|
|
@@ -1204,6 +988,101 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
|
|
1204
988
|
};
|
1205
989
|
}
|
1206
990
|
|
991
|
+
// xtc
|
992
|
+
|
993
|
+
struct llama_sampler_xtc {
|
994
|
+
const float probability;
|
995
|
+
const float threshold;
|
996
|
+
const size_t min_keep;
|
997
|
+
|
998
|
+
const uint32_t seed;
|
999
|
+
uint32_t seed_cur;
|
1000
|
+
|
1001
|
+
std::mt19937 rng;
|
1002
|
+
};
|
1003
|
+
|
1004
|
+
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
|
1005
|
+
return "xtc";
|
1006
|
+
}
|
1007
|
+
|
1008
|
+
static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
1009
|
+
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
1010
|
+
|
1011
|
+
if (ctx->probability <= 0.0f
|
1012
|
+
|| ctx->threshold > 0.5f
|
1013
|
+
|| cur_p->size < 2) {
|
1014
|
+
return;
|
1015
|
+
}
|
1016
|
+
|
1017
|
+
std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
|
1018
|
+
float chance = distribution(ctx->rng);
|
1019
|
+
if (chance > ctx->probability) return;
|
1020
|
+
|
1021
|
+
// in case it's not sorted/recalculated yet
|
1022
|
+
llama_sampler_softmax_impl(cur_p);
|
1023
|
+
|
1024
|
+
int pos_last = 0;
|
1025
|
+
|
1026
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1027
|
+
if (cur_p->data[i].p >= ctx->threshold) {
|
1028
|
+
pos_last = i;
|
1029
|
+
} else break;
|
1030
|
+
}
|
1031
|
+
|
1032
|
+
if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
|
1033
|
+
cur_p->data += pos_last;
|
1034
|
+
cur_p->size -= pos_last;
|
1035
|
+
}
|
1036
|
+
}
|
1037
|
+
|
1038
|
+
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
|
1039
|
+
const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
|
1040
|
+
auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);
|
1041
|
+
|
1042
|
+
// copy the state
|
1043
|
+
{
|
1044
|
+
auto * result_ctx = (llama_sampler_xtc *) result->ctx;
|
1045
|
+
|
1046
|
+
result_ctx->rng = ctx->rng;
|
1047
|
+
}
|
1048
|
+
|
1049
|
+
return result;
|
1050
|
+
}
|
1051
|
+
|
1052
|
+
static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
|
1053
|
+
delete (llama_sampler_xtc *) smpl->ctx;
|
1054
|
+
}
|
1055
|
+
|
1056
|
+
static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
|
1057
|
+
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
1058
|
+
ctx->seed_cur = get_rng_seed(ctx->seed);
|
1059
|
+
ctx->rng.seed(ctx->seed_cur);
|
1060
|
+
}
|
1061
|
+
|
1062
|
+
static struct llama_sampler_i llama_sampler_xtc_i = {
|
1063
|
+
/* .name = */ llama_sampler_xtc_name,
|
1064
|
+
/* .accept = */ nullptr,
|
1065
|
+
/* .apply = */ llama_sample_xtc_apply,
|
1066
|
+
/* .reset = */ llama_sampler_xtc_reset,
|
1067
|
+
/* .clone = */ llama_sampler_xtc_clone,
|
1068
|
+
/* .free = */ llama_sampler_xtc_free,
|
1069
|
+
};
|
1070
|
+
|
1071
|
+
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
|
1072
|
+
auto seed_cur = get_rng_seed(seed);
|
1073
|
+
return new llama_sampler {
|
1074
|
+
/* .iface = */ &llama_sampler_xtc_i,
|
1075
|
+
/* .ctx = */ new llama_sampler_xtc {
|
1076
|
+
/* .probability = */ p,
|
1077
|
+
/* .threshold = */ t,
|
1078
|
+
/* .min_keep = */ min_keep,
|
1079
|
+
/* .seed = */ seed,
|
1080
|
+
/* .seed_cur = */ seed_cur,
|
1081
|
+
/* .rng = */ std::mt19937(seed_cur),
|
1082
|
+
},
|
1083
|
+
};
|
1084
|
+
}
|
1085
|
+
|
1207
1086
|
// mirostat
|
1208
1087
|
|
1209
1088
|
struct llama_sampler_mirostat {
|
@@ -1710,6 +1589,397 @@ struct llama_sampler * llama_sampler_init_penalties(
|
|
1710
1589
|
};
|
1711
1590
|
}
|
1712
1591
|
|
1592
|
+
// DRY
|
1593
|
+
|
1594
|
+
struct llama_sampler_dry {
|
1595
|
+
int32_t total_context_size;
|
1596
|
+
|
1597
|
+
const float dry_multiplier;
|
1598
|
+
const float dry_base;
|
1599
|
+
const int32_t dry_allowed_length;
|
1600
|
+
const int32_t dry_penalty_last_n;
|
1601
|
+
|
1602
|
+
std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
|
1603
|
+
std::vector<int> dry_repeat_count;
|
1604
|
+
std::unordered_map<llama_token, int> dry_max_token_repeat;
|
1605
|
+
ring_buffer<llama_token> last_tokens;
|
1606
|
+
};
|
1607
|
+
|
1608
|
+
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
|
1609
|
+
static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
|
1610
|
+
for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
|
1611
|
+
std::string word = llama_detokenize(vocab, {token_id}, true);
|
1612
|
+
if (word.find(str) != std::string::npos) {
|
1613
|
+
token_sequences.emplace(token_id, std::vector<llama_token>());
|
1614
|
+
} else {
|
1615
|
+
size_t word_len = word.size(), str_len = str.size();
|
1616
|
+
size_t pos = -1;
|
1617
|
+
while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
|
1618
|
+
bool match = true;
|
1619
|
+
size_t i;
|
1620
|
+
for (i = 1; i < str_len && i + pos < word_len; ++i) {
|
1621
|
+
if (word[pos + i] != str[i]) {
|
1622
|
+
match = false;
|
1623
|
+
break;
|
1624
|
+
}
|
1625
|
+
}
|
1626
|
+
if (match) {
|
1627
|
+
std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
|
1628
|
+
if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
|
1629
|
+
tokenization.resize(max_tail_len);
|
1630
|
+
}
|
1631
|
+
|
1632
|
+
// Ensure we don't already have a duplicate matching tokenization
|
1633
|
+
auto its = token_sequences.equal_range(token_id);
|
1634
|
+
bool found = false;
|
1635
|
+
for (auto it = its.first; it != its.second; ++it) {
|
1636
|
+
if (tokenization == it->second) {
|
1637
|
+
found = true;
|
1638
|
+
break;
|
1639
|
+
}
|
1640
|
+
}
|
1641
|
+
if (!found) {
|
1642
|
+
token_sequences.emplace(token_id, tokenization);
|
1643
|
+
}
|
1644
|
+
}
|
1645
|
+
}
|
1646
|
+
}
|
1647
|
+
}
|
1648
|
+
}
|
1649
|
+
|
1650
|
+
static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) {
|
1651
|
+
return "dry";
|
1652
|
+
}
|
1653
|
+
|
1654
|
+
static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
|
1655
|
+
auto * ctx = (llama_sampler_dry *) smpl->ctx;
|
1656
|
+
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
|
1657
|
+
return;
|
1658
|
+
}
|
1659
|
+
|
1660
|
+
ctx->last_tokens.push_back(token);
|
1661
|
+
}
|
1662
|
+
|
1663
|
+
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
|
1664
|
+
static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
1665
|
+
auto * ctx = (llama_sampler_dry *) smpl->ctx;
|
1666
|
+
|
1667
|
+
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
|
1668
|
+
return;
|
1669
|
+
}
|
1670
|
+
|
1671
|
+
int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0);
|
1672
|
+
int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size);
|
1673
|
+
|
1674
|
+
if (last_n_repeat <= ctx->dry_allowed_length) {
|
1675
|
+
return;
|
1676
|
+
}
|
1677
|
+
|
1678
|
+
ctx->dry_repeat_count.assign(last_n_repeat, 0);
|
1679
|
+
ctx->dry_max_token_repeat.clear();
|
1680
|
+
|
1681
|
+
// Step 1: Look for restart sequences to limit the maximum repetition length.
|
1682
|
+
// Work backwards through the context looking for any token that begins a restart sequence.
|
1683
|
+
//
|
1684
|
+
// The collection `restart_sequences` is a mapping from a "head" token to all "tail"
|
1685
|
+
// sequences that together comprise a restart sequence. This allows us to quickly check
|
1686
|
+
// whether each token is the head of a complete sequence. Most restart sequences are actually
|
1687
|
+
// a single token, and for these the "tail" is an empty vector.
|
1688
|
+
//
|
1689
|
+
// If the token is a "head", test all restart sequences that begin with this token
|
1690
|
+
// (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
|
1691
|
+
// 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
|
1692
|
+
// longest matching sequence (if any) is used to limit the maximum repetition length.
|
1693
|
+
//
|
1694
|
+
// Note that in the case case of a short sequence contained in a longer one, this might fail to
|
1695
|
+
// find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
|
1696
|
+
// restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
|
1697
|
+
// 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
|
1698
|
+
//
|
1699
|
+
// This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
|
1700
|
+
// have already clamped the maximum tail sequence length when generating `restart_sequences`.
|
1701
|
+
// With clamping, this scan is O(N) in the context length.
|
1702
|
+
|
1703
|
+
int rep_limit = last_n_repeat;
|
1704
|
+
for (int i = 0; i < last_n_repeat; ++i) {
|
1705
|
+
llama_token token = ctx->last_tokens.rat(i);
|
1706
|
+
auto its = ctx->dry_processed_breakers.equal_range(token);
|
1707
|
+
if (its.first == ctx->dry_processed_breakers.end()) {
|
1708
|
+
continue;
|
1709
|
+
}
|
1710
|
+
int longest_match = -1;
|
1711
|
+
for (auto it = its.first; it != its.second; ++it) {
|
1712
|
+
// Note that (*it) does not contain the head character, so seq_len will be
|
1713
|
+
// the restart sequence length minus 1.
|
1714
|
+
// In the common case of a single-token restart sequence, (*it) will be empty
|
1715
|
+
// and we will trivially match.
|
1716
|
+
int seq_len = (int)it->second.size();
|
1717
|
+
if (seq_len > longest_match && seq_len <= (int)i) {
|
1718
|
+
bool match = true;
|
1719
|
+
for (int offset = 0; offset < seq_len; ++offset) {
|
1720
|
+
// The -1 when indexing `last_tokens` is because we already matched the head.
|
1721
|
+
if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) {
|
1722
|
+
match = false;
|
1723
|
+
break;
|
1724
|
+
}
|
1725
|
+
}
|
1726
|
+
if (match) {
|
1727
|
+
longest_match = seq_len;
|
1728
|
+
}
|
1729
|
+
}
|
1730
|
+
}
|
1731
|
+
if (longest_match >= 0) {
|
1732
|
+
// We found a restart sequence starting `i` tokens from the end and continuing for
|
1733
|
+
// `longest_match` tokens.
|
1734
|
+
rep_limit = i - longest_match;
|
1735
|
+
break;
|
1736
|
+
}
|
1737
|
+
}
|
1738
|
+
if (rep_limit < ctx->dry_allowed_length) {
|
1739
|
+
return;
|
1740
|
+
}
|
1741
|
+
|
1742
|
+
// Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
|
1743
|
+
// the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
|
1744
|
+
// elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
|
1745
|
+
//
|
1746
|
+
// This algorithm is not currently documented on Wikipedia, but there is a clear description here:
|
1747
|
+
// https://ivanyu.me/blog/2014/10/15/z-algorithm/
|
1748
|
+
//
|
1749
|
+
// The code below is adapted from the public domain implementation by the same author here:
|
1750
|
+
// https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
|
1751
|
+
//
|
1752
|
+
// Example:
|
1753
|
+
// Last N tokens: a b c c b c y a b c
|
1754
|
+
// Repeat counts: 0 0 3 1 0 2 0 0 0 0
|
1755
|
+
// ^
|
1756
|
+
// This `3` means that the last three tokens of the context (a b c) also appear here.
|
1757
|
+
//
|
1758
|
+
// This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
|
1759
|
+
// for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
|
1760
|
+
// repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
|
1761
|
+
// ensure that the inner while loops only examine each token in the context once as the outer
|
1762
|
+
// for loop iterates over the context.
|
1763
|
+
|
1764
|
+
{
|
1765
|
+
const int last = last_n_repeat - 1;
|
1766
|
+
int rt = 0, lt = 0;
|
1767
|
+
|
1768
|
+
for (int k = 1; k < last_n_repeat; ++k) {
|
1769
|
+
if (k > rt) {
|
1770
|
+
// If k is outside the current Z-box, do naive computation.
|
1771
|
+
int n = 0;
|
1772
|
+
while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) {
|
1773
|
+
++n;
|
1774
|
+
}
|
1775
|
+
ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
|
1776
|
+
if (n > 0) {
|
1777
|
+
lt = k;
|
1778
|
+
rt = k+n-1;
|
1779
|
+
}
|
1780
|
+
} else {
|
1781
|
+
// If k is inside the current Z-box, consider two cases.
|
1782
|
+
|
1783
|
+
int p = k - lt; // Pair index.
|
1784
|
+
int right_part_len = rt - k + 1;
|
1785
|
+
|
1786
|
+
if (ctx->dry_repeat_count[last - p] < right_part_len) {
|
1787
|
+
int n = std::min(ctx->dry_repeat_count[last - p], rep_limit);
|
1788
|
+
ctx->dry_repeat_count[last - k] = n;
|
1789
|
+
} else {
|
1790
|
+
int i = rt + 1;
|
1791
|
+
while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) {
|
1792
|
+
i += 1;
|
1793
|
+
}
|
1794
|
+
|
1795
|
+
int n = std::min(i - k, rep_limit);
|
1796
|
+
ctx->dry_repeat_count[last - k] = n;
|
1797
|
+
lt = k;
|
1798
|
+
rt = i - 1;
|
1799
|
+
}
|
1800
|
+
}
|
1801
|
+
}
|
1802
|
+
}
|
1803
|
+
|
1804
|
+
// Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
|
1805
|
+
// that would be generated by emitting each new token that would extend a sequence.
|
1806
|
+
//
|
1807
|
+
// Following the same example as above:
|
1808
|
+
// Last N tokens: a b c c b c y a b c
|
1809
|
+
// Repeat counts: 0 0 3 1 0 2 0 0 0 0
|
1810
|
+
//
|
1811
|
+
// For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
|
1812
|
+
// c: 3 -> 4 (from `a b c` to `a b c c`)
|
1813
|
+
// b: 1 -> 2 (from `c` to `c b`)
|
1814
|
+
// y: 2 -> 3 (from `b c` to `b c y`)
|
1815
|
+
|
1816
|
+
for (int i = 0; i < last_n_repeat - 1; ++i) {
|
1817
|
+
int repeat_len = ctx->dry_repeat_count[i];
|
1818
|
+
if (repeat_len >= ctx->dry_allowed_length) {
|
1819
|
+
// This token ends a repeat, so the next token would continue one.
|
1820
|
+
// By convention, the value of `repeat_len` only includes the tokens currently
|
1821
|
+
// in the context, not the new token that would be added.
|
1822
|
+
llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i);
|
1823
|
+
// Track the maximum sequence ending in this token.
|
1824
|
+
const auto& it = ctx->dry_max_token_repeat.find(token);
|
1825
|
+
if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
|
1826
|
+
ctx->dry_max_token_repeat[token] = repeat_len;
|
1827
|
+
}
|
1828
|
+
}
|
1829
|
+
}
|
1830
|
+
|
1831
|
+
// Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.
|
1832
|
+
|
1833
|
+
// Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
|
1834
|
+
// Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
|
1835
|
+
const float FLOAT_MAX_LOG = 88.7228391f;
|
1836
|
+
int max_exponent = 0;
|
1837
|
+
if (ctx->dry_base > 1.000001f) {
|
1838
|
+
max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base);
|
1839
|
+
}
|
1840
|
+
|
1841
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1842
|
+
const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id);
|
1843
|
+
if (af_kvp != ctx->dry_max_token_repeat.end()) {
|
1844
|
+
// Check all sequence breakers starting with this token
|
1845
|
+
auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id);
|
1846
|
+
bool is_single_token_breaker = false;
|
1847
|
+
|
1848
|
+
for (auto it = range.first; it != range.second; ++it) {
|
1849
|
+
if (it->second.empty()) {
|
1850
|
+
is_single_token_breaker = true;
|
1851
|
+
break;
|
1852
|
+
}
|
1853
|
+
}
|
1854
|
+
|
1855
|
+
// Apply penalty only if it's not a single-token sequence breaker
|
1856
|
+
if (!is_single_token_breaker) {
|
1857
|
+
int repeat_exp = af_kvp->second - ctx->dry_allowed_length;
|
1858
|
+
if (max_exponent > 0 && repeat_exp > max_exponent) {
|
1859
|
+
repeat_exp = max_exponent;
|
1860
|
+
}
|
1861
|
+
float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp);
|
1862
|
+
cur_p->data[i].logit -= penalty;
|
1863
|
+
}
|
1864
|
+
}
|
1865
|
+
}
|
1866
|
+
|
1867
|
+
cur_p->sorted = false;
|
1868
|
+
}
|
1869
|
+
|
1870
|
+
static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
|
1871
|
+
auto * ctx = (llama_sampler_dry *) smpl->ctx;
|
1872
|
+
ctx->last_tokens.clear();
|
1873
|
+
ctx->dry_repeat_count.clear();
|
1874
|
+
ctx->dry_max_token_repeat.clear();
|
1875
|
+
}
|
1876
|
+
|
1877
|
+
static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
|
1878
|
+
const auto * ctx = (llama_sampler_dry *) smpl->ctx;
|
1879
|
+
|
1880
|
+
// nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying
|
1881
|
+
auto * result = llama_sampler_init_dry(nullptr, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
|
1882
|
+
// Copy the state, including the processed breakers
|
1883
|
+
{
|
1884
|
+
auto * result_ctx = (llama_sampler_dry *) result->ctx;
|
1885
|
+
result_ctx->dry_processed_breakers = ctx->dry_processed_breakers;
|
1886
|
+
result_ctx->dry_repeat_count = ctx->dry_repeat_count;
|
1887
|
+
result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat;
|
1888
|
+
result_ctx->last_tokens = ctx->last_tokens;
|
1889
|
+
}
|
1890
|
+
|
1891
|
+
return result;
|
1892
|
+
}
|
1893
|
+
|
1894
|
+
static void llama_sampler_dry_free(struct llama_sampler * smpl) {
|
1895
|
+
delete (llama_sampler_dry *) smpl->ctx;
|
1896
|
+
}
|
1897
|
+
|
1898
|
+
// interface table (vtable) for the DRY repetition-penalty sampler
static struct llama_sampler_i llama_sampler_dry_i = {
    /* .name   = */ llama_sampler_dry_name,
    /* .accept = */ llama_sampler_dry_accept,
    /* .apply  = */ llama_sampler_dry_apply,
    /* .reset  = */ llama_sampler_dry_reset,
    /* .clone  = */ llama_sampler_dry_clone,
    /* .free   = */ llama_sampler_dry_free,
};
|
1906
|
+
|
1907
|
+
struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
|
1908
|
+
int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
|
1909
|
+
std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
|
1910
|
+
const int MAX_CHAR_LEN = 40;
|
1911
|
+
const int MAX_SEQ_LEN = 20;
|
1912
|
+
|
1913
|
+
const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
|
1914
|
+
|
1915
|
+
if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
|
1916
|
+
// Process sequence breakers
|
1917
|
+
for (size_t i = 0; i < num_breakers; ++i) {
|
1918
|
+
if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
|
1919
|
+
LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
|
1920
|
+
continue;
|
1921
|
+
}
|
1922
|
+
|
1923
|
+
std::string sequence_break(seq_breakers[i]);
|
1924
|
+
if (sequence_break.empty()) {
|
1925
|
+
LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n");
|
1926
|
+
continue;
|
1927
|
+
}
|
1928
|
+
|
1929
|
+
if (sequence_break.size() > MAX_CHAR_LEN) {
|
1930
|
+
LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN);
|
1931
|
+
sequence_break.resize(MAX_CHAR_LEN);
|
1932
|
+
}
|
1933
|
+
|
1934
|
+
get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
|
1935
|
+
}
|
1936
|
+
}
|
1937
|
+
|
1938
|
+
return new llama_sampler {
|
1939
|
+
/* .iface = */ &llama_sampler_dry_i,
|
1940
|
+
/* .ctx = */ new llama_sampler_dry {
|
1941
|
+
/* .total_context_size = */ context_size,
|
1942
|
+
/* .dry_multiplier = */ dry_multiplier,
|
1943
|
+
/* .dry_base = */ dry_base,
|
1944
|
+
/* .dry_allowed_length = */ dry_allowed_length,
|
1945
|
+
/* .dry_penalty_last_n = */ dry_penalty_last_n,
|
1946
|
+
/* .dry_processed_breakers = */ std::move(processed_breakers),
|
1947
|
+
/* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
|
1948
|
+
/* .dry_max_token_repeat = */ {},
|
1949
|
+
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
|
1950
|
+
},
|
1951
|
+
};
|
1952
|
+
}
|
1953
|
+
|
1954
|
+
// wrapper for test-sampling.cpp
|
1955
|
+
struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
|
1956
|
+
llama_vocab dummy_vocab;
|
1957
|
+
auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
|
1958
|
+
auto * ctx = (llama_sampler_dry *) result->ctx;
|
1959
|
+
|
1960
|
+
// Process the token-based sequence breakers
|
1961
|
+
ctx->dry_processed_breakers.clear();
|
1962
|
+
if (seq_breakers.empty()) {
|
1963
|
+
LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n");
|
1964
|
+
} else {
|
1965
|
+
for (const auto& breaker : seq_breakers) {
|
1966
|
+
if (breaker.empty()) {
|
1967
|
+
LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n");
|
1968
|
+
continue;
|
1969
|
+
}
|
1970
|
+
llama_token head_token = breaker[0];
|
1971
|
+
std::vector<llama_token> tail_tokens(breaker.begin() + 1, breaker.end());
|
1972
|
+
ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens));
|
1973
|
+
}
|
1974
|
+
|
1975
|
+
if (ctx->dry_processed_breakers.empty()) {
|
1976
|
+
LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n");
|
1977
|
+
}
|
1978
|
+
}
|
1979
|
+
|
1980
|
+
return result;
|
1981
|
+
}
|
1982
|
+
|
1713
1983
|
// logit-bias
|
1714
1984
|
|
1715
1985
|
struct llama_sampler_logit_bias {
|
@@ -1789,6 +2059,229 @@ struct llama_sampler * llama_sampler_init_logit_bias(
|
|
1789
2059
|
};
|
1790
2060
|
}
|
1791
2061
|
|
2062
|
+
// infill
|
2063
|
+
|
2064
|
+
//#define LM_GGML_DEBUG_SAMPLER_INFILL
|
2065
|
+
|
2066
|
+
// context for the infill (fill-in-the-middle) sampler
struct llama_sampler_infill {
    const struct llama_vocab * vocab; // used for EOG checks and token-to-piece conversion

    // scratch buffers for comparing token text, reused across apply() calls
    std::vector<char> buf0;
    std::vector<char> buf1;
};
|
2072
|
+
|
2073
|
+
// human-readable sampler name
static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
    return "infill";
}
|
2076
|
+
|
2077
|
+
static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
2078
|
+
auto * ctx = (llama_sampler_infill *) smpl->ctx;
|
2079
|
+
|
2080
|
+
llama_sampler_softmax_impl(cur_p);
|
2081
|
+
|
2082
|
+
#if defined(LM_GGML_DEBUG_SAMPLER_INFILL)
|
2083
|
+
#define LOG_DBG_CUR LLAMA_LOG_DEBUG
|
2084
|
+
#else
|
2085
|
+
#define LOG_DBG_CUR(...)
|
2086
|
+
#endif
|
2087
|
+
|
2088
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
2089
|
+
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
|
2090
|
+
}
|
2091
|
+
|
2092
|
+
float p_txt_sum = 0.0f;
|
2093
|
+
float p_eog_sum = 0.0f;
|
2094
|
+
|
2095
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
2096
|
+
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
|
2097
|
+
p_eog_sum += cur_p->data[i].p;
|
2098
|
+
} else {
|
2099
|
+
p_txt_sum += cur_p->data[i].p;
|
2100
|
+
}
|
2101
|
+
}
|
2102
|
+
|
2103
|
+
const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; LM_GGML_UNUSED(rat);
|
2104
|
+
|
2105
|
+
LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
|
2106
|
+
|
2107
|
+
if (3*p_eog_sum*cur_p->size > p_txt_sum) {
|
2108
|
+
LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
|
2109
|
+
|
2110
|
+
// keep just the EOG tokens
|
2111
|
+
const auto size_org = cur_p->size;
|
2112
|
+
|
2113
|
+
cur_p->size = 0;
|
2114
|
+
|
2115
|
+
float p_sum = 0.0f;
|
2116
|
+
|
2117
|
+
for (size_t i = 0; i < size_org; ++i) {
|
2118
|
+
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
|
2119
|
+
p_sum += cur_p->data[i].p;
|
2120
|
+
|
2121
|
+
cur_p->data[cur_p->size++] = cur_p->data[i];
|
2122
|
+
}
|
2123
|
+
}
|
2124
|
+
|
2125
|
+
// normalize probs
|
2126
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
2127
|
+
cur_p->data[i].p /= p_sum;
|
2128
|
+
}
|
2129
|
+
|
2130
|
+
return;
|
2131
|
+
}
|
2132
|
+
|
2133
|
+
size_t n_combined = 0; LM_GGML_UNUSED(n_combined);
|
2134
|
+
|
2135
|
+
// combine tokens with common prefix
|
2136
|
+
for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
|
2137
|
+
for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
|
2138
|
+
if (cur_p->data[i0].logit == -INFINITY) {
|
2139
|
+
break;
|
2140
|
+
}
|
2141
|
+
|
2142
|
+
if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
|
2143
|
+
continue;
|
2144
|
+
}
|
2145
|
+
|
2146
|
+
int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
|
2147
|
+
if (len0 < 0) {
|
2148
|
+
ctx->buf0.resize(len0);
|
2149
|
+
len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
|
2150
|
+
assert(len0 > 0);
|
2151
|
+
}
|
2152
|
+
|
2153
|
+
int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
|
2154
|
+
if (len1 < 0) {
|
2155
|
+
ctx->buf1.resize(len1);
|
2156
|
+
len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
|
2157
|
+
assert(len1 > 0);
|
2158
|
+
}
|
2159
|
+
|
2160
|
+
// token i0 is a prefix of token i1
|
2161
|
+
if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
|
2162
|
+
int dst = i0;
|
2163
|
+
int src = i1;
|
2164
|
+
|
2165
|
+
// merge into the token with higher probability
|
2166
|
+
if (cur_p->data[i1].p > cur_p->data[i0].p) {
|
2167
|
+
std::swap(dst, src);
|
2168
|
+
}
|
2169
|
+
|
2170
|
+
cur_p->data[dst].p += cur_p->data[src].p;
|
2171
|
+
cur_p->data[src].logit = -INFINITY;
|
2172
|
+
cur_p->data[src].p = 0.0f;
|
2173
|
+
|
2174
|
+
n_combined++;
|
2175
|
+
}
|
2176
|
+
}
|
2177
|
+
}
|
2178
|
+
|
2179
|
+
size_t n_non_eog = 0;
|
2180
|
+
|
2181
|
+
size_t size_org = cur_p->size;
|
2182
|
+
|
2183
|
+
float p_sum = 0.0f;
|
2184
|
+
float thold = 0.2f;
|
2185
|
+
|
2186
|
+
cur_p->size = 0;
|
2187
|
+
|
2188
|
+
LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
|
2189
|
+
|
2190
|
+
for (size_t i = 0; i < size_org; ++i) {
|
2191
|
+
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
|
2192
|
+
|
2193
|
+
if (cur_p->data[i].p < thold && !is_eog) {
|
2194
|
+
continue;
|
2195
|
+
}
|
2196
|
+
|
2197
|
+
if (!is_eog) {
|
2198
|
+
++n_non_eog;
|
2199
|
+
}
|
2200
|
+
|
2201
|
+
p_sum += cur_p->data[i].p;
|
2202
|
+
|
2203
|
+
// keep this token
|
2204
|
+
cur_p->data[cur_p->size++] = cur_p->data[i];
|
2205
|
+
}
|
2206
|
+
|
2207
|
+
LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
|
2208
|
+
|
2209
|
+
// if no non-EOG tokens are left -> reduce cur_p to single EOT token
|
2210
|
+
if (n_non_eog == 0) {
|
2211
|
+
cur_p->size = 1;
|
2212
|
+
cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
|
2213
|
+
cur_p->data[0].logit = 1.0f;
|
2214
|
+
|
2215
|
+
return;
|
2216
|
+
}
|
2217
|
+
|
2218
|
+
// normalize probs
|
2219
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
2220
|
+
cur_p->data[i].p /= p_sum;
|
2221
|
+
|
2222
|
+
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
|
2223
|
+
}
|
2224
|
+
|
2225
|
+
size_org = cur_p->size;
|
2226
|
+
p_sum = 0.0f;
|
2227
|
+
thold = 1.0/(n_non_eog + 1);
|
2228
|
+
|
2229
|
+
cur_p->size = 0;
|
2230
|
+
|
2231
|
+
LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
|
2232
|
+
|
2233
|
+
for (size_t i = 0; i < size_org; ++i) {
|
2234
|
+
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
|
2235
|
+
|
2236
|
+
if (cur_p->data[i].p < thold && !is_eog) {
|
2237
|
+
continue;
|
2238
|
+
}
|
2239
|
+
|
2240
|
+
p_sum += cur_p->data[i].p;
|
2241
|
+
|
2242
|
+
cur_p->data[cur_p->size++] = cur_p->data[i];
|
2243
|
+
}
|
2244
|
+
|
2245
|
+
// normalize probs
|
2246
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
2247
|
+
cur_p->data[i].p /= p_sum;
|
2248
|
+
|
2249
|
+
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
|
2250
|
+
}
|
2251
|
+
|
2252
|
+
#undef LOG_DBG_CUR
|
2253
|
+
}
|
2254
|
+
|
2255
|
+
static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
|
2256
|
+
const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
|
2257
|
+
return llama_sampler_init_infill_impl(*ctx->vocab);
|
2258
|
+
}
|
2259
|
+
|
2260
|
+
static void llama_sampler_infill_free(struct llama_sampler * smpl) {
|
2261
|
+
delete (llama_sampler_infill *) smpl->ctx;
|
2262
|
+
}
|
2263
|
+
|
2264
|
+
// interface table (vtable) for the infill sampler; it keeps no per-token
// state, so accept/reset are not implemented
static struct llama_sampler_i llama_sampler_infill_i = {
    /* .name   = */ llama_sampler_infill_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_infill_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_infill_clone,
    /* .free   = */ llama_sampler_infill_free,
};
|
2272
|
+
|
2273
|
+
// Construct an infill (fill-in-the-middle) sampler bound to the given vocab.
// The 512-byte scratch buffers are an initial capacity for token-to-piece
// conversion; apply() grows them on demand.
struct llama_sampler * llama_sampler_init_infill_impl(
        const struct llama_vocab & vocab) {
    return new llama_sampler {
        /* .iface = */ &llama_sampler_infill_i,
        /* .ctx   = */ new llama_sampler_infill {
            /* .vocab = */ &vocab,
            /* .buf0 = */ std::vector<char>(512),
            /* .buf1 = */ std::vector<char>(512),
        },
    };
}
|
2284
|
+
|
1792
2285
|
// utils
|
1793
2286
|
|
1794
2287
|
uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
|