@novastera-oss/llamarn 0.2.4 → 0.2.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. package/RNLlamaCpp.podspec +3 -2
  2. package/android/CMakeLists.txt +6 -3
  3. package/android/src/main/cpp/include/llama.h +12 -8
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  12. package/cpp/LlamaCppModel.cpp +46 -65
  13. package/cpp/LlamaCppModel.h +5 -0
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/README.md +1 -0
  16. package/cpp/llama.cpp/common/CMakeLists.txt +5 -8
  17. package/cpp/llama.cpp/common/arg.cpp +8 -6
  18. package/cpp/llama.cpp/common/chat-parser.cpp +4 -3
  19. package/cpp/llama.cpp/common/chat-parser.h +2 -1
  20. package/cpp/llama.cpp/common/chat.cpp +4 -4
  21. package/cpp/llama.cpp/common/common.cpp +2 -0
  22. package/cpp/llama.cpp/common/json-partial.cpp +5 -4
  23. package/cpp/llama.cpp/common/json-partial.h +2 -1
  24. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
  25. package/cpp/llama.cpp/common/json-schema-to-grammar.h +4 -4
  26. package/cpp/llama.cpp/convert_hf_to_gguf.py +31 -28
  27. package/cpp/llama.cpp/ggml/include/ggml.h +1 -3
  28. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +2 -0
  29. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +10 -5
  30. package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  31. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +23 -0
  32. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1 -1
  34. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +19 -8
  35. package/cpp/llama.cpp/ggml/src/ggml-impl.h +2 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -2
  37. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +118 -11
  39. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1 -1
  40. package/cpp/llama.cpp/ggml/src/ggml.c +9 -2
  41. package/cpp/llama.cpp/ggml/src/ggml.cpp +26 -0
  42. package/cpp/llama.cpp/ggml/src/gguf.cpp +19 -2
  43. package/cpp/llama.cpp/include/llama.h +12 -8
  44. package/cpp/llama.cpp/src/CMakeLists.txt +3 -0
  45. package/cpp/llama.cpp/src/llama-batch.cpp +19 -12
  46. package/cpp/llama.cpp/src/llama-batch.h +15 -10
  47. package/cpp/llama.cpp/src/llama-context.cpp +226 -151
  48. package/cpp/llama.cpp/src/llama-context.h +25 -8
  49. package/cpp/llama.cpp/src/llama-graph.cpp +50 -47
  50. package/cpp/llama.cpp/src/llama-graph.h +25 -24
  51. package/cpp/llama.cpp/src/llama-kv-cache-recurrent.cpp +1132 -0
  52. package/cpp/llama.cpp/src/llama-kv-cache-recurrent.h +191 -0
  53. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +249 -0
  54. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +136 -0
  55. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +1717 -0
  56. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +278 -0
  57. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -2746
  58. package/cpp/llama.cpp/src/llama-kv-cache.h +14 -472
  59. package/cpp/llama.cpp/src/llama-kv-cells.h +37 -6
  60. package/cpp/llama.cpp/src/llama-memory.h +44 -0
  61. package/cpp/llama.cpp/src/llama-model.cpp +23 -16
  62. package/cpp/llama.cpp/src/llama-vocab.cpp +7 -2
  63. package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +10518 -0
  64. package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +93468 -0
  65. package/cpp/llama.cpp/{common → vendor}/minja/chat-template.hpp +1 -1
  66. package/cpp/llama.cpp/{common → vendor}/minja/minja.hpp +1 -1
  67. package/cpp/llama.cpp/{common → vendor/nlohmann}/json.hpp +3027 -2267
  68. package/cpp/llama.cpp/vendor/nlohmann/json_fwd.hpp +187 -0
  69. package/cpp/llama.cpp/vendor/stb/stb_image.h +7988 -0
  70. package/cpp/rn-completion.cpp +101 -52
  71. package/cpp/rn-utils.hpp +8 -1
  72. package/ios/include/common/minja/chat-template.hpp +1 -1
  73. package/ios/include/common/minja/minja.hpp +1 -1
  74. package/ios/include/json-schema-to-grammar.h +4 -4
  75. package/ios/include/llama.h +12 -8
  76. package/ios/include/{common → nlohmann}/json.hpp +3027 -2267
  77. package/ios/libs/llama.xcframework/Info.plist +22 -22
  78. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  79. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4689 -4617
  80. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +1 -3
  81. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +12 -8
  82. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  83. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  84. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4710 -4638
  85. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3622 -3557
  86. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  87. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +12 -8
  88. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  89. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  90. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4710 -4638
  91. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3624 -3559
  92. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +1 -3
  93. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +12 -8
  94. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +1 -3
  95. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +12 -8
  96. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  97. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +1 -3
  98. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +12 -8
  99. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  100. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  101. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  102. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4689 -4616
  103. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +1 -3
  104. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +12 -8
  105. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  106. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  107. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4710 -4637
  108. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3622 -3556
  109. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  110. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +12 -8
  111. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  112. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  113. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4725 -4653
  114. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +1 -3
  115. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +12 -8
  116. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  117. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  118. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4746 -4674
  119. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3652 -3587
  120. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  121. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +12 -8
  122. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  123. package/package.json +1 -1
