@fugood/llama.node 1.4.11 → 1.4.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. package/package.json +15 -15
  2. package/scripts/llama.cpp.patch +31 -31
  3. package/src/llama.cpp/common/arg.cpp +128 -59
  4. package/src/llama.cpp/common/arg.h +1 -0
  5. package/src/llama.cpp/common/chat-parser.cpp +11 -0
  6. package/src/llama.cpp/common/chat.cpp +36 -7
  7. package/src/llama.cpp/common/chat.h +1 -0
  8. package/src/llama.cpp/common/common.cpp +42 -23
  9. package/src/llama.cpp/common/common.h +11 -1
  10. package/src/llama.cpp/common/llguidance.cpp +10 -6
  11. package/src/llama.cpp/common/regex-partial.cpp +13 -13
  12. package/src/llama.cpp/common/sampling.cpp +58 -14
  13. package/src/llama.cpp/common/sampling.h +3 -1
  14. package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
  15. package/src/llama.cpp/ggml/include/ggml-backend.h +1 -1
  16. package/src/llama.cpp/ggml/src/CMakeLists.txt +23 -9
  17. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +12 -2
  18. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +1 -1
  19. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +86 -25
  20. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +15 -8
  21. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +768 -0
  22. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +0 -4
  23. package/src/llama.cpp/include/llama.h +100 -12
  24. package/src/llama.cpp/src/CMakeLists.txt +4 -0
  25. package/src/llama.cpp/src/llama-adapter.cpp +12 -3
  26. package/src/llama.cpp/src/llama-adapter.h +7 -1
  27. package/src/llama.cpp/src/llama-arch.cpp +78 -0
  28. package/src/llama.cpp/src/llama-arch.h +8 -0
  29. package/src/llama.cpp/src/llama-chat.cpp +11 -0
  30. package/src/llama.cpp/src/llama-chat.h +1 -0
  31. package/src/llama.cpp/src/llama-context.cpp +637 -49
  32. package/src/llama.cpp/src/llama-context.h +43 -1
  33. package/src/llama.cpp/src/llama-grammar.cpp +40 -13
  34. package/src/llama.cpp/src/llama-grammar.h +2 -0
  35. package/src/llama.cpp/src/llama-graph.cpp +173 -5
  36. package/src/llama.cpp/src/llama-graph.h +71 -6
  37. package/src/llama.cpp/src/llama-hparams.cpp +4 -0
  38. package/src/llama.cpp/src/llama-hparams.h +12 -5
  39. package/src/llama.cpp/src/llama-kv-cache.h +1 -1
  40. package/src/llama.cpp/src/llama-mmap.cpp +11 -4
  41. package/src/llama.cpp/src/llama-model-loader.cpp +23 -0
  42. package/src/llama.cpp/src/llama-model-loader.h +2 -0
  43. package/src/llama.cpp/src/llama-model-saver.cpp +3 -0
  44. package/src/llama.cpp/src/llama-model.cpp +337 -26
  45. package/src/llama.cpp/src/llama-model.h +13 -2
  46. package/src/llama.cpp/src/llama-sampling.cpp +1259 -186
  47. package/src/llama.cpp/src/llama-sampling.h +19 -7
  48. package/src/llama.cpp/src/llama-vocab.cpp +101 -33
  49. package/src/llama.cpp/src/llama-vocab.h +2 -0
  50. package/src/llama.cpp/src/llama.cpp +87 -64
  51. package/src/llama.cpp/src/models/afmoe.cpp +9 -5
  52. package/src/llama.cpp/src/models/bert.cpp +4 -2
  53. package/src/llama.cpp/src/models/cogvlm.cpp +5 -3
  54. package/src/llama.cpp/src/models/cohere2-iswa.cpp +3 -0
  55. package/src/llama.cpp/src/models/deepseek2.cpp +1 -1
  56. package/src/llama.cpp/src/models/gemma-embedding.cpp +2 -6
  57. package/src/llama.cpp/src/models/gemma2-iswa.cpp +5 -2
  58. package/src/llama.cpp/src/models/gemma3.cpp +3 -4
  59. package/src/llama.cpp/src/models/gemma3n-iswa.cpp +4 -7
  60. package/src/llama.cpp/src/models/llama-iswa.cpp +6 -2
  61. package/src/llama.cpp/src/models/llama.cpp +19 -6
  62. package/src/llama.cpp/src/models/maincoder.cpp +117 -0
  63. package/src/llama.cpp/src/models/mimo2-iswa.cpp +123 -0
  64. package/src/llama.cpp/src/models/models.h +18 -0
  65. package/src/llama.cpp/src/models/modern-bert.cpp +116 -0
  66. package/src/llama.cpp/src/models/openai-moe-iswa.cpp +5 -2
  67. package/src/llama.cpp/src/models/plamo3.cpp +128 -0
  68. package/src/llama.cpp/src/models/smallthinker.cpp +11 -5
  69. package/src/llama.cpp/src/unicode.cpp +23 -14
@@ -60,6 +60,25 @@ llama_context::llama_context(
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
 
+    // Initialize backend samplers here so they are part of the sampling graph
+    // before the reserve passes run later in this function. This avoids a later
+    // re-reserve when graph nodes change.
+    if (params.samplers != nullptr && params.n_samplers > 0) {
+        for (size_t i = 0; i < params.n_samplers; ++i) {
+            const auto & config = params.samplers[i];
+
+            if (llama_sampler_chain_get(config.sampler, -1) == nullptr) {
+                throw std::runtime_error("the backend samplers must be of type llama_sampler_chain");
+            }
+
+            if (set_sampler(config.seq_id, config.sampler)) {
+                const int n_samplers = llama_sampler_chain_n(config.sampler);
+
+                LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
+            }
+        }
+    }
+
     auto rope_scaling_type = params.rope_scaling_type;
     if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
         rope_scaling_type = hparams.rope_scaling_type_train;
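The constructor only accepts backend samplers that are sampler chains: the llama_sampler_chain_get(config.sampler, -1) probe above returns nullptr for anything else and triggers the runtime_error. A minimal sketch of building a chain that satisfies this check with the existing llama.cpp sampler API (parameter values are illustrative, not taken from the package):

    // build a sampler chain; only chains can be attached as backend samplers
    llama_sampler_chain_params sparams = llama_sampler_chain_default_params();
    llama_sampler * chain = llama_sampler_chain_init(sparams);

    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    // a bare sampler such as llama_sampler_init_greedy() would be rejected
    // by the llama_sampler_chain_get(..., -1) check in the constructor

The chain is then attached per sequence, either through the samplers/n_samplers context parameters read above or via llama_set_sampler() added later in this diff.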
@@ -231,7 +250,10 @@ llama_context::llama_context(
     // graph outputs buffer
     {
         // resized during inference when a batch uses more outputs
-        if (output_reserve(params.n_seq_max) < params.n_seq_max) {
+        // Create a dummy batch for initialization.
+        llama_batch dummy_batch = {};
+        dummy_batch.n_tokens = 0;
+        if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) {
             throw std::runtime_error("failed to reserve initial output buffer");
         }
 
@@ -294,8 +316,8 @@ llama_context::llama_context(
         // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
         bool pipeline_parallel =
             model.n_devices() > 1 &&
-            model.params.n_gpu_layers > (int) model.hparams.n_layer &&
-            model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
+            model.n_gpu_layers() > model.hparams.n_layer &&
+            model.split_mode() == LLAMA_SPLIT_MODE_LAYER &&
             cparams.offload_kqv &&
             !model.has_tensor_overrides();
 
@@ -456,26 +478,35 @@ llama_context::llama_context(
             LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
         }
     }
+
+    // Initialize the full vocabulary token ids for backend samplers.
+    {
+        const int n_vocab = model.vocab.n_tokens();
+
+        sampling.token_ids_full_vocab.resize(n_vocab);
+        for (int i = 0; i < n_vocab; ++i) {
+            sampling.token_ids_full_vocab[i] = i;
+        }
+    }
 }
 
 llama_context::~llama_context() {
-    // FIXME this currently results in a use-after-free bug if the model is freed before the context
-    // if (!model.hparams.no_alloc) {
-    // for (size_t i = 0; i < backend_ptrs.size(); ++i) {
-    // ggml_backend_t backend = backend_ptrs[i];
-    // ggml_backend_buffer_type_t buft = backend_buft[i];
-
-    // const size_t size_exp = backend_buf_exp_size[i];
-    // const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
-    // if (size_exp == size_act) {
-    // LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
-    // __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
-    // } else {
-    // LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
-    // __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
-    // }
-    // }
-    // }
+    if (!model.hparams.no_alloc) {
+        for (size_t i = 0; i < backend_ptrs.size(); ++i) {
+            ggml_backend_t backend = backend_ptrs[i];
+            ggml_backend_buffer_type_t buft = backend_buft[i];
+
+            const size_t size_exp = backend_buf_exp_size[i];
+            const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
+            if (size_exp == size_act) {
+                LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
+                    __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
+            } else {
+                LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
+                    __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
+            }
+        }
+    }
     ggml_opt_free(opt_ctx);
 }
 
@@ -617,6 +648,35 @@ float * llama_context::get_logits() {
     return logits;
 }
 
+int64_t llama_context::output_resolve_row(int32_t i) const {
+    int64_t j = -1;
+
+    // support negative indices (last output row)
+    if (i < 0) {
+        j = n_outputs + i;
+        if (j < 0) {
+            throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
+        }
+    } else if ((size_t) i >= output_ids.size()) {
+        throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
+    } else {
+        // use output_ids to translate the batch token index into a row number
+        // that holds this token's data.
+        j = output_ids[i];
+    }
+
+    if (j < 0) {
+        // the batch token was not configured to output anything
+        throw std::runtime_error(format("batch.logits[%d] != true", i));
+    }
+
+    if (j >= n_outputs) {
+        throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
+    }
+
+    return j;
+}
+
 float * llama_context::get_logits_ith(int32_t i) {
     int64_t j = -1;
 
@@ -627,6 +687,7 @@ float * llama_context::get_logits_ith(int32_t i) {
             throw std::runtime_error("no logits");
         }
 
+        // TODO: use output_resolve_row()
         if (i < 0) {
             j = n_outputs + i;
             if (j < 0) {
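A worked example of the index translation that output_resolve_row() centralizes (values are illustrative): for a batch of 5 tokens where only positions 2 and 4 requested output, output_ids = { -1, -1, 0, -1, 1 } and n_outputs = 2, so:

    // output_resolve_row( 4) -> output_ids[4] == 1      (row 1)
    // output_resolve_row(-1) -> n_outputs + (-1) == 1   (last output row)
    // output_resolve_row( 0) -> output_ids[0] == -1     -> throws "batch.logits[0] != true"

The TODO comments added to get_logits_ith() and get_embeddings_ith() note that those functions still duplicate this logic inline.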
@@ -663,6 +724,10 @@ float * llama_context::get_embeddings() {
     return embd;
 }
 
+llama_token * llama_context::get_sampled_tokens() const{
+    return sampling.sampled;
+}
+
 float * llama_context::get_embeddings_ith(int32_t i) {
     int64_t j = -1;
 
@@ -673,6 +738,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
             throw std::runtime_error("no embeddings");
         }
 
+        // TODO: use output_resolve_row()
        if (i < 0) {
            j = n_outputs + i;
            if (j < 0) {
@@ -692,7 +758,8 @@ float * llama_context::get_embeddings_ith(int32_t i) {
            throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
        }
 
-        return embd + j*model.hparams.n_embd;
+        const uint32_t n_embd_out = model.hparams.get_n_embd_out();
+        return embd + j*n_embd_out;
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
@@ -712,6 +779,136 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
     return it->second.data();
 }
 
+llama_token llama_context::get_sampled_token_ith(int32_t idx) {
+    output_reorder();
+
+    if (sampling.sampled == nullptr) {
+        return LLAMA_TOKEN_NULL;
+    }
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        GGML_ASSERT(row < (int64_t) sampling.sampled_size);
+        return sampling.sampled[row];
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
+        return LLAMA_TOKEN_NULL;
+    }
+}
+
+float * llama_context::get_sampled_probs_ith(int32_t idx) {
+    output_reorder();
+
+    if (sampling.probs == nullptr) {
+        return nullptr;
+    }
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
+            return nullptr;
+        }
+        return sampling.probs + row*model.vocab.n_tokens();
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
+        return nullptr;
+    }
+}
+
+float * llama_context::get_sampled_logits_ith(int32_t idx) {
+    output_reorder();
+
+    if (sampling.logits == nullptr) {
+        return nullptr;
+    }
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
+            return nullptr;
+        }
+        return sampling.logits + row*model.vocab.n_tokens();
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
+        return nullptr;
+    }
+}
+
+const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
+    output_reorder();
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        if (sampling.candidates != nullptr &&
+            (size_t) row < sampling.candidates_count.size() &&
+            sampling.candidates_count[row] > 0) {
+            return sampling.candidates + row*model.vocab.n_tokens();
+        }
+    } catch (const std::exception & err) {
+        // fallback to full vocab list
+    }
+
+    return sampling.token_ids_full_vocab.data();
+}
+
+size_t llama_context::get_sampled_candidates_count(int32_t idx) {
+    output_reorder();
+
+    if (sampling.candidates == nullptr) {
+        return 0;
+    }
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        if ((size_t) row >= sampling.candidates_count.size()) {
+            return 0;
+        }
+        return sampling.candidates_count[row];
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what());
+        return 0;
+    }
+}
+
+size_t llama_context::get_sampled_logits_count(int32_t idx) {
+    output_reorder();
+
+    if (sampling.logits == nullptr) {
+        return model.vocab.n_tokens();
+    }
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        if ((size_t) row >= sampling.logits_count.size()) {
+            return 0;
+        }
+        return sampling.logits_count[row];
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what());
+        return 0;
+    }
+}
+
+size_t llama_context::get_sampled_probs_count(int32_t idx) {
+    output_reorder();
+
+    if (sampling.probs == nullptr) {
+        return 0;
+    }
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        if ((size_t) row >= sampling.probs_count.size()) {
+            return 0;
+        }
+        return sampling.probs_count[row];
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what());
+        return 0;
+    }
+}
+
+
 void llama_context::attach_threadpool(
         ggml_threadpool_t threadpool,
         ggml_threadpool_t threadpool_batch) {
@@ -768,6 +965,42 @@ void llama_context::set_warmup(bool value) {
     cparams.warmup = value;
 }
 
+bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
+    LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
+
+    const bool can_offload =
+        sampler &&
+        sampler->iface->backend_init &&
+        sampler->iface->backend_apply &&
+        llama_sampler_chain_n(sampler) > 0;
+
+    if (sampler && can_offload) {
+        ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());
+        auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output());
+        if (host_buft) {
+            buft = host_buft;
+        }
+
+        sampler->iface->backend_init(sampler, buft);
+
+        sampling.samplers[seq_id] = sampler;
+
+        return true;
+    }
+
+    if (sampler && !can_offload) {
+        LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);
+
+        sampling.samplers.erase(seq_id);
+
+        return false;
+    }
+
+    sampling.samplers.erase(seq_id);
+
+    return true;
+}
+
 void llama_context::set_adapter_lora(
         llama_adapter_lora * adapter,
         float scale) {
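set_sampler() is exposed through the C API as llama_set_sampler() (added near the end of this file). A hedged usage sketch, reusing the chain built in the earlier example; whether a chain can actually be offloaded depends on its samplers implementing the backend_init/backend_apply interface, otherwise the call logs a warning and returns false:

    // attach the chain to sequence 0 for backend (in-graph) sampling
    if (!llama_set_sampler(ctx, /* seq_id */ 0, chain)) {
        // not offloadable: keep using the regular CPU-side sampling path
    }

    // passing nullptr detaches the backend sampler for that sequence
    llama_set_sampler(ctx, 0, nullptr);

ctx is assumed to be a llama_context * created beforehand.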
@@ -908,7 +1141,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     n_queued_tokens += n_tokens;
 
     // reserve output buffer
-    if (output_reserve(n_tokens) < n_tokens) {
+    if (output_reserve(n_tokens, batch_inp) < n_tokens) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
         return -2;
     };
@@ -962,9 +1195,10 @@ int llama_context::encode(const llama_batch & batch_inp) {
                 {
                     // extract token embeddings
                     GGML_ASSERT(embd != nullptr);
+                    const uint32_t n_embd_out = hparams.get_n_embd_out();
 
-                    GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
-                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
+                    GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size);
+                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
                 } break;
             case LLAMA_POOLING_TYPE_MEAN:
             case LLAMA_POOLING_TYPE_CLS:
@@ -1032,6 +1266,112 @@ int llama_context::encode(const llama_batch & batch_inp) {
     return 0;
 }
 
+static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
+    std::map<llama_seq_id, uint32_t> seq_to_row;
+    // how many output tokens we have seen so far for this ubatch.
+    uint32_t local = 0;
+    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+        // skip tokens that are not output.
+        if (!ubatch.output[i]) {
+            continue;
+        }
+
+        const llama_seq_id seq_id = ubatch.seq_id[i][0];
+        // row_offset is the number of output tokens before this ubatch.
+        seq_to_row[seq_id] = row_offset + local;
+        ++local;
+    }
+    return seq_to_row;
+}
+
+static void copy_tensor_async_ints(
+        const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
+        llama_token * sampled,
+        size_t sampled_size,
+        const std::map<llama_seq_id, uint32_t> & seq_to_row,
+        ggml_backend_sched_t sched) {
+    if (sampled == nullptr) {
+        return;
+    }
+
+    for (const auto & [seq_id, tensor] : tensor_map) {
+        auto it = seq_to_row.find(seq_id);
+        if (it == seq_to_row.end()) {
+            continue;
+        }
+
+        const uint32_t row = it->second;
+        GGML_ASSERT(row < sampled_size);
+
+        GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
+
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+        ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
+    }
+}
+
+static void copy_tensor_async_floats(
+        const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
+        float * dst,
+        size_t stride,
+        std::vector<uint32_t> & counts,
+        const std::map<llama_seq_id, uint32_t> & seq_to_row,
+        ggml_backend_sched_t sched) {
+    if (dst == nullptr) {
+        return;
+    }
+
+    for (const auto & [seq_id, tensor] : tensor_map) {
+        auto it = seq_to_row.find(seq_id);
+        if (it == seq_to_row.end()) {
+            continue;
+        }
+
+        const uint32_t row = it->second;
+        GGML_ASSERT(row < counts.size());
+
+        GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
+
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+        float * row_ptr = dst + (size_t) row * stride;
+        ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
+
+        // Update the actual number of logits/probabilities that were written for this row.
+        counts[row] = ggml_nelements(tensor);
+    }
+}
+
+static void copy_tensor_async_candidates(
+        const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
+        llama_token * dst,
+        size_t stride,
+        std::vector<uint32_t> & counts,
+        const std::map<llama_seq_id, uint32_t> & seq_to_row,
+        ggml_backend_sched_t sched) {
+    if (dst == nullptr) {
+        return;
+    }
+
+    for (const auto & [seq_id, tensor] : tensor_map) {
+        auto it = seq_to_row.find(seq_id);
+        if (it == seq_to_row.end()) {
+            continue;
+        }
+
+        const uint32_t row = it->second;
+        GGML_ASSERT(row < counts.size());
+
+        GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
+
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+        llama_token * row_ptr = dst + (size_t) row * stride;
+        ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
+
+        // Update the actual number of candidates that were written.
+        counts[row] = ggml_nelements(tensor);
+    }
+}
+
 int llama_context::decode(const llama_batch & batch_inp) {
     GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
 
@@ -1052,9 +1392,36 @@ int llama_context::decode(const llama_batch & batch_inp) {
     const int64_t n_embd = hparams.n_embd_inp();
 
     // when computing embeddings, all tokens are output
-    const bool output_all = cparams.embeddings;
+    const bool output_all   = cparams.embeddings;
+    const bool has_samplers = !sampling.samplers.empty();
+
+    const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max;
+
+    // TODO: avoid this workaround in the future
+    if (has_samplers && batch_inp.logits) {
+        std::vector<int32_t> seq_output_count(n_seq_max, 0);
+
+        for (int32_t i = 0; i < batch_inp.n_tokens; ++i) {
+            if (batch_inp.logits[i] == 0) {
+                continue;
+            }
+
+            const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1;
+
+            for (int32_t s = 0; s < ns; ++s) {
+                const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0;
 
-    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
+                seq_output_count[seq_id]++;
+                if (seq_output_count[seq_id] > 1) {
+                    LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n",
+                            __func__, seq_id, seq_output_count[seq_id]);
+                    return -1;
+                }
+            }
+        }
+    }
+
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
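The check above means that, when backend samplers are attached and the batch provides an explicit logits array, at most one position per sequence may request output. A minimal sketch of a conforming single-sequence batch (prompt_tokens and n_prompt are placeholders):

    llama_batch batch = llama_batch_init(/* n_tokens */ n_prompt, /* embd */ 0, /* n_seq_max */ 1);

    for (int32_t i = 0; i < n_prompt; ++i) {
        batch.token[i]     = prompt_tokens[i];
        batch.pos[i]       = i;
        batch.n_seq_id[i]  = 1;
        batch.seq_id[i][0] = 0;
        batch.logits[i]    = (i == n_prompt - 1) ? 1 : 0; // exactly one output for seq 0
    }
    batch.n_tokens = n_prompt;

Batches built with llama_batch_get_one() leave logits as nullptr, so they skip this check and only the last token is treated as output.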
@@ -1135,7 +1502,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     }
 
     // reserve output buffer
-    if (output_reserve(n_outputs_all) < n_outputs_all) {
+    if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
         return -2;
     };
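For reference, a small worked example of the build_seq_to_output_row() helper added earlier, which the sampled-data copies later in decode() rely on (values are illustrative):

    // ubatch.n_tokens = 4, row_offset = n_outputs_prev = 0
    // token index :   0      1      2      3
    // seq_id      :   0      0      1      1
    // output      :   false  true   false  true
    //
    // result: seq_to_row = { {0, 0}, {1, 1} }
    //   seq 0 -> output row 0, seq 1 -> output row 1
    // a subsequent ubatch passes row_offset = n_outputs_prev and continues from the next row

Only the first seq_id of each output token is considered, which matches the one-output-per-sequence restriction enforced above.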
@@ -1208,7 +1575,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }
 
         // extract logits
-        if (t_logits && n_outputs > 0) {
+        // For multi-sequence batches that mix backend samplers and CPU sampler
+        // this is currently inefficient as we copy all logits even for the
+        // backend sampled tokens.
+        if (logits && t_logits && n_outputs > 0) {
             ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
             GGML_ASSERT(backend_res != nullptr);
             GGML_ASSERT(logits != nullptr);
@@ -1223,7 +1593,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }
 
         // extract embeddings
-        if (t_embd && n_outputs > 0) {
+        if (embd && t_embd && n_outputs > 0) {
             ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
             GGML_ASSERT(backend_embd != nullptr);
 
@@ -1232,12 +1602,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
                 {
                     // extract token embeddings
                     GGML_ASSERT(embd != nullptr);
-                    float * embd_out = embd + n_outputs_prev*n_embd;
+                    const uint32_t n_embd_out = hparams.get_n_embd_out();
+                    float * embd_out = embd + n_outputs_prev*n_embd_out;
 
                     if (n_outputs) {
                         GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
-                        GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
-                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
+                        GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size);
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float));
                     }
                 } break;
             case LLAMA_POOLING_TYPE_MEAN:
@@ -1277,6 +1648,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
             }
         }
 
+        // This flag indicates whether a backend sampler has actually sampled a specific
+        // token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings.
+        const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
+
+        if (has_samplers && has_sampled) {
+            const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
+            const auto stride = n_vocab;
+
+            // async copy the sampling data from the backend to the host
+            copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
+
+            copy_tensor_async_floats    (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
+            copy_tensor_async_floats    (res->t_sampled_probs,  sampling.probs,  stride, sampling.probs_count,  seq_to_output_row, sched.get());
+            copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get());
+        }
+
         n_outputs_prev += n_outputs;
     } while (mctx->next());
 
@@ -1340,15 +1727,15 @@ int llama_context::decode(const llama_batch & batch_inp) {
 // output
 //
 
-uint32_t llama_context::output_reserve(int32_t n_outputs) {
+uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) {
     const auto & hparams = model.hparams;
     const auto & vocab   = model.vocab;
 
     const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
 
-    const auto n_batch = cparams.n_batch;
-    const auto n_vocab = vocab.n_tokens();
-    const auto n_embd  = hparams.n_embd;
+    const auto n_batch    = cparams.n_batch;
+    const auto n_vocab    = vocab.n_tokens();
+    const auto n_embd_out = hparams.get_n_embd_out();
 
     bool has_logits = true;
     bool has_embd   = cparams.embeddings;
@@ -1359,8 +1746,53 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
         has_embd = true;
    }
 
-    logits_size = has_logits ? n_vocab*n_outputs_max : 0;
-    embd_size   = has_embd   ? n_embd*n_outputs_max  : 0;
+    // Check which sampling modes are needed for the current batch.
+    // TODO: avoid this branching by working with the worst-case
+    bool has_sampling = false;
+    bool cpu_logits   = false;
+
+    if (batch.logits) {
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            if (!batch.logits[i]) {
+                continue;
+            }
+            for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
+                llama_seq_id seq_id = batch.seq_id[i][j];
+                if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
+                    has_sampling = true;
+                } else {
+                    cpu_logits = true;
+                }
+            }
+        }
+    } else {
+        // When batch.logits is nullptr (when loading state with a dummy batch),
+        // allocate CPU logits.
+        cpu_logits = true;
+    }
+
+    size_t backend_float_count = 0;
+    size_t backend_token_count = 0;
+
+    // Allocate CPU logits buffer only if needed by sequences in this batch
+    logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0;
+    embd_size   = has_embd ? n_embd_out*n_outputs_max : 0;
+
+    // TODO: avoid this branching by working with the worst-case
+    if (!has_sampling) {
+        sampling.logits_size     = 0;
+        sampling.probs_size      = 0;
+        sampling.sampled_size    = 0;
+        sampling.candidates_size = 0;
+    } else {
+        sampling.logits_size     = n_vocab*n_outputs_max;
+        sampling.probs_size      = n_vocab*n_outputs_max;
+        sampling.sampled_size    = n_outputs_max;
+        sampling.candidates_size = n_vocab*n_outputs_max;
+
+        backend_float_count = sampling.logits_size  + sampling.probs_size;
+        backend_token_count = sampling.sampled_size + sampling.candidates_size;
+    }
 
     if (output_ids.empty()) {
         // init, never resized afterwards
@@ -1368,7 +1800,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     }
 
     const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
-    const size_t new_size  = (logits_size + embd_size) * sizeof(float);
+    const size_t new_size  =
+        (logits_size + embd_size + backend_float_count) * sizeof(float) +
+        (                          backend_token_count) * sizeof(llama_token);
 
     // alloc only when more than the current capacity is required
     // TODO: also consider shrinking the buffer
@@ -1376,9 +1810,11 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     if (buf_output) {
 #ifndef NDEBUG
         // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
-        LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
+        LLAMA_LOG_DEBUG("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
 #endif
         synchronize();
+
+        // TODO: not needed?
         buf_output = nullptr;
         logits = nullptr;
         embd = nullptr;
@@ -1400,8 +1836,49 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 
     float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
 
-    logits = has_logits ? output_base               : nullptr;
-    embd   = has_embd   ? output_base + logits_size : nullptr;
+    logits = nullptr;
+    embd   = nullptr;
+
+    size_t offset = 0;
+    uint8_t * base = (uint8_t *) output_base;
+
+    logits = (has_logits && cpu_logits) ? output_base : nullptr;
+    offset += logits_size * sizeof(float);
+
+    embd = has_embd ? (float *) (base + offset) : nullptr;
+    offset += embd_size * sizeof(float);
+
+    sampling.logits     = nullptr;
+    sampling.probs      = nullptr;
+    sampling.sampled    = nullptr;
+    sampling.candidates = nullptr;
+
+    if (has_sampling) {
+        sampling.logits = (float *) (base + offset);
+        offset += sampling.logits_size * sizeof(float);
+
+        sampling.probs = (float *) (base + offset);
+        offset += sampling.probs_size * sizeof(float);
+
+        sampling.sampled = (llama_token *) (base + offset);
+        offset += sampling.sampled_size * sizeof(llama_token);
+
+        sampling.candidates = (llama_token *) (base + offset);
+        offset += sampling.candidates_size * sizeof(llama_token);
+
+        // The count vectors keep track of the actual number of logits/probs/candidates
+        // copied from the backend for each output row.
+
+        sampling.logits_count.resize(n_outputs_max);
+        sampling.probs_count.resize(n_outputs_max);
+        sampling.candidates_count.resize(n_outputs_max);
+
+        std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
+        std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
+        std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
+
+        std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
+    }
 
     // set all ids as invalid (negative)
     std::fill(output_ids.begin(), output_ids.end(), -1);
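To make the layout above concrete, a rough worked example with illustrative numbers (not taken from the package): n_vocab = 32000, n_outputs_max = 4, CPU logits and backend sampling both needed, embeddings disabled, 4-byte float and llama_token:

    // logits              : 32000 * 4 floats      = 512000 bytes   (offset 0)
    // embd                : 0 bytes
    // sampling.logits     : 32000 * 4 floats      = 512000 bytes
    // sampling.probs      : 32000 * 4 floats      = 512000 bytes
    // sampling.sampled    :     4 llama_token     =     16 bytes
    // sampling.candidates : 32000 * 4 llama_token = 512000 bytes
    //
    // new_size = (logits_size + embd_size + backend_float_count) * sizeof(float)
    //          +  backend_token_count * sizeof(llama_token)
    //          = 2,048,016 bytes, carved out of the single buf_output allocation
    //            at increasing offsets in the order shown above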
@@ -1430,6 +1907,40 @@ void llama_context::output_reorder() {
                 std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
             }
         }
+
+        if (sampling.logits && sampling.logits_size > 0) {
+            for (uint64_t k = 0; k < n_vocab; ++k) {
+                std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
+            }
+        }
+
+        if (sampling.probs && sampling.probs_size > 0) {
+            for (uint64_t k = 0; k < n_vocab; ++k) {
+                std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
+            }
+        }
+
+        if (sampling.candidates && sampling.candidates_size > 0) {
+            for (uint64_t k = 0; k < n_vocab; ++k) {
+                std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
+            }
+        }
+
+        if (sampling.sampled && sampling.sampled_size > 0) {
+            std::swap(sampling.sampled[i0], sampling.sampled[i1]);
+        }
+
+        if (!sampling.logits_count.empty()) {
+            std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
+        }
+
+        if (!sampling.probs_count.empty()) {
+            std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
+        }
+
+        if (!sampling.candidates_count.empty()) {
+            std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
+        }
     }
 
     output_swaps.clear();
@@ -1443,7 +1954,9 @@ uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
     if (model.arch == LLM_ARCH_QWEN3NEXT) {
         return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
     }
-    return std::max<uint32_t>(1024u, 8u*model.n_tensors());
+    uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors());
+    res += model.n_lora_nodes;
+    return res;
 }
 
 llm_graph_result * llama_context::get_gf_res_reserve() const {
@@ -1457,7 +1970,7 @@ ggml_cgraph * llama_context::graph_reserve(
 
     if (n_tokens % n_seqs != 0) {
         n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
-        n_outputs = std::min(n_outputs, n_tokens);
+        n_outputs = std::max(n_outputs, n_tokens);
 
         LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
     }
@@ -1476,6 +1989,15 @@ ggml_cgraph * llama_context::graph_reserve(
     llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 
+    // set one output token per sequence in order to activate all backend samplers
+    std::vector<llama_seq_id> seq_ids(n_seqs);
+    for (uint32_t i = 0; i < n_seqs; ++i) {
+        seq_ids[i] = i;
+        ubatch.n_seq_id[i] = 1;
+        ubatch.seq_id[i] = &seq_ids[i];
+        ubatch.output[i] = true;
+    }
+
     auto * res = gf_res_reserve.get();
 
     const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
@@ -1506,7 +2028,7 @@ llm_graph_params llama_context::graph_params(
         llm_graph_result * res,
         const llama_ubatch & ubatch,
         const llama_memory_context_i * mctx,
-        llm_graph_type gtype) const {
+        llm_graph_type gtype) const {
     return {
         /*.arch =*/ model.arch,
        /*.hparams =*/ model.hparams,
@@ -1519,6 +2041,7 @@ llm_graph_params llama_context::graph_params(
         /*.loras =*/ &loras,
         /*.mctx =*/ mctx,
         /*.cross =*/ &cross,
+        /*.samplers =*/ sampling.samplers,
         /*.n_outputs =*/ n_outputs,
         /*.cb =*/ graph_get_cb(),
         /*.res =*/ res,
@@ -1571,7 +2094,7 @@ llm_graph_cb llama_context::graph_get_cb() const {
 
         // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
         // FIXME: fix in ggml_backend_sched
-        const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer;
+        const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer;
         if (ubatch.n_tokens < 32 || full_offload) {
             if (il != -1 && strcmp(name, "norm") == 0) {
                 const auto & dev_layer = model.dev_layer(il);
@@ -1974,6 +2497,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
         }
     }
 
+    // TODO: handle sampling buffers and samplers state ?
+    // https://github.com/ggml-org/llama.cpp/pull/17004
+
     if (memory != nullptr) {
         LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
         memory->state_write(io);
@@ -2006,7 +2532,10 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
     auto n_outputs = this->n_outputs;
     io.read_to(&n_outputs, sizeof(n_outputs));
 
-    if (n_outputs > output_reserve(n_outputs)) {
+    // Create a dummy batch for state loading.
+    llama_batch dummy_batch = {};
+    dummy_batch.n_tokens = 0;
+    if (n_outputs > output_reserve(n_outputs, dummy_batch)) {
         throw std::runtime_error("could not reserve outputs");
     }
 
@@ -2060,6 +2589,9 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
         }
     }
 
+    // TODO: handle sampling buffers and samplers state ?
+    // https://github.com/ggml-org/llama.cpp/pull/17004
+
     if (memory) {
         LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
 
@@ -2248,7 +2780,7 @@ void llama_context::opt_epoch_iter(
         }
 
         // reserve output buffer
-        if (output_reserve(n_outputs_all) < n_outputs_all) {
+        if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
             LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
             GGML_ABORT("TODO: handle this error");
         };
@@ -2393,6 +2925,8 @@ llama_context_params llama_context_default_params() {
         /*.op_offload =*/ true,
         /*.swa_full =*/ true,
         /*.kv_unified =*/ false,
+        /*.sampler =*/ nullptr,
+        /*.n_sampler =*/ 0,
     };
 
     return result;
@@ -2552,7 +3086,15 @@ float * llama_get_logits(llama_context * ctx) {
 float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
     ctx->synchronize();
 
-    return ctx->get_logits_ith(i);
+    float * res = nullptr;
+
+    res = ctx->get_sampled_logits_ith(i);
+
+    if (!res) {
+        res = ctx->get_logits_ith(i);
+    }
+
+    return res;
 }
 
 float * llama_get_embeddings(llama_context * ctx) {
@@ -2573,6 +3115,52 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
     return ctx->get_embeddings_seq(seq_id);
 }
 
+bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
+    return ctx->set_sampler(seq_id, smpl);
+}
+
+llama_token llama_get_sampled_token_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return ctx->get_sampled_token_ith(i);
+}
+
+float * llama_get_sampled_probs_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return ctx->get_sampled_probs_ith(i);
+}
+
+float * llama_get_sampled_logits_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return ctx->get_sampled_logits_ith(i);
+}
+
+llama_token * llama_get_sampled_candidates_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return const_cast<llama_token *>(ctx->get_sampled_candidates_ith(i));
+}
+
+uint32_t llama_get_sampled_candidates_count_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return static_cast<uint32_t>(ctx->get_sampled_candidates_count(i));
+}
+
+uint32_t llama_get_sampled_logits_count_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return static_cast<uint32_t>(ctx->get_sampled_logits_count(i));
+}
+
+uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return static_cast<uint32_t>(ctx->get_sampled_probs_count(i));
+}
+
 // llama adapter API
 
 int32_t llama_set_adapter_lora(
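Putting the new public entry points together, a hedged end-to-end sketch of consuming a backend-sampled token after llama_decode(); ctx, chain, and batch are assumed from the earlier sketches, and the LLAMA_TOKEN_NULL fallback mirrors what llama_get_logits_ith() now does internally by preferring the sampled logits:

    llama_set_sampler(ctx, 0, chain);

    if (llama_decode(ctx, batch) != 0) {
        // handle the error
    }

    // token sampled in-graph for the last output row of the batch
    llama_token tok = llama_get_sampled_token_ith(ctx, -1);

    if (tok == LLAMA_TOKEN_NULL) {
        // no backend-sampled token for this row: fall back to CPU-side sampling
        // on llama_get_logits_ith(ctx, -1) with a regular llama_sampler
    }

    // optionally inspect what the backend sampler produced for that row
    const float       * probs   = llama_get_sampled_probs_ith(ctx, -1);
    const uint32_t      n_probs = llama_get_sampled_probs_count_ith(ctx, -1);
    const llama_token * cands   = llama_get_sampled_candidates_ith(ctx, -1);
    const uint32_t      n_cands = llama_get_sampled_candidates_count_ith(ctx, -1);

llama_get_sampled_probs_ith() returns nullptr when the chain produced no probabilities for that row, in which case the corresponding count is 0; the candidates accessor falls back to the full vocabulary list.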