cui-llama.rn 1.6.1 → 1.7.0

This diff shows the changes between publicly available versions of this package as released to one of the supported registries; it is provided for informational purposes only.
Files changed (196)
  1. package/android/src/main/CMakeLists.txt +6 -0
  2. package/android/src/main/java/com/rnllama/LlamaContext.java +38 -5
  3. package/android/src/main/java/com/rnllama/RNLlama.java +139 -4
  4. package/android/src/main/jni.cpp +153 -14
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  13. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
  14. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
  15. package/cpp/chat.cpp +128 -106
  16. package/cpp/chat.h +2 -0
  17. package/cpp/common.cpp +41 -76
  18. package/cpp/common.h +23 -19
  19. package/cpp/ggml-backend.cpp +9 -5
  20. package/cpp/ggml-backend.h +4 -4
  21. package/cpp/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  22. package/cpp/ggml-cpu/ggml-cpu-quants.c +306 -6
  23. package/cpp/ggml-cpu/ggml-cpu.c +5 -13
  24. package/cpp/ggml-cpu/ggml-cpu.cpp +29 -16
  25. package/cpp/ggml-cpu/ops.cpp +107 -13
  26. package/cpp/ggml-cpu/vec.cpp +0 -6
  27. package/cpp/ggml-cpu/vec.h +16 -0
  28. package/cpp/ggml-llama-sim.metallib +0 -0
  29. package/cpp/ggml-llama.metallib +0 -0
  30. package/cpp/ggml-metal-impl.h +36 -11
  31. package/cpp/ggml-metal.m +321 -132
  32. package/cpp/ggml-opt.cpp +373 -190
  33. package/cpp/ggml-opt.h +49 -28
  34. package/cpp/ggml-quants.c +0 -6
  35. package/cpp/ggml.c +93 -38
  36. package/cpp/ggml.h +21 -7
  37. package/cpp/gguf.cpp +33 -33
  38. package/cpp/llama-adapter.cpp +6 -0
  39. package/cpp/llama-arch.cpp +3 -0
  40. package/cpp/llama-batch.cpp +3 -1
  41. package/cpp/llama-chat.cpp +8 -6
  42. package/cpp/llama-chat.h +1 -0
  43. package/cpp/llama-context.cpp +349 -135
  44. package/cpp/llama-context.h +30 -3
  45. package/cpp/llama-cparams.h +1 -0
  46. package/cpp/llama-graph.cpp +150 -234
  47. package/cpp/llama-graph.h +52 -7
  48. package/cpp/llama-hparams.cpp +17 -1
  49. package/cpp/llama-hparams.h +34 -5
  50. package/cpp/llama-kv-cache.cpp +662 -321
  51. package/cpp/llama-kv-cache.h +203 -93
  52. package/cpp/llama-memory.h +3 -2
  53. package/cpp/llama-model-loader.cpp +24 -15
  54. package/cpp/llama-model-saver.cpp +281 -0
  55. package/cpp/llama-model-saver.h +37 -0
  56. package/cpp/llama-model.cpp +536 -132
  57. package/cpp/llama-model.h +7 -1
  58. package/cpp/llama-sampling.cpp +18 -6
  59. package/cpp/llama-vocab.cpp +46 -8
  60. package/cpp/llama-vocab.h +6 -0
  61. package/cpp/llama.cpp +14 -0
  62. package/cpp/llama.h +72 -131
  63. package/cpp/minja/chat-template.hpp +9 -5
  64. package/cpp/minja/minja.hpp +69 -36
  65. package/cpp/rn-llama.cpp +611 -47
  66. package/cpp/rn-llama.h +33 -3
  67. package/cpp/sampling.cpp +57 -50
  68. package/cpp/tools/mtmd/clip-impl.h +462 -0
  69. package/cpp/tools/mtmd/clip.cpp +4024 -0
  70. package/cpp/tools/mtmd/clip.h +101 -0
  71. package/cpp/tools/mtmd/miniaudio.h +93468 -0
  72. package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
  73. package/cpp/tools/mtmd/mtmd-audio.h +62 -0
  74. package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
  75. package/cpp/tools/mtmd/mtmd.cpp +942 -0
  76. package/cpp/tools/mtmd/mtmd.h +362 -0
  77. package/cpp/tools/mtmd/stb_image.h +7988 -0
  78. package/ios/CMakeLists.txt +7 -0
  79. package/ios/RNLlama.mm +77 -3
  80. package/ios/RNLlamaContext.h +5 -1
  81. package/ios/RNLlamaContext.mm +105 -10
  82. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
  83. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +23 -19
  84. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  85. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  86. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  87. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +21 -7
  88. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  89. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +30 -3
  90. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  91. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
  92. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
  93. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  94. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
  95. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  96. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +7 -1
  97. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  98. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +72 -131
  99. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  100. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  101. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
  102. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  103. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  104. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  105. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  106. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
  107. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  108. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  109. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  110. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
  111. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  112. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
  113. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  114. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
  115. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
  116. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  117. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
  118. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  119. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
  120. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  121. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
  122. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  123. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  124. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
  125. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  126. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  127. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  128. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  129. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
  130. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +23 -19
  131. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  132. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  133. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  134. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +21 -7
  135. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  136. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +30 -3
  137. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  138. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
  139. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
  140. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  141. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
  142. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  143. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +7 -1
  144. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  145. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +72 -131
  146. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  147. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  148. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
  149. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  150. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  151. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  152. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  153. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
  154. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  155. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  156. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  157. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
  158. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  159. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
  160. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  161. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
  162. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
  163. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  164. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
  165. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  166. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
  167. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  168. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
  169. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  170. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  171. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
  172. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  173. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  174. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  175. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  176. package/jest/mock.js +33 -7
  177. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  178. package/lib/commonjs/index.js +153 -21
  179. package/lib/commonjs/index.js.map +1 -1
  180. package/lib/module/NativeRNLlama.js.map +1 -1
  181. package/lib/module/index.js +152 -20
  182. package/lib/module/index.js.map +1 -1
  183. package/lib/typescript/NativeRNLlama.d.ts +50 -4
  184. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  185. package/lib/typescript/index.d.ts +72 -6
  186. package/lib/typescript/index.d.ts.map +1 -1
  187. package/package.json +1 -1
  188. package/src/NativeRNLlama.ts +67 -4
  189. package/src/index.ts +212 -38
  190. package/lib/commonjs/chat.js +0 -37
  191. package/lib/commonjs/chat.js.map +0 -1
  192. package/lib/module/chat.js +0 -33
  193. package/lib/module/chat.js.map +0 -1
  194. package/lib/typescript/chat.d.ts +0 -10
  195. package/lib/typescript/chat.d.ts.map +0 -1
  196. package/src/chat.ts +0 -44
@@ -94,6 +94,8 @@ llama_context::llama_context(

  cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);

+ cparams.op_offload = params.op_offload;
+
  const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;

  LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
@@ -116,8 +118,6 @@ llama_context::llama_context(
  __func__, n_ctx_per_seq, hparams.n_ctx_train);
  }

- logits_all = params.logits_all;
-
  if (!hparams.vocab_only) {
  // GPU backends
  for (auto * dev : model.devices) {
@@ -177,8 +177,9 @@ llama_context::llama_context(
  // init the memory module
  if (!hparams.vocab_only) {
  llama_memory_params params_mem = {
- /*.type_k =*/ params.type_k,
- /*.type_v =*/ params.type_v,
+ /*.type_k =*/ params.type_k,
+ /*.type_v =*/ params.type_v,
+ /*.swa_full =*/ params.swa_full,
  };

  memory.reset(model.create_memory(params_mem, cparams));
@@ -245,7 +246,7 @@ llama_context::llama_context(
  }
  }

- sched.reset(lm_ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
+ sched.reset(lm_ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload));

  if (pipeline_parallel) {
  LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, lm_ggml_backend_sched_get_n_copies(sched.get()));
@@ -253,7 +254,7 @@ llama_context::llama_context(
  }

  // reserve worst-case graph
- if (!hparams.vocab_only) {
+ if (!hparams.vocab_only && memory) {
  const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

@@ -360,7 +361,9 @@ llama_context::llama_context(
  }
  }

- llama_context::~llama_context() = default;
+ llama_context::~llama_context() {
+ lm_ggml_opt_free(opt_ctx);
+ }

  void llama_context::synchronize() {
  lm_ggml_backend_sched_synchronize(sched.get());
@@ -702,6 +705,8 @@ int llama_context::encode(llama_batch & inp_batch) {
  t_compute_start_us = lm_ggml_time_us();
  }

+ embd_seq.clear();
+
  n_queued_tokens += n_tokens;

  const int64_t n_embd = hparams.n_embd;
@@ -763,12 +768,12 @@ int llama_context::encode(llama_batch & inp_batch) {
  lm_ggml_backend_t backend_embd = lm_ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
  LM_GGML_ASSERT(backend_embd != nullptr);

- LM_GGML_ASSERT(embd != nullptr);
-
  switch (cparams.pooling_type) {
  case LLAMA_POOLING_TYPE_NONE:
  {
  // extract token embeddings
+ LM_GGML_ASSERT(embd != nullptr);
+
  LM_GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
  lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
  } break;
@@ -793,11 +798,18 @@ int llama_context::encode(llama_batch & inp_batch) {
  } break;
  case LLAMA_POOLING_TYPE_RANK:
  {
- // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
- // wait for an encoder model that requires this pooling type in order to test it
- // https://github.com/ggerganov/llama.cpp/pull/9510
- LM_GGML_ABORT("RANK pooling not implemented yet");
- }
+ // extract the rerank score - a single float per sequence
+ auto & embd_seq_out = embd_seq;
+
+ for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
+ if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+ continue;
+ }
+ embd_seq_out[seq_id].resize(1);
+ lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+ }
+ } break;
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
  {
  LM_GGML_ABORT("unknown pooling type");
@@ -835,16 +847,27 @@ int llama_context::encode(llama_batch & inp_batch) {
  }

  int llama_context::decode(llama_batch & inp_batch) {
+ if (!memory) {
+ LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
+ return encode(inp_batch);
+ }
+
  if (inp_batch.n_tokens == 0) {
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
  return -1;
  }

+ if (!inp_batch.pos) {
+ if (inp_batch.seq_id) {
+ LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
+ return -1;
+ }
+ }
+
  llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());

  // temporary allocate memory for the input batch if needed
- // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);

  const llama_batch & batch = batch_allocr.batch;

@@ -890,7 +913,7 @@ int llama_context::decode(llama_batch & inp_batch) {
  for (uint32_t i = 0; i < n_tokens_all; ++i) {
  n_outputs_all += batch.logits[i] != 0;
  }
- } else if (logits_all || embd_pooled) {
+ } else if (embd_pooled) {
  n_outputs_all = n_tokens_all;
  } else {
  // keep last output only
@@ -932,8 +955,6 @@ int llama_context::decode(llama_batch & inp_batch) {

  // find KV slot
  if (!kv_self->find_slot(ubatch)) {
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
  return 1;
  }

@@ -1689,10 +1710,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
  }
  }

- LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
  llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());

- kv_self->state_write(io);
+ if (kv_self != nullptr) {
+ LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+ kv_self->state_write(io);
+ }

  return io.n_bytes();
  }
@@ -1775,10 +1798,13 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
  }
  }

- LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+ if (memory) {
+ LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);

- kv_self->state_read(io);
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ kv_self->state_read(io);
+ }

  return io.n_bytes();
  }
@@ -1786,9 +1812,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
  size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
  LM_GGML_UNUSED(seq_id);

- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+ if (memory) {
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());

- kv_self->state_write(io, seq_id);
+ kv_self->state_write(io, seq_id);
+ }

  return io.n_bytes();
  }
@@ -1796,9 +1824,11 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
  size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
  LM_GGML_UNUSED(seq_id);

- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+ if (memory) {
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());

- kv_self->state_read(io, seq_id);
+ kv_self->state_read(io, seq_id);
+ }

  return io.n_bytes();
  }
@@ -1826,6 +1856,215 @@ void llama_context::perf_reset() {
  t_p_eval_us = n_p_eval = 0;
  }

+ //
+ // training
+ //
+
+ static void llama_set_param(struct lm_ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
+ if (!tensor || tensor->type != LM_GGML_TYPE_F32) {
+ return;
+ }
+ if (!param_filter(tensor, userdata)) {
+ return;
+ }
+ if (strcmp(tensor->name, "token_embd.weight") == 0) {
+ return; // FIXME
+ }
+ if (strcmp(tensor->name, "rope_freqs.weight") == 0) {
+ return; // FIXME
+ }
+ lm_ggml_set_param(tensor);
+ }
+
+ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) {
+ LM_GGML_ASSERT(!opt_ctx);
+ model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx();
+ const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train);
+ const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
+ LM_GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0);
+ LM_GGML_ASSERT(n_batch % n_ubatch == 0);
+
+ lm_ggml_opt_params opt_params = lm_ggml_opt_default_params(sched.get(), LM_GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
+ opt_params.opt_period = n_batch / n_ubatch;
+ opt_params.get_opt_pars = lopt_params.get_opt_pars;
+ opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
+
+ opt_ctx = lm_ggml_opt_init(opt_params);
+
+ llama_opt_param_filter param_filter = lopt_params.param_filter;
+ void * param_filter_ud = lopt_params.param_filter_ud;
+
+ //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME
+ llama_set_param(model->type_embd, param_filter, param_filter_ud);
+ llama_set_param(model->pos_embd, param_filter, param_filter_ud);
+ llama_set_param(model->tok_norm, param_filter, param_filter_ud);
+ llama_set_param(model->tok_norm_b, param_filter, param_filter_ud);
+ llama_set_param(model->output_norm, param_filter, param_filter_ud);
+ llama_set_param(model->output_norm_b, param_filter, param_filter_ud);
+ llama_set_param(model->output, param_filter, param_filter_ud);
+ llama_set_param(model->output_b, param_filter, param_filter_ud);
+ llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
+ llama_set_param(model->cls, param_filter, param_filter_ud);
+ llama_set_param(model->cls_b, param_filter, param_filter_ud);
+ llama_set_param(model->cls_out, param_filter, param_filter_ud);
+ llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
+
+ for (struct llama_layer & layer : model->layers) {
+ for (size_t i = 0; i < sizeof(layer)/sizeof(struct lm_ggml_tensor *); ++i) {
+ llama_set_param(reinterpret_cast<struct lm_ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
+ }
+ }
+ }
+
+ void llama_context::opt_epoch_iter(
+ lm_ggml_opt_dataset_t dataset,
+ lm_ggml_opt_result_t result,
+ const std::vector<llama_token> & tokens,
+ const std::vector<llama_token> & labels_sparse,
+ llama_batch & batch,
+ lm_ggml_opt_epoch_callback callback,
+ bool train,
+ int64_t idata_in_loop,
+ int64_t ndata_in_loop,
+ int64_t t_loop_start) {
+ LM_GGML_ASSERT(opt_ctx);
+ const uint32_t n_ctx = llama_model_n_ctx_train(&model);
+ const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
+ const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
+
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ kv_self->clear();
+ llama_kv_cache_guard kv_guard(kv_self);
+
+ for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
+ batch.n_tokens = n_batch;
+ for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
+ batch.token [pos_batch] = tokens[pos_ctx + pos_batch];
+ batch.pos [pos_batch] = pos_ctx + pos_batch;
+ batch.n_seq_id[pos_batch] = 1;
+ batch.seq_id [pos_batch][0] = 0;
+ batch.logits [pos_batch] = true;
+ }
+
+ const auto n_tokens_all = batch.n_tokens;
+
+ n_queued_tokens += n_tokens_all;
+
+ // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+ const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+ embd_seq.clear();
+
+ int64_t n_outputs_all = n_tokens_all;
+
+ llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
+
+ // reserve output buffer
+ if (output_reserve(n_outputs_all) < n_outputs_all) {
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+ LM_GGML_ABORT("TODO: handle this error");
+ };
+
+ for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
+ llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+
+ n_outputs = ubatch.n_tokens;
+
+ // TODO: not sure if this is needed
+ if (!kv_self->find_slot(ubatch)) {
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+
+ LM_GGML_ABORT("TODO: handle this error");
+ }
+
+ auto * gf = graph_init();
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+
+ struct lm_ggml_context * ctx_compute_opt;
+ {
+ const size_t size_gf = lm_ggml_graph_size(gf);
+ const size_t size_meta = 4*size_gf*lm_ggml_tensor_overhead() + 2*lm_ggml_graph_overhead_custom(size_gf, /*grads = */ true);
+ struct lm_ggml_init_params params = {
+ /*.mem_size =*/ size_meta,
+ /*.mem_buffer =*/ nullptr,
+ /*.no_alloc =*/ true,
+ };
+ ctx_compute_opt = lm_ggml_init(params);
+ }
+ lm_ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
+ lm_ggml_opt_alloc(opt_ctx, train);
+ res->set_inputs(&ubatch);
+ {
+ struct lm_ggml_tensor * labels = lm_ggml_opt_labels(opt_ctx);
+ LM_GGML_ASSERT(labels->ne[1] == n_ubatch);
+ lm_ggml_set_zero(labels);
+ const float onef = 1.0f;
+ for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
+ const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
+ LM_GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
+ lm_ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
+ }
+ }
+ lm_ggml_opt_eval(opt_ctx, result);
+ if (callback) {
+ callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
+ }
+ lm_ggml_free(ctx_compute_opt);
+ }
+ }
+
+ kv_guard.commit();
+ }
+
+ void llama_context::opt_epoch(
+ lm_ggml_opt_dataset_t dataset,
+ lm_ggml_opt_result_t result_train,
+ lm_ggml_opt_result_t result_eval,
+ int64_t idata_split,
+ lm_ggml_opt_epoch_callback callback_train,
+ lm_ggml_opt_epoch_callback callback_eval) {
+ const uint32_t n_ctx = this->n_ctx();
+ const uint32_t n_batch = std::min(cparams.n_batch, n_ctx);
+ const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
+ const int64_t ndata = lm_ggml_opt_dataset_ndata(dataset);
+
+ LM_GGML_ASSERT(idata_split >= 0);
+ LM_GGML_ASSERT(idata_split <= ndata);
+
+ const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
+
+ struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
+ std::vector<llama_token> tokens(n_ctx);
+ std::vector<llama_token> labels_sparse(n_ctx);
+
+ int64_t idata = 0;
+
+ int64_t t_loop_start = lm_ggml_time_us();
+ int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
+ for (; idata < idata_split; ++idata) {
+ constexpr bool train = true;
+ const int64_t idata_in_loop = idata*ubatch_per_ctx;
+
+ lm_ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
+ opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch,
+ callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
+ }
+
+ t_loop_start = lm_ggml_time_us();
+ ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
+ for (; idata < ndata; ++idata) {
+ constexpr bool train = false;
+ const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
+
+ lm_ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
+ opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch,
+ callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
+ }
+
+ llama_batch_free(batch);
+ }
+
  //
  // interface implementation
  //
@@ -1853,13 +2092,14 @@ llama_context_params llama_context_default_params() {
  /*.cb_eval_user_data =*/ nullptr,
  /*.type_k =*/ LM_GGML_TYPE_F16,
  /*.type_v =*/ LM_GGML_TYPE_F16,
- /*.logits_all =*/ false,
+ /*.abort_callback =*/ nullptr,
+ /*.abort_callback_data =*/ nullptr,
  /*.embeddings =*/ false,
  /*.offload_kqv =*/ true,
  /*.flash_attn =*/ false,
  /*.no_perf =*/ true,
- /*.abort_callback =*/ nullptr,
- /*.abort_callback_data =*/ nullptr,
+ /*.op_offload =*/ true,
+ /*.swa_full =*/ true,
  };

  return result;
@@ -2054,65 +2294,51 @@ int32_t llama_apply_adapter_cvec(
  return res ? 0 : -1;
  }

- //
- // kv cache view
- //
-
- llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
- const auto * kv = ctx->get_kv_self();
- if (kv == nullptr) {
- LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
- return {};
- }
-
- return llama_kv_cache_view_init(*kv, n_seq_max);
- }
-
- void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
- const auto * kv = ctx->get_kv_self();
- if (kv == nullptr) {
- LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
- return;
- }
-
- llama_kv_cache_view_update(view, kv);
- }
-
  //
  // kv cache
  //

  // deprecated
- int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
- return llama_kv_self_n_tokens(ctx);
- }
-
  int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
  const auto * kv = ctx->get_kv_self();
  if (!kv) {
  return 0;
  }

- return kv->get_n_tokens();
- }
+ int32_t res = 0;

- // deprecated
- int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
- return llama_kv_self_used_cells(ctx);
+ for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+ const llama_pos p0 = kv->seq_pos_min(s);
+ const llama_pos p1 = kv->seq_pos_max(s);
+
+ if (p0 >= 0) {
+ res += (p1 - p0) + 1;
+ }
+ }
+
+ return res;
  }

+ // deprecated
+ // note: this is the same as above - will be removed anyway, so it's ok
  int32_t llama_kv_self_used_cells(const llama_context * ctx) {
  const auto * kv = ctx->get_kv_self();
  if (!kv) {
  return 0;
  }

- return kv->get_used_cells();
- }
+ int32_t res = 0;

- // deprecated
- void llama_kv_cache_clear(llama_context * ctx) {
- llama_kv_self_clear(ctx);
+ for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+ const llama_pos p0 = kv->seq_pos_min(s);
+ const llama_pos p1 = kv->seq_pos_max(s);
+
+ if (p0 >= 0) {
+ res += (p1 - p0) + 1;
+ }
+ }
+
+ return res;
  }

  void llama_kv_self_clear(llama_context * ctx) {
@@ -2124,15 +2350,6 @@ void llama_kv_self_clear(llama_context * ctx) {
  kv->clear();
  }

- // deprecated
- bool llama_kv_cache_seq_rm(
- llama_context * ctx,
- llama_seq_id seq_id,
- llama_pos p0,
- llama_pos p1) {
- return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
- }
-
  bool llama_kv_self_seq_rm(
  llama_context * ctx,
  llama_seq_id seq_id,
@@ -2146,16 +2363,6 @@ bool llama_kv_self_seq_rm(
  return kv->seq_rm(seq_id, p0, p1);
  }

- // deprecated
- void llama_kv_cache_seq_cp(
- llama_context * ctx,
- llama_seq_id seq_id_src,
- llama_seq_id seq_id_dst,
- llama_pos p0,
- llama_pos p1) {
- llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
- }
-
  void llama_kv_self_seq_cp(
  llama_context * ctx,
  llama_seq_id seq_id_src,
@@ -2170,13 +2377,6 @@ void llama_kv_self_seq_cp(
  kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
  }

- // deprecated
- void llama_kv_cache_seq_keep(
- llama_context * ctx,
- llama_seq_id seq_id) {
- llama_kv_self_seq_keep(ctx, seq_id);
- }
-
  void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
  auto * kv = ctx->get_kv_self();
  if (!kv) {
@@ -2186,16 +2386,6 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
  kv->seq_keep(seq_id);
  }

- // deprecated
- void llama_kv_cache_seq_add(
- llama_context * ctx,
- llama_seq_id seq_id,
- llama_pos p0,
- llama_pos p1,
- llama_pos delta) {
- llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
- }
-
  void llama_kv_self_seq_add(
  llama_context * ctx,
  llama_seq_id seq_id,
@@ -2210,16 +2400,6 @@ void llama_kv_self_seq_add(
  kv->seq_add(seq_id, p0, p1, delta);
  }

- // deprecated
- void llama_kv_cache_seq_div(
- llama_context * ctx,
- llama_seq_id seq_id,
- llama_pos p0,
- llama_pos p1,
- int d) {
- llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
- }
-
  void llama_kv_self_seq_div(
  llama_context * ctx,
  llama_seq_id seq_id,
@@ -2234,25 +2414,24 @@ void llama_kv_self_seq_div(
  kv->seq_div(seq_id, p0, p1, d);
  }

- // deprecated
- llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
- return llama_kv_self_seq_pos_max(ctx, seq_id);
+ llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
+ const auto * kv = ctx->get_kv_self();
+ if (!kv) {
+ return -1;
+ }
+
+ return kv->seq_pos_min(seq_id);
  }

  llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
  const auto * kv = ctx->get_kv_self();
  if (!kv) {
- return 0;
+ return -1;
  }

  return kv->seq_pos_max(seq_id);
  }

- // deprecated
- void llama_kv_cache_defrag(llama_context * ctx) {
- llama_kv_self_defrag(ctx);
- }
-
  void llama_kv_self_defrag(llama_context * ctx) {
  auto * kv = ctx->get_kv_self();
  if (!kv) {
@@ -2263,11 +2442,6 @@ void llama_kv_self_defrag(llama_context * ctx) {
  kv->defrag_sched(-1.0f);
  }

- // deprecated
- bool llama_kv_cache_can_shift(const llama_context * ctx) {
- return llama_kv_self_can_shift(ctx);
- }
-
  bool llama_kv_self_can_shift(const llama_context * ctx) {
  const auto * kv = ctx->get_kv_self();
  if (!kv) {
@@ -2277,11 +2451,6 @@ bool llama_kv_self_can_shift(const llama_context * ctx) {
  return kv->get_can_shift();
  }

- // deprecated
- void llama_kv_cache_update(llama_context * ctx) {
- llama_kv_self_update(ctx);
- }
-
  // llama state API

  // deprecated
@@ -2404,7 +2573,21 @@ int32_t llama_encode(
  int32_t llama_decode(
  llama_context * ctx,
  llama_batch batch) {
- const int ret = ctx->decode(batch);
+ int ret = ctx->decode(batch);
+
+ // defrag and try again
+ // TODO: distinguish return code when we are sure that even after defrag there is no space available
+ if (ret == 1) {
+ llama_kv_self_defrag(ctx);
+ ret = ctx->decode(batch);
+
+ if (ret == 1) {
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+ return ret;
+ }
+ }
+
  if (ret != 0) {
  LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
  }
@@ -2444,3 +2627,34 @@ void llama_perf_context_print(const llama_context * ctx) {
  void llama_perf_context_reset(llama_context * ctx) {
  ctx->perf_reset();
  }
+
+ //
+ // training
+ //
+
+ bool llama_opt_param_filter_all(const struct lm_ggml_tensor * tensor, void * userdata) {
+ LM_GGML_UNUSED(tensor);
+ LM_GGML_UNUSED(userdata);
+ return true;
+ }
+
+ void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) {
+ ctx->opt_init(model, lopt_params);
+ }
+
+ void llama_opt_epoch(
+ struct llama_context * ctx,
+ lm_ggml_opt_dataset_t dataset,
+ lm_ggml_opt_result_t result_train,
+ lm_ggml_opt_result_t result_eval,
+ int64_t idata_split,
+ lm_ggml_opt_epoch_callback callback_train,
+ lm_ggml_opt_epoch_callback callback_eval) {
+ ctx->opt_epoch(
+ dataset,
+ result_train,
+ result_eval,
+ idata_split,
+ callback_train,
+ callback_eval);
+ }
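
The hunks above introduce a training entry point on llama_context (opt_init / opt_epoch) and export it through llama_opt_init, llama_opt_epoch and llama_opt_param_filter_all. A minimal sketch of how that API could be driven is given below; it is not part of this package, it assumes a token dataset (lm_ggml_opt_dataset_t) has already been prepared, and it assumes the lm_-prefixed ggml-opt helpers (lm_ggml_opt_result_init/reset/free, lm_ggml_opt_get_default_optimizer_params) match their upstream ggml-opt.h counterparts.

    // Sketch only: run a few training epochs with the 1.7.0 API.
    // `ctx`, `model` and `dataset` are assumed to exist already; see the hunks
    // above for the llama_opt_* signatures and the llama_opt_params fields they read.
    #include "llama.h"
    #include "ggml-opt.h"

    static void finetune_epochs(llama_context * ctx, llama_model * model,
                                lm_ggml_opt_dataset_t dataset, int n_epoch) {
        struct llama_opt_params lopt_params = {};
        lopt_params.n_ctx_train     = 0;                          // 0: opt_init() falls back to n_ctx()
        lopt_params.param_filter    = llama_opt_param_filter_all; // train every F32 tensor
        lopt_params.param_filter_ud = nullptr;
        lopt_params.get_opt_pars    = lm_ggml_opt_get_default_optimizer_params; // assumed default optimizer params from ggml-opt.h
        lopt_params.get_opt_pars_ud = nullptr;

        llama_opt_init(ctx, model, lopt_params);

        // hold out the last tenth of the data for evaluation
        const int64_t ndata       = lm_ggml_opt_dataset_ndata(dataset);
        const int64_t idata_split = ndata - ndata/10;

        lm_ggml_opt_result_t result_train = lm_ggml_opt_result_init();
        lm_ggml_opt_result_t result_eval  = lm_ggml_opt_result_init();

        for (int epoch = 0; epoch < n_epoch; ++epoch) {
            // trains on [0, idata_split), then evaluates the remainder
            llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
                            /*callback_train =*/ nullptr, /*callback_eval =*/ nullptr);
            lm_ggml_opt_result_reset(result_train);
            lm_ggml_opt_result_reset(result_eval);
        }

        lm_ggml_opt_result_free(result_train);
        lm_ggml_opt_result_free(result_eval);
    }

The idata_split argument separates the training and evaluation loops of opt_epoch() shown in the diff; note that opt_init() asserts the training context size is a multiple of n_batch, so the context should be sized accordingly.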