@fugood/llama.node 0.3.17 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. package/CMakeLists.txt +3 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +39 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +366 -19
  24. package/src/LlamaCompletionWorker.h +30 -10
  25. package/src/LlamaContext.cpp +213 -5
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
  29. package/src/llama.cpp/.github/workflows/build.yml +41 -762
  30. package/src/llama.cpp/.github/workflows/docker.yml +5 -2
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +12 -12
  33. package/src/llama.cpp/CMakeLists.txt +5 -17
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +31 -3
  37. package/src/llama.cpp/common/arg.cpp +48 -29
  38. package/src/llama.cpp/common/chat.cpp +128 -106
  39. package/src/llama.cpp/common/chat.h +2 -0
  40. package/src/llama.cpp/common/common.cpp +37 -1
  41. package/src/llama.cpp/common/common.h +18 -9
  42. package/src/llama.cpp/common/llguidance.cpp +1 -0
  43. package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
  44. package/src/llama.cpp/common/minja/minja.hpp +69 -36
  45. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  46. package/src/llama.cpp/common/regex-partial.h +56 -0
  47. package/src/llama.cpp/common/sampling.cpp +57 -50
  48. package/src/llama.cpp/examples/CMakeLists.txt +2 -23
  49. package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
  50. package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
  51. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  52. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  53. package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
  54. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  55. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  56. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  57. package/src/llama.cpp/ggml/include/ggml.h +10 -7
  58. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  60. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  61. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
  62. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  63. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
  64. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
  65. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
  66. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  67. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  68. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  69. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
  70. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
  71. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
  72. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  73. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
  74. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
  75. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  76. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  77. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
  78. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  79. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
  80. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
  81. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
  82. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  83. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  84. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  85. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  86. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
  87. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  88. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
  89. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  90. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  91. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  92. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
  93. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
  94. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
  95. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
  96. package/src/llama.cpp/ggml/src/ggml.c +29 -20
  97. package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
  98. package/src/llama.cpp/include/llama.h +52 -11
  99. package/src/llama.cpp/requirements/requirements-all.txt +3 -3
  100. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  101. package/src/llama.cpp/src/CMakeLists.txt +1 -0
  102. package/src/llama.cpp/src/llama-adapter.cpp +6 -0
  103. package/src/llama.cpp/src/llama-arch.cpp +3 -0
  104. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  105. package/src/llama.cpp/src/llama-batch.h +2 -1
  106. package/src/llama.cpp/src/llama-chat.cpp +17 -7
  107. package/src/llama.cpp/src/llama-chat.h +1 -0
  108. package/src/llama.cpp/src/llama-context.cpp +389 -501
  109. package/src/llama.cpp/src/llama-context.h +44 -32
  110. package/src/llama.cpp/src/llama-cparams.h +1 -0
  111. package/src/llama.cpp/src/llama-graph.cpp +20 -38
  112. package/src/llama.cpp/src/llama-graph.h +12 -8
  113. package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
  114. package/src/llama.cpp/src/llama-kv-cache.h +271 -85
  115. package/src/llama.cpp/src/llama-memory.h +11 -1
  116. package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
  117. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  118. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  119. package/src/llama.cpp/src/llama-model.cpp +316 -69
  120. package/src/llama.cpp/src/llama-model.h +8 -1
  121. package/src/llama.cpp/src/llama-quant.cpp +15 -13
  122. package/src/llama.cpp/src/llama-sampling.cpp +18 -6
  123. package/src/llama.cpp/src/llama-vocab.cpp +42 -4
  124. package/src/llama.cpp/src/llama-vocab.h +6 -0
  125. package/src/llama.cpp/src/llama.cpp +14 -0
  126. package/src/llama.cpp/tests/CMakeLists.txt +10 -2
  127. package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
  128. package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
  129. package/src/llama.cpp/tests/test-chat.cpp +3 -1
  130. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  131. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  132. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  133. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  134. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  135. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
  136. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  137. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
  138. package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
  139. package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
  140. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
  141. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
  142. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  143. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
  144. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  145. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
  146. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  147. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
  148. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
  149. package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
  150. package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
  151. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  152. package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
  153. package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
  154. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  155. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  156. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  157. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  158. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  159. package/src/llama.cpp/examples/llava/clip.h +0 -135
  160. package/src/llama.cpp/examples/llava/llava.cpp +0 -586
  161. package/src/llama.cpp/examples/llava/llava.h +0 -49
  162. package/src/llama.cpp/examples/llava/mtmd.h +0 -168
  163. package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
  164. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  165. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  166. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  167. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  168. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  169. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  170. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  171. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  172. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  173. /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
  174. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  175. /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
  176. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  177. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  178. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  179. /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
  180. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  181. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  182. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  183. /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
  184. /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
  185. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  186. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  187. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  188. /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
  189. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  190. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  191. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  192. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
  193. /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
@@ -6,11 +6,9 @@
  #include "llama-model.h"
  #include "llama-kv-cache.h"
 
- #include <cassert>
  #include <cstring>
  #include <stdexcept>
  #include <cinttypes>
- #include <cmath>
 
  //
  // llama_context
@@ -95,6 +93,7 @@ llama_context::llama_context(
  }
 
  cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+ cparams.op_offload = params.op_offload;
 
  const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
@@ -118,8 +117,6 @@ llama_context::llama_context(
  __func__, n_ctx_per_seq, hparams.n_ctx_train);
  }
 
- logits_all = params.logits_all;
-
  if (!hparams.vocab_only) {
  // GPU backends
  for (auto * dev : model.devices) {
@@ -177,44 +174,13 @@ llama_context::llama_context(
  }
 
  // init the memory module
- // TODO: for now, always create a unified KV cache
  if (!hparams.vocab_only) {
- kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
-
- LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
- cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
-
- LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
- uint32_t kv_size = cparams.n_ctx;
- ggml_type type_k = params.type_k;
- ggml_type type_v = params.type_v;
-
- if (llama_model_is_recurrent(&model)) {
- // Mamba needs at least as many KV cells as there are sequences kept at any time
- kv_size = std::max((uint32_t) 1, params.n_seq_max);
- // it's probably best to keep as much precision as possible for the states
- type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
- type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
- }
-
- GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
- GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
-
- if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
- throw std::runtime_error("failed to initialize self-attention cache");
- }
+ llama_memory_params params_mem = {
+ /*.type_k =*/ params.type_k,
+ /*.type_v =*/ params.type_v,
+ };
 
- {
- const size_t memory_size_k = kv_self->size_k_bytes();
- const size_t memory_size_v = kv_self->size_v_bytes();
-
- LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
- (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
- ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
- ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
- }
+ memory.reset(model.create_memory(params_mem, cparams));
  }
 
  // init backends
@@ -278,7 +244,7 @@ llama_context::llama_context(
  }
  }
 
- sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
+ sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload));
 
  if (pipeline_parallel) {
  LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
@@ -286,7 +252,7 @@ llama_context::llama_context(
  }
 
  // reserve worst-case graph
- if (!hparams.vocab_only) {
+ if (!hparams.vocab_only && memory) {
  const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
@@ -305,7 +271,9 @@ llama_context::llama_context(
  int n_nodes_tg = -1;
 
  // simulate full KV cache
- kv_self->n = kv_self->size;
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ kv_self->set_full();
 
  cross.v_embd.clear();
 
@@ -391,7 +359,9 @@ llama_context::llama_context(
  }
  }
 
- llama_context::~llama_context() = default;
+ llama_context::~llama_context() {
+ ggml_opt_free(opt_ctx);
+ }
 
  void llama_context::synchronize() {
  ggml_backend_sched_synchronize(sched.get());
@@ -427,6 +397,18 @@ const llama_model & llama_context::get_model() const {
  return model;
  }
 
+ const llama_cparams & llama_context::get_cparams() const {
+ return cparams;
+ }
+
+ ggml_backend_sched_t llama_context::get_sched() const {
+ return sched.get();
+ }
+
+ ggml_context * llama_context::get_ctx_compute() const {
+ return ctx_compute.get();
+ }
+
  uint32_t llama_context::n_ctx() const {
  return cparams.n_ctx;
  }
@@ -456,337 +438,21 @@ uint32_t llama_context::n_threads_batch() const {
  }
 
  llama_kv_cache * llama_context::get_kv_self() {
- return kv_self.get();
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+ return kv_self;
  }
 
  const llama_kv_cache * llama_context::get_kv_self() const {
- return kv_self.get();
- }
-
- ggml_tensor * llama_context::build_rope_shift(
- ggml_context * ctx0,
- ggml_tensor * cur,
- ggml_tensor * shift,
- ggml_tensor * factors,
- float freq_base,
- float freq_scale) const {
- const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
-
- const auto & yarn_ext_factor = cparams.yarn_ext_factor;
- const auto & yarn_beta_fast = cparams.yarn_beta_fast;
- const auto & yarn_beta_slow = cparams.yarn_beta_slow;
-
- const auto & hparams = model.hparams;
-
- const auto & n_rot = hparams.n_rot;
- const auto & rope_type = hparams.rope_type;
-
- // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
- // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
- const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
-
- ggml_tensor * tmp;
-
- if (ggml_is_quantized(cur->type)) {
- // dequantize to f32 -> RoPE -> quantize back
- tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);
-
- tmp = ggml_rope_ext(ctx0, tmp,
- shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
-
- tmp = ggml_cpy(ctx0, tmp, cur);
- } else {
- // we rotate only the first n_rot dimensions
- tmp = ggml_rope_ext_inplace(ctx0, cur,
- shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
- }
-
- return tmp;
- }
-
- class llm_graph_input_k_shift : public llm_graph_input_i {
- public:
- llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
- virtual ~llm_graph_input_k_shift() = default;
-
- void set_input(const llama_ubatch * ubatch) override;
-
- ggml_tensor * k_shift; // I32 [kv_size]
-
- const llama_kv_cache_unified * kv_self;
- };
-
- void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
- GGML_UNUSED(ubatch);
-
- if (k_shift) {
- assert(ggml_backend_buffer_is_host(k_shift->buffer));
-
- int32_t * data = (int32_t *) k_shift->data;
-
- for (uint32_t i = 0; i < kv_self->size; ++i) {
- data[i] = kv_self->cells[i].delta;
- }
- }
- }
-
- llm_graph_result_ptr llama_context::build_kv_self_shift(
- ggml_context * ctx0,
- ggml_cgraph * gf) const {
- auto res = std::make_unique<llm_graph_result>();
-
- const auto & hparams = model.hparams;
-
- const auto & n_layer = hparams.n_layer;
-
- const auto & n_embd_head_k = hparams.n_embd_head_k;
- //const auto & n_embd_head_v = hparams.n_embd_head_v;
-
- //GGML_ASSERT(kv_self->size == n_ctx);
-
- auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
-
- inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
- ggml_set_input(inp->k_shift);
-
- for (uint32_t il = 0; il < n_layer; ++il) {
- const int64_t n_head_kv = hparams.n_head_kv(il);
- const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-
- const bool is_swa = hparams.is_swa(il);
-
- // note: the swa rope params could become part of the cparams in the future
- // if we decide to make them configurable, like the non-sliding ones
- const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
- const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
-
- ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
-
- ggml_tensor * k =
- ggml_view_3d(ctx0, kv_self->k_l[il],
- n_embd_head_k, n_head_kv, kv_self->size,
- ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
- ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
- 0);
-
- ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
-
- ggml_build_forward_expand(gf, cur);
- }
-
- res->add_input(std::move(inp));
-
- return res;
- }
-
- llm_graph_result_ptr llama_context::build_kv_self_defrag(
- ggml_context * ctx0,
- ggml_cgraph * gf) const {
- auto res = std::make_unique<llm_graph_result>();
-
- const auto & hparams = model.hparams;
-
- const auto & ids = kv_self->defrag_info.ids;
-
- #if 0
- // CPU defrag
- //
- // TODO: optimizations are possible:
- // - multiple threads
- // - avoid copying to the host memory when already there
- //
- // likely not worth the effort, as we have ggml_graph based defrag
- //
-
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-
- const uint32_t kv_size = size;
-
- std::vector<uint8_t> buf_k;
- std::vector<uint8_t> buf_v;
-
- for (uint32_t il = 0; il < n_layer; ++il) {
- const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
- const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
-
- const size_t v_size_el = ggml_type_size(v_l[il]->type);
- const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
-
- buf_k.resize(k_size);
- buf_v.resize(v_size);
-
- ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
- ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
-
- // batch move [i, i+nm) to [id, id+nm)
- // note: cells can move only to a lower index
- for (uint32_t i = 0; i < n_kv; ++i) {
- const uint32_t id = ids[i];
-
- if (i == id || id == n_kv) {
- continue;
- }
-
- uint32_t nm = 1;
-
- while (i + nm < n_kv && ids[i + nm] == id + nm) {
- nm++;
- }
-
- // move keys
- {
- const int64_t os = i*k_size_row;
- const int64_t od = id*k_size_row;
-
- memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
- }
-
- // move values (note: they are transposed)
- {
- const int64_t os = i;
- const int64_t od = id;
-
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
- memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
- }
- }
-
- i += nm - 1;
- }
-
- ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
- ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
- }
- #else
- for (uint32_t i = 0; i < ids.size(); ++i) {
- const uint32_t id = ids[i];
-
- if (i == id || id == ids.size()) {
- continue;
- }
-
- uint32_t nm = 1;
-
- while (i + nm < ids.size() && ids[i + nm] == id + nm) {
- nm++;
- }
-
- for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
- const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
- const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
- ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
- n_embd_k_gqa, nm,
- ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
- ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
-
- ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
- n_embd_k_gqa, nm,
- ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
- ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
-
- ggml_tensor * view_v_src;
- ggml_tensor * view_v_dst;
-
- if (cparams.flash_attn) {
- // NOTE: the V cache is not transposed when using flash attention
- view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
- n_embd_v_gqa, nm,
- ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
- ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
-
- view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
- n_embd_v_gqa, nm,
- ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
- ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
- } else {
- view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
- nm, n_embd_v_gqa,
- ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
- ggml_row_size(kv_self->v_l[il]->type, i));
-
- view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
- nm, n_embd_v_gqa,
- ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
- ggml_row_size(kv_self->v_l[il]->type, id));
- }
-
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
- }
-
- i += nm - 1;
- }
-
- //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
- #endif
-
- return res;
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+ return kv_self;
  }
 
  void llama_context::kv_self_update() {
- auto & kv = kv_self;
-
  bool need_reserve = false;
 
- if (kv->has_shift) {
- if (!kv->get_can_shift()) {
- GGML_ABORT("The current context does not support K-shift");
- }
-
- LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
-
- // apply K-shift if needed
- if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
- ggml_backend_sched_reset(sched.get());
-
- auto * gf = graph_init();
-
- auto res = build_kv_self_shift(ctx_compute.get(), gf);
-
- ggml_backend_sched_alloc_graph(sched.get(), gf);
-
- res->set_inputs(nullptr);
-
- graph_compute(gf, false);
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
- need_reserve = true;
- }
-
- {
- kv->has_shift = false;
-
- for (uint32_t i = 0; i < kv->size; ++i) {
- kv->cells[i].delta = 0;
- }
- }
- }
-
- // defragment the KV cache if needed
- if (kv->do_defrag) {
- LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
-
- if (kv->defrag_prepare(graph_max_nodes())) {
- ggml_backend_sched_reset(sched.get());
-
- auto * gf = graph_init();
-
- auto res = build_kv_self_defrag(ctx_compute.get(), gf);
-
- ggml_backend_sched_alloc_graph(sched.get(), gf);
-
- res->set_inputs(nullptr);
-
- graph_compute(gf, false);
-
- need_reserve = true;
- }
-
- kv->do_defrag = false;
- }
+ need_reserve = kv_self->update(*this);
 
  // reserve a worst case graph if needed
  if (need_reserve) {
@@ -797,7 +463,7 @@ void llama_context::kv_self_update() {
  uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
  // simulate full KV cache
- kv_self->n = kv_self->size;
+ kv_self->set_full();
 
  llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
  llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
@@ -818,9 +484,6 @@ enum llama_pooling_type llama_context::pooling_type() const {
  }
 
  float * llama_context::get_logits() {
- // reorder logits for backward compatibility
- output_reorder();
-
  return logits;
  }
 
@@ -863,9 +526,6 @@ float * llama_context::get_logits_ith(int32_t i) {
  }
 
  float * llama_context::get_embeddings() {
- // reorder embeddings for backward compatibility
- output_reorder();
-
  return embd;
  }
 
@@ -1017,8 +677,8 @@ int llama_context::encode(llama_batch & inp_batch) {
  }
 
  // temporary allocate memory for the input batch if needed
- // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
+ // note: during encode, we always pass the full sequence starting from pos = 0
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
 
  const llama_batch & batch = batch_allocr.batch;
  const int32_t n_tokens = batch.n_tokens;
@@ -1043,11 +703,13 @@ int llama_context::encode(llama_batch & inp_batch) {
  t_compute_start_us = ggml_time_us();
  }
 
+ embd_seq.clear();
+
  n_queued_tokens += n_tokens;
 
  const int64_t n_embd = hparams.n_embd;
 
- sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+ llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
 
  const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
@@ -1104,12 +766,12 @@ int llama_context::encode(llama_batch & inp_batch) {
  ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
  GGML_ASSERT(backend_embd != nullptr);
 
- GGML_ASSERT(embd != nullptr);
-
  switch (cparams.pooling_type) {
  case LLAMA_POOLING_TYPE_NONE:
  {
  // extract token embeddings
+ GGML_ASSERT(embd != nullptr);
+
  GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
  ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
  } break;
@@ -1134,11 +796,18 @@ int llama_context::encode(llama_batch & inp_batch) {
  } break;
  case LLAMA_POOLING_TYPE_RANK:
  {
- // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
- // wait for an encoder model that requires this pooling type in order to test it
- // https://github.com/ggerganov/llama.cpp/pull/9510
- GGML_ABORT("RANK pooling not implemented yet");
- }
+ // extract the rerank score - a single float per sequence
+ auto & embd_seq_out = embd_seq;
+
+ for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
+ if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+ continue;
+ }
+ embd_seq_out[seq_id].resize(1);
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+ }
+ } break;
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
  {
  GGML_ABORT("unknown pooling type");
@@ -1176,14 +845,21 @@ int llama_context::encode(llama_batch & inp_batch) {
  }
 
  int llama_context::decode(llama_batch & inp_batch) {
+ if (!memory) {
+ LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
+ return encode(inp_batch);
+ }
+
  if (inp_batch.n_tokens == 0) {
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
  return -1;
  }
 
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
  // temporary allocate memory for the input batch if needed
- // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
+ // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
 
  const llama_batch & batch = batch_allocr.batch;
 
@@ -1195,7 +871,7 @@ int llama_context::decode(llama_batch & inp_batch) {
  const int64_t n_tokens_all = batch.n_tokens;
  const int64_t n_embd = hparams.n_embd;
 
- llama_kv_cache_guard kv_guard(kv_self.get());
+ llama_kv_cache_guard kv_guard(kv_self);
 
  GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
@@ -1229,18 +905,14 @@ int llama_context::decode(llama_batch & inp_batch) {
  for (uint32_t i = 0; i < n_tokens_all; ++i) {
  n_outputs_all += batch.logits[i] != 0;
  }
- } else if (logits_all || embd_pooled) {
+ } else if (embd_pooled) {
  n_outputs_all = n_tokens_all;
  } else {
  // keep last output only
  n_outputs_all = 1;
  }
 
- const bool logits_all = n_outputs_all == n_tokens_all;
-
- sbatch.from_batch(batch, n_embd,
- /* simple_split */ !kv_self->recurrent,
- /* logits_all */ logits_all);
+ llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
 
  // reserve output buffer
  if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -1254,22 +926,7 @@ int llama_context::decode(llama_batch & inp_batch) {
  int64_t n_outputs_prev = 0;
 
  while (sbatch.n_tokens > 0) {
- llama_ubatch ubatch = llama_ubatch();
-
- const auto & n_ubatch = cparams.n_ubatch;
-
- if (kv_self->recurrent) {
- if (embd_pooled) {
- // Pooled embeddings cannot be split across ubatches (yet)
- ubatch = sbatch.split_seq(cparams.n_ubatch);
- } else {
- // recurrent model architectures are easier to implement
- // with equal-length sequences
- ubatch = sbatch.split_equal(cparams.n_ubatch);
- }
- } else {
- ubatch = sbatch.split_simple(n_ubatch);
- }
+ llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
 
  // count the outputs in this u_batch
  {
@@ -1289,24 +946,12 @@ int llama_context::decode(llama_batch & inp_batch) {
  }
 
  // find KV slot
- {
- if (!kv_self->find_slot(ubatch)) {
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
- return 1;
- }
+ if (!kv_self->find_slot(ubatch)) {
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
 
- if (!kv_self->recurrent) {
- // a heuristic, to avoid attending the full cache if it is not yet utilized
- // after enough generations, the benefit from this heuristic disappears
- // if we start defragmenting the cache, the benefit from this will be more important
- const uint32_t pad = kv_self->get_padding(cparams);
- kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
- }
+ return 1;
  }
 
- //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
-
  ggml_backend_sched_reset(sched.get());
  ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
@@ -1420,43 +1065,68 @@ int llama_context::decode(llama_batch & inp_batch) {
  // finalize the batch processing
  kv_guard.commit();
 
+ // set to total number of outputs in the batch, for use in llama_get_logits_ith
+ n_outputs = n_outputs_all;
+
  // set output mappings
  {
  bool sorted_output = true;
 
- GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
+ auto & out_ids = sbatch.out_ids;
+
+ GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
  for (int64_t i = 0; i < n_outputs_all; ++i) {
- int64_t out_id = sbatch.out_ids[i];
+ int64_t out_id = out_ids[i];
  output_ids[out_id] = i;
  if (out_id != i) {
  sorted_output = false;
  }
  }
 
- if (sorted_output) {
- sbatch.out_ids.clear();
+ // make the outputs have the same order they had in the user-provided batch
+ // note: this is mostly relevant for recurrent models atm
+ if (!sorted_output) {
+ const uint32_t n_vocab = model.vocab.n_tokens();
+ const uint32_t n_embd = model.hparams.n_embd;
+
+ GGML_ASSERT((size_t) n_outputs == out_ids.size());
+
+ // TODO: is there something more efficient which also minimizes swaps?
+ // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+ for (int32_t i = 0; i < n_outputs - 1; ++i) {
+ int32_t j_min = i;
+ for (int32_t j = i + 1; j < n_outputs; ++j) {
+ if (out_ids[j] < out_ids[j_min]) {
+ j_min = j;
+ }
+ }
+ if (j_min == i) { continue; }
+ std::swap(out_ids[i], out_ids[j_min]);
+ if (logits_size > 0) {
+ for (uint32_t k = 0; k < n_vocab; k++) {
+ std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
+ }
+ }
+ if (embd_size > 0) {
+ for (uint32_t k = 0; k < n_embd; k++) {
+ std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
+ }
+ }
+ }
+ std::fill(output_ids.begin(), output_ids.end(), -1);
+ for (int32_t i = 0; i < n_outputs; ++i) {
+ output_ids[out_ids[i]] = i;
+ }
  }
  }
 
- // set to total number of outputs in the batch, for use in llama_get_logits_ith
- n_outputs = n_outputs_all;
-
  // wait for the computation to finish (automatically done when obtaining the model output)
  //synchronize();
 
  // decide if we need to defrag the kv cache
- if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
- // - do not defrag small contexts (i.e. < 2048 tokens)
- // - count the padding towards the number of used tokens
- const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
-
- // queue defragmentation for next llama_kv_cache_update
- if (fragmentation > cparams.defrag_thold) {
- LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
-
- kv_self->defrag();
- }
+ if (cparams.defrag_thold > 0.0f) {
+ kv_self->defrag_sched(cparams.defrag_thold);
  }
 
  // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
@@ -1542,44 +1212,6 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
  return n_outputs_max;
  }
 
- void llama_context::output_reorder() {
- auto & out_ids = sbatch.out_ids;
- if (!out_ids.empty()) {
- const uint32_t n_vocab = model.vocab.n_tokens();
- const uint32_t n_embd = model.hparams.n_embd;
-
- GGML_ASSERT((size_t) n_outputs == out_ids.size());
-
- // TODO: is there something more efficient which also minimizes swaps?
- // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
- for (int32_t i = 0; i < n_outputs - 1; ++i) {
- int32_t j_min = i;
- for (int32_t j = i + 1; j < n_outputs; ++j) {
- if (out_ids[j] < out_ids[j_min]) {
- j_min = j;
- }
- }
- if (j_min == i) { continue; }
- std::swap(out_ids[i], out_ids[j_min]);
- if (logits_size > 0) {
- for (uint32_t k = 0; k < n_vocab; k++) {
- std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
- }
- }
- if (embd_size > 0) {
- for (uint32_t k = 0; k < n_embd; k++) {
- std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
- }
- }
- }
- std::fill(output_ids.begin(), output_ids.end(), -1);
- for (int32_t i = 0; i < n_outputs; ++i) {
- output_ids[out_ids[i]] = i;
- }
- out_ids.clear();
- }
- }
-
 
  //
  // graph
@@ -1616,7 +1248,7 @@ llm_graph_result_ptr llama_context::graph_build(
  /*.backend_cpu =*/ backend_cpu,
  /*.cvec =*/ &cvec,
  /*.loras =*/ &loras,
- /*.memory =*/ kv_self.get(),
+ /*.memory =*/ memory.get(),
  /*.cross =*/ &cross,
  /*.n_outputs =*/ n_outputs,
  /*.cb =*/ graph_get_cb(),
@@ -2020,8 +1652,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
  {
  LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
 
- output_reorder();
-
  const auto n_outputs = this->n_outputs;
  const auto & output_ids = this->output_ids;
 
@@ -2074,8 +1704,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
  }
  }
 
- LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
- kv_self->state_write(io);
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ if (kv_self != nullptr) {
+ LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+ kv_self->state_write(io);
+ }
 
  return io.n_bytes();
  }
@@ -2158,8 +1792,13 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
  }
  }
 
- LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
- kv_self->state_read(io);
+ if (memory) {
+ LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ kv_self->state_read(io);
+ }
 
  return io.n_bytes();
  }
@@ -2167,7 +1806,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
  size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
  GGML_UNUSED(seq_id);
 
- kv_self->state_write(io, seq_id);
+ if (memory) {
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ kv_self->state_write(io, seq_id);
+ }
 
  return io.n_bytes();
  }
@@ -2175,7 +1818,11 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
  size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
  GGML_UNUSED(seq_id);
 
- kv_self->state_read(io, seq_id);
+ if (memory) {
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ kv_self->state_read(io, seq_id);
+ }
 
  return io.n_bytes();
  }
@@ -2203,6 +1850,215 @@ void llama_context::perf_reset() {
  t_p_eval_us = n_p_eval = 0;
  }
 
+ //
+ // training
+ //
+
+ static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
+ if (!tensor || tensor->type != GGML_TYPE_F32) {
+ return;
+ }
+ if (!param_filter(tensor, userdata)) {
+ return;
+ }
+ if (strcmp(tensor->name, "token_embd.weight") == 0) {
+ return; // FIXME
+ }
+ if (strcmp(tensor->name, "rope_freqs.weight") == 0) {
+ return; // FIXME
+ }
+ ggml_set_param(tensor);
+ }
+
+ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) {
+ GGML_ASSERT(!opt_ctx);
+ model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx();
+ const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train);
+ const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
+ GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0);
+ GGML_ASSERT(n_batch % n_ubatch == 0);
+
+ ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
+ opt_params.opt_period = n_batch / n_ubatch;
+ opt_params.get_opt_pars = lopt_params.get_opt_pars;
+ opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
+
+ opt_ctx = ggml_opt_init(opt_params);
+
+ llama_opt_param_filter param_filter = lopt_params.param_filter;
+ void * param_filter_ud = lopt_params.param_filter_ud;
+
+ //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME
+ llama_set_param(model->type_embd, param_filter, param_filter_ud);
+ llama_set_param(model->pos_embd, param_filter, param_filter_ud);
+ llama_set_param(model->tok_norm, param_filter, param_filter_ud);
+ llama_set_param(model->tok_norm_b, param_filter, param_filter_ud);
+ llama_set_param(model->output_norm, param_filter, param_filter_ud);
+ llama_set_param(model->output_norm_b, param_filter, param_filter_ud);
+ llama_set_param(model->output, param_filter, param_filter_ud);
+ llama_set_param(model->output_b, param_filter, param_filter_ud);
+ llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
+ llama_set_param(model->cls, param_filter, param_filter_ud);
+ llama_set_param(model->cls_b, param_filter, param_filter_ud);
+ llama_set_param(model->cls_out, param_filter, param_filter_ud);
+ llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
+
+ for (struct llama_layer & layer : model->layers) {
+ for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
+ llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
+ }
+ }
+ }
+
+ void llama_context::opt_epoch_iter(
+ ggml_opt_dataset_t dataset,
+ ggml_opt_result_t result,
+ const std::vector<llama_token> & tokens,
+ const std::vector<llama_token> & labels_sparse,
+ llama_batch & batch,
+ ggml_opt_epoch_callback callback,
+ bool train,
+ int64_t idata_in_loop,
+ int64_t ndata_in_loop,
+ int64_t t_loop_start) {
+ GGML_ASSERT(opt_ctx);
+ const uint32_t n_ctx = llama_model_n_ctx_train(&model);
+ const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
+ const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
+
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ kv_self->clear();
+ llama_kv_cache_guard kv_guard(kv_self);
+
+ for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
+ batch.n_tokens = n_batch;
+ for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
+ batch.token [pos_batch] = tokens[pos_ctx + pos_batch];
+ batch.pos [pos_batch] = pos_ctx + pos_batch;
+ batch.n_seq_id[pos_batch] = 1;
+ batch.seq_id [pos_batch][0] = 0;
+ batch.logits [pos_batch] = true;
+ }
+
+ const auto n_tokens_all = batch.n_tokens;
+
+ n_queued_tokens += n_tokens_all;
+
+ // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+ const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+ embd_seq.clear();
+
+ int64_t n_outputs_all = n_tokens_all;
+
+ llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
+
+ // reserve output buffer
+ if (output_reserve(n_outputs_all) < n_outputs_all) {
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+ GGML_ABORT("TODO: handle this error");
+ };
+
+ for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
+ llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+
+ n_outputs = ubatch.n_tokens;
+
+ // TODO: not sure if this is needed
+ if (!kv_self->find_slot(ubatch)) {
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+
+ GGML_ABORT("TODO: handle this error");
+ }
+
+ auto * gf = graph_init();
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+
+ struct ggml_context * ctx_compute_opt;
+ {
+ const size_t size_gf = ggml_graph_size(gf);
+ const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true);
+ struct ggml_init_params params = {
+ /*.mem_size =*/ size_meta,
+ /*.mem_buffer =*/ nullptr,
+ /*.no_alloc =*/ true,
+ };
+ ctx_compute_opt = ggml_init(params);
+ }
+ ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
+ ggml_opt_alloc(opt_ctx, train);
+ res->set_inputs(&ubatch);
+ {
+ struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
+ GGML_ASSERT(labels->ne[1] == n_ubatch);
+ ggml_set_zero(labels);
+ const float onef = 1.0f;
+ for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
+ const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
+ GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
+ ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
+ }
+ }
+ ggml_opt_eval(opt_ctx, result);
+ if (callback) {
+ callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
+ }
+ ggml_free(ctx_compute_opt);
+ }
+ }
+
+ kv_guard.commit();
+ }
+
+ void llama_context::opt_epoch(
+ ggml_opt_dataset_t dataset,
+ ggml_opt_result_t result_train,
+ ggml_opt_result_t result_eval,
+ int64_t idata_split,
+ ggml_opt_epoch_callback callback_train,
+ ggml_opt_epoch_callback callback_eval) {
+ const uint32_t n_ctx = this->n_ctx();
+ const uint32_t n_batch = std::min(cparams.n_batch, n_ctx);
+ const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
+ const int64_t ndata = ggml_opt_dataset_ndata(dataset);
+
+ GGML_ASSERT(idata_split >= 0);
+ GGML_ASSERT(idata_split <= ndata);
+
+ const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
+
+ struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
+ std::vector<llama_token> tokens(n_ctx);
+ std::vector<llama_token> labels_sparse(n_ctx);
+
+ int64_t idata = 0;
+
+ int64_t t_loop_start = ggml_time_us();
+ int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
+ for (; idata < idata_split; ++idata) {
+ constexpr bool train = true;
+ const int64_t idata_in_loop = idata*ubatch_per_ctx;
+
+ ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
+ opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch,
+ callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
+ }
+
+ t_loop_start = ggml_time_us();
+ ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
+ for (; idata < ndata; ++idata) {
+ constexpr bool train = false;
+ const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
+
+ ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
+ opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch,
+ callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
+ }
+
+ llama_batch_free(batch);
+ }
+
  //
  // interface implementation
  //
@@ -2230,13 +2086,13 @@ llama_context_params llama_context_default_params() {
  /*.cb_eval_user_data =*/ nullptr,
  /*.type_k =*/ GGML_TYPE_F16,
  /*.type_v =*/ GGML_TYPE_F16,
- /*.logits_all =*/ false,
+ /*.abort_callback =*/ nullptr,
+ /*.abort_callback_data =*/ nullptr,
  /*.embeddings =*/ false,
  /*.offload_kqv =*/ true,
  /*.flash_attn =*/ false,
  /*.no_perf =*/ true,
- /*.abort_callback =*/ nullptr,
- /*.abort_callback_data =*/ nullptr,
+ /*.op_offload =*/ true,
  };
 
  return result;
@@ -2530,7 +2386,7 @@ void llama_kv_cache_seq_cp(
  llama_seq_id seq_id_dst,
  llama_pos p0,
  llama_pos p1) {
- return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
+ llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
  }
 
  void llama_kv_self_seq_cp(
@@ -2544,14 +2400,14 @@ void llama_kv_self_seq_cp(
  return;
  }
 
- return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+ kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
  }
 
  // deprecated
  void llama_kv_cache_seq_keep(
  llama_context * ctx,
  llama_seq_id seq_id) {
- return llama_kv_self_seq_keep(ctx, seq_id);
+ llama_kv_self_seq_keep(ctx, seq_id);
  }
 
  void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
@@ -2560,7 +2416,7 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
  return;
  }
 
- return kv->seq_keep(seq_id);
+ kv->seq_keep(seq_id);
  }
 
  // deprecated
@@ -2570,7 +2426,7 @@ void llama_kv_cache_seq_add(
  llama_pos p0,
  llama_pos p1,
  llama_pos delta) {
- return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
+ llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
  }
 
  void llama_kv_self_seq_add(
@@ -2584,7 +2440,7 @@ void llama_kv_self_seq_add(
  return;
  }
 
- return kv->seq_add(seq_id, p0, p1, delta);
+ kv->seq_add(seq_id, p0, p1, delta);
  }
 
  // deprecated
@@ -2594,7 +2450,7 @@ void llama_kv_cache_seq_div(
  llama_pos p0,
  llama_pos p1,
  int d) {
- return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
+ llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
  }
 
  void llama_kv_self_seq_div(
@@ -2608,7 +2464,7 @@ void llama_kv_self_seq_div(
  return;
  }
 
- return kv->seq_div(seq_id, p0, p1, d);
+ kv->seq_div(seq_id, p0, p1, d);
  }
 
  // deprecated
@@ -2627,7 +2483,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
 
  // deprecated
  void llama_kv_cache_defrag(llama_context * ctx) {
- return llama_kv_self_defrag(ctx);
+ llama_kv_self_defrag(ctx);
  }
 
  void llama_kv_self_defrag(llama_context * ctx) {
@@ -2636,7 +2492,8 @@ void llama_kv_self_defrag(llama_context * ctx) {
  return;
  }
 
- return kv->defrag();
+ // force defrag
+ kv->defrag_sched(-1.0f);
  }
 
  // deprecated
@@ -2820,3 +2677,34 @@ void llama_perf_context_print(const llama_context * ctx) {
  void llama_perf_context_reset(llama_context * ctx) {
  ctx->perf_reset();
  }
+
+ //
+ // training
+ //
+
+ bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) {
+ GGML_UNUSED(tensor);
+ GGML_UNUSED(userdata);
+ return true;
+ }
+
+ void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) {
+ ctx->opt_init(model, lopt_params);
+ }
+
+ void llama_opt_epoch(
+ struct llama_context * ctx,
+ ggml_opt_dataset_t dataset,
+ ggml_opt_result_t result_train,
+ ggml_opt_result_t result_eval,
+ int64_t idata_split,
+ ggml_opt_epoch_callback callback_train,
+ ggml_opt_epoch_callback callback_eval) {
+ ctx->opt_epoch(
+ dataset,
+ result_train,
+ result_eval,
+ idata_split,
+ callback_train,
+ callback_eval);
+ }
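
Besides the KV-cache/memory refactor, the hunks above expose a public training interface from llama-context.cpp (llama_opt_param_filter_all, llama_opt_init, llama_opt_epoch), which the new examples/training/finetune.cpp in the file table builds on. The sketch below shows how a caller might drive that interface; it is not code from this package. The llama_opt_params field names come from the diff itself, while the ggml_opt_* helpers (ggml_opt_result_init, ggml_opt_get_default_optimizer_params, ggml_opt_epoch_callback_progress_bar) are assumed from the bundled ggml-opt.h, and train_one_epoch, model, ctx, and dataset are hypothetical caller-side names.

// Sketch only (not part of the package): assumes `model` and `ctx` were created with the
// usual llama_model_load_from_file / llama_init_from_model calls, and that `dataset` is a
// ggml_opt_dataset_t already filled with tokenized training data. Note that opt_init asserts
// the training context size is a multiple of n_batch and n_batch a multiple of n_ubatch
// (see the GGML_ASSERTs in llama_context::opt_init above).
#include "llama.h"
#include "ggml-opt.h"

static void train_one_epoch(llama_model * model, llama_context * ctx, ggml_opt_dataset_t dataset) {
    llama_opt_params lopt_params = {};
    lopt_params.n_ctx_train     = 0;                          // 0 -> use the context size of `ctx`
    lopt_params.param_filter    = llama_opt_param_filter_all; // train every F32 parameter tensor
    lopt_params.param_filter_ud = nullptr;
    lopt_params.get_opt_pars    = ggml_opt_get_default_optimizer_params; // default optimizer settings
    lopt_params.get_opt_pars_ud = nullptr;

    llama_opt_init(ctx, model, lopt_params);

    // hold out the last ~10% of the dataset for evaluation
    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * 9 / 10;

    ggml_opt_result_t result_train = ggml_opt_result_init();
    ggml_opt_result_t result_eval  = ggml_opt_result_init();

    llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
                    ggml_opt_epoch_callback_progress_bar,
                    ggml_opt_epoch_callback_progress_bar);

    ggml_opt_result_free(result_train);
    ggml_opt_result_free(result_eval);
}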