@@ -15,24 +15,31 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
15
15
  break;
16
16
  }
17
17
  }
18
- ubatch_token.resize(!has_embd ? n_ubatch : 0);
19
- ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
20
- ubatch_pos.resize(n_ubatch);
21
- ubatch_n_seq_id.resize(n_ubatch);
22
- ubatch_seq_id.resize(n_ubatch);
23
- ubatch_output.resize(n_ubatch);
18
+
19
+ udatas.push_back({});
20
+
21
+ auto & udata = udatas.back();
22
+
23
+ udata.token.resize(!has_embd ? n_ubatch : 0);
24
+ udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
25
+ udata.pos.resize(n_ubatch);
26
+ udata.n_seq_id.resize(n_ubatch);
27
+ udata.seq_id.resize(n_ubatch);
28
+ udata.output.resize(n_ubatch);
29
+
24
30
  llama_ubatch ubatch = {
25
31
  /*equal_seqs =*/ true,
26
32
  /*n_tokens =*/ 0,
27
33
  /*n_seq_tokens =*/ 0,
28
34
  /*n_seqs =*/ 0,
29
- /*token =*/ !has_embd ? ubatch_token.data() : nullptr,
30
- /*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
31
- /*pos =*/ ubatch_pos.data(),
32
- /*n_seq_id =*/ ubatch_n_seq_id.data(),
33
- /*seq_id =*/ ubatch_seq_id.data(),
34
- /*output =*/ ubatch_output.data(),
35
+ /*token =*/ !has_embd ? udata.token.data() : nullptr,
36
+ /*embd =*/ has_embd ? udata.embd.data() : nullptr,
37
+ /*pos =*/ udata.pos.data(),
38
+ /*n_seq_id =*/ udata.n_seq_id.data(),
39
+ /*seq_id =*/ udata.seq_id.data(),
40
+ /*output =*/ udata.output.data(),
35
41
  };
42
+
36
43
  return ubatch;
37
44
  }
38
45
 
@@ -11,15 +11,15 @@ struct llama_ubatch {
11
11
  bool equal_seqs;
12
12
  // TODO: whole_seqs for embeddings?
13
13
 
14
- uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
14
+ uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
15
15
  uint32_t n_seq_tokens; // tokens per sequence
16
16
  uint32_t n_seqs;
17
17
 
18
18
  llama_token * token; // [n_tokens]
19
19
  float * embd; // [n_embd, n_tokens]
20
20
  llama_pos * pos; // [n_tokens]
21
- int32_t * n_seq_id; // [n_seqs]
22
- llama_seq_id ** seq_id; // [n_seqs]
21
+ int32_t * n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
22
+ llama_seq_id ** seq_id; // [n_seqs] // TODO: become llama_seq_id * seq_id;
23
23
  int8_t * output; // [n_tokens]
24
24
  };
25
25
 
