@fugood/llama.node 0.3.16 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (281)
  1. package/CMakeLists.txt +6 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +44 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +374 -19
  24. package/src/LlamaCompletionWorker.h +31 -10
  25. package/src/LlamaContext.cpp +216 -7
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
  29. package/src/llama.cpp/.github/workflows/build.yml +89 -767
  30. package/src/llama.cpp/.github/workflows/docker.yml +9 -6
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +19 -23
  33. package/src/llama.cpp/CMakeLists.txt +11 -1
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +35 -4
  37. package/src/llama.cpp/common/arg.cpp +844 -121
  38. package/src/llama.cpp/common/arg.h +9 -0
  39. package/src/llama.cpp/common/chat.cpp +129 -107
  40. package/src/llama.cpp/common/chat.h +2 -0
  41. package/src/llama.cpp/common/common.cpp +64 -518
  42. package/src/llama.cpp/common/common.h +35 -45
  43. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  44. package/src/llama.cpp/common/llguidance.cpp +31 -47
  45. package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
  46. package/src/llama.cpp/common/minja/minja.hpp +186 -127
  47. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  48. package/src/llama.cpp/common/regex-partial.h +56 -0
  49. package/src/llama.cpp/common/sampling.cpp +60 -50
  50. package/src/llama.cpp/docs/build.md +122 -7
  51. package/src/llama.cpp/examples/CMakeLists.txt +2 -32
  52. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  53. package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
  54. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  55. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  56. package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
  57. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  58. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  59. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  60. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  61. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  62. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  63. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  64. package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
  65. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  66. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  67. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  68. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  69. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  70. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  71. package/src/llama.cpp/ggml/include/ggml.h +76 -106
  72. package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
  73. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  74. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  75. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  76. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  77. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  78. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  79. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  80. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  81. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  82. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  83. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
  84. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  85. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  86. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  87. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  88. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
  89. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  90. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
  91. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
  93. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
  94. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
  95. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
  96. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  101. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  102. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
  103. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  104. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  105. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  106. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  107. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  108. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  109. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
  110. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  111. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
  112. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  113. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  114. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
  115. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
  116. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
  117. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  119. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  120. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
  121. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  122. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  123. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  124. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  130. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  131. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  133. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  134. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  135. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  136. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  137. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  138. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  140. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  141. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  142. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
  143. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
  144. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
  145. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
  146. package/src/llama.cpp/ggml/src/ggml.c +170 -265
  147. package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
  148. package/src/llama.cpp/include/llama.h +82 -22
  149. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  150. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  151. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  152. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  153. package/src/llama.cpp/requirements/requirements-all.txt +5 -3
  154. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  155. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  156. package/src/llama.cpp/src/CMakeLists.txt +4 -2
  157. package/src/llama.cpp/src/llama-adapter.cpp +43 -1
  158. package/src/llama.cpp/src/llama-arch.cpp +163 -17
  159. package/src/llama.cpp/src/llama-arch.h +16 -0
  160. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  161. package/src/llama.cpp/src/llama-batch.h +2 -1
  162. package/src/llama.cpp/src/llama-chat.cpp +91 -16
  163. package/src/llama.cpp/src/llama-chat.h +7 -2
  164. package/src/llama.cpp/src/llama-context.cpp +479 -575
  165. package/src/llama.cpp/src/llama-context.h +44 -33
  166. package/src/llama.cpp/src/llama-cparams.h +1 -0
  167. package/src/llama.cpp/src/llama-graph.cpp +209 -157
  168. package/src/llama.cpp/src/llama-graph.h +38 -14
  169. package/src/llama.cpp/src/llama-hparams.h +13 -0
  170. package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
  171. package/src/llama.cpp/src/llama-kv-cache.h +283 -171
  172. package/src/llama.cpp/src/llama-memory.h +12 -2
  173. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  174. package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
  175. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  176. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  177. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  178. package/src/llama.cpp/src/llama-model.cpp +1803 -330
  179. package/src/llama.cpp/src/llama-model.h +21 -2
  180. package/src/llama.cpp/src/llama-quant.cpp +33 -10
  181. package/src/llama.cpp/src/llama-sampling.cpp +25 -7
  182. package/src/llama.cpp/src/llama-vocab.cpp +86 -10
  183. package/src/llama.cpp/src/llama-vocab.h +6 -0
  184. package/src/llama.cpp/src/llama.cpp +15 -1
  185. package/src/llama.cpp/tests/CMakeLists.txt +52 -31
  186. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  187. package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
  188. package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
  189. package/src/llama.cpp/tests/test-chat.cpp +15 -3
  190. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  191. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  192. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  193. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  194. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  195. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  196. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  197. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  198. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  199. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  200. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  201. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  202. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  203. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  204. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
  205. package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
  206. package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
  207. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  208. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
  209. package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
  210. package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
  211. package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
  212. package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
  213. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  214. package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
  215. package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
  216. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  217. package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
  218. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  219. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
  220. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
  221. package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
  222. package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
  223. package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
  224. package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
  225. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  226. package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
  227. package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
  228. package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
  229. package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
  230. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  231. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  232. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  233. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  234. package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
  235. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  236. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  237. package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
  238. package/src/llama.cpp/examples/llava/clip.h +0 -118
  239. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  240. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  241. package/src/llama.cpp/examples/llava/llava.cpp +0 -574
  242. package/src/llama.cpp/examples/llava/llava.h +0 -49
  243. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  244. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
  245. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  246. package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
  247. package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
  248. package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
  249. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  250. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  251. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  252. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  253. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  254. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  255. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  256. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  257. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  258. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  259. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  260. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  261. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  262. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  263. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  264. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  265. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  266. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  267. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  268. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  269. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  270. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  271. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  272. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  273. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  274. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  275. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  276. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  277. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  278. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  279. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  280. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  281. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
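
package/src/llama.cpp/src/llama-context.cpp (diff excerpt)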
@@ -6,7 +6,6 @@
 #include "llama-model.h"
 #include "llama-kv-cache.h"

- #include <cassert>
 #include <cstring>
 #include <stdexcept>
 #include <cinttypes>
@@ -94,6 +93,7 @@ llama_context::llama_context(
 }

 cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+ cparams.op_offload = params.op_offload;

 const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;

@@ -113,12 +113,10 @@ llama_context::llama_context(
 }

 if (n_ctx_per_seq > hparams.n_ctx_train) {
- LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
+ LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
 __func__, n_ctx_per_seq, hparams.n_ctx_train);
 }

- logits_all = params.logits_all;
-
 if (!hparams.vocab_only) {
 // GPU backends
 for (auto * dev : model.devices) {
@@ -176,44 +174,13 @@ llama_context::llama_context(
 }

 // init the memory module
- // TODO: for now, always create a unified KV cache
 if (!hparams.vocab_only) {
- kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
-
- LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
- cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
-
- LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
- uint32_t kv_size = cparams.n_ctx;
- ggml_type type_k = params.type_k;
- ggml_type type_v = params.type_v;
-
- if (llama_model_is_recurrent(&model)) {
- // Mamba needs at least as many KV cells as there are sequences kept at any time
- kv_size = std::max((uint32_t) 1, params.n_seq_max);
- // it's probably best to keep as much precision as possible for the states
- type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
- type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
- }
-
- GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
- GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
-
- if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
- throw std::runtime_error("failed to initialize self-attention cache");
- }
-
- {
- const size_t memory_size_k = kv_self->size_k_bytes();
- const size_t memory_size_v = kv_self->size_v_bytes();
+ llama_memory_params params_mem = {
+ /*.type_k =*/ params.type_k,
+ /*.type_v =*/ params.type_v,
+ };

- LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
- (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
- ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
- ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
- }
+ memory.reset(model.create_memory(params_mem, cparams));
 }

 // init backends
@@ -255,7 +222,8 @@ llama_context::llama_context(
 model.n_devices() > 1 &&
 model.params.n_gpu_layers > (int) model.hparams.n_layer &&
 model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
- cparams.offload_kqv;
+ cparams.offload_kqv &&
+ !model.has_tensor_overrides();

 // pipeline parallelism requires support for async compute and events in all devices
 if (pipeline_parallel) {
@@ -276,7 +244,7 @@ llama_context::llama_context(
 }
 }

- sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
+ sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload));

 if (pipeline_parallel) {
 LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
@@ -284,7 +252,7 @@ llama_context::llama_context(
 }

 // reserve worst-case graph
- if (!hparams.vocab_only) {
+ if (!hparams.vocab_only && memory) {
 const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
 const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

@@ -294,10 +262,7 @@ llama_context::llama_context(
 // TODO: something cleaner
 const auto n_outputs_save = n_outputs;

- // max number of outputs
- n_outputs = n_tokens;
-
- LLAMA_LOG_DEBUG("%s: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+ LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);

 int n_splits_pp = -1;
 int n_nodes_pp = -1;
@@ -306,15 +271,24 @@ llama_context::llama_context(
 int n_nodes_tg = -1;

 // simulate full KV cache
- kv_self->n = kv_self->size;
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ kv_self->set_full();

 cross.v_embd.clear();

 // reserve pp graph first so that buffers are only allocated once
 {
 llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+ // max number of outputs
+ n_outputs = ubatch_pp.n_tokens;
+
+ LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
+
 auto * gf = graph_init();
 graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+
 if (!ggml_backend_sched_reserve(sched.get(), gf)) {
 throw std::runtime_error("failed to allocate compute pp buffers");
 }
@@ -326,11 +300,18 @@ llama_context::llama_context(
 // reserve with tg graph to get the number of splits and nodes
 {
 llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+ n_outputs = ubatch_tg.n_tokens;
+
+ LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
+
 auto * gf = graph_init();
 graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
+
 if (!ggml_backend_sched_reserve(sched.get(), gf)) {
 throw std::runtime_error("failed to allocate compute tg buffers");
 }
+
 n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
 n_nodes_tg = ggml_graph_n_nodes(gf);
 }
@@ -338,8 +319,14 @@ llama_context::llama_context(
 // reserve again with pp graph to avoid ggml-alloc reallocations during inference
 {
 llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+ n_outputs = ubatch_pp.n_tokens;
+
+ LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
+
 auto * gf = graph_init();
 graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+
 if (!ggml_backend_sched_reserve(sched.get(), gf)) {
 throw std::runtime_error("failed to allocate compute pp buffers");
 }
@@ -372,7 +359,9 @@ llama_context::llama_context(
 }
 }

- llama_context::~llama_context() = default;
+ llama_context::~llama_context() {
+ ggml_opt_free(opt_ctx);
+ }

 void llama_context::synchronize() {
 ggml_backend_sched_synchronize(sched.get());
@@ -408,6 +397,18 @@ const llama_model & llama_context::get_model() const {
 return model;
 }

+ const llama_cparams & llama_context::get_cparams() const {
+ return cparams;
+ }
+
+ ggml_backend_sched_t llama_context::get_sched() const {
+ return sched.get();
+ }
+
+ ggml_context * llama_context::get_ctx_compute() const {
+ return ctx_compute.get();
+ }
+
 uint32_t llama_context::n_ctx() const {
 return cparams.n_ctx;
 }
@@ -437,345 +438,21 @@ uint32_t llama_context::n_threads_batch() const {
 }

 llama_kv_cache * llama_context::get_kv_self() {
- return kv_self.get();
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+ return kv_self;
 }

 const llama_kv_cache * llama_context::get_kv_self() const {
- return kv_self.get();
- }
-
- ggml_tensor * llama_context::build_rope_shift(
- ggml_context * ctx0,
- ggml_tensor * cur,
- ggml_tensor * shift,
- ggml_tensor * factors,
- float freq_base,
- float freq_scale,
- ggml_backend_buffer * bbuf) const {
- const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
-
- const auto & yarn_ext_factor = cparams.yarn_ext_factor;
- const auto & yarn_attn_factor = cparams.yarn_attn_factor;
- const auto & yarn_beta_fast = cparams.yarn_beta_fast;
- const auto & yarn_beta_slow = cparams.yarn_beta_slow;
-
- const auto & hparams = model.hparams;
-
- const auto & n_rot = hparams.n_rot;
- const auto & rope_type = hparams.rope_type;
-
- ggml_tensor * tmp;
-
- if (ggml_is_quantized(cur->type)) {
- // dequantize to f32 -> RoPE -> quantize back
- tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);
-
- if (bbuf) {
- for (const auto & backend : backends) {
- // Figure out which backend KV cache belongs to
- if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) {
- ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
- break;
- }
- }
- }
-
- tmp = ggml_rope_ext_inplace(ctx0, tmp,
- shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
-
- tmp = ggml_cpy(ctx0, tmp, cur);
- } else {
- // we rotate only the first n_rot dimensions
- tmp = ggml_rope_ext_inplace(ctx0, cur,
- shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
- }
-
- return tmp;
- }
-
- class llm_graph_input_k_shift : public llm_graph_input_i {
- public:
- llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
- virtual ~llm_graph_input_k_shift() = default;
-
- void set_input(const llama_ubatch * ubatch) override;
-
- ggml_tensor * k_shift; // I32 [kv_size]
-
- const llama_kv_cache_unified * kv_self;
- };
-
- void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
- GGML_UNUSED(ubatch);
-
- if (k_shift) {
- assert(ggml_backend_buffer_is_host(k_shift->buffer));
-
- int32_t * data = (int32_t *) k_shift->data;
-
- for (uint32_t i = 0; i < kv_self->size; ++i) {
- data[i] = kv_self->cells[i].delta;
- }
- }
- }
-
- llm_graph_result_ptr llama_context::build_kv_self_shift(
- ggml_context * ctx0,
- ggml_cgraph * gf) const {
- auto res = std::make_unique<llm_graph_result>();
-
- const auto & hparams = model.hparams;
-
- const auto & n_layer = hparams.n_layer;
-
- const auto & n_embd_head_k = hparams.n_embd_head_k;
- //const auto & n_embd_head_v = hparams.n_embd_head_v;
-
- //GGML_ASSERT(kv_self->size == n_ctx);
-
- auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
-
- inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
- ggml_set_input(inp->k_shift);
-
- for (uint32_t il = 0; il < n_layer; ++il) {
- const int64_t n_head_kv = hparams.n_head_kv(il);
- const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-
- const bool is_swa = hparams.is_swa(il);
-
- // note: the swa rope params could become part of the cparams in the future
- // if we decide to make them configurable, like the non-sliding ones
- const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
- const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
-
- ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
-
- ggml_tensor * k =
- ggml_view_3d(ctx0, kv_self->k_l[il],
- n_embd_head_k, n_head_kv, kv_self->size,
- ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
- ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
- 0);
-
- ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
-
- ggml_build_forward_expand(gf, cur);
- }
-
- res->add_input(std::move(inp));
-
- return res;
- }
-
- llm_graph_result_ptr llama_context::build_kv_self_defrag(
- ggml_context * ctx0,
- ggml_cgraph * gf) const {
- auto res = std::make_unique<llm_graph_result>();
-
- const auto & hparams = model.hparams;
-
- const auto & ids = kv_self->defrag_info.ids;
-
- #if 0
- // CPU defrag
- //
- // TODO: optimizations are possible:
- // - multiple threads
- // - avoid copying to the host memory when already there
- //
- // likely not worth the effort, as we have ggml_graph based defrag
- //
-
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-
- const uint32_t kv_size = size;
-
- std::vector<uint8_t> buf_k;
- std::vector<uint8_t> buf_v;
-
- for (uint32_t il = 0; il < n_layer; ++il) {
- const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
- const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
-
- const size_t v_size_el = ggml_type_size(v_l[il]->type);
- const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
-
- buf_k.resize(k_size);
- buf_v.resize(v_size);
-
- ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
- ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
-
- // batch move [i, i+nm) to [id, id+nm)
- // note: cells can move only to a lower index
- for (uint32_t i = 0; i < n_kv; ++i) {
- const uint32_t id = ids[i];
-
- if (i == id || id == n_kv) {
- continue;
- }
-
- uint32_t nm = 1;
-
- while (i + nm < n_kv && ids[i + nm] == id + nm) {
- nm++;
- }
-
- // move keys
- {
- const int64_t os = i*k_size_row;
- const int64_t od = id*k_size_row;
-
- memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
- }
-
- // move values (note: they are transposed)
- {
- const int64_t os = i;
- const int64_t od = id;
-
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
- memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
- }
- }
-
- i += nm - 1;
- }
-
- ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
- ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
- }
- #else
- for (uint32_t i = 0; i < ids.size(); ++i) {
- const uint32_t id = ids[i];
-
- if (i == id || id == ids.size()) {
- continue;
- }
-
- uint32_t nm = 1;
-
- while (i + nm < ids.size() && ids[i + nm] == id + nm) {
- nm++;
- }
-
- for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
- const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
- const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
- ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
- n_embd_k_gqa, nm,
- ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
- ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
-
- ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
- n_embd_k_gqa, nm,
- ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
- ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
-
- ggml_tensor * view_v_src;
- ggml_tensor * view_v_dst;
-
- if (cparams.flash_attn) {
- // NOTE: the V cache is not transposed when using flash attention
- view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
- n_embd_v_gqa, nm,
- ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
- ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
-
- view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
- n_embd_v_gqa, nm,
- ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
- ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
- } else {
- view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
- nm, n_embd_v_gqa,
- ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
- ggml_row_size(kv_self->v_l[il]->type, i));
-
- view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
- nm, n_embd_v_gqa,
- ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
- ggml_row_size(kv_self->v_l[il]->type, id));
- }
-
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
- }
-
- i += nm - 1;
- }
-
- //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
- #endif
-
- return res;
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+ return kv_self;
 }

 void llama_context::kv_self_update() {
- auto & kv = kv_self;
-
 bool need_reserve = false;

- if (kv->has_shift) {
- if (!kv->get_can_shift()) {
- GGML_ABORT("The current context does not support K-shift");
- }
-
- LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
-
- // apply K-shift if needed
- if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
- ggml_backend_sched_reset(sched.get());
-
- auto * gf = graph_init();
-
- auto res = build_kv_self_shift(ctx_compute.get(), gf);
-
- ggml_backend_sched_alloc_graph(sched.get(), gf);
-
- res->set_inputs(nullptr);
-
- graph_compute(gf, false);
-
- need_reserve = true;
- }
-
- {
- kv->has_shift = false;
-
- for (uint32_t i = 0; i < kv->size; ++i) {
- kv->cells[i].delta = 0;
- }
- }
- }
-
- // defragment the KV cache if needed
- if (kv->do_defrag) {
- LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
-
- if (kv->defrag_prepare(graph_max_nodes())) {
- ggml_backend_sched_reset(sched.get());
-
- auto * gf = graph_init();
-
- auto res = build_kv_self_defrag(ctx_compute.get(), gf);
-
- ggml_backend_sched_alloc_graph(sched.get(), gf);
-
- res->set_inputs(nullptr);
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());

- graph_compute(gf, false);
-
- need_reserve = true;
- }
-
- kv->do_defrag = false;
- }
+ need_reserve = kv_self->update(*this);

 // reserve a worst case graph if needed
 if (need_reserve) {
@@ -786,7 +463,7 @@ void llama_context::kv_self_update() {
 uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

 // simulate full KV cache
- kv_self->n = kv_self->size;
+ kv_self->set_full();

 llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
 llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
@@ -807,9 +484,6 @@ enum llama_pooling_type llama_context::pooling_type() const {
 }

 float * llama_context::get_logits() {
- // reorder logits for backward compatibility
- output_reorder();
-
 return logits;
 }

@@ -852,9 +526,6 @@ float * llama_context::get_logits_ith(int32_t i) {
 }

 float * llama_context::get_embeddings() {
- // reorder embeddings for backward compatibility
- output_reorder();
-
 return embd;
 }

@@ -1006,8 +677,8 @@ int llama_context::encode(llama_batch & inp_batch) {
 }

 // temporary allocate memory for the input batch if needed
- // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
+ // note: during encode, we always pass the full sequence starting from pos = 0
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);

 const llama_batch & batch = batch_allocr.batch;
 const int32_t n_tokens = batch.n_tokens;
@@ -1032,11 +703,13 @@ int llama_context::encode(llama_batch & inp_batch) {
 t_compute_start_us = ggml_time_us();
 }

+ embd_seq.clear();
+
 n_queued_tokens += n_tokens;

 const int64_t n_embd = hparams.n_embd;

- sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+ llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);

 const llama_ubatch ubatch = sbatch.split_simple(n_tokens);

@@ -1093,12 +766,12 @@ int llama_context::encode(llama_batch & inp_batch) {
 ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
 GGML_ASSERT(backend_embd != nullptr);

- GGML_ASSERT(embd != nullptr);
-
 switch (cparams.pooling_type) {
 case LLAMA_POOLING_TYPE_NONE:
 {
 // extract token embeddings
+ GGML_ASSERT(embd != nullptr);
+
 GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
 ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
 } break;
@@ -1123,11 +796,18 @@ int llama_context::encode(llama_batch & inp_batch) {
 } break;
 case LLAMA_POOLING_TYPE_RANK:
 {
- // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
- // wait for an encoder model that requires this pooling type in order to test it
- // https://github.com/ggerganov/llama.cpp/pull/9510
- GGML_ABORT("RANK pooling not implemented yet");
- }
+ // extract the rerank score - a single float per sequence
+ auto & embd_seq_out = embd_seq;
+
+ for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
+ if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+ continue;
+ }
+ embd_seq_out[seq_id].resize(1);
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+ }
+ } break;
 case LLAMA_POOLING_TYPE_UNSPECIFIED:
 {
 GGML_ABORT("unknown pooling type");
@@ -1165,14 +845,21 @@ int llama_context::encode(llama_batch & inp_batch) {
 }

 int llama_context::decode(llama_batch & inp_batch) {
+ if (!memory) {
+ LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
+ return encode(inp_batch);
+ }
+
 if (inp_batch.n_tokens == 0) {
 LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
 return -1;
 }

+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
 // temporary allocate memory for the input batch if needed
- // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
+ // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);

 const llama_batch & batch = batch_allocr.batch;

@@ -1184,33 +871,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 const int64_t n_tokens_all = batch.n_tokens;
 const int64_t n_embd = hparams.n_embd;

- // TODO: remove this stuff
- class batch_guard {
- public:
- batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
- }
-
- ~batch_guard() {
- if (!is_done) {
- kv_slot_restorer.restore();
- }
- }
-
- void done() {
- is_done = true;
- }
-
- void save(const llama_kv_cache_slot_info & slot_info) {
- kv_slot_restorer.save(slot_info);
- }
-
- private:
- bool is_done = false;
-
- llama_kv_slot_restorer kv_slot_restorer;
- };
-
- batch_guard bg(*kv_self);
+ llama_kv_cache_guard kv_guard(kv_self);

 GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

@@ -1244,18 +905,14 @@ int llama_context::decode(llama_batch & inp_batch) {
 for (uint32_t i = 0; i < n_tokens_all; ++i) {
 n_outputs_all += batch.logits[i] != 0;
 }
- } else if (logits_all || embd_pooled) {
+ } else if (embd_pooled) {
 n_outputs_all = n_tokens_all;
 } else {
 // keep last output only
 n_outputs_all = 1;
 }

- const bool logits_all = n_outputs_all == n_tokens_all;
-
- sbatch.from_batch(batch, n_embd,
- /* simple_split */ !kv_self->recurrent,
- /* logits_all */ logits_all);
+ llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);

 // reserve output buffer
 if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -1263,25 +920,13 @@ int llama_context::decode(llama_batch & inp_batch) {
 return -2;
 };

+ // handle any pending defrags/shifts
+ kv_self_update();
+
 int64_t n_outputs_prev = 0;

 while (sbatch.n_tokens > 0) {
- llama_ubatch ubatch = llama_ubatch();
-
- const auto & n_ubatch = cparams.n_ubatch;
-
- if (kv_self->recurrent) {
- if (embd_pooled) {
- // Pooled embeddings cannot be split across ubatches (yet)
- ubatch = sbatch.split_seq(cparams.n_ubatch);
- } else {
- // recurrent model architectures are easier to implement
- // with equal-length sequences
- ubatch = sbatch.split_equal(cparams.n_ubatch);
- }
- } else {
- ubatch = sbatch.split_simple(n_ubatch);
- }
+ llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);

 // count the outputs in this u_batch
 {
@@ -1300,35 +945,13 @@ int llama_context::decode(llama_batch & inp_batch) {
 n_outputs = n_outputs_new;
 }

- // non-causal masks do not use the KV cache
- if (hparams.causal_attn) {
- kv_self_update();
+ // find KV slot
+ if (!kv_self->find_slot(ubatch)) {
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);

- // if we have enough unused cells before the current head ->
- // better to start searching from the beginning of the cache, hoping to fill it
- if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
- kv_self->head = 0;
- }
-
- const auto slot_info = kv_self->find_slot(ubatch);
- if (!slot_info) {
- LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
- return -3;
- }
-
- bg.save(slot_info);
-
- if (!kv_self->recurrent) {
- // a heuristic, to avoid attending the full cache if it is not yet utilized
- // after enough generations, the benefit from this heuristic disappears
- // if we start defragmenting the cache, the benefit from this will be more important
- const uint32_t pad = kv_self->get_padding(cparams);
- kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
- }
+ return 1;
 }

- //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
-
 ggml_backend_sched_reset(sched.get());
 ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

@@ -1354,16 +977,6 @@ int llama_context::decode(llama_batch & inp_batch) {
 }
 }

- // update the kv ring buffer
- {
- kv_self->head += ubatch.n_tokens;
-
- // Ensure kv cache head points to a valid index.
- if (kv_self->head >= kv_self->size) {
- kv_self->head = 0;
- }
- }
-
 // plot the computation graph in dot format (for debugging purposes)
 //if (n_past%100 == 0) {
 // ggml_graph_dump_dot(gf, NULL, "llama.dot");
@@ -1450,45 +1063,70 @@ int llama_context::decode(llama_batch & inp_batch) {
 }

 // finalize the batch processing
- bg.done();
+ kv_guard.commit();
+
+ // set to total number of outputs in the batch, for use in llama_get_logits_ith
+ n_outputs = n_outputs_all;

 // set output mappings
 {
 bool sorted_output = true;

- GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
+ auto & out_ids = sbatch.out_ids;
+
+ GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);

 for (int64_t i = 0; i < n_outputs_all; ++i) {
- int64_t out_id = sbatch.out_ids[i];
+ int64_t out_id = out_ids[i];
 output_ids[out_id] = i;
 if (out_id != i) {
 sorted_output = false;
 }
 }

- if (sorted_output) {
- sbatch.out_ids.clear();
+ // make the outputs have the same order they had in the user-provided batch
+ // note: this is mostly relevant for recurrent models atm
+ if (!sorted_output) {
+ const uint32_t n_vocab = model.vocab.n_tokens();
+ const uint32_t n_embd = model.hparams.n_embd;
+
+ GGML_ASSERT((size_t) n_outputs == out_ids.size());
+
+ // TODO: is there something more efficient which also minimizes swaps?
+ // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+ for (int32_t i = 0; i < n_outputs - 1; ++i) {
+ int32_t j_min = i;
+ for (int32_t j = i + 1; j < n_outputs; ++j) {
+ if (out_ids[j] < out_ids[j_min]) {
+ j_min = j;
+ }
+ }
+ if (j_min == i) { continue; }
+ std::swap(out_ids[i], out_ids[j_min]);
+ if (logits_size > 0) {
+ for (uint32_t k = 0; k < n_vocab; k++) {
+ std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
+ }
+ }
+ if (embd_size > 0) {
+ for (uint32_t k = 0; k < n_embd; k++) {
+ std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
+ }
+ }
+ }
+ std::fill(output_ids.begin(), output_ids.end(), -1);
+ for (int32_t i = 0; i < n_outputs; ++i) {
+ output_ids[out_ids[i]] = i;
+ }
 }
 }

- // set to total number of outputs in the batch, for use in llama_get_logits_ith
- n_outputs = n_outputs_all;
-
 // wait for the computation to finish (automatically done when obtaining the model output)
 //synchronize();

 // decide if we need to defrag the kv cache
- if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
- // - do not defrag small contexts (i.e. < 2048 tokens)
- // - count the padding towards the number of used tokens
- const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
-
- // queue defragmentation for next llama_kv_cache_update
- if (fragmentation > cparams.defrag_thold) {
- LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
-
- kv_self->defrag();
- }
+ if (cparams.defrag_thold > 0.0f) {
+ kv_self->defrag_sched(cparams.defrag_thold);
 }

 // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
@@ -1568,52 +1206,12 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
 // set all ids as invalid (negative)
 std::fill(output_ids.begin(), output_ids.end(), -1);

- ggml_backend_buffer_clear(buf_output.get(), 0);
-
 this->n_outputs = 0;
 this->n_outputs_max = n_outputs_max;

 return n_outputs_max;
 }

- void llama_context::output_reorder() {
- auto & out_ids = sbatch.out_ids;
- if (!out_ids.empty()) {
- const uint32_t n_vocab = model.vocab.n_tokens();
- const uint32_t n_embd = model.hparams.n_embd;
-
- GGML_ASSERT((size_t) n_outputs == out_ids.size());
-
- // TODO: is there something more efficient which also minimizes swaps?
- // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
- for (int32_t i = 0; i < n_outputs - 1; ++i) {
- int32_t j_min = i;
- for (int32_t j = i + 1; j < n_outputs; ++j) {
- if (out_ids[j] < out_ids[j_min]) {
- j_min = j;
- }
- }
- if (j_min == i) { continue; }
- std::swap(out_ids[i], out_ids[j_min]);
- if (logits_size > 0) {
- for (uint32_t k = 0; k < n_vocab; k++) {
- std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
- }
- }
- if (embd_size > 0) {
- for (uint32_t k = 0; k < n_embd; k++) {
- std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
- }
- }
- }
- std::fill(output_ids.begin(), output_ids.end(), -1);
- for (int32_t i = 0; i < n_outputs; ++i) {
- output_ids[out_ids[i]] = i;
- }
- out_ids.clear();
- }
- }
-
 //
 // graph
 //
@@ -1650,7 +1248,7 @@ llm_graph_result_ptr llama_context::graph_build(
 /*.backend_cpu =*/ backend_cpu,
 /*.cvec =*/ &cvec,
 /*.loras =*/ &loras,
- /*.memory =*/ kv_self.get(),
+ /*.memory =*/ memory.get(),
 /*.cross =*/ &cross,
 /*.n_outputs =*/ n_outputs,
 /*.cb =*/ graph_get_cb(),
@@ -2054,8 +1652,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
 {
 LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);

- output_reorder();
-
 const auto n_outputs = this->n_outputs;
 const auto & output_ids = this->output_ids;

@@ -2108,8 +1704,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
 }
 }

- LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
- kv_self->state_write(io);
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ if (kv_self != nullptr) {
+ LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+ kv_self->state_write(io);
+ }

 return io.n_bytes();
 }
@@ -2192,8 +1792,13 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
 }
 }

- LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
- kv_self->state_read(io);
+ if (memory) {
+ LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ kv_self->state_read(io);
+ }

 return io.n_bytes();
 }
@@ -2201,7 +1806,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
 size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
 GGML_UNUSED(seq_id);

- kv_self->state_write(io, seq_id);
+ if (memory) {
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ kv_self->state_write(io, seq_id);
+ }

 return io.n_bytes();
 }
@@ -2209,7 +1818,11 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
 size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
 GGML_UNUSED(seq_id);

- kv_self->state_read(io, seq_id);
+ if (memory) {
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ kv_self->state_read(io, seq_id);
+ }

 return io.n_bytes();
 }
@@ -2237,6 +1850,215 @@ void llama_context::perf_reset() {
  t_p_eval_us = n_p_eval = 0;
  }

+ //
+ // training
+ //
+
+ static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
+ if (!tensor || tensor->type != GGML_TYPE_F32) {
+ return;
+ }
+ if (!param_filter(tensor, userdata)) {
+ return;
+ }
+ if (strcmp(tensor->name, "token_embd.weight") == 0) {
+ return; // FIXME
+ }
+ if (strcmp(tensor->name, "rope_freqs.weight") == 0) {
+ return; // FIXME
+ }
+ ggml_set_param(tensor);
+ }
+
+ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) {
+ GGML_ASSERT(!opt_ctx);
+ model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx();
+ const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train);
+ const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
+ GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0);
+ GGML_ASSERT(n_batch % n_ubatch == 0);
+
+ ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
+ opt_params.opt_period = n_batch / n_ubatch;
+ opt_params.get_opt_pars = lopt_params.get_opt_pars;
+ opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
+
+ opt_ctx = ggml_opt_init(opt_params);
+
+ llama_opt_param_filter param_filter = lopt_params.param_filter;
+ void * param_filter_ud = lopt_params.param_filter_ud;
+
+ //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME
+ llama_set_param(model->type_embd, param_filter, param_filter_ud);
+ llama_set_param(model->pos_embd, param_filter, param_filter_ud);
+ llama_set_param(model->tok_norm, param_filter, param_filter_ud);
+ llama_set_param(model->tok_norm_b, param_filter, param_filter_ud);
+ llama_set_param(model->output_norm, param_filter, param_filter_ud);
+ llama_set_param(model->output_norm_b, param_filter, param_filter_ud);
+ llama_set_param(model->output, param_filter, param_filter_ud);
+ llama_set_param(model->output_b, param_filter, param_filter_ud);
+ llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
+ llama_set_param(model->cls, param_filter, param_filter_ud);
+ llama_set_param(model->cls_b, param_filter, param_filter_ud);
+ llama_set_param(model->cls_out, param_filter, param_filter_ud);
+ llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
+
+ for (struct llama_layer & layer : model->layers) {
+ for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
+ llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
+ }
+ }
+ }
+
+ void llama_context::opt_epoch_iter(
+ ggml_opt_dataset_t dataset,
+ ggml_opt_result_t result,
+ const std::vector<llama_token> & tokens,
+ const std::vector<llama_token> & labels_sparse,
+ llama_batch & batch,
+ ggml_opt_epoch_callback callback,
+ bool train,
+ int64_t idata_in_loop,
+ int64_t ndata_in_loop,
+ int64_t t_loop_start) {
+ GGML_ASSERT(opt_ctx);
+ const uint32_t n_ctx = llama_model_n_ctx_train(&model);
+ const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
+ const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
+
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ kv_self->clear();
+ llama_kv_cache_guard kv_guard(kv_self);
+
+ for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
+ batch.n_tokens = n_batch;
+ for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
+ batch.token [pos_batch] = tokens[pos_ctx + pos_batch];
+ batch.pos [pos_batch] = pos_ctx + pos_batch;
+ batch.n_seq_id[pos_batch] = 1;
+ batch.seq_id [pos_batch][0] = 0;
+ batch.logits [pos_batch] = true;
+ }
+
+ const auto n_tokens_all = batch.n_tokens;
+
+ n_queued_tokens += n_tokens_all;
+
+ // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+ const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+ embd_seq.clear();
+
+ int64_t n_outputs_all = n_tokens_all;
+
+ llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
+
+ // reserve output buffer
+ if (output_reserve(n_outputs_all) < n_outputs_all) {
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+ GGML_ABORT("TODO: handle this error");
+ };
+
+ for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
+ llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+
+ n_outputs = ubatch.n_tokens;
+
+ // TODO: not sure if this is needed
+ if (!kv_self->find_slot(ubatch)) {
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+
+ GGML_ABORT("TODO: handle this error");
+ }
+
+ auto * gf = graph_init();
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+
+ struct ggml_context * ctx_compute_opt;
+ {
+ const size_t size_gf = ggml_graph_size(gf);
+ const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true);
+ struct ggml_init_params params = {
+ /*.mem_size =*/ size_meta,
+ /*.mem_buffer =*/ nullptr,
+ /*.no_alloc =*/ true,
+ };
+ ctx_compute_opt = ggml_init(params);
+ }
+ ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
+ ggml_opt_alloc(opt_ctx, train);
+ res->set_inputs(&ubatch);
+ {
+ struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
+ GGML_ASSERT(labels->ne[1] == n_ubatch);
+ ggml_set_zero(labels);
+ const float onef = 1.0f;
+ for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
+ const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
+ GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
+ ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
+ }
+ }
+ ggml_opt_eval(opt_ctx, result);
+ if (callback) {
+ callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
+ }
+ ggml_free(ctx_compute_opt);
+ }
+ }
+
+ kv_guard.commit();
+ }
+
+ void llama_context::opt_epoch(
+ ggml_opt_dataset_t dataset,
+ ggml_opt_result_t result_train,
+ ggml_opt_result_t result_eval,
+ int64_t idata_split,
+ ggml_opt_epoch_callback callback_train,
+ ggml_opt_epoch_callback callback_eval) {
+ const uint32_t n_ctx = this->n_ctx();
+ const uint32_t n_batch = std::min(cparams.n_batch, n_ctx);
+ const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
+ const int64_t ndata = ggml_opt_dataset_ndata(dataset);
+
+ GGML_ASSERT(idata_split >= 0);
+ GGML_ASSERT(idata_split <= ndata);
+
+ const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
+
+ struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
+ std::vector<llama_token> tokens(n_ctx);
+ std::vector<llama_token> labels_sparse(n_ctx);
+
+ int64_t idata = 0;
+
+ int64_t t_loop_start = ggml_time_us();
+ int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
+ for (; idata < idata_split; ++idata) {
+ constexpr bool train = true;
+ const int64_t idata_in_loop = idata*ubatch_per_ctx;
+
+ ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
+ opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch,
+ callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
+ }
+
+ t_loop_start = ggml_time_us();
+ ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
+ for (; idata < ndata; ++idata) {
+ constexpr bool train = false;
+ const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
+
+ ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
+ opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch,
+ callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
+ }
+
+ llama_batch_free(batch);
+ }
+
  //
  // interface implementation
  //
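The new training block above wires lopt_params.get_opt_pars / get_opt_pars_ud straight into ggml_opt_init, so per-step optimizer hyperparameters come from a user callback; opt_epoch_iter then streams the dataset through the KV cache in n_batch chunks and builds one-hot cross-entropy labels per ubatch. A minimal sketch of such a callback, assuming ggml's ggml_opt_get_default_optimizer_params helper and the adamw.alpha field of ggml_opt_optimizer_params (both come from ggml-opt.h, not from this diff, and my_opt_pars is a placeholder name):

    static ggml_opt_optimizer_params my_opt_pars(void * userdata) {
        // start from ggml's defaults, then override the AdamW learning rate
        ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(userdata);
        p.adamw.alpha = 1e-5f; // assumed field for the learning rate
        return p;
    }
    // wired up as: lopt_params.get_opt_pars = my_opt_pars; lopt_params.get_opt_pars_ud = nullptr;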
@@ -2264,13 +2086,13 @@ llama_context_params llama_context_default_params() {
  /*.cb_eval_user_data =*/ nullptr,
  /*.type_k =*/ GGML_TYPE_F16,
  /*.type_v =*/ GGML_TYPE_F16,
- /*.logits_all =*/ false,
+ /*.abort_callback =*/ nullptr,
+ /*.abort_callback_data =*/ nullptr,
  /*.embeddings =*/ false,
  /*.offload_kqv =*/ true,
  /*.flash_attn =*/ false,
  /*.no_perf =*/ true,
- /*.abort_callback =*/ nullptr,
- /*.abort_callback_data =*/ nullptr,
+ /*.op_offload =*/ true,
  };

  return result;
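In llama_context_default_params() the abort_callback pair moves ahead of embeddings, logits_all is dropped from the struct, and a new op_offload flag (defaulting to true) is appended. Since the comment order mirrors the struct layout, callers that brace-initialize llama_context_params positionally need updating; field-by-field assignment keeps working. A hedged sketch (my_abort_cb is a placeholder, and describing op_offload as controlling host-op offload is an assumption not stated in this diff):

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings          = false;
    cparams.flash_attn          = true;
    cparams.op_offload          = true;         // new field, defaults to true
    cparams.abort_callback      = my_abort_cb;  // same semantics, new position in the struct
    cparams.abort_callback_data = nullptr;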
@@ -2299,11 +2121,6 @@ llama_context * llama_init_from_model(
  params.flash_attn = false;
  }

- if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
- LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
- params.flash_attn = false;
- }
-
  if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
  LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
  return nullptr;
@@ -2504,7 +2321,12 @@ int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
  }

  int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
- return llama_kv_cache_n_tokens(ctx->get_kv_self());
+ const auto * kv = ctx->get_kv_self();
+ if (!kv) {
+ return 0;
+ }
+
+ return kv->get_n_tokens();
  }

  // deprecated
@@ -2513,7 +2335,12 @@ int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
  }

  int32_t llama_kv_self_used_cells(const llama_context * ctx) {
- return llama_kv_cache_used_cells(ctx->get_kv_self());
+ const auto * kv = ctx->get_kv_self();
+ if (!kv) {
+ return 0;
+ }
+
+ return kv->get_used_cells();
  }

  // deprecated
@@ -2522,7 +2349,12 @@ void llama_kv_cache_clear(llama_context * ctx) {
  }

  void llama_kv_self_clear(llama_context * ctx) {
- llama_kv_cache_clear(ctx->get_kv_self());
+ auto * kv = ctx->get_kv_self();
+ if (!kv) {
+ return;
+ }
+
+ kv->clear();
  }

  // deprecated
@@ -2539,7 +2371,12 @@ bool llama_kv_self_seq_rm(
  llama_seq_id seq_id,
  llama_pos p0,
  llama_pos p1) {
- return llama_kv_cache_seq_rm(ctx->get_kv_self(), seq_id, p0, p1);
+ auto * kv = ctx->get_kv_self();
+ if (!kv) {
+ return true;
+ }
+
+ return kv->seq_rm(seq_id, p0, p1);
  }

  // deprecated
@@ -2549,7 +2386,7 @@ void llama_kv_cache_seq_cp(
  llama_seq_id seq_id_dst,
  llama_pos p0,
  llama_pos p1) {
- return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
+ llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
  }

  void llama_kv_self_seq_cp(
@@ -2558,18 +2395,28 @@ void llama_kv_self_seq_cp(
  llama_seq_id seq_id_dst,
  llama_pos p0,
  llama_pos p1) {
- return llama_kv_cache_seq_cp(ctx->get_kv_self(), seq_id_src, seq_id_dst, p0, p1);
+ auto * kv = ctx->get_kv_self();
+ if (!kv) {
+ return;
+ }
+
+ kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
  }

  // deprecated
  void llama_kv_cache_seq_keep(
  llama_context * ctx,
  llama_seq_id seq_id) {
- return llama_kv_self_seq_keep(ctx, seq_id);
+ llama_kv_self_seq_keep(ctx, seq_id);
  }

  void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
- return llama_kv_cache_seq_keep(ctx->get_kv_self(), seq_id);
+ auto * kv = ctx->get_kv_self();
+ if (!kv) {
+ return;
+ }
+
+ kv->seq_keep(seq_id);
  }

  // deprecated
@@ -2579,7 +2426,7 @@ void llama_kv_cache_seq_add(
  llama_pos p0,
  llama_pos p1,
  llama_pos delta) {
- return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
+ llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
  }

  void llama_kv_self_seq_add(
@@ -2588,7 +2435,12 @@ void llama_kv_self_seq_add(
  llama_pos p0,
  llama_pos p1,
  llama_pos delta) {
- return llama_kv_cache_seq_add(ctx->get_kv_self(), seq_id, p0, p1, delta);
+ auto * kv = ctx->get_kv_self();
+ if (!kv) {
+ return;
+ }
+
+ kv->seq_add(seq_id, p0, p1, delta);
  }

  // deprecated
@@ -2598,7 +2450,7 @@ void llama_kv_cache_seq_div(
  llama_pos p0,
  llama_pos p1,
  int d) {
- return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
+ llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
  }

  void llama_kv_self_seq_div(
@@ -2607,7 +2459,12 @@ void llama_kv_self_seq_div(
  llama_pos p0,
  llama_pos p1,
  int d) {
- return llama_kv_cache_seq_div(ctx->get_kv_self(), seq_id, p0, p1, d);
+ auto * kv = ctx->get_kv_self();
+ if (!kv) {
+ return;
+ }
+
+ kv->seq_div(seq_id, p0, p1, d);
  }

  // deprecated
@@ -2616,16 +2473,27 @@ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
  }

  llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
- return llama_kv_cache_seq_pos_max(ctx->get_kv_self(), seq_id);
+ const auto * kv = ctx->get_kv_self();
+ if (!kv) {
+ return 0;
+ }
+
+ return kv->seq_pos_max(seq_id);
  }

  // deprecated
  void llama_kv_cache_defrag(llama_context * ctx) {
- return llama_kv_self_defrag(ctx);
+ llama_kv_self_defrag(ctx);
  }

  void llama_kv_self_defrag(llama_context * ctx) {
- llama_kv_cache_defrag(ctx->get_kv_self());
+ auto * kv = ctx->get_kv_self();
+ if (!kv) {
+ return;
+ }
+
+ // force defrag
+ kv->defrag_sched(-1.0f);
  }

  // deprecated
@@ -2634,7 +2502,12 @@ bool llama_kv_cache_can_shift(const llama_context * ctx) {
  }

  bool llama_kv_self_can_shift(const llama_context * ctx) {
- return llama_kv_cache_can_shift(ctx->get_kv_self());
+ const auto * kv = ctx->get_kv_self();
+ if (!kv) {
+ return false;
+ }
+
+ return kv->get_can_shift();
  }

  // deprecated
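Every llama_kv_self_* wrapper above now fetches the cache through get_kv_self() and returns a neutral value when it is null: 0 for token, cell and position queries, false for can_shift, true for seq_rm, and plain no-ops for the mutators. A hedged caller-side sketch of what that permits (n_keep and n_discard are placeholder variables, not part of this diff):

    // sketch: context-shift logic that no longer needs to special-case
    // contexts created without a KV cache
    if (llama_kv_self_can_shift(ctx)) {
        llama_kv_self_seq_rm (ctx, 0, n_keep, n_keep + n_discard);
        llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, -1, -n_discard);
    } else {
        llama_kv_self_clear(ctx); // harmless no-op when there is no cache
    }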
@@ -2804,3 +2677,34 @@ void llama_perf_context_print(const llama_context * ctx) {
  void llama_perf_context_reset(llama_context * ctx) {
  ctx->perf_reset();
  }
+
+ //
+ // training
+ //
+
+ bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) {
+ GGML_UNUSED(tensor);
+ GGML_UNUSED(userdata);
+ return true;
+ }
+
+ void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) {
+ ctx->opt_init(model, lopt_params);
+ }
+
+ void llama_opt_epoch(
+ struct llama_context * ctx,
+ ggml_opt_dataset_t dataset,
+ ggml_opt_result_t result_train,
+ ggml_opt_result_t result_eval,
+ int64_t idata_split,
+ ggml_opt_epoch_callback callback_train,
+ ggml_opt_epoch_callback callback_eval) {
+ ctx->opt_epoch(
+ dataset,
+ result_train,
+ result_eval,
+ idata_split,
+ callback_train,
+ callback_eval);
+ }
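These new public entry points (llama_opt_param_filter_all, llama_opt_init, llama_opt_epoch) are the C-level surface for the training code added earlier in this file. A rough sketch of driving them from application code, under the assumption that dataset construction and result handling come from ggml's opt API (ggml_opt_result_init/reset/free, ggml_opt_get_default_optimizer_params and ggml_opt_epoch_callback_progress_bar are from ggml-opt.h, not this diff; dataset, n_epochs and idata_split are placeholders):

    llama_opt_params oparams = {};
    oparams.n_ctx_train     = 0;                          // 0 => use the context's n_ctx
    oparams.param_filter    = llama_opt_param_filter_all; // train every F32 tensor
    oparams.param_filter_ud = nullptr;
    oparams.get_opt_pars    = ggml_opt_get_default_optimizer_params;
    oparams.get_opt_pars_ud = nullptr;

    llama_opt_init(ctx, model, oparams);

    ggml_opt_result_t result_train = ggml_opt_result_init();
    ggml_opt_result_t result_eval  = ggml_opt_result_init();

    for (int epoch = 0; epoch < n_epochs; ++epoch) {
        // first idata_split chunks train, the remainder evaluates
        llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
                        ggml_opt_epoch_callback_progress_bar,
                        ggml_opt_epoch_callback_progress_bar);
        ggml_opt_result_reset(result_train);
        ggml_opt_result_reset(result_eval);
    }

    ggml_opt_result_free(result_train);
    ggml_opt_result_free(result_eval);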