cui-llama.rn 1.6.0 → 1.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. package/README.md +35 -7
  2. package/android/src/main/CMakeLists.txt +16 -11
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +4 -1
  4. package/android/src/main/jni.cpp +20 -4
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  13. package/cpp/LICENSE +21 -0
  14. package/cpp/chat.cpp +1 -1
  15. package/cpp/common.cpp +17 -2
  16. package/cpp/common.h +7 -3
  17. package/cpp/ggml-alloc.c +4 -1
  18. package/cpp/ggml-cpp.h +1 -1
  19. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  20. package/cpp/ggml-cpu/amx/amx.h +8 -0
  21. package/cpp/ggml-cpu/amx/common.h +91 -0
  22. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  23. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  24. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/binary-ops.h +1 -1
  25. package/cpp/ggml-cpu/common.h +72 -0
  26. package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -101
  27. package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +109 -42
  28. package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +3 -0
  29. package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +246 -160
  30. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/ops.h +2 -20
  31. package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
  32. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/simd-mappings.h +7 -3
  33. package/{ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/unary-ops.h +1 -1
  34. package/cpp/ggml-cpu.h +5 -0
  35. package/cpp/ggml-impl.h +16 -9
  36. package/cpp/ggml-llama-sim.metallib +0 -0
  37. package/cpp/ggml-llama.metallib +0 -0
  38. package/cpp/ggml-metal.m +492 -47
  39. package/cpp/ggml.c +134 -244
  40. package/cpp/ggml.h +61 -94
  41. package/cpp/json-schema-to-grammar.cpp +3 -0
  42. package/cpp/llama-arch.cpp +46 -17
  43. package/cpp/llama-arch.h +9 -0
  44. package/cpp/llama-batch.cpp +5 -1
  45. package/cpp/llama-batch.h +2 -1
  46. package/cpp/llama-chat.cpp +31 -10
  47. package/cpp/llama-chat.h +3 -2
  48. package/cpp/llama-context.cpp +104 -489
  49. package/cpp/llama-context.h +14 -30
  50. package/cpp/llama-graph.cpp +69 -62
  51. package/cpp/llama-graph.h +21 -18
  52. package/cpp/llama-hparams.h +5 -0
  53. package/cpp/llama-kv-cache.cpp +1497 -391
  54. package/cpp/llama-kv-cache.h +272 -80
  55. package/cpp/llama-memory.h +11 -1
  56. package/cpp/llama-model.cpp +502 -176
  57. package/cpp/llama-model.h +13 -3
  58. package/cpp/llama-sampling.cpp +2 -1
  59. package/cpp/llama-vocab.cpp +8 -1
  60. package/cpp/llama.h +14 -11
  61. package/cpp/rn-llama.cpp +20 -172
  62. package/cpp/rn-llama.h +1 -5
  63. package/ios/CMakeLists.txt +13 -10
  64. package/ios/RNLlama.h +6 -0
  65. package/ios/RNLlama.mm +5 -0
  66. package/ios/RNLlamaContext.mm +26 -28
  67. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +7 -3
  68. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  69. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  70. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  71. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +61 -94
  72. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  73. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  74. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
  75. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +14 -30
  76. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
  77. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
  78. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  79. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
  80. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +13 -3
  81. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +14 -11
  82. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
  83. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  84. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  85. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
  86. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  87. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  88. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  89. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
  90. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  91. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  92. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
  93. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
  94. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
  95. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
  96. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  97. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
  98. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
  99. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
  100. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
  101. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  102. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  103. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +7 -3
  104. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  105. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  106. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  107. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +61 -94
  108. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  109. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  110. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
  111. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +14 -30
  112. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
  113. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
  114. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  115. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
  116. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +13 -3
  117. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +14 -11
  118. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
  119. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  120. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  121. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
  122. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  123. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  124. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  125. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
  126. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  127. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  128. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
  129. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
  130. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
  131. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
  132. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  133. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
  134. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
  135. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
  136. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
  137. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  138. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  139. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  140. package/lib/module/NativeRNLlama.js.map +1 -1
  141. package/lib/typescript/NativeRNLlama.d.ts +4 -0
  142. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  143. package/package.json +1 -1
  144. package/src/NativeRNLlama.ts +5 -0
  145. package/cpp/binary-ops.h +0 -16
  146. package/cpp/ops.h +0 -128
  147. package/cpp/simd-mappings.h +0 -888
  148. package/cpp/unary-ops.h +0 -28
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +0 -128
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  157. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +0 -802
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  165. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  166. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  167. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  168. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  169. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  170. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +0 -128
  171. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  172. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  173. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +0 -28
  174. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/vec.h +0 -802
  175. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +0 -16
  176. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  177. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  178. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  179. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  180. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +0 -128
  181. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  182. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +0 -888
  183. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  184. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  185. /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
  186. /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
  187. /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
  188. /package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +0 -0
  189. /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
  190. /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
  191. /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
  192. /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
  193. /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
  194. /package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -0
  195. /package/cpp/{vec.h → ggml-cpu/vec.h} +0 -0
package/cpp/llama-context.cpp
@@ -6,7 +6,6 @@
  #include "llama-model.h"
  #include "llama-kv-cache.h"
 
- #include <cassert>
  #include <cstring>
  #include <stdexcept>
  #include <cinttypes>
@@ -113,7 +112,7 @@ llama_context::llama_context(
  }
 
  if (n_ctx_per_seq > hparams.n_ctx_train) {
- LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
+ LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
  __func__, n_ctx_per_seq, hparams.n_ctx_train);
  }
 
