@fugood/llama.node 0.3.15 → 0.3.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. package/CMakeLists.txt +3 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +5 -0
  19. package/package.json +1 -1
  20. package/src/LlamaCompletionWorker.cpp +8 -0
  21. package/src/LlamaCompletionWorker.h +1 -0
  22. package/src/LlamaContext.cpp +3 -2
  23. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +70 -27
  25. package/src/llama.cpp/.github/workflows/docker.yml +6 -6
  26. package/src/llama.cpp/.github/workflows/server.yml +7 -11
  27. package/src/llama.cpp/CMakeLists.txt +23 -1
  28. package/src/llama.cpp/common/CMakeLists.txt +6 -3
  29. package/src/llama.cpp/common/arg.cpp +809 -105
  30. package/src/llama.cpp/common/arg.h +9 -0
  31. package/src/llama.cpp/common/chat.cpp +1 -1
  32. package/src/llama.cpp/common/common.cpp +31 -521
  33. package/src/llama.cpp/common/common.h +17 -36
  34. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  35. package/src/llama.cpp/common/llguidance.cpp +30 -47
  36. package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
  37. package/src/llama.cpp/common/minja/minja.hpp +119 -93
  38. package/src/llama.cpp/common/sampling.cpp +3 -0
  39. package/src/llama.cpp/docs/build.md +122 -7
  40. package/src/llama.cpp/examples/CMakeLists.txt +0 -9
  41. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
  43. package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
  44. package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
  45. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
  46. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
  48. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
  50. package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
  51. package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
  52. package/src/llama.cpp/examples/llava/clip.h +39 -22
  53. package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
  54. package/src/llama.cpp/examples/llava/llava.cpp +64 -52
  55. package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
  56. package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
  57. package/src/llama.cpp/examples/llava/mtmd.h +168 -0
  58. package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
  59. package/src/llama.cpp/examples/main/main.cpp +16 -5
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
  64. package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
  65. package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
  66. package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
  67. package/src/llama.cpp/examples/run/run.cpp +14 -28
  68. package/src/llama.cpp/examples/server/httplib.h +313 -247
  69. package/src/llama.cpp/examples/server/server.cpp +243 -139
  70. package/src/llama.cpp/examples/server/utils.hpp +51 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  74. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  75. package/src/llama.cpp/examples/tts/tts.cpp +14 -9
  76. package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
  77. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  78. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  79. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  80. package/src/llama.cpp/ggml/include/ggml.h +66 -99
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +15 -8
  82. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  83. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  84. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  85. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  87. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  88. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  89. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  90. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
  91. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  93. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2413 -228
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1004 -13516
  99. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
  101. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
  102. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
  103. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
  105. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  106. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  107. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  108. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  109. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
  110. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  111. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  112. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  114. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  115. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
  116. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  117. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
  118. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
  119. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +127 -33
  120. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  124. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +29 -293
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  130. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  131. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +210 -286
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  133. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  134. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  135. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  136. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  137. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  138. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  139. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
  140. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +692 -126
  141. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
  142. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +21 -10
  143. package/src/llama.cpp/ggml/src/ggml.c +141 -245
  144. package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
  145. package/src/llama.cpp/include/llama.h +30 -11
  146. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  147. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  148. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  149. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  150. package/src/llama.cpp/requirements/requirements-all.txt +2 -0
  151. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  152. package/src/llama.cpp/src/CMakeLists.txt +3 -2
  153. package/src/llama.cpp/src/llama-adapter.cpp +37 -1
  154. package/src/llama.cpp/src/llama-arch.cpp +161 -17
  155. package/src/llama.cpp/src/llama-arch.h +16 -0
  156. package/src/llama.cpp/src/llama-chat.cpp +82 -17
  157. package/src/llama.cpp/src/llama-chat.h +6 -2
  158. package/src/llama.cpp/src/llama-context.cpp +108 -92
  159. package/src/llama.cpp/src/llama-context.h +1 -2
  160. package/src/llama.cpp/src/llama-graph.cpp +189 -119
  161. package/src/llama.cpp/src/llama-graph.h +26 -6
  162. package/src/llama.cpp/src/llama-hparams.h +13 -0
  163. package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
  164. package/src/llama.cpp/src/llama-kv-cache.h +41 -115
  165. package/src/llama.cpp/src/llama-memory.h +1 -1
  166. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  167. package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
  168. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  169. package/src/llama.cpp/src/llama-model.cpp +1544 -291
  170. package/src/llama.cpp/src/llama-model.h +13 -1
  171. package/src/llama.cpp/src/llama-quant.cpp +29 -8
  172. package/src/llama.cpp/src/llama-sampling.cpp +7 -1
  173. package/src/llama.cpp/src/llama-vocab.cpp +44 -6
  174. package/src/llama.cpp/src/llama.cpp +1 -1
  175. package/src/llama.cpp/tests/CMakeLists.txt +43 -30
  176. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  177. package/src/llama.cpp/tests/test-backend-ops.cpp +139 -57
  178. package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
  179. package/src/llama.cpp/tests/test-chat.cpp +12 -2
  180. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  181. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  182. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  183. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  184. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  185. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  186. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  187. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  188. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  189. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  190. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  191. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  192. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  193. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  194. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  195. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  196. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  197. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  198. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  199. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  200. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  201. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  202. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  203. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
package/src/llama.cpp/src/llama-context.cpp

@@ -10,6 +10,7 @@
 #include <cstring>
 #include <stdexcept>
 #include <cinttypes>
+#include <cmath>
 
 //
 // llama_context
@@ -113,7 +114,7 @@ llama_context::llama_context(
     }
 
     if (n_ctx_per_seq > hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
+        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
                 __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }
 
@@ -255,7 +256,8 @@ llama_context::llama_context(
             model.n_devices() > 1 &&
             model.params.n_gpu_layers > (int) model.hparams.n_layer &&
             model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
-            cparams.offload_kqv;
+            cparams.offload_kqv &&
+            !model.has_tensor_overrides();
 
         // pipeline parallelism requires support for async compute and events in all devices
         if (pipeline_parallel) {
@@ -294,10 +296,7 @@ llama_context::llama_context(
         // TODO: something cleaner
         const auto n_outputs_save = n_outputs;
 
-        // max number of outputs
-        n_outputs = n_tokens;
-
-        LLAMA_LOG_DEBUG("%s: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
         int n_splits_pp = -1;
         int n_nodes_pp  = -1;
@@ -313,8 +312,15 @@ llama_context::llama_context(
         // reserve pp graph first so that buffers are only allocated once
         {
             llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+            // max number of outputs
+            n_outputs = ubatch_pp.n_tokens;
+
+            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
+
             auto * gf = graph_init();
             graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+
             if (!ggml_backend_sched_reserve(sched.get(), gf)) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -326,11 +332,18 @@ llama_context::llama_context(
         // reserve with tg graph to get the number of splits and nodes
        {
             llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+            n_outputs = ubatch_tg.n_tokens;
+
+            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
+
             auto * gf = graph_init();
             graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
+
             if (!ggml_backend_sched_reserve(sched.get(), gf)) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
+
             n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
             n_nodes_tg  = ggml_graph_n_nodes(gf);
         }
@@ -338,8 +351,14 @@ llama_context::llama_context(
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
             llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+            n_outputs = ubatch_pp.n_tokens;
+
+            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
+
             auto * gf = graph_init();
             graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+
             if (!ggml_backend_sched_reserve(sched.get(), gf)) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -450,12 +469,10 @@ ggml_tensor * llama_context::build_rope_shift(
         ggml_tensor * shift,
         ggml_tensor * factors,
               float   freq_base,
-              float   freq_scale,
-        ggml_backend_buffer * bbuf) const {
+              float   freq_scale) const {
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
 
     const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
-    const auto & yarn_attn_factor = cparams.yarn_attn_factor;
     const auto & yarn_beta_fast   = cparams.yarn_beta_fast;
     const auto & yarn_beta_slow   = cparams.yarn_beta_slow;
 
@@ -464,23 +481,17 @@ ggml_tensor * llama_context::build_rope_shift(
     const auto & n_rot     = hparams.n_rot;
     const auto & rope_type = hparams.rope_type;
 
+    // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
+    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
+    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
+
     ggml_tensor * tmp;
 
     if (ggml_is_quantized(cur->type)) {
         // dequantize to f32 -> RoPE -> quantize back
         tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);
 
-        if (bbuf) {
-            for (const auto & backend : backends) {
-                // Figure out which backend KV cache belongs to
-                if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) {
-                    ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
-                    break;
-                }
-            }
-        }
-
-        tmp = ggml_rope_ext_inplace(ctx0, tmp,
+        tmp = ggml_rope_ext(ctx0, tmp,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                 yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
 
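For context, the DEEPSEEK2 branch added above folds the YaRN attention-scale correction directly into the shift graph. A minimal standalone sketch evaluating the same expression for a hypothetical freq_scale of 0.25 (a 4x context extension; the value and the program around it are illustrative only, not part of this package):

    // Standalone sketch, not part of the diff: evaluates the DEEPSEEK2 YaRN
    // attention factor used above for an assumed freq_scale of 0.25.
    #include <cmath>
    #include <cstdio>

    int main() {
        const float freq_scale = 0.25f; // hypothetical value for illustration
        const float yarn_attn_factor = 1.0f / (1.0f + 0.1f * std::log(1.0f / freq_scale));
        std::printf("yarn_attn_factor = %.3f\n", yarn_attn_factor); // prints ~0.878
        return 0;
    }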
@@ -560,7 +571,7 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
                 ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
                 0);
 
-        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
+        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
 
         ggml_build_forward_expand(gf, cur);
     }
@@ -1184,33 +1195,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd       = hparams.n_embd;
 
-    // TODO: remove this stuff
-    class batch_guard {
-    public:
-        batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
-        }
-
-        ~batch_guard() {
-            if (!is_done) {
-                kv_slot_restorer.restore();
-            }
-        }
-
-        void done() {
-            is_done = true;
-        }
-
-        void save(const llama_kv_cache_slot_info & slot_info) {
-            kv_slot_restorer.save(slot_info);
-        }
-
-    private:
-        bool is_done = false;
-
-        llama_kv_slot_restorer kv_slot_restorer;
-    };
-
-    batch_guard bg(*kv_self);
+    llama_kv_cache_guard kv_guard(kv_self.get());
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
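The hand-rolled batch_guard removed above is replaced by llama_kv_cache_guard, which follows the usual commit-or-rollback RAII shape: the destructor undoes the pending KV-cache changes unless the batch completed and commit() was called. A minimal generic sketch of that shape (illustrative only; the names below are not the llama.cpp API):

    // Generic commit-or-rollback guard, for illustration; not the actual llama_kv_cache_guard.
    #include <functional>
    #include <utility>

    class scoped_rollback {
    public:
        explicit scoped_rollback(std::function<void()> undo) : undo_(std::move(undo)) {}
        ~scoped_rollback() { if (!committed_) undo_(); } // roll back on early return or exception
        void commit() { committed_ = true; }             // call once the work has succeeded
    private:
        std::function<void()> undo_;
        bool committed_ = false;
    };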
@@ -1263,6 +1248,9 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -2;
     };
 
+    // handle any pending defrags/shifts
+    kv_self_update();
+
     int64_t n_outputs_prev = 0;
 
     while (sbatch.n_tokens > 0) {
@@ -1300,24 +1288,14 @@ int llama_context::decode(llama_batch & inp_batch) {
             n_outputs = n_outputs_new;
         }
 
-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
-            kv_self_update();
-
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
-                kv_self->head = 0;
-            }
+        // find KV slot
+        {
+            if (!kv_self->find_slot(ubatch)) {
+                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
 
-            const auto slot_info = kv_self->find_slot(ubatch);
-            if (!slot_info) {
-                LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
-                return -3;
+                return 1;
             }
 
-            bg.save(slot_info);
-
             if (!kv_self->recurrent) {
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
@@ -1354,16 +1332,6 @@ int llama_context::decode(llama_batch & inp_batch) {
             }
         }
 
-        // update the kv ring buffer
-        {
-            kv_self->head += ubatch.n_tokens;
-
-            // Ensure kv cache head points to a valid index.
-            if (kv_self->head >= kv_self->size) {
-                kv_self->head = 0;
-            }
-        }
-
         // plot the computation graph in dot format (for debugging purposes)
         //if (n_past%100 == 0) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
@@ -1450,7 +1418,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     }
 
     // finalize the batch processing
-    bg.done();
+    kv_guard.commit();
 
     // set output mappings
     {
@@ -1568,8 +1536,6 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
     // set all ids as invalid (negative)
     std::fill(output_ids.begin(), output_ids.end(), -1);
 
-    ggml_backend_buffer_clear(buf_output.get(), 0);
-
     this->n_outputs     = 0;
     this->n_outputs_max = n_outputs_max;
 
@@ -2299,11 +2265,6 @@ llama_context * llama_init_from_model(
         params.flash_attn = false;
     }
 
-    if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
-        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
-        params.flash_attn = false;
-    }
-
     if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
@@ -2504,7 +2465,12 @@ int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
 }
 
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
-    return llama_kv_cache_n_tokens(ctx->get_kv_self());
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return 0;
+    }
+
+    return kv->get_n_tokens();
 }
 
 // deprecated
@@ -2513,7 +2479,12 @@ int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
 }
 
 int32_t llama_kv_self_used_cells(const llama_context * ctx) {
-    return llama_kv_cache_used_cells(ctx->get_kv_self());
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return 0;
+    }
+
+    return kv->get_used_cells();
 }
 
 // deprecated
@@ -2522,7 +2493,12 @@ void llama_kv_cache_clear(llama_context * ctx) {
 }
 
 void llama_kv_self_clear(llama_context * ctx) {
-    llama_kv_cache_clear(ctx->get_kv_self());
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
+    kv->clear();
 }
 
 // deprecated
@@ -2539,7 +2515,12 @@ bool llama_kv_self_seq_rm(
         llama_seq_id   seq_id,
            llama_pos   p0,
            llama_pos   p1) {
-    return llama_kv_cache_seq_rm(ctx->get_kv_self(), seq_id, p0, p1);
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return true;
+    }
+
+    return kv->seq_rm(seq_id, p0, p1);
 }
 
 // deprecated
@@ -2558,7 +2539,12 @@ void llama_kv_self_seq_cp(
         llama_seq_id   seq_id_dst,
            llama_pos   p0,
            llama_pos   p1) {
-    return llama_kv_cache_seq_cp(ctx->get_kv_self(), seq_id_src, seq_id_dst, p0, p1);
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
+    return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
 }
 
 // deprecated
@@ -2569,7 +2555,12 @@ void llama_kv_cache_seq_keep(
 }
 
 void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
-    return llama_kv_cache_seq_keep(ctx->get_kv_self(), seq_id);
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
+    return kv->seq_keep(seq_id);
 }
 
 // deprecated
@@ -2588,7 +2579,12 @@ void llama_kv_self_seq_add(
            llama_pos   p0,
            llama_pos   p1,
            llama_pos   delta) {
-    return llama_kv_cache_seq_add(ctx->get_kv_self(), seq_id, p0, p1, delta);
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
+    return kv->seq_add(seq_id, p0, p1, delta);
 }
 
 // deprecated
@@ -2607,7 +2603,12 @@ void llama_kv_self_seq_div(
            llama_pos   p0,
            llama_pos   p1,
                  int   d) {
-    return llama_kv_cache_seq_div(ctx->get_kv_self(), seq_id, p0, p1, d);
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
+    return kv->seq_div(seq_id, p0, p1, d);
 }
 
 // deprecated
@@ -2616,7 +2617,12 @@ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
 }
 
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
-    return llama_kv_cache_seq_pos_max(ctx->get_kv_self(), seq_id);
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return 0;
+    }
+
+    return kv->seq_pos_max(seq_id);
 }
 
 // deprecated
@@ -2625,7 +2631,12 @@ void llama_kv_cache_defrag(llama_context * ctx) {
 }
 
 void llama_kv_self_defrag(llama_context * ctx) {
-    llama_kv_cache_defrag(ctx->get_kv_self());
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
+    return kv->defrag();
 }
 
 // deprecated
@@ -2634,7 +2645,12 @@ bool llama_kv_cache_can_shift(const llama_context * ctx) {
 }
 
 bool llama_kv_self_can_shift(const llama_context * ctx) {
-    return llama_kv_cache_can_shift(ctx->get_kv_self());
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return false;
+    }
+
+    return kv->get_can_shift();
 }
 
 // deprecated
package/src/llama.cpp/src/llama-context.h

@@ -170,8 +170,7 @@ private:
            ggml_tensor * shift,
            ggml_tensor * factors,
                  float   freq_base,
-                 float   freq_scale,
-            ggml_backend_buffer * bbuf) const;
+                 float   freq_scale) const;
 
     llm_graph_result_ptr build_kv_self_shift(
             ggml_context * ctx0,