@fugood/llama.node 1.4.12 → 1.4.13
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +15 -15
- package/scripts/llama.cpp.patch +9 -9
- package/src/llama.cpp/common/arg.cpp +99 -45
- package/src/llama.cpp/common/chat.cpp +4 -4
- package/src/llama.cpp/common/common.cpp +19 -0
- package/src/llama.cpp/common/common.h +10 -0
- package/src/llama.cpp/common/llguidance.cpp +10 -6
- package/src/llama.cpp/common/regex-partial.cpp +13 -13
- package/src/llama.cpp/common/sampling.cpp +58 -14
- package/src/llama.cpp/common/sampling.h +3 -1
- package/src/llama.cpp/include/llama.h +87 -8
- package/src/llama.cpp/src/llama-arch.cpp +2 -0
- package/src/llama.cpp/src/llama-arch.h +1 -0
- package/src/llama.cpp/src/llama-context.cpp +615 -28
- package/src/llama.cpp/src/llama-context.h +43 -1
- package/src/llama.cpp/src/llama-grammar.cpp +40 -13
- package/src/llama.cpp/src/llama-grammar.h +2 -0
- package/src/llama.cpp/src/llama-graph.cpp +173 -5
- package/src/llama.cpp/src/llama-graph.h +71 -6
- package/src/llama.cpp/src/llama-hparams.cpp +4 -0
- package/src/llama.cpp/src/llama-hparams.h +8 -2
- package/src/llama.cpp/src/llama-model-saver.cpp +3 -0
- package/src/llama.cpp/src/llama-model.cpp +51 -11
- package/src/llama.cpp/src/llama-sampling.cpp +1232 -170
- package/src/llama.cpp/src/llama-sampling.h +16 -7
- package/src/llama.cpp/src/llama.cpp +38 -30
- package/src/llama.cpp/src/models/afmoe.cpp +9 -5
- package/src/llama.cpp/src/models/cohere2-iswa.cpp +3 -0
- package/src/llama.cpp/src/models/gemma2-iswa.cpp +5 -2
- package/src/llama.cpp/src/models/llama-iswa.cpp +6 -2
- package/src/llama.cpp/src/models/modern-bert.cpp +4 -3
- package/src/llama.cpp/src/models/openai-moe-iswa.cpp +5 -2
- package/src/llama.cpp/src/models/smallthinker.cpp +11 -5
@@ -60,6 +60,25 @@ llama_context::llama_context(
     cparams.cb_eval = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;

+    // Initialize backend samplers here so they are part of the sampling graph
+    // before the reserve passes run later in this function. This avoids a later
+    // re-reserve when graph nodes change.
+    if (params.samplers != nullptr && params.n_samplers > 0) {
+        for (size_t i = 0; i < params.n_samplers; ++i) {
+            const auto & config = params.samplers[i];
+
+            if (llama_sampler_chain_get(config.sampler, -1) == nullptr) {
+                throw std::runtime_error("the backend samplers must be of type llama_sampler_chain");
+            }
+
+            if (set_sampler(config.seq_id, config.sampler)) {
+                const int n_samplers = llama_sampler_chain_n(config.sampler);
+
+                LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
+            }
+        }
+    }
+
     auto rope_scaling_type = params.rope_scaling_type;
     if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
         rope_scaling_type = hparams.rope_scaling_type_train;
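All hunks shown in this section come from package/src/llama.cpp/src/llama-context.cpp (+615 −28 in the file list above). The constructor change above wires per-sequence backend samplers passed through `params.samplers` / `params.n_samplers` and rejects any entry that is not a non-empty `llama_sampler_chain`. Below is a minimal sketch of building a chain that satisfies that check, using the existing llama.h sampling helpers; whether a given chain is actually offloaded is decided later by `set_sampler()` (see the hunk at new line 965 further down), so the specific samplers here are illustrative only.

```cpp
#include "llama.h"

// Sketch only: a non-empty llama_sampler_chain, i.e.
// llama_sampler_chain_get(chain, -1) != nullptr, as required by the constructor above.
static llama_sampler * make_backend_chain() {
    llama_sampler_chain_params sparams = llama_sampler_chain_default_params();
    llama_sampler * chain = llama_sampler_chain_init(sparams);

    // ordinary chain members; offload eligibility is checked later via
    // iface->backend_init / iface->backend_apply in set_sampler()
    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    return chain;
}
```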
@@ -231,7 +250,10 @@ llama_context::llama_context(
     // graph outputs buffer
     {
         // resized during inference when a batch uses more outputs
-
+        // Create a dummy batch for initialization.
+        llama_batch dummy_batch = {};
+        dummy_batch.n_tokens = 0;
+        if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) {
            throw std::runtime_error("failed to reserve initial output buffer");
        }

@@ -456,6 +478,16 @@ llama_context::llama_context(
             LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
         }
     }
+
+    // Initialize the full vocabulary token ids for backend samplers.
+    {
+        const int n_vocab = model.vocab.n_tokens();
+
+        sampling.token_ids_full_vocab.resize(n_vocab);
+        for (int i = 0; i < n_vocab; ++i) {
+            sampling.token_ids_full_vocab[i] = i;
+        }
+    }
 }

 llama_context::~llama_context() {
@@ -616,6 +648,35 @@ float * llama_context::get_logits() {
     return logits;
 }

+int64_t llama_context::output_resolve_row(int32_t i) const {
+    int64_t j = -1;
+
+    // support negative indices (last output row)
+    if (i < 0) {
+        j = n_outputs + i;
+        if (j < 0) {
+            throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
+        }
+    } else if ((size_t) i >= output_ids.size()) {
+        throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
+    } else {
+        // use output_ids to translate the batch token index into a row number
+        // that holds this token's data.
+        j = output_ids[i];
+    }
+
+    if (j < 0) {
+        // the batch token was not configured to output anything
+        throw std::runtime_error(format("batch.logits[%d] != true", i));
+    }
+
+    if (j >= n_outputs) {
+        throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
+    }
+
+    return j;
+}
+
 float * llama_context::get_logits_ith(int32_t i) {
     int64_t j = -1;

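`output_resolve_row()` above centralizes the translation from a batch-token index to an output-buffer row, including negative indexing from the end. The following stand-alone sketch reproduces that mapping with plain containers and hypothetical values; it is illustrative only and not llama.cpp code.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    // a batch of 4 tokens where only tokens 1 and 3 requested output (batch.logits = {0,1,0,1})
    std::vector<int32_t> output_ids = { -1, 0, -1, 1 };
    const int32_t n_outputs = 2;

    assert(output_ids[3] == 1);      // batch index 3  -> output row 1
    assert(n_outputs + (-1) == 1);   // negative index -1 -> last output row
    assert(output_ids[0] == -1);     // token 0 produced no output -> would throw "batch.logits[0] != true"
    return 0;
}
```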
@@ -626,6 +687,7 @@ float * llama_context::get_logits_ith(int32_t i) {
            throw std::runtime_error("no logits");
        }

+        // TODO: use output_resolve_row()
        if (i < 0) {
            j = n_outputs + i;
            if (j < 0) {
@@ -662,6 +724,10 @@ float * llama_context::get_embeddings() {
     return embd;
 }

+llama_token * llama_context::get_sampled_tokens() const{
+    return sampling.sampled;
+}
+
 float * llama_context::get_embeddings_ith(int32_t i) {
     int64_t j = -1;

@@ -672,6 +738,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
            throw std::runtime_error("no embeddings");
        }

+        // TODO: use output_resolve_row()
        if (i < 0) {
            j = n_outputs + i;
            if (j < 0) {
@@ -691,7 +758,8 @@ float * llama_context::get_embeddings_ith(int32_t i) {
            throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
        }

-
+        const uint32_t n_embd_out = model.hparams.get_n_embd_out();
+        return embd + j*n_embd_out;
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
@@ -711,6 +779,136 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
     return it->second.data();
 }

+llama_token llama_context::get_sampled_token_ith(int32_t idx) {
+    output_reorder();
+
+    if (sampling.sampled == nullptr) {
+        return LLAMA_TOKEN_NULL;
+    }
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        GGML_ASSERT(row < (int64_t) sampling.sampled_size);
+        return sampling.sampled[row];
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
+        return LLAMA_TOKEN_NULL;
+    }
+}
+
+float * llama_context::get_sampled_probs_ith(int32_t idx) {
+    output_reorder();
+
+    if (sampling.probs == nullptr) {
+        return nullptr;
+    }
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
+            return nullptr;
+        }
+        return sampling.probs + row*model.vocab.n_tokens();
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
+        return nullptr;
+    }
+}
+
+float * llama_context::get_sampled_logits_ith(int32_t idx) {
+    output_reorder();
+
+    if (sampling.logits == nullptr) {
+        return nullptr;
+    }
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
+            return nullptr;
+        }
+        return sampling.logits + row*model.vocab.n_tokens();
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
+        return nullptr;
+    }
+}
+
+const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
+    output_reorder();
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        if (sampling.candidates != nullptr &&
+            (size_t) row < sampling.candidates_count.size() &&
+            sampling.candidates_count[row] > 0) {
+            return sampling.candidates + row*model.vocab.n_tokens();
+        }
+    } catch (const std::exception & err) {
+        // fallback to full vocab list
+    }
+
+    return sampling.token_ids_full_vocab.data();
+}
+
+size_t llama_context::get_sampled_candidates_count(int32_t idx) {
+    output_reorder();
+
+    if (sampling.candidates == nullptr) {
+        return 0;
+    }
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        if ((size_t) row >= sampling.candidates_count.size()) {
+            return 0;
+        }
+        return sampling.candidates_count[row];
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what());
+        return 0;
+    }
+}
+
+size_t llama_context::get_sampled_logits_count(int32_t idx) {
+    output_reorder();
+
+    if (sampling.logits == nullptr) {
+        return model.vocab.n_tokens();
+    }
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        if ((size_t) row >= sampling.logits_count.size()) {
+            return 0;
+        }
+        return sampling.logits_count[row];
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what());
+        return 0;
+    }
+}
+
+size_t llama_context::get_sampled_probs_count(int32_t idx) {
+    output_reorder();
+
+    if (sampling.probs == nullptr) {
+        return 0;
+    }
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        if ((size_t) row >= sampling.probs_count.size()) {
+            return 0;
+        }
+        return sampling.probs_count[row];
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what());
+        return 0;
+    }
+}
+
+
 void llama_context::attach_threadpool(
            ggml_threadpool_t threadpool,
            ggml_threadpool_t threadpool_batch) {
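The accessors above expose what a backend sampler produced for a given output row: a concrete token, per-row logits/probabilities with their element counts, and the candidate id list (falling back to the full vocabulary when nothing was truncated). A hedged sketch using the public C wrappers added near the end of this diff (`llama_get_sampled_*`); it assumes `ctx` has already decoded a batch and that the output index `i` is valid for that batch.

```cpp
#include "llama.h"
#include <cstdint>
#include <cstdio>

// Sketch: inspect backend-sampling results for output index i after llama_decode().
static void inspect_output(llama_context * ctx, int32_t i) {
    const llama_token tok = llama_get_sampled_token_ith(ctx, i);
    if (tok != LLAMA_TOKEN_NULL) {
        // a backend sampler picked a concrete token for this row
        printf("sampled token: %d\n", tok);
        return;
    }

    // otherwise the chain may still have produced probabilities over a
    // (possibly truncated) candidate list
    const float       * probs = llama_get_sampled_probs_ith(ctx, i);
    const llama_token * cands = llama_get_sampled_candidates_ith(ctx, i);
    const uint32_t      n     = llama_get_sampled_probs_count_ith(ctx, i);

    for (uint32_t k = 0; probs && cands && k < n; ++k) {
        printf("cand %d  p=%.4f\n", cands[k], probs[k]);
    }
}
```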
@@ -767,6 +965,42 @@ void llama_context::set_warmup(bool value) {
     cparams.warmup = value;
 }

+bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
+    LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
+
+    const bool can_offload =
+        sampler &&
+        sampler->iface->backend_init &&
+        sampler->iface->backend_apply &&
+        llama_sampler_chain_n(sampler) > 0;
+
+    if (sampler && can_offload) {
+        ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());
+        auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output());
+        if (host_buft) {
+            buft = host_buft;
+        }
+
+        sampler->iface->backend_init(sampler, buft);
+
+        sampling.samplers[seq_id] = sampler;
+
+        return true;
+    }
+
+    if (sampler && !can_offload) {
+        LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);
+
+        sampling.samplers.erase(seq_id);
+
+        return false;
+    }
+
+    sampling.samplers.erase(seq_id);
+
+    return true;
+}
+
 void llama_context::set_adapter_lora(
        llama_adapter_lora * adapter,
        float scale) {
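`set_sampler()` above decides per sequence whether a chain can be offloaded (its interface must provide `backend_init`/`backend_apply` and the chain must be non-empty); otherwise the sequence keeps the CPU sampling path. A small sketch using the public wrapper `llama_set_sampler()` defined at the end of this diff; the helper names here are illustrative.

```cpp
#include "llama.h"

// Sketch: attach or detach a per-sequence backend sampler after the context exists.
static bool attach_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * chain) {
    if (!llama_set_sampler(ctx, seq_id, chain)) {
        // chain rejected for offload; a warning is logged (see above) and this
        // sequence falls back to llama_get_logits_ith() + CPU-side llama_sampler_sample()
        return false;
    }
    return true;
}

static void detach_sampler(llama_context * ctx, llama_seq_id seq_id) {
    // passing nullptr clears the mapping for this sequence (returns true per the code above)
    llama_set_sampler(ctx, seq_id, nullptr);
}
```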
@@ -907,7 +1141,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     n_queued_tokens += n_tokens;

     // reserve output buffer
-    if (output_reserve(n_tokens) < n_tokens) {
+    if (output_reserve(n_tokens, batch_inp) < n_tokens) {
        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
        return -2;
    };
@@ -961,9 +1195,10 @@ int llama_context::encode(const llama_batch & batch_inp) {
                {
                    // extract token embeddings
                    GGML_ASSERT(embd != nullptr);
+                    const uint32_t n_embd_out = hparams.get_n_embd_out();

-                    GGML_ASSERT(n_tokens*
-                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*
+                    GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size);
+                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
                } break;
            case LLAMA_POOLING_TYPE_MEAN:
            case LLAMA_POOLING_TYPE_CLS:
@@ -1031,6 +1266,112 @@ int llama_context::encode(const llama_batch & batch_inp) {
     return 0;
 }

+static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
+    std::map<llama_seq_id, uint32_t> seq_to_row;
+    // how many output tokens we have seen so far for this ubatch.
+    uint32_t local = 0;
+    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+        // skip tokens that are not output.
+        if (!ubatch.output[i]) {
+            continue;
+        }
+
+        const llama_seq_id seq_id = ubatch.seq_id[i][0];
+        // row_offset is the number of output tokens before this ubatch.
+        seq_to_row[seq_id] = row_offset + local;
+        ++local;
+    }
+    return seq_to_row;
+}
+
+static void copy_tensor_async_ints(
+        const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
+        llama_token * sampled,
+        size_t sampled_size,
+        const std::map<llama_seq_id, uint32_t> & seq_to_row,
+        ggml_backend_sched_t sched) {
+    if (sampled == nullptr) {
+        return;
+    }
+
+    for (const auto & [seq_id, tensor] : tensor_map) {
+        auto it = seq_to_row.find(seq_id);
+        if (it == seq_to_row.end()) {
+            continue;
+        }
+
+        const uint32_t row = it->second;
+        GGML_ASSERT(row < sampled_size);
+
+        GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
+
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+        ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
+    }
+}
+
+static void copy_tensor_async_floats(
+        const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
+        float * dst,
+        size_t stride,
+        std::vector<uint32_t> & counts,
+        const std::map<llama_seq_id, uint32_t> & seq_to_row,
+        ggml_backend_sched_t sched) {
+    if (dst == nullptr) {
+        return;
+    }
+
+    for (const auto & [seq_id, tensor] : tensor_map) {
+        auto it = seq_to_row.find(seq_id);
+        if (it == seq_to_row.end()) {
+            continue;
+        }
+
+        const uint32_t row = it->second;
+        GGML_ASSERT(row < counts.size());
+
+        GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
+
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+        float * row_ptr = dst + (size_t) row * stride;
+        ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
+
+        // Update the actual number of logits/probabilities that were written for this row.
+        counts[row] = ggml_nelements(tensor);
+    }
+}
+
+static void copy_tensor_async_candidates(
+        const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
+        llama_token * dst,
+        size_t stride,
+        std::vector<uint32_t> & counts,
+        const std::map<llama_seq_id, uint32_t> & seq_to_row,
+        ggml_backend_sched_t sched) {
+    if (dst == nullptr) {
+        return;
+    }
+
+    for (const auto & [seq_id, tensor] : tensor_map) {
+        auto it = seq_to_row.find(seq_id);
+        if (it == seq_to_row.end()) {
+            continue;
+        }
+
+        const uint32_t row = it->second;
+        GGML_ASSERT(row < counts.size());
+
+        GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
+
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+        llama_token * row_ptr = dst + (size_t) row * stride;
+        ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
+
+        // Update the actual number of candidates that were written.
+        counts[row] = ggml_nelements(tensor);
+    }
+}
+
 int llama_context::decode(const llama_batch & batch_inp) {
     GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT

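`build_seq_to_output_row()` above maps each sequence that produced an output token in the current ubatch to its row in the host-side output buffers, offset by rows already written by earlier ubatches; only the first `seq_id` of each token is used. The following stand-alone snippet reproduces that mapping with plain containers and made-up values, purely for illustration.

```cpp
#include <cassert>
#include <map>
#include <vector>

int main() {
    std::vector<bool> output = { false, true, true };   // which ubatch tokens are output
    std::vector<int>  seq0   = { 0, 0, 1 };             // first seq_id of each token
    const unsigned row_offset = 2;                       // outputs already emitted by earlier ubatches

    std::map<int, unsigned> seq_to_row;
    unsigned local = 0;
    for (size_t i = 0; i < output.size(); ++i) {
        if (!output[i]) continue;
        seq_to_row[seq0[i]] = row_offset + local++;
    }

    assert(seq_to_row[0] == 2);   // seq 0 -> output row 2
    assert(seq_to_row[1] == 3);   // seq 1 -> output row 3
    return 0;
}
```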
@@ -1051,9 +1392,36 @@ int llama_context::decode(const llama_batch & batch_inp) {
     const int64_t n_embd = hparams.n_embd_inp();

     // when computing embeddings, all tokens are output
-    const bool output_all
+    const bool output_all = cparams.embeddings;
+    const bool has_samplers = !sampling.samplers.empty();
+
+    const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max;

-
+    // TODO: avoid this workaround in the future
+    if (has_samplers && batch_inp.logits) {
+        std::vector<int32_t> seq_output_count(n_seq_max, 0);
+
+        for (int32_t i = 0; i < batch_inp.n_tokens; ++i) {
+            if (batch_inp.logits[i] == 0) {
+                continue;
+            }
+
+            const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1;
+
+            for (int32_t s = 0; s < ns; ++s) {
+                const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0;
+
+                seq_output_count[seq_id]++;
+                if (seq_output_count[seq_id] > 1) {
+                    LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n",
+                        __func__, seq_id, seq_output_count[seq_id]);
+                    return -1;
+                }
+            }
+        }
+    }
+
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) {
        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
        return -1;
    }
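The check above enforces that, when backend samplers are attached, a decode batch requests at most one output token per sequence. A sketch of a batch layout that satisfies this, built with the existing `llama_batch_init()` API; prompt contents and the helper name are placeholders.

```cpp
#include "llama.h"
#include <vector>

// Sketch: each sequence requests logits only for its last token, so the
// per-sequence output count stays at 1 when backend samplers are active.
static llama_batch make_batch(const std::vector<std::vector<llama_token>> & prompts) {
    int32_t n_total = 0;
    for (const auto & p : prompts) n_total += (int32_t) p.size();

    llama_batch batch = llama_batch_init(n_total, /*embd*/ 0, /*n_seq_max*/ 1);

    for (size_t s = 0; s < prompts.size(); ++s) {
        for (size_t i = 0; i < prompts[s].size(); ++i) {
            const int32_t t = batch.n_tokens;
            batch.token[t]     = prompts[s][i];
            batch.pos[t]       = (llama_pos) i;
            batch.n_seq_id[t]  = 1;
            batch.seq_id[t][0] = (llama_seq_id) s;
            // only the last token of each sequence produces an output row
            batch.logits[t]    = (i + 1 == prompts[s].size());
            batch.n_tokens++;
        }
    }
    return batch;   // free later with llama_batch_free()
}
```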
@@ -1134,7 +1502,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     }

     // reserve output buffer
-    if (output_reserve(n_outputs_all) < n_outputs_all) {
+    if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
        return -2;
    };
@@ -1207,7 +1575,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
        }

        // extract logits
-
+        // For multi-sequence batches that mix backend samplers and CPU sampler
+        // this is currently inefficient as we copy all logits even for the
+        // backend sampled tokens.
+        if (logits && t_logits && n_outputs > 0) {
            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
            GGML_ASSERT(backend_res != nullptr);
            GGML_ASSERT(logits != nullptr);
@@ -1222,7 +1593,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
        }

        // extract embeddings
-        if (t_embd && n_outputs > 0) {
+        if (embd && t_embd && n_outputs > 0) {
            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
            GGML_ASSERT(backend_embd != nullptr);

@@ -1231,12 +1602,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
                {
                    // extract token embeddings
                    GGML_ASSERT(embd != nullptr);
-
+                    const uint32_t n_embd_out = hparams.get_n_embd_out();
+                    float * embd_out = embd + n_outputs_prev*n_embd_out;

                    if (n_outputs) {
                        GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
-                        GGML_ASSERT((n_outputs_prev + n_outputs)*
-                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*
+                        GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size);
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float));
                    }
                } break;
            case LLAMA_POOLING_TYPE_MEAN:
@@ -1276,6 +1648,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
            }
        }

+        // This flag indicates whether a backend sampler has actually sampled a specific
+        // token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings.
+        const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
+
+        if (has_samplers && has_sampled) {
+            const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
+            const auto stride = n_vocab;
+
+            // async copy the sampling data from the backend to the host
+            copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
+
+            copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
+            copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
+            copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get());
+        }
+
        n_outputs_prev += n_outputs;
    } while (mctx->next());

@@ -1339,15 +1727,15 @@ int llama_context::decode(const llama_batch & batch_inp) {
 // output
 //

-uint32_t llama_context::output_reserve(int32_t n_outputs) {
+uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) {
     const auto & hparams = model.hparams;
     const auto & vocab = model.vocab;

     const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());

-    const auto n_batch
-    const auto n_vocab
-    const auto
+    const auto n_batch = cparams.n_batch;
+    const auto n_vocab = vocab.n_tokens();
+    const auto n_embd_out = hparams.get_n_embd_out();

     bool has_logits = true;
     bool has_embd = cparams.embeddings;
@@ -1358,8 +1746,53 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
        has_embd = true;
    }

-
-
+    // Check which sampling modes are needed for the current batch.
+    // TODO: avoid this branching by working with the worst-case
+    bool has_sampling = false;
+    bool cpu_logits = false;
+
+    if (batch.logits) {
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            if (!batch.logits[i]) {
+                continue;
+            }
+            for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
+                llama_seq_id seq_id = batch.seq_id[i][j];
+                if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
+                    has_sampling = true;
+                } else {
+                    cpu_logits = true;
+                }
+            }
+        }
+    } else {
+        // When batch.logits is nullptr (when loading state with a dummy batch),
+        // allocate CPU logits.
+        cpu_logits = true;
+    }
+
+    size_t backend_float_count = 0;
+    size_t backend_token_count = 0;
+
+    // Allocate CPU logits buffer only if needed by sequences in this batch
+    logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0;
+    embd_size = has_embd ? n_embd_out*n_outputs_max : 0;
+
+    // TODO: avoid this branching by working with the worst-case
+    if (!has_sampling) {
+        sampling.logits_size = 0;
+        sampling.probs_size = 0;
+        sampling.sampled_size = 0;
+        sampling.candidates_size = 0;
+    } else {
+        sampling.logits_size = n_vocab*n_outputs_max;
+        sampling.probs_size = n_vocab*n_outputs_max;
+        sampling.sampled_size = n_outputs_max;
+        sampling.candidates_size = n_vocab*n_outputs_max;
+
+        backend_float_count = sampling.logits_size + sampling.probs_size;
+        backend_token_count = sampling.sampled_size + sampling.candidates_size;
+    }

     if (output_ids.empty()) {
        // init, never resized afterwards
@@ -1367,7 +1800,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     }

     const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
-    const size_t new_size =
+    const size_t new_size =
+        (logits_size + embd_size + backend_float_count) * sizeof(float) +
+        ( backend_token_count) * sizeof(llama_token);

     // alloc only when more than the current capacity is required
     // TODO: also consider shrinking the buffer
@@ -1375,9 +1810,11 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     if (buf_output) {
 #ifndef NDEBUG
        // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
-
+        LLAMA_LOG_DEBUG("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
 #endif
        synchronize();
+
+        // TODO: not needed?
        buf_output = nullptr;
        logits = nullptr;
        embd = nullptr;
@@ -1399,8 +1836,49 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {

     float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());

-    logits =
-    embd =
+    logits = nullptr;
+    embd = nullptr;
+
+    size_t offset = 0;
+    uint8_t * base = (uint8_t *) output_base;
+
+    logits = (has_logits && cpu_logits) ? output_base : nullptr;
+    offset += logits_size * sizeof(float);
+
+    embd = has_embd ? (float *) (base + offset) : nullptr;
+    offset += embd_size * sizeof(float);
+
+    sampling.logits = nullptr;
+    sampling.probs = nullptr;
+    sampling.sampled = nullptr;
+    sampling.candidates = nullptr;
+
+    if (has_sampling) {
+        sampling.logits = (float *) (base + offset);
+        offset += sampling.logits_size * sizeof(float);
+
+        sampling.probs = (float *) (base + offset);
+        offset += sampling.probs_size * sizeof(float);
+
+        sampling.sampled = (llama_token *) (base + offset);
+        offset += sampling.sampled_size * sizeof(llama_token);
+
+        sampling.candidates = (llama_token *) (base + offset);
+        offset += sampling.candidates_size * sizeof(llama_token);
+
+        // The count vectors keep track of the actual number of logits/probs/candidates
+        // copied from the backend for each output row.
+
+        sampling.logits_count.resize(n_outputs_max);
+        sampling.probs_count.resize(n_outputs_max);
+        sampling.candidates_count.resize(n_outputs_max);
+
+        std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
+        std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
+        std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
+
+        std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
+    }

     // set all ids as invalid (negative)
     std::fill(output_ids.begin(), output_ids.end(), -1);
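With the layout above, the single host output buffer now also carries the backend-sampling regions (`sampling.logits`, `sampling.probs`, `sampling.sampled`, `sampling.candidates`) packed behind the CPU logits and embeddings. Below is a worked size calculation for illustrative values (not taken from the diff), following the `new_size` formula in the hunk above; `llama_token` is a 32-bit integer in llama.h.

```cpp
#include <cstdio>

int main() {
    // illustrative values only
    const size_t n_vocab = 32000, n_outputs_max = 4, n_embd_out = 4096;

    const size_t logits_size    = n_vocab * n_outputs_max;                    // CPU logits
    const size_t embd_size      = n_embd_out * n_outputs_max;                 // embeddings
    const size_t backend_floats = 2 * n_vocab * n_outputs_max;                // sampling.logits + sampling.probs
    const size_t backend_tokens = n_outputs_max + n_vocab * n_outputs_max;    // sampled + candidates

    const size_t new_size = (logits_size + embd_size + backend_floats) * sizeof(float)
                          + (backend_tokens) * sizeof(int32_t);               // llama_token is int32_t

    printf("output buffer: %.2f MiB\n", new_size / 1024.0 / 1024.0);
    return 0;
}
```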
@@ -1429,6 +1907,40 @@ void llama_context::output_reorder() {
                std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
            }
        }
+
+        if (sampling.logits && sampling.logits_size > 0) {
+            for (uint64_t k = 0; k < n_vocab; ++k) {
+                std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
+            }
+        }
+
+        if (sampling.probs && sampling.probs_size > 0) {
+            for (uint64_t k = 0; k < n_vocab; ++k) {
+                std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
+            }
+        }
+
+        if (sampling.candidates && sampling.candidates_size > 0) {
+            for (uint64_t k = 0; k < n_vocab; ++k) {
+                std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
+            }
+        }
+
+        if (sampling.sampled && sampling.sampled_size > 0) {
+            std::swap(sampling.sampled[i0], sampling.sampled[i1]);
+        }
+
+        if (!sampling.logits_count.empty()) {
+            std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
+        }
+
+        if (!sampling.probs_count.empty()) {
+            std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
+        }
+
+        if (!sampling.candidates_count.empty()) {
+            std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
+        }
    }

    output_swaps.clear();
@@ -1458,7 +1970,7 @@ ggml_cgraph * llama_context::graph_reserve(

     if (n_tokens % n_seqs != 0) {
        n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
-        n_outputs = std::
+        n_outputs = std::max(n_outputs, n_tokens);

        LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
    }
@@ -1477,6 +1989,15 @@ ggml_cgraph * llama_context::graph_reserve(
     llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);

+    // set one output token per sequence in order to activate all backend samplers
+    std::vector<llama_seq_id> seq_ids(n_seqs);
+    for (uint32_t i = 0; i < n_seqs; ++i) {
+        seq_ids[i] = i;
+        ubatch.n_seq_id[i] = 1;
+        ubatch.seq_id[i] = &seq_ids[i];
+        ubatch.output[i] = true;
+    }
+
     auto * res = gf_res_reserve.get();

     const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
@@ -1507,7 +2028,7 @@ llm_graph_params llama_context::graph_params(
        llm_graph_result * res,
        const llama_ubatch & ubatch,
        const llama_memory_context_i * mctx,
-
+        llm_graph_type gtype) const {
     return {
        /*.arch =*/ model.arch,
        /*.hparams =*/ model.hparams,
@@ -1520,6 +2041,7 @@ llm_graph_params llama_context::graph_params(
        /*.loras =*/ &loras,
        /*.mctx =*/ mctx,
        /*.cross =*/ &cross,
+        /*.samplers =*/ sampling.samplers,
        /*.n_outputs =*/ n_outputs,
        /*.cb =*/ graph_get_cb(),
        /*.res =*/ res,
@@ -1975,6 +2497,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
        }
    }

+    // TODO: handle sampling buffers and samplers state ?
+    // https://github.com/ggml-org/llama.cpp/pull/17004
+
     if (memory != nullptr) {
        LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
        memory->state_write(io);
@@ -2007,7 +2532,10 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
     auto n_outputs = this->n_outputs;
     io.read_to(&n_outputs, sizeof(n_outputs));

-
+    // Create a dummy batch for state loading.
+    llama_batch dummy_batch = {};
+    dummy_batch.n_tokens = 0;
+    if (n_outputs > output_reserve(n_outputs, dummy_batch)) {
        throw std::runtime_error("could not reserve outputs");
    }

@@ -2061,6 +2589,9 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
        }
    }

+    // TODO: handle sampling buffers and samplers state ?
+    // https://github.com/ggml-org/llama.cpp/pull/17004
+
     if (memory) {
        LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);

@@ -2249,7 +2780,7 @@ void llama_context::opt_epoch_iter(
        }

        // reserve output buffer
-        if (output_reserve(n_outputs_all) < n_outputs_all) {
+        if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
            LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
            GGML_ABORT("TODO: handle this error");
        };
@@ -2394,6 +2925,8 @@ llama_context_params llama_context_default_params() {
        /*.op_offload =*/ true,
        /*.swa_full =*/ true,
        /*.kv_unified =*/ false,
+        /*.sampler =*/ nullptr,
+        /*.n_sampler =*/ 0,
    };

    return result;
@@ -2553,7 +3086,15 @@ float * llama_get_logits(llama_context * ctx) {
 float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
     ctx->synchronize();

-
+    float * res = nullptr;
+
+    res = ctx->get_sampled_logits_ith(i);
+
+    if (!res) {
+        res = ctx->get_logits_ith(i);
+    }
+
+    return res;
 }

 float * llama_get_embeddings(llama_context * ctx) {
@@ -2574,6 +3115,52 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
     return ctx->get_embeddings_seq(seq_id);
 }

+bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
+    return ctx->set_sampler(seq_id, smpl);
+}
+
+llama_token llama_get_sampled_token_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return ctx->get_sampled_token_ith(i);
+}
+
+float * llama_get_sampled_probs_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return ctx->get_sampled_probs_ith(i);
+}
+
+float * llama_get_sampled_logits_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return ctx->get_sampled_logits_ith(i);
+}
+
+llama_token * llama_get_sampled_candidates_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return const_cast<llama_token *>(ctx->get_sampled_candidates_ith(i));
+}
+
+uint32_t llama_get_sampled_candidates_count_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return static_cast<uint32_t>(ctx->get_sampled_candidates_count(i));
+}
+
+uint32_t llama_get_sampled_logits_count_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return static_cast<uint32_t>(ctx->get_sampled_logits_count(i));
+}
+
+uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return static_cast<uint32_t>(ctx->get_sampled_probs_count(i));
+}
+
 // llama adapter API

 int32_t llama_set_adapter_lora(
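Taken together, the new public entry points let a caller prefer backend-sampled tokens and fall back to the classic CPU path when no sampler is attached (or its chain could not be offloaded). A hedged end-to-end sketch follows: single sequence, error handling simplified, and `cpu_chain` is an ordinary caller-owned llama_sampler chain; none of this is prescribed by the diff itself.

```cpp
#include "llama.h"

// Sketch of one generation step. Assumes `ctx` is a valid context and `batch`
// requests output for its last token.
static llama_token next_token(llama_context * ctx, llama_sampler * cpu_chain, llama_batch batch) {
    if (llama_decode(ctx, batch) != 0) {
        return LLAMA_TOKEN_NULL;   // error handling simplified
    }

    // preferred path: the token was already sampled on the backend
    llama_token tok = llama_get_sampled_token_ith(ctx, -1);
    if (tok != LLAMA_TOKEN_NULL) {
        return tok;
    }

    // fallback path: classic CPU sampling from the last output row's logits.
    // Note that llama_get_logits_ith() above now prefers backend-sampled logits when present.
    return llama_sampler_sample(cpu_chain, ctx, -1);
}
```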