@@ -176,44 +175,13 @@ llama_context::llama_context(
  }
 
  // init the memory module
- // TODO: for now, always create a unified KV cache
  if (!hparams.vocab_only) {
- kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
+ llama_memory_params params_mem = {
+ /*.type_k =*/ params.type_k,
+ /*.type_v =*/ params.type_v,
+ };
 
- LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
- cparams.n_ctx = LM_GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
-
- LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
- uint32_t kv_size = cparams.n_ctx;
- lm_ggml_type type_k = params.type_k;
- lm_ggml_type type_v = params.type_v;
-
- if (llama_model_is_recurrent(&model)) {
- // Mamba needs at least as many KV cells as there are sequences kept at any time
- kv_size = std::max((uint32_t) 1, params.n_seq_max);
- // it's probably best to keep as much precision as possible for the states
- type_k = LM_GGML_TYPE_F32; // required by lm_ggml_ssm_conv for Mamba's conv_states
- type_v = LM_GGML_TYPE_F32; // required by lm_ggml_ssm_scan for Mamba's ssm_states
- }
-
- LM_GGML_ASSERT(hparams.n_embd_head_k % lm_ggml_blck_size(type_k) == 0);
- LM_GGML_ASSERT(hparams.n_embd_head_v % lm_ggml_blck_size(type_v) == 0);
-
- if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
- throw std::runtime_error("failed to initialize self-attention cache");
- }
-
- {
- const size_t memory_size_k = kv_self->size_k_bytes();
- const size_t memory_size_v = kv_self->size_v_bytes();
-
- LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
- (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
- lm_ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
- lm_ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
- }
+ memory.reset(model.create_memory(params_mem, cparams));
  }
 
  // init backends
