@fugood/llama.node 0.3.16 → 0.4.0

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (281)
  1. package/CMakeLists.txt +6 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +44 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +374 -19
  24. package/src/LlamaCompletionWorker.h +31 -10
  25. package/src/LlamaContext.cpp +216 -7
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
  29. package/src/llama.cpp/.github/workflows/build.yml +89 -767
  30. package/src/llama.cpp/.github/workflows/docker.yml +9 -6
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +19 -23
  33. package/src/llama.cpp/CMakeLists.txt +11 -1
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +35 -4
  37. package/src/llama.cpp/common/arg.cpp +844 -121
  38. package/src/llama.cpp/common/arg.h +9 -0
  39. package/src/llama.cpp/common/chat.cpp +129 -107
  40. package/src/llama.cpp/common/chat.h +2 -0
  41. package/src/llama.cpp/common/common.cpp +64 -518
  42. package/src/llama.cpp/common/common.h +35 -45
  43. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  44. package/src/llama.cpp/common/llguidance.cpp +31 -47
  45. package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
  46. package/src/llama.cpp/common/minja/minja.hpp +186 -127
  47. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  48. package/src/llama.cpp/common/regex-partial.h +56 -0
  49. package/src/llama.cpp/common/sampling.cpp +60 -50
  50. package/src/llama.cpp/docs/build.md +122 -7
  51. package/src/llama.cpp/examples/CMakeLists.txt +2 -32
  52. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  53. package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
  54. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  55. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  56. package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
  57. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  58. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  59. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  60. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  61. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  62. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  63. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  64. package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
  65. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  66. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  67. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  68. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  69. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  70. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  71. package/src/llama.cpp/ggml/include/ggml.h +76 -106
  72. package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
  73. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  74. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  75. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  76. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  77. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  78. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  79. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  80. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  81. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  82. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  83. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
  84. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  85. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  86. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  87. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  88. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
  89. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  90. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
  91. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
  93. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
  94. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
  95. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
  96. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  101. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  102. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
  103. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  104. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  105. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  106. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  107. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  108. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  109. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
  110. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  111. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
  112. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  113. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  114. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
  115. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
  116. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
  117. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  119. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  120. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
  121. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  122. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  123. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  124. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  130. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  131. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  133. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  134. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  135. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  136. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  137. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  138. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  140. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  141. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  142. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
  143. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
  144. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
  145. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
  146. package/src/llama.cpp/ggml/src/ggml.c +170 -265
  147. package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
  148. package/src/llama.cpp/include/llama.h +82 -22
  149. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  150. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  151. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  152. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  153. package/src/llama.cpp/requirements/requirements-all.txt +5 -3
  154. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  155. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  156. package/src/llama.cpp/src/CMakeLists.txt +4 -2
  157. package/src/llama.cpp/src/llama-adapter.cpp +43 -1
  158. package/src/llama.cpp/src/llama-arch.cpp +163 -17
  159. package/src/llama.cpp/src/llama-arch.h +16 -0
  160. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  161. package/src/llama.cpp/src/llama-batch.h +2 -1
  162. package/src/llama.cpp/src/llama-chat.cpp +91 -16
  163. package/src/llama.cpp/src/llama-chat.h +7 -2
  164. package/src/llama.cpp/src/llama-context.cpp +479 -575
  165. package/src/llama.cpp/src/llama-context.h +44 -33
  166. package/src/llama.cpp/src/llama-cparams.h +1 -0
  167. package/src/llama.cpp/src/llama-graph.cpp +209 -157
  168. package/src/llama.cpp/src/llama-graph.h +38 -14
  169. package/src/llama.cpp/src/llama-hparams.h +13 -0
  170. package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
  171. package/src/llama.cpp/src/llama-kv-cache.h +283 -171
  172. package/src/llama.cpp/src/llama-memory.h +12 -2
  173. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  174. package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
  175. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  176. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  177. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  178. package/src/llama.cpp/src/llama-model.cpp +1803 -330
  179. package/src/llama.cpp/src/llama-model.h +21 -2
  180. package/src/llama.cpp/src/llama-quant.cpp +33 -10
  181. package/src/llama.cpp/src/llama-sampling.cpp +25 -7
  182. package/src/llama.cpp/src/llama-vocab.cpp +86 -10
  183. package/src/llama.cpp/src/llama-vocab.h +6 -0
  184. package/src/llama.cpp/src/llama.cpp +15 -1
  185. package/src/llama.cpp/tests/CMakeLists.txt +52 -31
  186. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  187. package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
  188. package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
  189. package/src/llama.cpp/tests/test-chat.cpp +15 -3
  190. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  191. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  192. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  193. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  194. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  195. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  196. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  197. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  198. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  199. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  200. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  201. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  202. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  203. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  204. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
  205. package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
  206. package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
  207. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  208. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
  209. package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
  210. package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
  211. package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
  212. package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
  213. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  214. package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
  215. package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
  216. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  217. package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
  218. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  219. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
  220. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
  221. package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
  222. package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
  223. package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
  224. package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
  225. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  226. package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
  227. package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
  228. package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
  229. package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
  230. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  231. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  232. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  233. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  234. package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
  235. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  236. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  237. package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
  238. package/src/llama.cpp/examples/llava/clip.h +0 -118
  239. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  240. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  241. package/src/llama.cpp/examples/llava/llava.cpp +0 -574
  242. package/src/llama.cpp/examples/llava/llava.h +0 -49
  243. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  244. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
  245. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  246. package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
  247. package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
  248. package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
  249. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  250. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  251. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  252. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  253. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  254. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  255. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  256. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  257. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  258. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  259. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  260. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  261. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  262. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  263. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  264. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  265. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  266. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  267. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  268. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  269. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  270. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  271. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  272. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  273. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  274. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  275. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  276. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  277. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  278. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  279. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  280. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  281. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
package/src/llama.cpp/src/llama-graph.cpp
@@ -55,7 +55,37 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && pos) {
         const int64_t n_tokens = ubatch->n_tokens;
 
-        ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_token*ggml_element_size(pos));
+        if (ubatch->token && n_pos_per_embd == 4) {
+            // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
+            // the 3 first dims are the same, and 4th dim is all 0
+            std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
+            // copy the first dimension
+            for (int i = 0; i < n_tokens; ++i) {
+                pos_data[               i] = ubatch->pos[i];
+                pos_data[    n_tokens + i] = ubatch->pos[i];
+                pos_data[2 * n_tokens + i] = ubatch->pos[i];
+                pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
+            }
+            ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
+        } else {
+            ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
+        }
+    }
+}
+
+void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
+    if (ubatch->pos && attn_scale) {
+        const int64_t n_tokens = ubatch->n_tokens;
+
+        std::vector<float> attn_scale_data(n_tokens, 0.0f);
+        for (int i = 0; i < n_tokens; ++i) {
+            const float pos = ubatch->pos[i];
+            attn_scale_data[i] = std::log(
+                std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
+            ) * f_attn_temp_scale + 1.0;
+        }
+
+        ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
     }
 }
 
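Two new graph inputs are introduced above: 4-plane M-RoPE positions for text tokens, and the logarithmic attention-temperature scale used by Llama 4. A standalone sketch of both buffer computations; the constants 8192 and 0.1 are placeholders, not values taken from this diff:

    // pos_and_temp_sketch.cpp - mirrors the two set_input() bodies above
    #include <cmath>
    #include <cstdint>
    #include <vector>

    int main() {
        const int64_t n_tokens = 3;
        const std::vector<int32_t> pos = {0, 1, 2};

        // M-RoPE positions for plain text: 4 planes laid out back to back,
        // the first three repeat the 1D position, the fourth is all zeros
        std::vector<int32_t> pos_data(4*n_tokens);
        for (int64_t i = 0; i < n_tokens; ++i) {
            pos_data[0*n_tokens + i] = pos[i];
            pos_data[1*n_tokens + i] = pos[i];
            pos_data[2*n_tokens + i] = pos[i];
            pos_data[3*n_tokens + i] = 0;
        }

        // attention temperature: steps up every n_attn_temp_floor_scale positions
        const float n_attn_temp_floor_scale = 8192.0f; // placeholder hparam
        const float f_attn_temp_scale       = 0.1f;    // placeholder hparam
        std::vector<float> attn_scale(n_tokens);
        for (int64_t i = 0; i < n_tokens; ++i) {
            attn_scale[i] = std::log(std::floor((pos[i] + 1.0f)/n_attn_temp_floor_scale) + 1.0f)
                          * f_attn_temp_scale + 1.0f;
        }
        return 0;
    }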
@@ -254,24 +284,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_kv; ++i) {
-            const uint32_t cell_id = i + kv_self->head;
-
-            //////////////////////////////////////////////
-            // TODO: this should not mutate the KV cache !
-            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
-
-            // prevent out-of-bound sources
-            if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
-                kv_cell.src = cell_id;
-            }
-
-            data[i] = kv_cell.src;
-
-            // TODO: do not mutate the KV cache
-            // ensure copy only happens once
-            if (kv_cell.src != (int32_t) cell_id) {
-                kv_cell.src = cell_id;
-            }
+            data[i] = kv_self->s_copy(i);
        }
    }
}
@@ -287,18 +300,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
 
         // clear unused states
         for (int i = 0; i < n_kv; ++i) {
-            const uint32_t cell_id = i + kv_self->head;
-
-            //////////////////////////////////////////////
-            // TODO: this should not mutate the KV cache !
-            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
-
-            data[i] = (float) (kv_cell.src >= 0);
-
-            // only clear once
-            if (kv_cell.src < 0) {
-                kv_cell.src = cell_id;
-            }
+            data[i] = kv_self->s_mask(i);
        }
    }
}
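Both recurrent-state inputs now call accessors on the cache instead of mutating `cells` through a `const_cast` inside the graph code. A hedged, self-contained sketch of what those accessors plausibly do, reconstructed from the inline logic deleted above (the real definitions live in llama-kv-cache.cpp, which this section does not show); note the mutation is moved, not removed:

    // recurrent_cache_sketch.cpp - illustrative only, not the library's types
    #include <cstdint>
    #include <vector>

    struct kv_cell_sketch {
        int32_t src = -1; // copy source; < 0 while the state is unused
    };

    struct recurrent_cache_sketch {
        uint32_t head = 0;
        uint32_t size = 0;
        std::vector<kv_cell_sketch> cells;

        int32_t s_copy(int i) {
            const uint32_t cell_id = i + head;
            kv_cell_sketch & cell = cells[cell_id];

            // prevent out-of-bound sources
            if (cell.src < 0 || (uint32_t) cell.src >= size) {
                cell.src = cell_id;
            }

            const int32_t res = cell.src;

            // ensure the copy only happens once
            if (cell.src != (int32_t) cell_id) {
                cell.src = cell_id;
            }
            return res;
        }

        float s_mask(int i) {
            const uint32_t cell_id = i + head;
            kv_cell_sketch & cell = cells[cell_id];

            const float res = (float) (cell.src >= 0);

            // only clear once
            if (cell.src < 0) {
                cell.src = cell_id;
            }
            return res;
        }
    };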
@@ -402,120 +404,94 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask || self_kq_mask_swa) {
-        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-        if (cparams.causal_attn) {
-            const int64_t n_kv = kv_self->n;
-            const int64_t n_tokens = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs = ubatch->n_seqs;
-
-            float * data = nullptr;
-            float * data_swa = nullptr;
-
-            if (self_kq_mask) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-                data = (float *) self_kq_mask->data;
-            }
+        const int64_t n_kv = kv_self->n;
+        const int64_t n_tokens = ubatch->n_tokens;
+        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+        const int64_t n_seqs = ubatch->n_seqs;
 
-            if (self_kq_mask_swa) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-                data_swa = (float *) self_kq_mask_swa->data;
-            }
+        float * data = nullptr;
+        float * data_swa = nullptr;
 
-            // For causal attention, use only the previous KV cells
-            // of the correct sequence for each token of the ubatch.
-            // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-            for (int h = 0; h < 1; ++h) {
-                for (int s = 0; s < n_seqs; ++s) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s][0];
+        if (self_kq_mask) {
+            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+            data = (float *) self_kq_mask->data;
+        }
 
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+        if (self_kq_mask_swa) {
+            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+            data_swa = (float *) self_kq_mask_swa->data;
+        }
 
-                        for (int i = 0; i < n_kv; ++i) {
-                            float f;
-                            if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) {
-                                f = -INFINITY;
+        // Use only the previous KV cells of the correct sequence for each token of the ubatch.
+        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+        // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
+        //   Causal mask:
+        //      xxx-------
+        //      xxxx------
+        //      xxxxx-----
+        //   Non-causal mask:
+        //      xxxxx-----
+        //      xxxxx-----
+        //      xxxxx-----
+        // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
+        for (int h = 0; h < 1; ++h) {
+            for (int s = 0; s < n_seqs; ++s) {
+                const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+                for (int j = 0; j < n_seq_tokens; ++j) {
+                    const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+                    for (int i = 0; i < n_kv; ++i) {
+                        float f;
+                        // mask the token if:
+                        if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
+                            || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
+                        ) {
+                            f = -INFINITY;
+                        } else {
+                            if (hparams.use_alibi) {
+                                f = -std::abs(kv_self->cells[i].pos - pos);
                             } else {
-                                if (hparams.use_alibi) {
-                                    f = -std::abs(kv_self->cells[i].pos - pos);
-                                } else {
-                                    f = 0.0f;
-                                }
+                                f = 0.0f;
                             }
+                        }
 
-                            if (data) {
-                                data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                            }
+                        if (data) {
+                            data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                        }
 
-                            // may need to cut off old tokens for sliding window
-                            if (data_swa) {
+                        // may need to cut off old tokens for sliding window
+                        // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
+                        if (data_swa) {
+                            if (hparams.n_attn_chunk) {
+                                llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
+                                if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
+                                    f = -INFINITY;
+                                }
+                            } else {
                                 if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
                                     f = -INFINITY;
                                 }
-                                data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                             }
+                            data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                        }
                    }
                }
+            }
 
-                if (data) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                        }
-                    }
-                }
-
-                if (data_swa) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                        }
+            // mask padded tokens
+            if (data) {
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                    }
                }
            }
-        } else {
-            const int64_t n_tokens = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs = ubatch->n_seqs;
-            // when using kv cache, the mask needs to match the kv cache size
-            const int64_t n_stride = n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-
-            float * data = (float *) self_kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
 
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
-                        }
+            // mask padded tokens
+            if (data_swa) {
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                    }
                }
            }
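The rewritten mask loop above folds the causal, non-causal, sliding-window, and Llama 4 chunked-attention cases into a single pass. For the chunked case each token may only attend to KV cells inside its own chunk; a small standalone sketch of the window arithmetic (the chunk size is an arbitrary example, not a real hparam value):

    // chunk_window.cpp - the pos_chunk_start arithmetic from the hunk above
    #include <cstdio>

    int main() {
        const int n_attn_chunk = 4; // example chunk size
        for (int pos = 0; pos < 8; ++pos) {
            const int pos_chunk_start = (pos / n_attn_chunk) * n_attn_chunk;
            // a cached token at cell_pos stays visible iff
            // pos_chunk_start <= cell_pos && cell_pos <= pos
            std::printf("pos=%d attends to [%d..%d]\n", pos, pos_chunk_start, pos);
        }
        return 0;
    }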
@@ -602,7 +578,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     res (std::make_unique<llm_graph_result>()) {
 }
 
-int64_t llm_graph_context::n_pos_per_token() const {
+int64_t llm_graph_context::n_pos_per_embd() const {
     return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
 }
 
@@ -806,13 +782,17 @@ ggml_tensor * llm_graph_context::build_ffn(
            } break;
    }
 
-    if (type_gate == LLM_FFN_PAR) {
+    if (gate && type_gate == LLM_FFN_PAR) {
         cur = ggml_mul(ctx0, cur, tmp);
         cb(cur, "ffn_gate_par", il);
    }
 
     if (down) {
         cur = build_lora_mm(down, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
    }
 
     if (down_b) {
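ggml lets the caller request a higher-precision accumulator per node, which is what the new GLM4 branch does for the FFN down projection. A minimal sketch of the same call pattern outside the graph builder (shapes and types are arbitrary; metadata only, nothing is computed):

    // prec_sketch.cpp - assumes linking against ggml
    #include "ggml.h"

    int main() {
        ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };
        ggml_context * ctx = ggml_init(params);

        ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 64, 64); // weight
        ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64,  8); // activations

        ggml_tensor * y = ggml_mul_mat(ctx, w, x);
        ggml_mul_mat_set_prec(y, GGML_PREC_F32); // F32 accumulation for this node only

        ggml_free(ctx);
        return 0;
    }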
@@ -846,8 +826,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         float w_scale,
         llama_expert_gating_func_type gating_op,
         int il) const {
-    int64_t n_embd = cur->ne[0];
-    int64_t n_tokens = cur->ne[1];
+    const int64_t n_embd = cur->ne[0];
+    const int64_t n_tokens = cur->ne[1];
+    const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
 
     ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
     cb(logits, "ffn_moe_logits", il);
@@ -875,6 +856,12 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(selection_probs, "ffn_moe_probs_biased", il);
    }
 
+    // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
+    // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
+    if (arch == LLM_ARCH_LLAMA4) {
+        selection_probs = logits;
+    }
+
     // select experts
     ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
     cb(selected_experts->src[0], "ffn_moe_argsort", il);
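Most MoE models apply the gating function before selecting experts; per the comment above, Llama 4 instead selects on the raw router logits and applies the sigmoid only to the weights of the chosen experts. A plain C++ sketch of that ordering (illustrative values, not library code):

    // llama4_gating_sketch.cpp
    #include <algorithm>
    #include <cmath>
    #include <vector>

    static std::vector<int> top_k(const std::vector<float> & scores, int k) {
        std::vector<int> idx(scores.size());
        for (int i = 0; i < (int) idx.size(); ++i) idx[i] = i;
        std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                          [&](int a, int b) { return scores[a] > scores[b]; });
        idx.resize(k);
        return idx;
    }

    int main() {
        const std::vector<float> logits = {0.3f, -1.2f, 2.0f, 0.9f}; // router output
        // select on raw logits, then sigmoid only the selected weights
        const std::vector<int> sel = top_k(logits, 2);
        std::vector<float> weights;
        for (int e : sel) {
            weights.push_back(1.0f / (1.0f + std::exp(-logits[e])));
        }
        return 0;
    }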
@@ -901,34 +888,53 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
    }
 
     cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
+
+    if (weight_before_ffn) {
+        // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
+        ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
+        repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
+        cur = ggml_mul(ctx0, repeated, weights);
+        cb(cur, "ffn_moe_weighted", il);
+    }
+
     ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);
 
-    ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(gate, "ffn_moe_gate", il);
+    ggml_tensor * experts = nullptr;
+    if (gate_exps) {
+        cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate", il);
+    } else {
+        cur = up;
+    }
 
     switch (type_op) {
         case LLM_FFN_SILU:
            {
-                gate = ggml_silu(ctx0, gate);
-                cb(gate, "ffn_moe_silu", il);
+                cur = ggml_silu(ctx0, cur);
+                cb(cur, "ffn_moe_silu", il);
            } break;
         case LLM_FFN_GELU:
            {
-                gate = ggml_gelu(ctx0, gate);
-                cb(gate, "ffn_moe_gelu", il);
+                cur = ggml_gelu(ctx0, cur);
+                cb(cur, "ffn_moe_gelu", il);
            } break;
         default:
             GGML_ABORT("fatal error");
    }
 
-    ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
-    cb(par, "ffn_moe_gate_par", il);
+    if (gate_exps) {
+        cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate_par", il);
+    }
 
-    ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
-    experts = ggml_mul(ctx0, experts, weights);
+    if (!weight_before_ffn) {
+        experts = ggml_mul(ctx0, experts, weights);
+        cb(cur, "ffn_moe_weighted", il);
+    }
 
     // aggregate experts
     ggml_tensor * moe_out = nullptr;
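In the weight_before_ffn path above, `cur` is [n_embd, 1, n_tokens] while the routed weights carry an expert dimension, so the broadcast has to be materialized with `ggml_repeat` before the element-wise multiply. A shape-only sketch (dimensions are arbitrary examples; metadata only):

    // moe_weight_sketch.cpp - assumes linking against ggml
    #include "ggml.h"
    #include <cstdio>

    int main() {
        ggml_init_params params = {16*1024*1024, nullptr, true};
        ggml_context * ctx = ggml_init(params);

        const int64_t n_embd = 128, n_expert_used = 2, n_tokens = 5;
        ggml_tensor * cur     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1, n_tokens);
        ggml_tensor * weights = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 1, n_expert_used, n_tokens);

        // expand cur along the expert dimension, then weight it
        ggml_tensor * rep = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_expert_used, n_tokens);
        rep = ggml_repeat(ctx, cur, rep);                // [n_embd, n_expert_used, n_tokens]
        ggml_tensor * out = ggml_mul(ctx, rep, weights); // weights broadcast over dim 0

        std::printf("%lld x %lld x %lld\n",
                    (long long) out->ne[0], (long long) out->ne[1], (long long) out->ne[2]);
        ggml_free(ctx);
        return 0;
    }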
@@ -948,6 +954,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         moe_out = ggml_cont(ctx0, moe_out);
    }
 
+    cb(moe_out, "ffn_moe_out", il);
+
     return moe_out;
}
 
@@ -963,6 +971,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
         inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
         //cb(inp->tokens, "inp_tokens", -1);
         ggml_set_input(inp->tokens);
+        res->t_tokens = inp->tokens;
 
         cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
 
@@ -1003,11 +1012,25 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
}
 
 ggml_tensor * llm_graph_context::build_inp_pos() const {
-    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_token());
+    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
 
     auto & cur = inp->pos;
 
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
+    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
+
+    auto & cur = inp->attn_scale;
+
+    // this needs to be 1x1xN for broadcasting
+    cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
     ggml_set_input(cur);
 
     res->add_input(std::move(inp));
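The comment on the 1x1xN shape is the whole trick: `ggml_mul` broadcasts a [1, 1, n_tokens] tensor across the lower dimensions, so each token contributes a single scalar. A sketch under assumed shapes (this diff does not show where the scale tensor is consumed):

    // attn_scale_shape.cpp - assumes the scale multiplies a per-head activation
    #include "ggml.h"

    int main() {
        ggml_init_params params = {16*1024*1024, nullptr, true};
        ggml_context * ctx = ggml_init(params);

        const int64_t n_embd_head = 64, n_head = 8, n_tokens = 5; // example dims
        ggml_tensor * q     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_tokens);
        ggml_tensor * scale = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 1, 1, n_tokens);

        // broadcasts over dims 0 and 1: one scalar per token
        q = ggml_mul(ctx, q, scale);

        ggml_free(ctx);
        return 0;
    }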
@@ -1055,7 +1078,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
}
 
 ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
 
@@ -1072,7 +1095,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
}
 
 ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
 
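These casts are the visible edge of the KV-cache split in this release: recurrent models (Mamba/RWKV) now resolve `memory` to a dedicated `llama_kv_cache_recurrent` rather than to flags on the unified cache. Roughly, and only as an orientation sketch (the real class layout is in llama-kv-cache.h, listed above but not shown in this section):

    // cache_split_sketch.cpp - illustrative hierarchy, not the real headers
    #include <cstdint>

    struct memory_sketch {
        virtual ~memory_sketch() = default;
    };

    // self-attention models: cells with pos/seq_id plus per-layer K/V tensors
    struct kv_cache_unified_sketch : memory_sketch {};

    // recurrent models: per-sequence state cells, queried through the
    // s_copy(i)/s_mask(i) accessors used by the graph inputs above
    struct kv_cache_recurrent_sketch : memory_sketch {
        int32_t s_copy(int /*i*/) const { return 0; } // stub
        float   s_mask(int /*i*/) const { return 0; } // stub
    };

    int main() { return 0; }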
@@ -1164,6 +1187,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * v,
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
+        ggml_tensor * v_mla,
         bool v_trans,
         float kq_scale) const {
     //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
@@ -1175,8 +1199,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     //const auto & n_embd_head_k = hparams.n_embd_head_k;
     //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
-    const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
-
     const auto n_tokens = q->ne[1];
     const auto n_head = q->ne[2];
     const auto n_kv = k->ne[1];
@@ -1191,12 +1213,37 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             v = ggml_transpose(ctx0, v);
        }
 
+        // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
+        if (k->type == GGML_TYPE_F32) {
+            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+        }
+
+        if (v->type == GGML_TYPE_F32) {
+            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+        }
+
         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                 hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
-        cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
+        if (v_mla) {
+#if 0
+            // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
+            // However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient.
+            cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
+            cur = ggml_mul_mat(ctx0, v_mla, cur);
+#else
+            // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
+            // The permutations are noops and only change how the tensor data is interpreted.
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_mul_mat(ctx0, v_mla, cur);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
+#endif
+        }
+
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
    } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
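On the flash-attention path the MLA projection is applied as one matrix-matrix multiplication over all tokens, with the permutes only reinterpreting strides, as the #else branch above explains. A shape sketch with made-up dimensions (graph construction only, nothing is computed):

    // mla_shape_sketch.cpp - assumes linking against ggml
    #include "ggml.h"
    #include <cstdio>

    int main() {
        ggml_init_params params = {64*1024*1024, nullptr, true};
        ggml_context * ctx = ggml_init(params);

        const int64_t rank = 32, n_embd_head_v = 64, n_head = 4, n_tokens = 5;
        ggml_tensor * v_mla = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, rank, n_embd_head_v, n_head);
        ggml_tensor * cur   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, rank, n_head, n_tokens);

        cur = ggml_permute(ctx, cur, 0, 2, 1, 3); // view: [rank, n_tokens, n_head]
        cur = ggml_mul_mat(ctx, v_mla, cur);      // [n_embd_head_v, n_tokens, n_head]
        cur = ggml_permute(ctx, cur, 0, 2, 1, 3); // view: [n_embd_head_v, n_head, n_tokens]
        cur = ggml_cont(ctx, cur);                // materialize before the 2D reshape

        std::printf("%lld x %lld x %lld\n",
                    (long long) cur->ne[0], (long long) cur->ne[1], (long long) cur->ne[2]);
        ggml_free(ctx);
        return 0;
    }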
@@ -1234,9 +1281,14 @@
 
         ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
 
-        ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+        // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
+        if (v_mla) {
+            kqv = ggml_mul_mat(ctx0, v_mla, kqv);
+        }
+
+        cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-        cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
@@ -1271,6 +1323,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     GGML_UNUSED(n_tokens);
@@ -1292,7 +1345,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
 
     cb(cur, "kqv_out", il);
 
@@ -1346,6 +1399,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1366,8 +1420,6 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        GGML_ASSERT(!kv_self->recurrent);
-
         const auto kv_head = kv_self->head;
 
         GGML_ASSERT(kv_self->size == n_ctx);
@@ -1431,7 +1483,7 @@ ggml_tensor * llm_graph_context::build_attn(
                 ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
                 0);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1471,6 +1523,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1490,7 +1543,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
 
     cb(cur, "kqv_out", il);
 
@@ -1516,7 +1569,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
         ggml_tensor * state_mask,
         int32_t n_state,
         int32_t n_seqs) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto n_kv = kv_self->n;
     const auto kv_head = kv_self->head;
@@ -1548,7 +1601,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         ggml_tensor * state_mask,
         const llama_ubatch & ubatch,
         int il) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto token_shift_count = hparams.token_shift_count;
 
@@ -1569,7 +1622,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
@@ -1659,4 +1712,3 @@ void llm_graph_context::build_pooling(
 
     ggml_build_forward_expand(gf, cur);
}
-