@fugood/llama.node 0.3.16 → 0.3.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202)
  1. package/CMakeLists.txt +3 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +5 -0
  19. package/package.json +1 -1
  20. package/src/LlamaCompletionWorker.cpp +8 -0
  21. package/src/LlamaCompletionWorker.h +1 -0
  22. package/src/LlamaContext.cpp +3 -2
  23. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +70 -27
  25. package/src/llama.cpp/.github/workflows/docker.yml +6 -6
  26. package/src/llama.cpp/.github/workflows/server.yml +7 -11
  27. package/src/llama.cpp/CMakeLists.txt +23 -1
  28. package/src/llama.cpp/common/CMakeLists.txt +6 -3
  29. package/src/llama.cpp/common/arg.cpp +809 -105
  30. package/src/llama.cpp/common/arg.h +9 -0
  31. package/src/llama.cpp/common/chat.cpp +1 -1
  32. package/src/llama.cpp/common/common.cpp +31 -521
  33. package/src/llama.cpp/common/common.h +17 -36
  34. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  35. package/src/llama.cpp/common/llguidance.cpp +30 -47
  36. package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
  37. package/src/llama.cpp/common/minja/minja.hpp +119 -93
  38. package/src/llama.cpp/common/sampling.cpp +3 -0
  39. package/src/llama.cpp/docs/build.md +122 -7
  40. package/src/llama.cpp/examples/CMakeLists.txt +0 -9
  41. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
  43. package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
  44. package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
  45. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
  46. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
  48. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
  50. package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
  51. package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
  52. package/src/llama.cpp/examples/llava/clip.h +39 -22
  53. package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
  54. package/src/llama.cpp/examples/llava/llava.cpp +64 -52
  55. package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
  56. package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
  57. package/src/llama.cpp/examples/llava/mtmd.h +168 -0
  58. package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
  59. package/src/llama.cpp/examples/main/main.cpp +16 -5
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
  64. package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
  65. package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
  66. package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
  67. package/src/llama.cpp/examples/run/run.cpp +14 -28
  68. package/src/llama.cpp/examples/server/httplib.h +313 -247
  69. package/src/llama.cpp/examples/server/server.cpp +238 -139
  70. package/src/llama.cpp/examples/server/utils.hpp +51 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  74. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  75. package/src/llama.cpp/examples/tts/tts.cpp +6 -9
  76. package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
  77. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  78. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  79. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  80. package/src/llama.cpp/ggml/include/ggml.h +66 -99
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
  82. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  83. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  84. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  85. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  87. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  88. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  89. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  90. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
  91. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  93. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -192
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1003 -13519
  99. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
  101. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
  102. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
  103. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
  105. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  106. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  107. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  108. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  109. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
  110. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  111. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  112. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  114. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  115. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
  116. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  117. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
  118. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
  119. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +96 -22
  120. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  124. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +2 -292
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  130. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +204 -280
  131. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  133. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  134. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  135. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  136. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  137. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  138. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
  139. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +646 -114
  140. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
  141. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +17 -8
  142. package/src/llama.cpp/ggml/src/ggml.c +141 -245
  143. package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
  144. package/src/llama.cpp/include/llama.h +30 -11
  145. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  146. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  147. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  148. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  149. package/src/llama.cpp/requirements/requirements-all.txt +2 -0
  150. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  151. package/src/llama.cpp/src/CMakeLists.txt +3 -2
  152. package/src/llama.cpp/src/llama-adapter.cpp +37 -1
  153. package/src/llama.cpp/src/llama-arch.cpp +160 -17
  154. package/src/llama.cpp/src/llama-arch.h +16 -0
  155. package/src/llama.cpp/src/llama-chat.cpp +82 -17
  156. package/src/llama.cpp/src/llama-chat.h +6 -2
  157. package/src/llama.cpp/src/llama-context.cpp +108 -92
  158. package/src/llama.cpp/src/llama-context.h +1 -2
  159. package/src/llama.cpp/src/llama-graph.cpp +189 -119
  160. package/src/llama.cpp/src/llama-graph.h +26 -6
  161. package/src/llama.cpp/src/llama-hparams.h +13 -0
  162. package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
  163. package/src/llama.cpp/src/llama-kv-cache.h +41 -115
  164. package/src/llama.cpp/src/llama-memory.h +1 -1
  165. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  166. package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
  167. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  168. package/src/llama.cpp/src/llama-model.cpp +1760 -534
  169. package/src/llama.cpp/src/llama-model.h +13 -1
  170. package/src/llama.cpp/src/llama-quant.cpp +29 -8
  171. package/src/llama.cpp/src/llama-sampling.cpp +7 -1
  172. package/src/llama.cpp/src/llama-vocab.cpp +44 -6
  173. package/src/llama.cpp/src/llama.cpp +1 -1
  174. package/src/llama.cpp/tests/CMakeLists.txt +43 -30
  175. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  176. package/src/llama.cpp/tests/test-backend-ops.cpp +82 -43
  177. package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
  178. package/src/llama.cpp/tests/test-chat.cpp +12 -2
  179. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  180. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  181. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  182. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  183. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  184. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  185. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  186. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  187. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  188. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  189. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  190. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  191. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  192. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  193. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  194. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  195. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  196. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  197. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  198. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  199. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  200. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  201. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  202. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
package/src/llama.cpp/src/llama-kv-cache.cpp

@@ -11,8 +11,6 @@
 #include <map>
 #include <stdexcept>
 
-static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
-
 llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
 }
 
@@ -29,7 +27,7 @@ bool llama_kv_cache_unified::init(
 
     recurrent = llama_model_is_recurrent(&model);
     v_trans   = !recurrent && !cparams.flash_attn;
-    can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
+    can_shift = !recurrent;
 
     LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
             __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);
@@ -133,7 +131,7 @@ int32_t llama_kv_cache_unified::get_n_tokens() const {
     return result;
 }
 
-uint32_t llama_kv_cache_unified::get_used_cells() const {
+int32_t llama_kv_cache_unified::get_used_cells() const {
     return used;
 }
 
@@ -206,6 +204,8 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
                 return false;
             }
         }
+
+        return true;
     }
 
     for (uint32_t i = 0; i < size; ++i) {
@@ -428,7 +428,7 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
     }
 }
 
-llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) {
+llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
    llama_pos result = 0;
 
    for (uint32_t i = 0; i < size; ++i) {
@@ -446,16 +446,71 @@ void llama_kv_cache_unified::defrag() {
     }
 }
 
+void llama_kv_cache_unified::restore() {
+    if (pending.ranges.empty()) {
+        return;
+    }
+
+    // TODO: tmp - move to llama_kv_cache_recurrent
+    if (recurrent) {
+        seq_rm(-1, -1, -1);
+        return;
+    }
+
+    uint32_t new_head = size;
+
+    for (auto & range : pending.ranges) {
+        for (uint32_t i = range.c0; i < range.c1; ++i) {
+            cells[i].seq_id.clear();
+
+            // keep count of the number of used cells
+            if (cells[i].pos >= 0) {
+                used--;
+            }
+
+            cells[i].pos = -1;
+            cells[i].src = -1;
+        }
+
+        new_head = std::min(new_head, range.c0);
+    }
+
+    if (new_head != size && new_head < head) {
+        head = new_head;
+    }
+}
+
+void llama_kv_cache_unified::commit() {
+    // TODO: tmp - move to llama_kv_cache_recurrent
+    if (recurrent) {
+        return;
+    }
+
+    if (pending.ranges.empty()) {
+        LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
+                __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
+        return;
+    }
+
+    pending.ranges.clear();
+}
+
 bool llama_kv_cache_unified::get_can_shift() const {
     return can_shift;
 }
 
-llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
+bool llama_kv_cache_unified::find_slot(
        const llama_ubatch & ubatch) {
     const uint32_t n_tokens     = ubatch.n_tokens;
     const uint32_t n_seqs       = ubatch.n_seqs;
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
+    // if we have enough unused cells before the current head ->
+    //   better to start searching from the beginning of the cache, hoping to fill it
+    if (head > used + 2*ubatch.n_tokens) {
+        head = 0;
+    }
+
     if (recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
@@ -477,7 +532,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
                 // too big seq_id
                 // TODO: would it be possible to resize the cache instead?
                 LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
-                return llama_kv_cache_slot_info_failed;
+                return false;
             }
             if (j > 0) {
                 llama_kv_cell & seq = cells[seq_id];
@@ -616,14 +671,14 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
                 [](const llama_kv_cell& cell){ return !cell.is_empty(); });
 
         // sanity check
-        return llama_kv_cache_slot_info(n >= n_seqs);
+        return n >= n_seqs;
     }
 
     // otherwise, one cell per token.
 
     if (n_tokens > size) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
-        return llama_kv_cache_slot_info_failed;
+        return false;
     }
 
     uint32_t n_tested = 0;
@@ -651,7 +706,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
 
         if (n_tested >= size) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return llama_kv_cache_slot_info_failed;
+            return false;
         }
     }
 
@@ -668,7 +723,9 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
 
     used += n_tokens;
 
-    return llama_kv_cache_slot_info(head, head + n_tokens);
+    pending.ranges.push_back({head, head + n_tokens});
+
+    return true;
 }
 
 uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
@@ -1033,6 +1090,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
         return false;
     }
+    commit();
 
     // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
     // Assume that this is one contiguous block of cells
@@ -1220,117 +1278,6 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     return true;
 }
 
-//
-// interface implementation
-//
-
-int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
-    if (!kv) {
-        return 0;
-    }
-
-    return kv->get_n_tokens();
-}
-
-int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
-    if (!kv) {
-        return 0;
-    }
-
-    return kv->get_used_cells();
-}
-
-void llama_kv_cache_clear(llama_kv_cache * kv) {
-    if (!kv) {
-        return;
-    }
-
-    kv->clear();
-}
-
-bool llama_kv_cache_seq_rm(
-        llama_kv_cache * kv,
-        llama_seq_id     seq_id,
-        llama_pos        p0,
-        llama_pos        p1) {
-    if (!kv) {
-        return true;
-    }
-
-    return kv->seq_rm(seq_id, p0, p1);
-}
-
-void llama_kv_cache_seq_cp(
-        llama_kv_cache * kv,
-        llama_seq_id     seq_id_src,
-        llama_seq_id     seq_id_dst,
-        llama_pos        p0,
-        llama_pos        p1) {
-    if (!kv) {
-        return;
-    }
-
-    kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
-}
-
-void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id) {
-    if (!kv) {
-        return;
-    }
-
-    kv->seq_keep(seq_id);
-}
-
-void llama_kv_cache_seq_add(
-        llama_kv_cache * kv,
-        llama_seq_id     seq_id,
-        llama_pos        p0,
-        llama_pos        p1,
-        llama_pos        delta) {
-    if (!kv) {
-        return;
-    }
-
-    kv->seq_add(seq_id, p0, p1, delta);
-}
-
-void llama_kv_cache_seq_div(
-        llama_kv_cache * kv,
-        llama_seq_id     seq_id,
-        llama_pos        p0,
-        llama_pos        p1,
-        int              d) {
-    if (!kv) {
-        return;
-    }
-
-    kv->seq_div(seq_id, p0, p1, d);
-}
-
-llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id) {
-    if (!kv) {
-        return 0;
-    }
-
-    return kv->seq_pos_max(seq_id);
-}
-
-void llama_kv_cache_defrag(llama_kv_cache * kv) {
-    if (!kv) {
-        return;
-    }
-
-    kv->defrag();
-}
-
-bool llama_kv_cache_can_shift(const llama_kv_cache * kv) {
-    if (!kv) {
-        return false;
-    }
-
-    return kv->get_can_shift();
-}
-
 //
 // kv cache view
 //
@@ -1340,7 +1287,7 @@ llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t
         /*.n_cells            = */ 0,
         /*.n_seq_max          = */ n_seq_max,
         /*.token_count        = */ 0,
-        /*.used_cells         = */ llama_kv_cache_used_cells(&kv),
+        /*.used_cells         = */ kv.get_used_cells(),
         /*.max_contiguous     = */ 0,
         /*.max_contiguous_idx = */ -1,
         /*.cells              = */ nullptr,
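
Taken together, the llama-kv-cache.cpp changes replace the old slot-info/restorer mechanism with explicit transaction-style bookkeeping: find_slot() now returns a plain bool and records the cell range it claimed in pending.ranges, which commit() later clears and restore() rolls back. A minimal caller sketch of that contract (illustrative only, assuming the headers in this package; process_ubatch and eval_ok are hypothetical names, not part of the diff):

    // Sketch: pair every successful find_slot() with either commit() or restore().
    static bool process_ubatch(llama_kv_cache_unified & kv, const llama_ubatch & ubatch) {
        if (!kv.find_slot(ubatch)) {
            return false;          // no cells were reserved, nothing is pending
        }

        bool eval_ok = true;
        // ... build and compute the graph for ubatch, setting eval_ok on failure ...

        if (!eval_ok) {
            kv.restore();          // frees the cells recorded in pending.ranges
            return false;
        }

        kv.commit();               // keeps the cells and clears pending.ranges
        return true;
    }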
package/src/llama.cpp/src/llama-kv-cache.h

@@ -17,17 +17,35 @@ struct llama_ubatch;
 struct llama_kv_cache : public llama_memory_i {
     using llama_memory_i::llama_memory_i;
 
-    virtual int32_t get_n_tokens() const = 0;
-    virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
+    virtual void restore() = 0; // call if batch processing fails - restores the cache state
+    virtual void commit()  = 0; // call after successful batch processing - clears any pending state
+
+    virtual int32_t get_n_tokens()   const = 0;
+    virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
 
     virtual bool get_can_shift() const = 0;
 
     bool get_can_edit() const override { return get_can_shift(); }
 };
 
+struct llama_kv_cache_guard {
+    llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
+
+    ~llama_kv_cache_guard() {
+        kv->restore();
+    }
+
+    void commit() {
+        kv->commit();
+    }
+
+private:
+    llama_kv_cache * kv;
+};
+
 struct llama_kv_cell {
     llama_pos pos   = -1;
-    llama_pos delta = 0;
+    llama_pos delta =  0;
     int32_t   src   = -1; // used by recurrent state models to copy states
     int32_t   tail  = -1;
 
@@ -46,17 +64,6 @@ struct llama_kv_cell {
     }
 };
 
-// a structure holds information about the slot found in llama_kv_cache_find_slot
-struct llama_kv_cache_slot_info {
-    std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
-    bool found = false;                       // the slot was found
-
-    explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
-    llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
-
-    operator bool() const { return found; }
-};
-
 // ring-buffer of cached KV data
 // TODO: pimpl
 // TODO: add notion of max sequences
@@ -82,8 +89,8 @@ public:
             uint32_t kv_size,
             bool     offload);
 
-    int32_t get_n_tokens() const override;
-    uint32_t get_used_cells() const override;
+    int32_t get_n_tokens()   const override;
+    int32_t get_used_cells() const override;
 
     size_t total_size() const;
 
@@ -93,22 +100,24 @@ public:
     void clear() override;
     void defrag() override;
 
+    virtual void restore() override;
+    virtual void commit() override;
+
     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
     void seq_keep(llama_seq_id seq_id) override;
     void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
     void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
 
-    llama_pos seq_pos_max(llama_seq_id seq_id) override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
     bool get_can_shift() const override;
 
     // find an empty slot of size "n_tokens" in the cache
     // updates the cache head
-    // returns a structure holding information about the slot found
     // Note: On success, it's important that cache.head points
     // to the first cell of the slot.
-    llama_kv_cache_slot_info find_slot(const llama_ubatch & batch);
+    bool find_slot(const llama_ubatch & batch);
 
     // TODO: maybe not needed
     uint32_t get_padding(const llama_cparams & cparams) const;
@@ -128,7 +137,19 @@ public:
     // return true if cells have been moved
     bool defrag_prepare(int32_t n_max_nodes);
 
-    // state save/load
+    // commit/restore cache
+
+    struct slot_range {
+        uint32_t c0 = 0; // note: these are cell indices, not sequence positions
+        uint32_t c1 = 0;
+    };
+
+    // pending cell updates that are not yet committed
+    struct {
+        std::vector<slot_range> ranges;
+    } pending;
+
+    // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
     void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
@@ -183,101 +204,6 @@ private:
 //    using llama_kv_cache_unified::llama_kv_cache_unified;
 //};
 
-//
-// kv cache restore
-//
-
-// saves the kv_cache state for future recovery.
-// used to rollback llama_kv_cache_find_slot changes.
-struct llama_kv_slot_restorer {
-    struct llama_kv_cache_state {
-        uint32_t head = 0;
-        uint32_t n    = 0;
-    } old_state;
-
-    // for non-recurrent models only
-    // list of slots to restore
-    std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
-
-    bool do_restore = false;
-
-    llama_kv_cache_unified & cache;
-
-    explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) {
-        old_state.head = cache.head;
-        old_state.n    = cache.n;
-    }
-
-    // saves a slot information for future restoration
-    void save(const llama_kv_cache_slot_info & slot) {
-        if (slot) {
-            do_restore = true;
-            if (slot.boundaries.first != slot.boundaries.second) {
-                slot_boundaries.push_back(slot.boundaries);
-            }
-        }
-    }
-
-    // must be explicitly called to restore the kv_cache state
-    // and rollback changes from all llama_kv_cache_find_slot calls
-    void restore() {
-        if (do_restore) {
-            cache.head = old_state.head;
-            cache.n    = old_state.n;
-
-            if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
-                cache.seq_rm(-1, -1, -1);
-            } else {
-                for (auto & slot : slot_boundaries) {
-                    cache.seq_rm(-1, slot.first, slot.second);
-                }
-            }
-        }
-    }
-};
-
-// TODO: maybe become part of the public llama_kv_cache in the future
-int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);
-
-int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv);
-
-void llama_kv_cache_clear(llama_kv_cache * kv);
-
-bool llama_kv_cache_seq_rm(
-        llama_kv_cache * kv,
-        llama_seq_id     seq_id,
-        llama_pos        p0,
-        llama_pos        p1);
-
-void llama_kv_cache_seq_cp(
-        llama_kv_cache * kv,
-        llama_seq_id     seq_id_src,
-        llama_seq_id     seq_id_dst,
-        llama_pos        p0,
-        llama_pos        p1);
-
-void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id);
-
-void llama_kv_cache_seq_add(
-        llama_kv_cache * kv,
-        llama_seq_id     seq_id,
-        llama_pos        p0,
-        llama_pos        p1,
-        llama_pos        delta);
-
-void llama_kv_cache_seq_div(
-        llama_kv_cache * kv,
-        llama_seq_id     seq_id,
-        llama_pos        p0,
-        llama_pos        p1,
-        int              d);
-
-llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id);
-
-void llama_kv_cache_defrag(llama_kv_cache * kv);
-
-bool llama_kv_cache_can_shift(const llama_kv_cache * kv);
-
 //
 // kv cache view
 //
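
The new llama_kv_cache_guard declared above wraps the same restore/commit pair in RAII form, so error paths do not have to call restore() by hand. A hedged usage sketch (assumed caller code, not part of this diff; decode_one is a hypothetical name):

    // Sketch: the guard's destructor calls restore(); commit() makes the slots permanent.
    static int decode_one(llama_kv_cache_unified & kv, const llama_ubatch & ubatch) {
        llama_kv_cache_guard guard(&kv);   // restore() runs automatically on early return

        if (!kv.find_slot(ubatch)) {
            return 1;                      // pending is empty, so restore() is a no-op
        }

        // ... run the compute graph; any early return below rolls the cache back ...

        guard.commit();                    // clears pending state, keeps the new KV entries
        return 0;
    }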
package/src/llama.cpp/src/llama-memory.h

@@ -15,7 +15,7 @@ public:
     virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
     virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
 
-    virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
+    virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
 
     virtual bool get_can_edit() const = 0;
 };
package/src/llama.cpp/src/llama-mmap.cpp

@@ -476,7 +476,7 @@ struct llama_mlock::impl {
 
         char* errmsg = std::strerror(errno);
         bool suggest = (errno == ENOMEM);
-#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV)
+#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX)
         // visionOS/tvOS dont't support RLIMIT_MEMLOCK
         // Skip resource limit checks on visionOS/tvOS
         suggest = false;
package/src/llama.cpp/src/llama-model-loader.cpp

@@ -445,7 +445,8 @@ llama_model_loader::llama_model_loader(
         std::vector<std::string> & splits,
         bool use_mmap,
         bool check_tensors,
-        const struct llama_model_kv_override * param_overrides_p) {
+        const llama_model_kv_override * param_overrides_p,
+        const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
     int trace = 0;
     if (getenv("LLAMA_TRACE")) {
         trace = atoi(getenv("LLAMA_TRACE"));
@@ -457,6 +458,8 @@ llama_model_loader::llama_model_loader(
         }
     }
 
+    tensor_buft_overrides = param_tensor_buft_overrides_p;
+
     // Load the main GGUF
     struct ggml_context * ctx = NULL;
     struct gguf_init_params params = {
@@ -600,7 +603,9 @@ llama_model_loader::llama_model_loader(
 
             if (trace > 0) {
                 const uint16_t sid = w.idx;
-                LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ]\n", __func__, sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str());
+                LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ] %8.2f MiB\n", __func__,
+                        sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str(),
+                        ggml_nbytes(tensor)/1024.0f/1024.0f);
             }
         }
 
@@ -640,9 +645,9 @@ llama_model_loader::llama_model_loader(
         ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED);
 
         {
-            const int kid = gguf_find_key(meta.get(), "general.file_type"); // TODO: use LLM_KV
-            if (kid >= 0) {
-                ftype = (llama_ftype) gguf_get_val_u32(meta.get(), kid);
+            uint32_t ftype_val = 0;
+            if (get_key(LLM_KV_GENERAL_FILE_TYPE, ftype_val, false)) {
+                ftype = (llama_ftype) ftype_val;
             }
         }
 
package/src/llama.cpp/src/llama-model-loader.h

@@ -77,8 +77,9 @@ struct llama_model_loader {
 
     llama_mmaps mappings;
 
-    std::map<std::string, struct llama_tensor_weight, weight_name_comparer> weights_map;
-    std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
+    std::map<std::string, llama_tensor_weight, weight_name_comparer> weights_map;
+    std::unordered_map<std::string, llama_model_kv_override> kv_overrides;
+    const llama_model_tensor_buft_override * tensor_buft_overrides;
 
     gguf_context_ptr meta;
     std::vector<ggml_context_ptr> contexts;
@@ -95,7 +96,8 @@ struct llama_model_loader {
         std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
         bool use_mmap,
         bool check_tensors,
-        const struct llama_model_kv_override * param_overrides_p);
+        const llama_model_kv_override * param_overrides_p,
+        const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);
 
     template<typename T>
     typename std::enable_if<std::is_integral<T>::value, bool>::type
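
llama_model_loader gains a tensor buffer-type override table alongside the existing KV overrides; the constructor stores the pointer in the new tensor_buft_overrides member. A hedged construction sketch (the leading arguments and variable names are placeholders not shown in this diff; callers with no overrides would pass null pointers):

    // Sketch: only the last two parameters are new in this version.
    llama_model_loader ml(
        fname, splits,
        /* use_mmap      = */ true,
        /* check_tensors = */ false,
        /* param_overrides_p             = */ nullptr,
        /* param_tensor_buft_overrides_p = */ nullptr);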