@@ -49,13 +49,18 @@ struct llama_sbatch {
49
49
 
50
50
  const llama_batch * batch = nullptr;
51
51
 
52
- // buffers for the ubatch
53
- std::vector<llama_token> ubatch_token;
54
- std::vector<float> ubatch_embd;
55
- std::vector<llama_pos> ubatch_pos;
56
- std::vector<int32_t> ubatch_n_seq_id;
57
- std::vector<llama_seq_id *> ubatch_seq_id;
58
- std::vector<int8_t> ubatch_output;
52
+ // buffers for the ubatches
53
+ // TODO: very hacky, this needs a complete rework
54
+ struct ubatch_data {
55
+ std::vector<llama_token> token;
56
+ std::vector<float> embd;
57
+ std::vector<llama_pos> pos;
58
+ std::vector<int32_t> n_seq_id;
59
+ std::vector<llama_seq_id *> seq_id;
60
+ std::vector<int8_t> output;
61
+ };
62
+
63
+ std::vector<ubatch_data> udatas;
59
64
 
60
65
  llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
61
66
 
@@ -6,9 +6,10 @@
6
6
  #include "llama-model.h"
7
7
  #include "llama-kv-cache.h"
8
8
 
9
+ #include <cinttypes>
9
10
  #include <cstring>
11
+ #include <limits>
10
12
  #include <stdexcept>
11
- #include <cinttypes>
12
13
 
13
14
  //
14
15
  // llama_context
@@ -122,6 +123,11 @@ llama_context::llama_context(
122
123
  __func__, n_ctx_per_seq, hparams.n_ctx_train);
123
124
  }
124
125
 