@@ -304,7 +272,9 @@ llama_context::llama_context(
  int n_nodes_tg = -1;
 
  // simulate full KV cache
- kv_self->n = kv_self->size;
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ kv_self->set_full();
 
  cross.v_embd.clear();
 
@@ -426,6 +396,18 @@ const llama_model & llama_context::get_model() const {
  return model;
  }
 
+ const llama_cparams & llama_context::get_cparams() const {
+ return cparams;
+ }
+
+ lm_ggml_backend_sched_t llama_context::get_sched() const {
+ return sched.get();
+ }
+
+ lm_ggml_context * llama_context::get_ctx_compute() const {
+ return ctx_compute.get();
+ }
+
  uint32_t llama_context::n_ctx() const {
  return cparams.n_ctx;
  }
@@ -455,345 +437,21 @@ uint32_t llama_context::n_threads_batch() const {
  }
 
  llama_kv_cache * llama_context::get_kv_self() {
- return kv_self.get();
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+ return kv_self;
  }
 
  const llama_kv_cache * llama_context::get_kv_self() const {
- return kv_self.get();
- }
-
- lm_ggml_tensor * llama_context::build_rope_shift(
- lm_ggml_context * ctx0,
- lm_ggml_tensor * cur,
- lm_ggml_tensor * shift,
- lm_ggml_tensor * factors,
- float freq_base,
- float freq_scale,
- lm_ggml_backend_buffer * bbuf) const {
- const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
-
- const auto & yarn_ext_factor = cparams.yarn_ext_factor;
- const auto & yarn_attn_factor = cparams.yarn_attn_factor;
- const auto & yarn_beta_fast = cparams.yarn_beta_fast;
- const auto & yarn_beta_slow = cparams.yarn_beta_slow;
-
- const auto & hparams = model.hparams;
-
- const auto & n_rot = hparams.n_rot;
- const auto & rope_type = hparams.rope_type;
-
- lm_ggml_tensor * tmp;
-
- if (lm_ggml_is_quantized(cur->type)) {
- // dequantize to f32 -> RoPE -> quantize back
- tmp = lm_ggml_cast(ctx0, cur, LM_GGML_TYPE_F32);
-
- if (bbuf) {
- for (const auto & backend : backends) {
- // Figure out which backend KV cache belongs to
- if (lm_ggml_backend_supports_buft(backend.get(), lm_ggml_backend_buffer_get_type(bbuf))) {
- lm_ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
- break;
- }
- }
- }
-
- tmp = lm_ggml_rope_ext_inplace(ctx0, tmp,
- shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
-
- tmp = lm_ggml_cpy(ctx0, tmp, cur);
- } else {
- // we rotate only the first n_rot dimensions
- tmp = lm_ggml_rope_ext_inplace(ctx0, cur,
- shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
- }
-
- return tmp;
- }
-
- class llm_graph_input_k_shift : public llm_graph_input_i {
- public:
- llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
- virtual ~llm_graph_input_k_shift() = default;
-
- void set_input(const llama_ubatch * ubatch) override;
-
- lm_ggml_tensor * k_shift; // I32 [kv_size]
-
- const llama_kv_cache_unified * kv_self;
- };
-
- void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
- LM_GGML_UNUSED(ubatch);
-
- if (k_shift) {
- assert(lm_ggml_backend_buffer_is_host(k_shift->buffer));
-
- int32_t * data = (int32_t *) k_shift->data;
-
- for (uint32_t i = 0; i < kv_self->size; ++i) {
- data[i] = kv_self->cells[i].delta;
- }
- }
- }
-
- llm_graph_result_ptr llama_context::build_kv_self_shift(
- lm_ggml_context * ctx0,
- lm_ggml_cgraph * gf) const {
- auto res = std::make_unique<llm_graph_result>();
-
- const auto & hparams = model.hparams;
-
- const auto & n_layer = hparams.n_layer;
-
- const auto & n_embd_head_k = hparams.n_embd_head_k;
- //const auto & n_embd_head_v = hparams.n_embd_head_v;
-
- //LM_GGML_ASSERT(kv_self->size == n_ctx);
-
- auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
-
- inp->k_shift = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, cparams.n_ctx);
- lm_ggml_set_input(inp->k_shift);
-
- for (uint32_t il = 0; il < n_layer; ++il) {
- const int64_t n_head_kv = hparams.n_head_kv(il);
- const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-
- const bool is_swa = hparams.is_swa(il);
-
- // note: the swa rope params could become part of the cparams in the future
- // if we decide to make them configurable, like the non-sliding ones
- const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
- const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
-
- lm_ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
-
- lm_ggml_tensor * k =
- lm_ggml_view_3d(ctx0, kv_self->k_l[il],
- n_embd_head_k, n_head_kv, kv_self->size,
- lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
- lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
- 0);
-
- lm_ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
-
- lm_ggml_build_forward_expand(gf, cur);
- }
-
- res->add_input(std::move(inp));
-
- return res;
- }
-
- llm_graph_result_ptr llama_context::build_kv_self_defrag(
- lm_ggml_context * ctx0,
- lm_ggml_cgraph * gf) const {
- auto res = std::make_unique<llm_graph_result>();
-
- const auto & hparams = model.hparams;
-
- const auto & ids = kv_self->defrag_info.ids;
-
- #if 0
- // CPU defrag
- //
- // TODO: optimizations are possible:
- // - multiple threads
- // - avoid copying to the host memory when already there
- //
- // likely not worth the effort, as we have lm_ggml_graph based defrag
- //
-
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-
- const uint32_t kv_size = size;
-
- std::vector<uint8_t> buf_k;
- std::vector<uint8_t> buf_v;
-
- for (uint32_t il = 0; il < n_layer; ++il) {
- const size_t k_size_row = lm_ggml_row_size(k_l[il]->type, n_embd_k_gqa);
- const size_t k_size = lm_ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
-
- const size_t v_size_el = lm_ggml_type_size(v_l[il]->type);
- const size_t v_size = lm_ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
-
- buf_k.resize(k_size);
- buf_v.resize(v_size);
-
- lm_ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
- lm_ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
-
- // batch move [i, i+nm) to [id, id+nm)
- // note: cells can move only to a lower index
- for (uint32_t i = 0; i < n_kv; ++i) {
- const uint32_t id = ids[i];
-
- if (i == id || id == n_kv) {
- continue;
- }
-
- uint32_t nm = 1;
-
- while (i + nm < n_kv && ids[i + nm] == id + nm) {
- nm++;
- }
-
- // move keys
- {
- const int64_t os = i*k_size_row;
- const int64_t od = id*k_size_row;
-
- memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
- }
-
- // move values (note: they are transposed)
- {
- const int64_t os = i;
- const int64_t od = id;
-
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
- memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
- }
- }
-
- i += nm - 1;
- }
-
- lm_ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
- lm_ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
- }
- #else
- for (uint32_t i = 0; i < ids.size(); ++i) {
- const uint32_t id = ids[i];
-
- if (i == id || id == ids.size()) {
- continue;
- }
-
- uint32_t nm = 1;
-
- while (i + nm < ids.size() && ids[i + nm] == id + nm) {
- nm++;
- }
-
- for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
- const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
- const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
- lm_ggml_tensor * view_k_src = lm_ggml_view_2d(ctx0, kv_self->k_l[il],
- n_embd_k_gqa, nm,
- lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
- lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
-
- lm_ggml_tensor * view_k_dst = lm_ggml_view_2d(ctx0, kv_self->k_l[il],
- n_embd_k_gqa, nm,
- lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
- lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
-
- lm_ggml_tensor * view_v_src;
- lm_ggml_tensor * view_v_dst;
-
- if (cparams.flash_attn) {
- // NOTE: the V cache is not transposed when using flash attention
- view_v_src = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
- n_embd_v_gqa, nm,
- lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
- lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
-
- view_v_dst = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
- n_embd_v_gqa, nm,
- lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
- lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
- } else {
- view_v_src = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
- nm, n_embd_v_gqa,
- lm_ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
- lm_ggml_row_size(kv_self->v_l[il]->type, i));
-
- view_v_dst = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
- nm, n_embd_v_gqa,
- lm_ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
- lm_ggml_row_size(kv_self->v_l[il]->type, id));
- }
-
- lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, view_k_src, view_k_dst));
- lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, view_v_src, view_v_dst));
- }
-
- i += nm - 1;
- }
-
- //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
- #endif
-
- return res;
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+ return kv_self;
  }
 
  void llama_context::kv_self_update() {
- auto & kv = kv_self;
-
  bool need_reserve = false;
 
- if (kv->has_shift) {
- if (!kv->get_can_shift()) {
- LM_GGML_ABORT("The current context does not support K-shift");
- }
-
- LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
-
- // apply K-shift if needed
- if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
- lm_ggml_backend_sched_reset(sched.get());
-
- auto * gf = graph_init();
-
- auto res = build_kv_self_shift(ctx_compute.get(), gf);
-
- lm_ggml_backend_sched_alloc_graph(sched.get(), gf);
-
- res->set_inputs(nullptr);
-
- graph_compute(gf, false);
-
- need_reserve = true;
- }
-
- {
- kv->has_shift = false;
-
- for (uint32_t i = 0; i < kv->size; ++i) {
- kv->cells[i].delta = 0;
- }
- }
- }
-
- // defragment the KV cache if needed
- if (kv->do_defrag) {
- LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
- if (kv->defrag_prepare(graph_max_nodes())) {
- lm_ggml_backend_sched_reset(sched.get());
-
- auto * gf = graph_init();
-
- auto res = build_kv_self_defrag(ctx_compute.get(), gf);
-
- lm_ggml_backend_sched_alloc_graph(sched.get(), gf);
-
- res->set_inputs(nullptr);
-
- graph_compute(gf, false);
-
- need_reserve = true;
- }
-
- kv->do_defrag = false;
- }
+ need_reserve = kv_self->update(*this);
 
  // reserve a worst case graph if needed
  if (need_reserve) {
@@ -804,7 +462,7 @@ void llama_context::kv_self_update() {
  uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
  // simulate full KV cache
- kv_self->n = kv_self->size;
+ kv_self->set_full();
 
  llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
  llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
@@ -825,9 +483,6 @@ enum llama_pooling_type llama_context::pooling_type() const {
  }
 
  float * llama_context::get_logits() {
- // reorder logits for backward compatibility
- output_reorder();
-
  return logits;
  }
 
@@ -870,9 +525,6 @@ float * llama_context::get_logits_ith(int32_t i) {
  }
 
  float * llama_context::get_embeddings() {
- // reorder embeddings for backward compatibility
- output_reorder();
-
  return embd;
  }
 
@@ -1024,8 +676,8 @@ int llama_context::encode(llama_batch & inp_batch) {
  }
 
  // temporary allocate memory for the input batch if needed
- // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
+ // note: during encode, we always pass the full sequence starting from pos = 0
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
 
  const llama_batch & batch = batch_allocr.batch;
  const int32_t n_tokens = batch.n_tokens;
@@ -1054,7 +706,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
  const int64_t n_embd = hparams.n_embd;
 
- sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+ llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
 
  const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
@@ -1188,9 +840,11 @@ int llama_context::decode(llama_batch & inp_batch) {
  return -1;
  }
 
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
  // temporary allocate memory for the input batch if needed
- // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
+ // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
 
  const llama_batch & batch = batch_allocr.batch;
 
@@ -1202,7 +856,7 @@ int llama_context::decode(llama_batch & inp_batch) {
  const int64_t n_tokens_all = batch.n_tokens;
  const int64_t n_embd = hparams.n_embd;
 
- llama_kv_cache_guard kv_guard(kv_self.get());
+ llama_kv_cache_guard kv_guard(kv_self);
 
  LM_GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
@@ -1243,11 +897,7 @@ int llama_context::decode(llama_batch & inp_batch) {
  n_outputs_all = 1;
  }
 
- const bool logits_all = n_outputs_all == n_tokens_all;
-
- sbatch.from_batch(batch, n_embd,
- /* simple_split */ !kv_self->recurrent,
- /* logits_all */ logits_all);
+ llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
 
  // reserve output buffer
  if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -1261,22 +911,7 @@ int llama_context::decode(llama_batch & inp_batch) {
  int64_t n_outputs_prev = 0;
 
  while (sbatch.n_tokens > 0) {
- llama_ubatch ubatch = llama_ubatch();
-
- const auto & n_ubatch = cparams.n_ubatch;
-
- if (kv_self->recurrent) {
- if (embd_pooled) {
- // Pooled embeddings cannot be split across ubatches (yet)
- ubatch = sbatch.split_seq(cparams.n_ubatch);
- } else {
- // recurrent model architectures are easier to implement
- // with equal-length sequences
- ubatch = sbatch.split_equal(cparams.n_ubatch);
- }
- } else {
- ubatch = sbatch.split_simple(n_ubatch);
- }
+ llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
 
  // count the outputs in this u_batch
  {
@@ -1296,24 +931,12 @@ int llama_context::decode(llama_batch & inp_batch) {
  }
 
  // find KV slot
- {
- if (!kv_self->find_slot(ubatch)) {
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
- return 1;
- }
+ if (!kv_self->find_slot(ubatch)) {
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
 
- if (!kv_self->recurrent) {
- // a heuristic, to avoid attending the full cache if it is not yet utilized
- // after enough generations, the benefit from this heuristic disappears
- // if we start defragmenting the cache, the benefit from this will be more important
- const uint32_t pad = kv_self->get_padding(cparams);
- kv_self->n = std::min(kv_self->size, std::max(pad, LM_GGML_PAD(kv_self->cell_max(), pad)));
- }
+ return 1;
  }
 
- //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
-
  lm_ggml_backend_sched_reset(sched.get());
  lm_ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
@@ -1427,43 +1050,68 @@ int llama_context::decode(llama_batch & inp_batch) {
  // finalize the batch processing
  kv_guard.commit();
 
+ // set to total number of outputs in the batch, for use in llama_get_logits_ith
+ n_outputs = n_outputs_all;
+
  // set output mappings
  {
  bool sorted_output = true;
 
- LM_GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
+ auto & out_ids = sbatch.out_ids;
+
+ LM_GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
  for (int64_t i = 0; i < n_outputs_all; ++i) {
- int64_t out_id = sbatch.out_ids[i];
+ int64_t out_id = out_ids[i];
  output_ids[out_id] = i;
  if (out_id != i) {
  sorted_output = false;
  }
  }
 
- if (sorted_output) {
- sbatch.out_ids.clear();
+ // make the outputs have the same order they had in the user-provided batch
+ // note: this is mostly relevant for recurrent models atm
+ if (!sorted_output) {
+ const uint32_t n_vocab = model.vocab.n_tokens();
+ const uint32_t n_embd = model.hparams.n_embd;
+
+ LM_GGML_ASSERT((size_t) n_outputs == out_ids.size());
+
+ // TODO: is there something more efficient which also minimizes swaps?
+ // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+ for (int32_t i = 0; i < n_outputs - 1; ++i) {
+ int32_t j_min = i;
+ for (int32_t j = i + 1; j < n_outputs; ++j) {
+ if (out_ids[j] < out_ids[j_min]) {
+ j_min = j;
+ }
+ }
+ if (j_min == i) { continue; }
+ std::swap(out_ids[i], out_ids[j_min]);
+ if (logits_size > 0) {
+ for (uint32_t k = 0; k < n_vocab; k++) {
+ std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
+ }
+ }
+ if (embd_size > 0) {
+ for (uint32_t k = 0; k < n_embd; k++) {
+ std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
+ }
+ }
+ }
+ std::fill(output_ids.begin(), output_ids.end(), -1);
+ for (int32_t i = 0; i < n_outputs; ++i) {
+ output_ids[out_ids[i]] = i;
+ }
  }
  }
 
- // set to total number of outputs in the batch, for use in llama_get_logits_ith
- n_outputs = n_outputs_all;
-
  // wait for the computation to finish (automatically done when obtaining the model output)
  //synchronize();
 
  // decide if we need to defrag the kv cache
- if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
- // - do not defrag small contexts (i.e. < 2048 tokens)
- // - count the padding towards the number of used tokens
- const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
-
- // queue defragmentation for next llama_kv_cache_update
- if (fragmentation > cparams.defrag_thold) {
- LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
-
- kv_self->defrag();
- }
+ if (cparams.defrag_thold > 0.0f) {
+ kv_self->defrag_sched(cparams.defrag_thold);
  }
 
  // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
@@ -1543,52 +1191,12 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
  // set all ids as invalid (negative)
  std::fill(output_ids.begin(), output_ids.end(), -1);
 
- lm_ggml_backend_buffer_clear(buf_output.get(), 0);
-
  this->n_outputs = 0;
  this->n_outputs_max = n_outputs_max;
 
  return n_outputs_max;
  }
 
- void llama_context::output_reorder() {
- auto & out_ids = sbatch.out_ids;
- if (!out_ids.empty()) {
- const uint32_t n_vocab = model.vocab.n_tokens();
- const uint32_t n_embd = model.hparams.n_embd;
-
- LM_GGML_ASSERT((size_t) n_outputs == out_ids.size());
-
- // TODO: is there something more efficient which also minimizes swaps?
- // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
- for (int32_t i = 0; i < n_outputs - 1; ++i) {
- int32_t j_min = i;
- for (int32_t j = i + 1; j < n_outputs; ++j) {
- if (out_ids[j] < out_ids[j_min]) {
- j_min = j;
- }
- }
- if (j_min == i) { continue; }
- std::swap(out_ids[i], out_ids[j_min]);
- if (logits_size > 0) {
- for (uint32_t k = 0; k < n_vocab; k++) {
- std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
- }
- }
- if (embd_size > 0) {
- for (uint32_t k = 0; k < n_embd; k++) {
- std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
- }
- }
- }
- std::fill(output_ids.begin(), output_ids.end(), -1);
- for (int32_t i = 0; i < n_outputs; ++i) {
- output_ids[out_ids[i]] = i;
- }
- out_ids.clear();
- }
- }
-
  //
  // graph
  //
@@ -1625,7 +1233,7 @@ llm_graph_result_ptr llama_context::graph_build(
  /*.backend_cpu =*/ backend_cpu,
  /*.cvec =*/ &cvec,
  /*.loras =*/ &loras,
- /*.memory =*/ kv_self.get(),
+ /*.memory =*/ memory.get(),
  /*.cross =*/ &cross,
  /*.n_outputs =*/ n_outputs,
  /*.cb =*/ graph_get_cb(),
@@ -2029,8 +1637,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
  {
  LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
 
- output_reorder();
-
  const auto n_outputs = this->n_outputs;
  const auto & output_ids = this->output_ids;
 
@@ -2084,6 +1690,8 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
  }
 
  LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
  kv_self->state_write(io);
 
  return io.n_bytes();
@@ -2168,6 +1776,8 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
  }
 
  LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
  kv_self->state_read(io);
 
  return io.n_bytes();
@@ -2176,6 +1786,8 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
  size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
  LM_GGML_UNUSED(seq_id);
 
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
  kv_self->state_write(io, seq_id);
 
  return io.n_bytes();
@@ -2184,6 +1796,8 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
  size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
  LM_GGML_UNUSED(seq_id);
 
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
  kv_self->state_read(io, seq_id);
 
  return io.n_bytes();
@@ -2539,7 +2153,7 @@ void llama_kv_cache_seq_cp(
  llama_seq_id seq_id_dst,
  llama_pos p0,
  llama_pos p1) {
- return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
+ llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
  }
 
  void llama_kv_self_seq_cp(
@@ -2553,14 +2167,14 @@ void llama_kv_self_seq_cp(
  return;
  }
 
- return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+ kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
  }
 
  // deprecated
  void llama_kv_cache_seq_keep(
  llama_context * ctx,
  llama_seq_id seq_id) {
- return llama_kv_self_seq_keep(ctx, seq_id);
+ llama_kv_self_seq_keep(ctx, seq_id);
  }
 
  void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
@@ -2569,7 +2183,7 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
  return;
  }
 
- return kv->seq_keep(seq_id);
+ kv->seq_keep(seq_id);
  }
 
  // deprecated
@@ -2579,7 +2193,7 @@ void llama_kv_cache_seq_add(
  llama_pos p0,
  llama_pos p1,
  llama_pos delta) {
- return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
+ llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
  }
 
  void llama_kv_self_seq_add(
@@ -2593,7 +2207,7 @@ void llama_kv_self_seq_add(
  return;
  }
 
- return kv->seq_add(seq_id, p0, p1, delta);
+ kv->seq_add(seq_id, p0, p1, delta);
  }
 
  // deprecated
@@ -2603,7 +2217,7 @@ void llama_kv_cache_seq_div(
  llama_pos p0,
  llama_pos p1,
  int d) {
- return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
+ llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
  }
 
  void llama_kv_self_seq_div(
@@ -2617,7 +2231,7 @@ void llama_kv_self_seq_div(
  return;
  }
 
- return kv->seq_div(seq_id, p0, p1, d);
+ kv->seq_div(seq_id, p0, p1, d);
  }
 
  // deprecated
@@ -2636,7 +2250,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
 
  // deprecated
  void llama_kv_cache_defrag(llama_context * ctx) {
- return llama_kv_self_defrag(ctx);
+ llama_kv_self_defrag(ctx);
  }
 
  void llama_kv_self_defrag(llama_context * ctx) {
@@ -2645,7 +2259,8 @@ void llama_kv_self_defrag(llama_context * ctx) {
  return;
  }
 
- return kv->defrag();
+ // force defrag
+ kv->defrag_sched(-1.0f);
  }
 
  // deprecated