126
+ if (!params.swa_full && cparams.n_seq_max > 1) {
127
+ LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
128
+ __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
129
+ }
130
+
125
131
  if (!hparams.vocab_only) {
126
132
  // GPU backends
127
133
  for (auto * dev : model.devices) {
@@ -259,15 +265,9 @@ llama_context::llama_context(
259
265
 
260
266
  // reserve worst-case graph
261
267
  if (!hparams.vocab_only && memory) {
262
- const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
268
+ const uint32_t n_seqs = cparams.n_seq_max;
263
269
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
264
270
 
265
- llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
266
-
267
- // restore later
268
- // TODO: something cleaner
269
- const auto n_outputs_save = n_outputs;
270
-
271
271
  LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
272
272
 
273
273
  int n_splits_pp = -1;
@@ -279,23 +279,17 @@ llama_context::llama_context(
279
279
  // simulate full KV cache
280
280
  llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
281
281
 
282
- kv_self->set_full();
282
+ const auto kv_state = kv_self->init_full();
283
+ if (!kv_state) {
284
+ throw std::runtime_error("failed to initialize KV cache");
285
+ }
283
286
 
284
287
  cross.v_embd.clear();
285
288
 
286
289
  // reserve pp graph first so that buffers are only allocated once
287
290
  {
288
- llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
289
-
290
- // max number of outputs
291
- n_outputs = ubatch_pp.n_tokens;
292
-
293
- LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
294
-
295
- auto * gf = graph_init();
296
- graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
297
-
298
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
291
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
292
+ if (!gf) {
299
293
  throw std::runtime_error("failed to allocate compute pp buffers");
300
294
  }
301
295
 
@@ -305,16 +299,8 @@ llama_context::llama_context(
305
299
 
306
300
  // reserve with tg graph to get the number of splits and nodes
307
301
  {
308
- llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
309
-
310
- n_outputs = ubatch_tg.n_tokens;
311
-
312
- LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
313
-
314
- auto * gf = graph_init();
315
- graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
316
-
317
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
302
+ auto * gf = graph_reserve(1, 1, 1, kv_state.get());
303
+ if (!gf) {
318
304
  throw std::runtime_error("failed to allocate compute tg buffers");
319
305
  }
320
306
 
@@ -324,22 +310,12 @@ llama_context::llama_context(
324
310
 
325
311
  // reserve again with pp graph to avoid ggml-alloc reallocations during inference
326
312
  {
327
- llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
328
-
329
- n_outputs = ubatch_pp.n_tokens;
330
-
331
- LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
332
-
333
- auto * gf = graph_init();
334
- graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
335
-
336
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
313
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
314
+ if (!gf) {
337
315
  throw std::runtime_error("failed to allocate compute pp buffers");
338
316
  }
339
317
  }
340
318
 
341
- n_outputs = n_outputs_save;
342
-
343
319
  for (size_t i = 0; i < backend_ptrs.size(); ++i) {
344
320
  ggml_backend_t backend = backend_ptrs[i];
345
321
  ggml_backend_buffer_type_t buft = backend_buft[i];
@@ -453,36 +429,33 @@ const llama_kv_cache * llama_context::get_kv_self() const {
453
429
  return kv_self;
454
430
  }
455
431
 
456
- void llama_context::kv_self_update() {
457
- bool need_reserve = false;
432
+ bool llama_context::kv_self_update() {
433
+ if (!memory) {
434
+ return false;
435
+ }
458
436
 
459
437
  llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
460
438
 
461
- need_reserve = kv_self->update(*this);
462
-
463
- // reserve a worst case graph if needed
464
- if (need_reserve) {
465
- LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
466
-
467
- // build worst-case graph
468
- uint32_t n_seqs = 1; // TODO: worst-case number of sequences
469
- uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
470
-
471
- // simulate full KV cache
472
- kv_self->set_full();
439
+ if (!kv_self->update(*this)) {
440
+ // no updates have been performed
441
+ return false;
442
+ }
473
443
 
474
- llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
475
- llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
444
+ // if the KV cache did any computation, we have to reserve a new worst-case graph
445
+ const auto kv_state = kv_self->init_full();
446
+ if (!kv_state) {
447
+ throw std::runtime_error("failed to initialize KV cache");
448
+ }
476
449
 
477
- auto * gf = graph_init();
478
- graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
450
+ const uint32_t n_seqs = cparams.n_seq_max;
451
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
479
452
 
480
- // initialize scheduler with the worst-case graph
481
- ggml_backend_sched_reset(sched.get());
482
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
483
- LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
484
- }
453
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
454
+ if (!gf) {
455
+ LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
485
456
  }
457
+
458
+ return true;
486
459
  }
487
460
 
488
461
  enum llama_pooling_type llama_context::pooling_type() const {
@@ -676,6 +649,49 @@ bool llama_context::apply_adapter_cvec(
676
649
  return cvec.apply(model, data, len, n_embd, il_start, il_end);
677
650
  }
678
651
 
652
+ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
653
+ if (mstate && !mstate->apply()) {
654
+ LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
655
+ ret = GGML_STATUS_FAILED;
656
+ return nullptr;
657
+ }
658
+
659
+ auto * gf = graph_init();
660
+ if (!gf) {
661
+ LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
662
+ ret = GGML_STATUS_FAILED;
663
+ return nullptr;
664
+ }
665
+
666
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
667
+ if (!res) {
668
+ LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
669
+ ret = GGML_STATUS_FAILED;
670
+ return nullptr;
671
+ }
672
+
673
+ // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
674
+
675
+ if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
676
+ LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
677
+ ret = GGML_STATUS_ALLOC_FAILED;
678
+ return nullptr;
679
+ }
680
+
681
+ res->set_inputs(&ubatch);
682
+
683
+ const auto status = graph_compute(gf, ubatch.n_tokens > 1);
684
+ if (status != GGML_STATUS_SUCCESS) {
685
+ LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
686
+ ret = status;
687
+ return nullptr;
688
+ }
689
+
690
+ ret = GGML_STATUS_SUCCESS;
691
+
692
+ return res;
693
+ }
694
+
679
695
  int llama_context::encode(llama_batch & inp_batch) {
680
696
  if (inp_batch.n_tokens == 0) {
681
697
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -737,8 +753,6 @@ int llama_context::encode(llama_batch & inp_batch) {
737
753
 
738
754
  n_outputs = n_tokens;
739
755
 
740
- //batch_manager->prepare(ubatch);
741
-
742
756
  ggml_backend_sched_reset(sched.get());
743
757
  ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
744
758
 
@@ -749,26 +763,18 @@ int llama_context::encode(llama_batch & inp_batch) {
749
763
  // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
750
764
  cparams.causal_attn = false;
751
765
 
752
- auto * gf = graph_init();
753
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
754
-
755
- ggml_backend_sched_alloc_graph(sched.get(), gf);
756
-
757
- res->set_inputs(&ubatch);
766
+ ggml_status status;
767
+ const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
758
768
 
759
769
  cparams.causal_attn = causal_attn_org;
760
770
 
761
- const auto compute_status = graph_compute(gf, n_tokens > 1);
762
- switch (compute_status) {
763
- case GGML_STATUS_SUCCESS:
764
- break;
765
- case GGML_STATUS_ABORTED:
766
- return 2;
767
- case GGML_STATUS_ALLOC_FAILED:
768
- return -2;
769
- case GGML_STATUS_FAILED:
770
- default:
771
- return -3;
771
+ if (!res) {
772
+ switch (status) {
773
+ case GGML_STATUS_ABORTED: return 2;
774
+ case GGML_STATUS_ALLOC_FAILED: return -2;
775
+ case GGML_STATUS_FAILED: return -3;
776
+ case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
777
+ }
772
778
  }
773
779
 
774
780
  auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
@@ -889,8 +895,6 @@ int llama_context::decode(llama_batch & inp_batch) {
889
895
  const int64_t n_tokens_all = batch.n_tokens;
890
896
  const int64_t n_embd = hparams.n_embd;
891
897
 
892
- llama_kv_cache_guard kv_guard(kv_self);
893
-
894
898
  GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
895
899
 
896
900
  // TODO: move the validation to the llama_batch_allocr
@@ -936,7 +940,48 @@ int llama_context::decode(llama_batch & inp_batch) {
936
940
  n_outputs_all = 1;
937
941
  }
938
942
 
939
- llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
943
+ // handle any pending defrags/shifts
944
+ kv_self_update();
945
+
946
+ llama_memory_state_ptr kv_state;
947
+
948
+ bool did_defrag = false;
949
+
950
+ while (true) {
951
+ kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
952
+ if (!kv_state) {
953
+ return -2;
954
+ }
955
+
956
+ switch (kv_state->get_status()) {
957
+ case LLAMA_MEMORY_STATUS_SUCCESS:
958
+ {
959
+ } break;
960
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
961
+ {
962
+ if (!did_defrag) {
963
+ did_defrag = true;
964
+
965
+ kv_self->defrag_sched(-1.0f);
966
+ if (kv_self_update()) {
967
+ LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
968
+
969
+ continue;
970
+ }
971
+ }
972
+
973
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
974
+
975
+ return 1;
976
+ }
977
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
978
+ {
979
+ return -2;
980
+ }
981
+ }
982
+
983
+ break;
984
+ }
940
985
 
941
986
  // reserve output buffer
942
987
  if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -944,13 +989,10 @@ int llama_context::decode(llama_batch & inp_batch) {
944
989
  return -2;
945
990
  };
946
991
 
947
- // handle any pending defrags/shifts
948
- kv_self_update();
949
-
950
992
  int64_t n_outputs_prev = 0;
951
993
 
952
- while (sbatch.n_tokens > 0) {
953
- llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
994
+ do {
995
+ const auto & ubatch = kv_state->get_ubatch();
954
996
 
955
997
  // count the outputs in this u_batch
956
998
  {
@@ -969,33 +1011,37 @@ int llama_context::decode(llama_batch & inp_batch) {
969
1011
  n_outputs = n_outputs_new;
970
1012
  }
971
1013
 
972
- // find KV slot
973
- if (!kv_self->find_slot(ubatch)) {
974
- return 1;
975
- }
976
-
977
1014
  ggml_backend_sched_reset(sched.get());
978
1015
  ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
979
1016
 
980
- auto * gf = graph_init();
981
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
1017
+ ggml_status status;
1018
+ const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
1019
+
1020
+ if (!res) {
1021
+ // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
1022
+ llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
982
1023
 
983
- // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
1024
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
1025
+ const auto & seq_id = ubatch.seq_id[i][0];
984
1026
 
985
- ggml_backend_sched_alloc_graph(sched.get(), gf);
1027
+ pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
1028
+ }
986
1029
 
987
- res->set_inputs(&ubatch);
1030
+ for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
1031
+ if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
1032
+ continue;
1033
+ }
988
1034
 
989
- const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
990
- if (compute_status != GGML_STATUS_SUCCESS) {
991
- switch (compute_status) {
992
- case GGML_STATUS_ABORTED:
993
- return 2;
994
- case GGML_STATUS_ALLOC_FAILED:
995
- return -2;
996
- case GGML_STATUS_FAILED:
997
- default:
998
- return -3;
1035
+ LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
1036
+
1037
+ llama_kv_self_seq_rm(this, s, pos_min[s], -1);
1038
+ }
1039
+
1040
+ switch (status) {
1041
+ case GGML_STATUS_ABORTED: return 2;
1042
+ case GGML_STATUS_ALLOC_FAILED: return -2;
1043
+ case GGML_STATUS_FAILED: return -3;
1044
+ case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
999
1045
  }
1000
1046
  }
1001
1047
 
@@ -1082,10 +1128,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1082
1128
  }
1083
1129
 
1084
1130
  n_outputs_prev += n_outputs;
1085
- }
1086
-
1087
- // finalize the batch processing
1088
- kv_guard.commit();
1131
+ } while (kv_state->next());
1089
1132
 
1090
1133
  // set to total number of outputs in the batch, for use in llama_get_logits_ith
1091
1134
  n_outputs = n_outputs_all;
@@ -1094,7 +1137,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1094
1137
  {
1095
1138
  bool sorted_output = true;
1096
1139
 
1097
- auto & out_ids = sbatch.out_ids;
1140
+ auto & out_ids = kv_state->out_ids();
1098
1141
 
1099
1142
  GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
1100
1143
 
@@ -1254,11 +1297,52 @@ ggml_cgraph * llama_context::graph_init() {
1254
1297
  return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
1255
1298
  }
1256
1299
 
1300
+ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
1301
+ LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
1302
+
1303
+ if (n_tokens % n_seqs != 0) {
1304
+ n_tokens = (n_tokens / n_seqs) * n_seqs;
1305
+ n_outputs = std::min(n_outputs, n_tokens);
1306
+
1307
+ LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
1308
+ }
1309
+
1310
+ // store the n_outputs as it is, and restore it afterwards
1311
+ // TODO: not sure if needed, might simplify in the future by removing this
1312
+ const auto save_n_outputs = this->n_outputs;
1313
+
1314
+ this->n_outputs = n_outputs;
1315
+
1316
+ llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
1317
+ llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
1318
+
1319
+ auto * gf = graph_init();
1320
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
1321
+
1322
+ this->n_outputs = save_n_outputs;
1323
+
1324
+ if (!res) {
1325
+ LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
1326
+ return nullptr;
1327
+ }
1328
+
1329
+ ggml_backend_sched_reset(sched.get());
1330
+
1331
+ // initialize scheduler with the specified graph
1332
+ if (!ggml_backend_sched_reserve(sched.get(), gf)) {
1333
+ LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
1334
+ return nullptr;
1335
+ }
1336
+
1337
+ return gf;
1338
+ }
1339
+
1257
1340
  llm_graph_result_ptr llama_context::graph_build(
1258
- ggml_context * ctx,
1259
- ggml_cgraph * gf,
1260
- const llama_ubatch & ubatch,
1261
- llm_graph_type gtype) {
1341
+ ggml_context * ctx,
1342
+ ggml_cgraph * gf,
1343
+ const llama_ubatch & ubatch,
1344
+ llm_graph_type gtype,
1345
+ const llama_memory_state_i * mstate) {
1262
1346
  return model.build_graph(
1263
1347
  {
1264
1348
  /*.ctx =*/ ctx,
@@ -1270,7 +1354,7 @@ llm_graph_result_ptr llama_context::graph_build(
1270
1354
  /*.backend_cpu =*/ backend_cpu,
1271
1355
  /*.cvec =*/ &cvec,
1272
1356
  /*.loras =*/ &loras,
1273
- /*.memory =*/ memory.get(),
1357
+ /*.mstate =*/ mstate,
1274
1358
  /*.cross =*/ &cross,
1275
1359
  /*.n_outputs =*/ n_outputs,
1276
1360
  /*.cb =*/ graph_get_cb(),
@@ -1951,7 +2035,6 @@ void llama_context::opt_epoch_iter(
1951
2035
  llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1952
2036
 
1953
2037
  kv_self->clear();
1954
- llama_kv_cache_guard kv_guard(kv_self);
1955
2038
 
1956
2039
  for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
1957
2040
  batch.n_tokens = n_batch;
@@ -1974,7 +2057,11 @@ void llama_context::opt_epoch_iter(
1974
2057
 
1975
2058
  int64_t n_outputs_all = n_tokens_all;
1976
2059
 
1977
- llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
2060
+ auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
2061
+ if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
2062
+ LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
2063
+ break;
2064
+ }
1978
2065
 
1979
2066
  // reserve output buffer
1980
2067
  if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -1982,20 +2069,19 @@ void llama_context::opt_epoch_iter(
1982
2069
  GGML_ABORT("TODO: handle this error");
1983
2070
  };
1984
2071
 
1985
- for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
1986
- llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
2072
+ uint32_t pos_batch = 0;
2073
+ do {
2074
+ const auto & ubatch = kv_state->get_ubatch();
1987
2075
 
1988
2076
  n_outputs = ubatch.n_tokens;
1989
2077
 
1990
- // TODO: not sure if this is needed
1991
- if (!kv_self->find_slot(ubatch)) {
1992
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
1993
-
1994
- GGML_ABORT("TODO: handle this error");
2078
+ if (!kv_state->apply()) {
2079
+ LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
2080
+ break;
1995
2081
  }
1996
2082
 
1997
2083
  auto * gf = graph_init();
1998
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
2084
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
1999
2085
 
2000
2086
  struct ggml_context * ctx_compute_opt;
2001
2087
  {
@@ -2010,6 +2096,7 @@ void llama_context::opt_epoch_iter(
2010
2096
  }
2011
2097
  ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
2012
2098
  ggml_opt_alloc(opt_ctx, train);
2099
+
2013
2100
  res->set_inputs(&ubatch);
2014
2101
  {
2015
2102
  struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
@@ -2027,10 +2114,10 @@ void llama_context::opt_epoch_iter(
2027
2114
  callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
2028
2115
  }
2029
2116
  ggml_free(ctx_compute_opt);
2030
- }
2031
- }
2032
2117
 
2033
- kv_guard.commit();
2118
+ pos_batch += ubatch.n_tokens;
2119
+ } while (kv_state->next());
2120
+ }
2034
2121
  }
2035
2122
 
2036
2123
  void llama_context::opt_epoch(
@@ -2194,6 +2281,7 @@ llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
2194
2281
  return ctx->get_kv_self();
2195
2282
  }
2196
2283
 
2284
+ // deprecated
2197
2285
  void llama_kv_self_update(llama_context * ctx) {
2198
2286
  ctx->kv_self_update();
2199
2287
  }
@@ -2448,6 +2536,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
2448
2536
  return kv->seq_pos_max(seq_id);
2449
2537
  }
2450
2538
 
2539
+ // deprecated
2451
2540
  void llama_kv_self_defrag(llama_context * ctx) {
2452
2541
  auto * kv = ctx->get_kv_self();
2453
2542
  if (!kv) {
@@ -2589,22 +2678,8 @@ int32_t llama_encode(
2589
2678
  int32_t llama_decode(
2590
2679
  llama_context * ctx,
2591
2680
  llama_batch batch) {
2592
- int ret = ctx->decode(batch);
2593
-
2594
- // defrag and try again
2595
- // TODO: distinguish return code when we are sure that even after defrag there is no space available
2596
- if (ret == 1) {
2597
- llama_kv_self_defrag(ctx);
2598
- ret = ctx->decode(batch);
2599
-
2600
- if (ret == 1) {
2601
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
2602
-
2603
- return ret;
2604
- }
2605
- }
2606
-
2607
- if (ret != 0) {
2681
+ const int ret = ctx->decode(batch);
2682
+ if (ret != 0 && ret != 1) {
2608
2683
  LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
2609
2684
  }
2610
2685