cui-llama.rn 1.7.3 → 1.7.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (276)
  1. package/README.md +217 -17
  2. package/android/src/main/CMakeLists.txt +34 -15
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +94 -8
  4. package/android/src/main/java/com/rnllama/RNLlama.java +247 -0
  5. package/android/src/main/jni.cpp +213 -14
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  14. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +35 -0
  15. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +34 -0
  16. package/cpp/README.md +1 -1
  17. package/cpp/chat-parser.cpp +385 -0
  18. package/cpp/chat-parser.h +120 -0
  19. package/cpp/chat.cpp +726 -596
  20. package/cpp/chat.h +71 -6
  21. package/cpp/common.cpp +56 -38
  22. package/cpp/common.h +9 -3
  23. package/cpp/ggml-backend-reg.cpp +5 -0
  24. package/cpp/ggml-backend.cpp +10 -2
  25. package/cpp/ggml-common.h +4 -0
  26. package/cpp/ggml-cpu/amx/amx.cpp +1 -1
  27. package/cpp/ggml-cpu/amx/mmq.cpp +11 -10
  28. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  29. package/cpp/ggml-cpu/arch/arm/quants.c +4114 -0
  30. package/cpp/ggml-cpu/arch/arm/repack.cpp +2163 -0
  31. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  32. package/cpp/ggml-cpu/arch/x86/quants.c +4311 -0
  33. package/cpp/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  34. package/cpp/ggml-cpu/arch-fallback.h +184 -0
  35. package/cpp/ggml-cpu/common.h +4 -3
  36. package/cpp/ggml-cpu/ggml-cpu-impl.h +21 -16
  37. package/cpp/ggml-cpu/ggml-cpu.c +123 -104
  38. package/cpp/ggml-cpu/ggml-cpu.cpp +11 -8
  39. package/cpp/ggml-cpu/ops.cpp +330 -148
  40. package/cpp/ggml-cpu/ops.h +1 -0
  41. package/cpp/ggml-cpu/quants.c +1158 -0
  42. package/cpp/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  43. package/cpp/ggml-cpu/repack.cpp +1571 -0
  44. package/cpp/ggml-cpu/repack.h +98 -0
  45. package/cpp/ggml-cpu/simd-mappings.h +330 -38
  46. package/cpp/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  47. package/cpp/ggml-cpu/vec.cpp +87 -18
  48. package/cpp/ggml-cpu/vec.h +249 -94
  49. package/cpp/ggml-cpu.h +1 -0
  50. package/cpp/ggml-impl.h +63 -183
  51. package/cpp/ggml-llama-sim.metallib +0 -0
  52. package/cpp/ggml-llama.metallib +0 -0
  53. package/cpp/ggml-metal.m +152 -45
  54. package/cpp/ggml-quants.c +0 -2
  55. package/cpp/ggml.c +61 -21
  56. package/cpp/ggml.h +22 -3
  57. package/cpp/gguf.cpp +24 -3
  58. package/cpp/json-partial.cpp +256 -0
  59. package/cpp/json-partial.h +38 -0
  60. package/cpp/json-schema-to-grammar.cpp +5 -47
  61. package/cpp/json-schema-to-grammar.h +4 -4
  62. package/cpp/llama-arch.cpp +153 -3
  63. package/cpp/llama-arch.h +27 -1
  64. package/cpp/llama-batch.cpp +741 -272
  65. package/cpp/llama-batch.h +112 -54
  66. package/cpp/llama-chat.cpp +30 -8
  67. package/cpp/llama-chat.h +1 -0
  68. package/cpp/llama-context.cpp +524 -339
  69. package/cpp/llama-context.h +38 -17
  70. package/cpp/llama-cparams.cpp +4 -0
  71. package/cpp/llama-cparams.h +2 -0
  72. package/cpp/llama-grammar.cpp +12 -2
  73. package/cpp/llama-graph.cpp +431 -356
  74. package/cpp/llama-graph.h +126 -58
  75. package/cpp/llama-hparams.cpp +10 -2
  76. package/cpp/llama-hparams.h +19 -2
  77. package/cpp/llama-kv-cache-unified-iswa.cpp +279 -0
  78. package/cpp/llama-kv-cache-unified-iswa.h +128 -0
  79. package/cpp/llama-kv-cache-unified.cpp +1841 -0
  80. package/cpp/llama-kv-cache-unified.h +303 -0
  81. package/cpp/llama-kv-cells.h +439 -0
  82. package/cpp/llama-memory-hybrid.cpp +246 -0
  83. package/cpp/llama-memory-hybrid.h +138 -0
  84. package/cpp/llama-memory-recurrent.cpp +1112 -0
  85. package/cpp/llama-memory-recurrent.h +183 -0
  86. package/cpp/llama-memory.cpp +41 -0
  87. package/cpp/llama-memory.h +86 -5
  88. package/cpp/llama-mmap.cpp +1 -1
  89. package/cpp/llama-model-loader.cpp +42 -17
  90. package/cpp/llama-model-saver.cpp +1 -0
  91. package/cpp/llama-model.cpp +1639 -513
  92. package/cpp/llama-model.h +26 -0
  93. package/cpp/llama-sampling.cpp +2 -2
  94. package/cpp/llama-vocab.cpp +65 -28
  95. package/cpp/llama-vocab.h +1 -0
  96. package/cpp/llama.cpp +11 -7
  97. package/cpp/llama.h +150 -42
  98. package/cpp/minja/chat-template.hpp +1 -1
  99. package/cpp/minja/minja.hpp +1 -1
  100. package/cpp/{json.hpp → nlohmann/json.hpp} +3027 -2267
  101. package/cpp/nlohmann/json_fwd.hpp +187 -0
  102. package/cpp/regex-partial.cpp +204 -0
  103. package/cpp/regex-partial.h +56 -0
  104. package/cpp/rn-llama.cpp +646 -35
  105. package/cpp/rn-llama.h +32 -1
  106. package/cpp/rn-tts.h +39 -0
  107. package/cpp/sampling.cpp +7 -8
  108. package/cpp/tools/mtmd/clip-impl.h +5 -0
  109. package/cpp/tools/mtmd/clip.cpp +572 -436
  110. package/cpp/tools/mtmd/clip.h +14 -4
  111. package/cpp/tools/mtmd/mtmd-audio.cpp +0 -86
  112. package/cpp/tools/mtmd/mtmd-audio.h +2 -17
  113. package/cpp/tools/mtmd/mtmd-helper.cpp +175 -12
  114. package/cpp/tools/mtmd/mtmd-helper.h +91 -0
  115. package/cpp/tools/mtmd/mtmd.cpp +368 -248
  116. package/cpp/tools/mtmd/mtmd.h +6 -70
  117. package/cpp/unicode.cpp +5 -0
  118. package/ios/CMakeLists.txt +26 -6
  119. package/ios/RNLlama.h +1 -1
  120. package/ios/RNLlama.mm +153 -3
  121. package/ios/RNLlamaContext.h +9 -1
  122. package/ios/RNLlamaContext.mm +112 -9
  123. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
  124. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +71 -6
  125. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +9 -3
  126. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
  127. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
  128. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
  129. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +22 -3
  130. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-partial.h +38 -0
  131. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  132. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
  133. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
  134. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  135. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +38 -17
  136. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
  137. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
  138. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
  139. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  140. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  141. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  142. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  143. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  144. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
  145. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +26 -0
  146. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
  147. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +150 -42
  148. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/{json.hpp → nlohmann/json.hpp} +3027 -2267
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  157. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
  165. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  166. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
  167. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
  168. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  169. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
  170. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
  171. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
  172. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
  173. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  174. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  175. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  176. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  177. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  178. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
  179. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
  180. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
  181. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
  182. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  183. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
  184. package/ios/rnllama.xcframework/{tvos-arm64/rnllama.framework/Headers → ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
  185. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  186. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
  187. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
  188. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
  189. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  190. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  191. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
  192. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +71 -6
  193. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +9 -3
  194. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
  195. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
  196. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
  197. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +22 -3
  198. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-partial.h +38 -0
  199. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  200. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
  201. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
  202. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  203. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +38 -17
  204. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
  205. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
  206. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
  207. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  208. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  209. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  210. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  211. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  212. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
  213. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +26 -0
  214. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
  215. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +150 -42
  216. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  217. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
  218. package/ios/rnllama.xcframework/{ios-arm64_x86_64-simulator/rnllama.framework/Headers → tvos-arm64/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
  219. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  220. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
  221. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
  222. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
  223. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  224. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  225. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
  226. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
  227. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
  228. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
  229. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
  230. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
  231. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
  232. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
  233. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  234. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
  235. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
  236. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  237. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
  238. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
  239. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
  240. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
  241. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  242. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  243. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  244. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  245. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  246. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
  247. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
  248. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
  249. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
  250. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  251. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
  252. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json.hpp +25526 -0
  253. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  254. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
  255. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
  256. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
  257. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  258. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  259. package/jest/mock.js +24 -0
  260. package/package.json +1 -1
  261. package/src/NativeRNLlama.ts +46 -2
  262. package/src/index.ts +105 -1
  263. package/cpp/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  264. package/cpp/ggml-cpu/ggml-cpu-quants.c +0 -13326
  265. package/cpp/ggml-cpu/sgemm.cpp +0 -3544
  266. package/cpp/ggml-cpu/sgemm.h +0 -14
  267. package/cpp/llama-kv-cache.cpp +0 -2827
  268. package/cpp/llama-kv-cache.h +0 -515
  269. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  270. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  271. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  272. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +0 -24766
  273. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  274. /package/cpp/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  275. /package/cpp/tools/mtmd/{miniaudio.h → miniaudio/miniaudio.h} +0 -0
  276. /package/cpp/tools/mtmd/{stb_image.h → stb/stb_image.h} +0 -0
package/cpp/llama-graph.cpp
@@ -3,7 +3,11 @@
  #include "llama-impl.h"
  #include "llama-batch.h"
  #include "llama-cparams.h"
- #include "llama-kv-cache.h"
+
+ #include "llama-kv-cache-unified.h"
+ #include "llama-kv-cache-unified-iswa.h"
+ #include "llama-memory-hybrid.h"
+ #include "llama-memory-recurrent.h"
 
  #include <cassert>
  #include <cmath>
@@ -83,41 +87,33 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
  void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
  if (pos_bucket) {
- kv_self->set_input_pos_bucket(pos_bucket, ubatch);
+ mctx->set_input_pos_bucket(pos_bucket, ubatch);
  }
  }
 
  void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
- if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
- //LM_GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
+ LM_GGML_ASSERT(out_ids);
 
- if (!out_ids) {
- LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
- } else {
- const int64_t n_tokens = ubatch->n_tokens;
+ const int64_t n_tokens = ubatch->n_tokens;
 
- LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(out_ids->buffer));
- int32_t * data = (int32_t *) out_ids->data;
+ LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(out_ids->buffer));
+ int32_t * data = (int32_t *) out_ids->data;
 
- if (n_outputs == n_tokens) {
- for (int i = 0; i < n_tokens; ++i) {
- data[i] = i;
- }
- } else if (ubatch->output) {
- int32_t n_outputs = 0;
- for (int i = 0; i < n_tokens; ++i) {
- if (ubatch->output[i]) {
- data[n_outputs++] = i;
- }
- }
- // the graph needs to have been passed the correct number of outputs
- LM_GGML_ASSERT(n_outputs == n_outputs);
- } else if (n_outputs == 1) {
- // only keep last output
- data[0] = n_tokens - 1;
- } else {
- LM_GGML_ASSERT(n_outputs == 0);
- }
+ if (n_outputs == n_tokens) {
+ for (int i = 0; i < n_tokens; ++i) {
+ data[i] = i;
+ }
+
+ return;
+ }
+
+ LM_GGML_ASSERT(ubatch->output);
+
+ int n_outputs = 0;
+
+ for (int i = 0; i < n_tokens; ++i) {
+ if (ubatch->output[i]) {
+ data[n_outputs++] = i;
  }
  }
  }
@@ -126,139 +122,114 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
  const int64_t n_tokens = ubatch->n_tokens;
  const int64_t n_seq_tokens = ubatch->n_seq_tokens;
- const int64_t n_seqs = ubatch->n_seqs;
+ const int64_t n_seqs_unq = ubatch->n_seqs_unq;
 
  LM_GGML_ASSERT(mean);
  LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(mean->buffer));
 
  float * data = (float *) mean->data;
- memset(mean->data, 0, n_tokens * n_tokens * lm_ggml_element_size(mean));
-
- std::vector<uint64_t> sum(n_tokens, 0);
+ memset(mean->data, 0, n_tokens*n_seqs_unq*lm_ggml_element_size(mean));
 
- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
+ std::vector<uint64_t> sums(n_seqs_unq, 0);
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
 
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
- LM_GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
-
- sum[seq_id] += ubatch->n_seq_tokens;
+ sums[seq_idx] += ubatch->n_seq_tokens;
+ }
  }
 
- std::vector<float> div(n_tokens, 0.0f);
- for (int i = 0; i < n_tokens; ++i) {
- const uint64_t s = sum[i];
- if (s > 0) {
- div[i] = 1.0f/float(s);
+ std::vector<float> div(n_seqs_unq, 0.0f);
+ for (int s = 0; s < n_seqs_unq; ++s) {
+ const uint64_t sum = sums[s];
+ if (sum > 0) {
+ div[s] = 1.0f/float(sum);
  }
  }
 
- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
 
- for (int i = 0; i < n_seq_tokens; ++i) {
- data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
+ for (int j = 0; j < n_seq_tokens; ++j) {
+ data[seq_idx*n_tokens + i + j] = div[seq_idx];
+ }
  }
  }
  }
  }
 
  void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
- if (cparams.embeddings && (
- cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
- cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
- const int64_t n_tokens = ubatch->n_tokens;
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
- const int64_t n_seqs = ubatch->n_seqs;
+ const int64_t n_tokens = ubatch->n_tokens;
+ const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+ const int64_t n_seqs_unq = ubatch->n_seqs_unq;
 
+ if (cparams.embeddings && (
+ cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
+ cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
+ )) {
  LM_GGML_ASSERT(cls);
  LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(cls->buffer));
 
  uint32_t * data = (uint32_t *) cls->data;
- memset(cls->data, 0, n_tokens * lm_ggml_element_size(cls));
-
- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
+ memset(cls->data, 0, n_seqs_unq*lm_ggml_element_size(cls));
 
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
- LM_GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
 
- for (int i = 0; i < n_seq_tokens; ++i) {
- const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
-
- if (pos == 0) {
- data[seq_id] = s*n_seq_tokens + i;
- }
+ data[seq_idx] = i;
  }
  }
  }
 
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
- const int64_t n_tokens = ubatch->n_tokens;
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
- const int64_t n_seqs = ubatch->n_seqs;
-
  LM_GGML_ASSERT(cls);
  LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(cls->buffer));
 
  uint32_t * data = (uint32_t *) cls->data;
- memset(cls->data, 0, n_tokens * lm_ggml_element_size(cls));
-
- std::vector<int> last_pos(n_tokens, -1);
- std::vector<int> last_row(n_tokens, -1);
+ memset(cls->data, 0, n_seqs_unq*lm_ggml_element_size(cls));
 
- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
+ std::vector<int> last_pos(n_seqs_unq, -1);
+ std::vector<int> last_row(n_seqs_unq, -1);
 
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
- LM_GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
+ for (int i = 0; i < n_tokens; ++i) {
+ const llama_pos pos = ubatch->pos[i];
 
- for (int i = 0; i < n_seq_tokens; ++i) {
- const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
 
- if (pos >= last_pos[seq_id]) {
- last_pos[seq_id] = pos;
- last_row[seq_id] = s*n_seq_tokens + i;
+ if (pos >= last_pos[seq_idx]) {
+ last_pos[seq_idx] = pos;
+ last_row[seq_idx] = i;
  }
  }
  }
 
- for (int i = 0; i < n_tokens; ++i) {
- if (last_row[i] >= 0) {
- data[i] = last_row[i];
+ for (int s = 0; s < n_seqs_unq; ++s) {
+ if (last_row[s] >= 0) {
+ data[s] = last_row[s];
  }
  }
  }
  }
 
- void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
+ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
  LM_GGML_UNUSED(ubatch);
 
- const int64_t n_kv = kv_self->n;
+ const int64_t n_rs = mctx->get_n_rs();
 
  if (s_copy) {
  LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(s_copy->buffer));
  int32_t * data = (int32_t *) s_copy->data;
 
  // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
- for (uint32_t i = 0; i < n_kv; ++i) {
- data[i] = kv_self->s_copy(i);
- }
- }
- }
-
- void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
- LM_GGML_UNUSED(ubatch);
-
- const int64_t n_kv = kv_self->n;
-
- if (s_mask) {
- LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(s_mask->buffer));
- float * data = (float *) s_mask->data;
-
- // clear unused states
- for (int i = 0; i < n_kv; ++i) {
- data[i] = kv_self->s_mask(i);
+ for (uint32_t i = 0; i < n_rs; ++i) {
+ data[i] = mctx->s_copy(i);
  }
  }
  }
@@ -274,87 +245,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
  }
 
  void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
- if (kq_mask) {
- if (cparams.causal_attn) {
- const int64_t n_kv = ubatch->n_tokens;
- const int64_t n_tokens = ubatch->n_tokens;
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
- const int64_t n_seqs = ubatch->n_seqs;
-
- LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(kq_mask->buffer));
- float * data = (float *) kq_mask->data;
-
- for (int h = 0; h < 1; ++h) {
- for (int s1 = 0; s1 < n_seqs; ++s1) {
- const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
- for (int j = 0; j < n_seq_tokens; ++j) {
- const int32_t tj = s1*n_seq_tokens + j;
-
- for (int s0 = 0; s0 < n_seqs; ++s0) {
- for (int i = 0; i < n_seq_tokens; ++i) {
- const int32_t ti = s0*n_seq_tokens + i;
- float f = -INFINITY;
-
- for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
- if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
- if (hparams.use_alibi) {
- f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
- } else {
- f = 0.0f;
- }
- break;
- }
- }
-
- data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
- }
- }
- }
- }
- }
- } else {
- const int64_t n_tokens = ubatch->n_tokens;
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
- const int64_t n_seqs = ubatch->n_seqs;
- const int64_t n_stride = ubatch->n_tokens;
-
- LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(kq_mask->buffer));
-
- float * data = (float *) kq_mask->data;
-
- for (int h = 0; h < 1; ++h) {
- for (int s1 = 0; s1 < n_seqs; ++s1) {
- const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
- for (int j = 0; j < n_seq_tokens; ++j) {
- const int32_t tj = s1*n_seq_tokens + j;
-
- for (int s0 = 0; s0 < n_seqs; ++s0) {
- for (int i = 0; i < n_seq_tokens; ++i) {
- const int32_t ti = s0*n_seq_tokens + i;
- float f = -INFINITY;
-
- for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
- if (ubatch->seq_id[s0][s] == seq_id) {
- if (hparams.use_alibi) {
- f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
- } else {
- f = 0.0f;
- }
- break;
- }
- }
-
- data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
- }
- }
+ const int64_t n_kv = ubatch->n_tokens;
+ const int64_t n_tokens = ubatch->n_tokens;
+
+ LM_GGML_ASSERT(kq_mask);
+ LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(kq_mask->buffer));
+
+ float * data = (float *) kq_mask->data;
+
+ for (int h = 0; h < 1; ++h) {
+ for (int i1 = 0; i1 < n_tokens; ++i1) {
+ const llama_seq_id s1 = ubatch->seq_id[i1][0];
 
- for (int i = n_tokens; i < n_stride; ++i) {
- data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
+ for (int i0 = 0; i0 < n_tokens; ++i0) {
+ float f = -INFINITY;
+
+ for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+ const llama_seq_id s0 = ubatch->seq_id[i0][0];
+
+ // TODO: reimplement this like in llama_kv_cache_unified
+ if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
+ if (hparams.use_alibi) {
+ f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+ } else {
+ f = 0.0f;
  }
+ break;
  }
  }
+
+ data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
  }
  }
  }
@@ -362,53 +282,80 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
  void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
  if (self_kq_mask) {
- kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+ mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
  }
  }
 
  void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
  if (self_kq_mask) {
- kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+ mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
  }
 
  if (self_kq_mask_swa) {
- kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
+ mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
  }
  }
 
  void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
- if (cross_kq_mask) {
- const int64_t n_enc = cross_kq_mask->ne[0];
- const int64_t n_tokens = ubatch->n_tokens;
+ LM_GGML_ASSERT(cross_kq_mask);
 
- LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(cross_kq_mask->buffer));
- LM_GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+ const int64_t n_enc = cross_kq_mask->ne[0];
+ const int64_t n_tokens = ubatch->n_tokens;
 
- float * data = (float *) cross_kq_mask->data;
+ LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(cross_kq_mask->buffer));
+ LM_GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
 
- for (int h = 0; h < 1; ++h) {
- for (int j = 0; j < n_tokens; ++j) {
- for (int i = 0; i < n_enc; ++i) {
- float f = -INFINITY;
- for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
- const llama_seq_id seq_id = ubatch->seq_id[j][s];
- if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
- f = 0.0f;
- }
+ float * data = (float *) cross_kq_mask->data;
+
+ for (int h = 0; h < 1; ++h) {
+ for (int i = 0; i < n_tokens; ++i) {
+ for (int j = 0; j < n_enc; ++j) {
+ float f = -INFINITY;
+
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
+
+ if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
+ f = 0.0f;
  }
- data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
  }
+
+ data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
  }
+ }
 
- for (int i = n_tokens; i < LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD); ++i) {
- for (int j = 0; j < n_enc; ++j) {
- data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
- }
+ for (int i = n_tokens; i < LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD); ++i) {
+ for (int j = 0; j < n_enc; ++j) {
+ data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
  }
  }
  }
  }
 
+ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+ if (self_kq_mask) {
+ mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+ }
+
+ const int64_t n_rs = mctx->get_recr()->get_n_rs();
+
+ if (s_copy) {
+ LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(s_copy->buffer));
+ int32_t * data = (int32_t *) s_copy->data;
+
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+ for (uint32_t i = 0; i < n_rs; ++i) {
+ data[i] = mctx->get_recr()->s_copy(i);
+ }
+ }
+ }
+
+ void llm_graph_input_one::set_input(const llama_ubatch *) {
+ LM_GGML_ASSERT(one && lm_ggml_nelements(one) == 1);
+ float f_one = 1.0f;
+ lm_ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
+ }
+
  //
  // llm_graph_context
  //
@@ -448,16 +395,12 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
  backend_cpu (params.backend_cpu),
  cvec (params.cvec),
  loras (params.loras),
- memory (params.memory),
+ mctx (params.mctx),
  cross (params.cross),
  cb_func (params.cb),
  res (std::make_unique<llm_graph_result>()) {
  }
 
- int64_t llm_graph_context::n_pos_per_embd() const {
- return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
- }
-
  void llm_graph_context::cb(lm_ggml_tensor * cur, const char * name, int il) const {
  if (cb_func) {
  cb_func(ubatch, cur, name, il);
@@ -647,6 +590,7 @@ lm_ggml_tensor * llm_graph_context::build_ffn(
  {
  // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
  int64_t split_point = cur->ne[0] / 2;
+ // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
  lm_ggml_tensor * x0 = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
  lm_ggml_tensor * x1 = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * lm_ggml_element_size(cur)));
 
@@ -656,6 +600,20 @@ lm_ggml_tensor * llm_graph_context::build_ffn(
  cur = lm_ggml_mul(ctx0, x0, x1);
  cb(cur, "ffn_mul", il);
  } break;
+ case LLM_FFN_GEGLU:
+ {
+ // Split into two equal parts
+ int64_t split_point = cur->ne[0] / 2;
+ // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
+ lm_ggml_tensor * x0 = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+ lm_ggml_tensor * x1 = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * lm_ggml_element_size(cur)));
+
+ x0 = lm_ggml_gelu(ctx0, x0);
+ cb(x0, "ffn_gelu", il);
+
+ cur = lm_ggml_mul(ctx0, x0, x1);
+ cb(cur, "ffn_geglu", il);
+ } break;
  }
 
  if (gate && type_gate == LLM_FFN_PAR) {
@@ -766,9 +724,8 @@ lm_ggml_tensor * llm_graph_context::build_moe_ffn(
  cur = lm_ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
 
  if (weight_before_ffn) {
- // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (lm_ggml_repeat_4d)
- lm_ggml_tensor * repeated = lm_ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
- repeated = lm_ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
+ // repeat cur to [n_embd, n_expert_used, n_tokens]
+ lm_ggml_tensor * repeated = lm_ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
  cur = lm_ggml_mul(ctx0, repeated, weights);
  cb(cur, "ffn_moe_weighted", il);
  }
@@ -888,11 +845,11 @@ lm_ggml_tensor * llm_graph_context::build_inp_embd(lm_ggml_tensor * tok_embd) co
  }
 
  lm_ggml_tensor * llm_graph_context::build_inp_pos() const {
- auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
+ auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
 
  auto & cur = inp->pos;
 
- cur = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, n_tokens*n_pos_per_embd());
+ cur = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
  lm_ggml_set_input(cur);
 
  res->add_input(std::move(inp));
@@ -915,6 +872,14 @@ lm_ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
  }
 
  lm_ggml_tensor * llm_graph_context::build_inp_out_ids() const {
+ // note: when all tokens are output, we could skip this optimization to spare the lm_ggml_get_rows() calls,
+ // but this would make the graph topology depend on the number of output tokens, which can interere with
+ // features that require constant topology such as pipline parallelism
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
+ //if (n_outputs < n_tokens) {
+ // return nullptr;
+ //}
+
  auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
 
  auto & cur = inp->out_ids;
@@ -932,7 +897,7 @@ lm_ggml_tensor * llm_graph_context::build_inp_mean() const {
 
  auto & cur = inp->mean;
 
- cur = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_tokens, n_tokens);
+ cur = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
  lm_ggml_set_input(cur);
 
  res->add_input(std::move(inp));
@@ -945,41 +910,7 @@ lm_ggml_tensor * llm_graph_context::build_inp_cls() const {
 
  auto & cur = inp->cls;
 
- cur = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, n_tokens);
- lm_ggml_set_input(cur);
-
- res->add_input(std::move(inp));
-
- return cur;
- }
-
- lm_ggml_tensor * llm_graph_context::build_inp_s_copy() const {
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
-
- auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
-
- const auto n_kv = kv_self->n;
-
- auto & cur = inp->s_copy;
-
- cur = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, n_kv);
- lm_ggml_set_input(cur);
-
- res->add_input(std::move(inp));
-
- return cur;
- }
-
- lm_ggml_tensor * llm_graph_context::build_inp_s_mask() const {
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
-
- auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
-
- const auto n_kv = kv_self->n;
-
- auto & cur = inp->s_mask;
-
- cur = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, 1, n_kv);
+ cur = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, ubatch.n_seqs_unq);
  lm_ggml_set_input(cur);
 
  res->add_input(std::move(inp));
@@ -1025,11 +956,11 @@ lm_ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
  }
 
  lm_ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
 
- auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
+ auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
 
- const auto n_kv = kv_self->get_n();
+ const auto n_kv = mctx_cur->get_n_kv();
 
  auto & cur = inp->pos_bucket;
 
@@ -1056,6 +987,33 @@ lm_ggml_tensor * llm_graph_context::build_pos_bias(lm_ggml_tensor * pos_bucket,
  return pos_bias;
  }
 
+ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
+
+ {
+ LM_GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
+
+ const auto n_kv = inp->mctx->get_attn()->get_n_kv();
+
+ inp->self_kq_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
+ //cb(inp->self_kq_mask, "KQ_mask", -1);
+ lm_ggml_set_input(inp->self_kq_mask);
+
+ inp->self_kq_mask_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask, LM_GGML_TYPE_F16) : inp->self_kq_mask;
+ }
+
+ {
+ const auto n_rs = mctx_cur->get_recr()->get_n_rs();
+
+ inp->s_copy = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, n_rs);
+ lm_ggml_set_input(inp->s_copy);
+ }
+
+ return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+ }
+
  lm_ggml_tensor * llm_graph_context::build_attn_mha(
  lm_ggml_cgraph * gf,
  lm_ggml_tensor * q,
@@ -1231,14 +1189,14 @@ lm_ggml_tensor * llm_graph_context::build_attn(
  }
 
  llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
 
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
 
  {
  LM_GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
- const auto n_kv = kv_self->get_n();
+ const auto n_kv = mctx_cur->get_n_kv();
 
  inp->self_kq_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
  //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1268,25 +1226,29 @@ lm_ggml_tensor * llm_graph_context::build_attn(
  lm_ggml_build_forward_expand(gf, k_cur);
  lm_ggml_build_forward_expand(gf, v_cur);
 
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
 
  // store to KV cache
  {
- lm_ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
- lm_ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+ lm_ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+ lm_ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
  }
 
  const auto & kq_mask = inp->get_kq_mask();
 
  lm_ggml_tensor * q = q_cur;
- lm_ggml_tensor * k = kv_self->get_k(ctx0, il);
- lm_ggml_tensor * v = kv_self->get_v(ctx0, il);
+ lm_ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+ lm_ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
  lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
  cb(cur, "kqv_out", il);
 
  if (wo) {
  cur = build_lora_mm(wo, cur);
+ if (arch == LLM_ARCH_GLM4) {
+ // GLM4 seems to have numerical issues with half-precision accumulators
+ lm_ggml_mul_mat_set_prec(cur, LM_GGML_PREC_F32);
+ }
  }
 
  if (wo_b) {
@@ -1296,36 +1258,6 @@ lm_ggml_tensor * llm_graph_context::build_attn(
  return cur;
  }
 
- llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
- const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
-
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
-
- {
- const auto n_kv = kv_self->get_kv_base()->get_n();
-
- inp->self_kq_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
- //cb(inp->self_kq_mask, "KQ_mask", -1);
- lm_ggml_set_input(inp->self_kq_mask);
-
- inp->self_kq_mask_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask, LM_GGML_TYPE_F16) : inp->self_kq_mask;
- }
-
- {
- LM_GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
-
- const auto n_kv = kv_self->get_kv_swa()->get_n();
-
- inp->self_kq_mask_swa = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
- //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
- lm_ggml_set_input(inp->self_kq_mask_swa);
-
- inp->self_kq_mask_swa_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask_swa, LM_GGML_TYPE_F16) : inp->self_kq_mask_swa;
- }
-
- return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
- }
-
  lm_ggml_tensor * llm_graph_context::build_attn(
  llm_graph_input_attn_kv_unified_iswa * inp,
  lm_ggml_cgraph * gf,
@@ -1341,36 +1273,41 @@ lm_ggml_tensor * llm_graph_context::build_attn(
  // these nodes are added to the graph together so that they are not reordered
  // by doing so, the number of splits in the graph is reduced
  lm_ggml_build_forward_expand(gf, q_cur);
- lm_ggml_build_forward_expand(gf, k_cur);
- lm_ggml_build_forward_expand(gf, v_cur);
+
+ if (k_cur) {
+ lm_ggml_build_forward_expand(gf, k_cur);
+ }
+
+ if (v_cur) {
+ lm_ggml_build_forward_expand(gf, v_cur);
+ }
+
+ const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
  const bool is_swa = hparams.is_swa(il);
 
- const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+ const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
 
- const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
+ // optionally store to KV cache
+ if (k_cur) {
+ lm_ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+ }
 
- // store to KV cache
- {
- lm_ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
- lm_ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
+ if (v_cur) {
+ lm_ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
  }
 
  const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
  lm_ggml_tensor * q = q_cur;
- lm_ggml_tensor * k = kv->get_k(ctx0, il);
- lm_ggml_tensor * v = kv->get_v(ctx0, il);
+ lm_ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+ lm_ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
  lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
  cb(cur, "kqv_out", il);
 
  if (wo) {
  cur = build_lora_mm(wo, cur);
- if (arch == LLM_ARCH_GLM4) {
- // GLM4 seems to have numerical issues with half-precision accumulators
- lm_ggml_mul_mat_set_prec(cur, LM_GGML_PREC_F32);
- }
  }
 
  if (wo_b) {
@@ -1439,56 +1376,182 @@ lm_ggml_tensor * llm_graph_context::build_attn(
  return cur;
  }
 
- lm_ggml_tensor * llm_graph_context::build_copy_mask_state(
- lm_ggml_cgraph * gf,
- lm_ggml_tensor * s,
- lm_ggml_tensor * state_copy,
- lm_ggml_tensor * state_mask,
- int32_t n_state,
- int32_t n_seqs) const {
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+ lm_ggml_tensor * llm_graph_context::build_attn(
+ llm_graph_input_mem_hybrid * inp,
+ lm_ggml_cgraph * gf,
+ lm_ggml_tensor * wo,
+ lm_ggml_tensor * wo_b,
+ lm_ggml_tensor * q_cur,
+ lm_ggml_tensor * k_cur,
+ lm_ggml_tensor * v_cur,
+ lm_ggml_tensor * kq_b,
+ lm_ggml_tensor * v_mla,
+ float kq_scale,
+ int il) const {
+ // these nodes are added to the graph together so that they are not reordered
+ // by doing so, the number of splits in the graph is reduced
+ lm_ggml_build_forward_expand(gf, q_cur);
+ lm_ggml_build_forward_expand(gf, k_cur);
+ lm_ggml_build_forward_expand(gf, v_cur);
+
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
+
+ // store to KV cache
+ {
+ lm_ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+ lm_ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+ }
+
+ const auto & kq_mask = inp->get_kq_mask();
+
+ lm_ggml_tensor * q = q_cur;
+ lm_ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+ lm_ggml_tensor * v = mctx_cur->get_v(ctx0, il);
+
+ lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+ cb(cur, "kqv_out", il);
+
+ if (wo) {
+ cur = build_lora_mm(wo, cur);
+ if (arch == LLM_ARCH_GLM4) {
+ // GLM4 seems to have numerical issues with half-precision accumulators
+ lm_ggml_mul_mat_set_prec(cur, LM_GGML_PREC_F32);
+ }
+ }
+
+ if (wo_b) {
+ cur = lm_ggml_add(ctx0, cur, wo_b);
+ }
+
+ return cur;
+ }
+
+ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
+
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
+
+ {
+ const auto n_kv = mctx_cur->get_base()->get_n_kv();
+
+ inp->self_kq_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
+ //cb(inp->self_kq_mask, "KQ_mask", -1);
+ lm_ggml_set_input(inp->self_kq_mask);
+
+ inp->self_kq_mask_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask, LM_GGML_TYPE_F16) : inp->self_kq_mask;
+ }
+
+ {
+ LM_GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
 
- const auto n_kv = kv_self->n;
- const auto kv_head = kv_self->head;
+ const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
- lm_ggml_tensor * states = lm_ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
+ inp->self_kq_mask_swa = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
+ //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+ lm_ggml_set_input(inp->self_kq_mask_swa);
 
- // copy states
- // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
- // this shrinks the tensors's ne[1] to n_kv
- states = lm_ggml_get_rows(ctx0, states, state_copy);
+ inp->self_kq_mask_swa_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask_swa, LM_GGML_TYPE_F16) : inp->self_kq_mask_swa;
+ }
 
- // clear states of sequences which are starting at the beginning of this batch
- // FIXME: zero-out NANs?
- states = lm_ggml_mul(ctx0, states, state_mask);
+ return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+ }
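
For orientation: both self_kq_mask tensors above allocate their second dimension as LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD), i.e. n_tokens rounded up to a multiple of the pad granularity, and the F16 cast is applied only when cparams.flash_attn is set. The standalone snippet below illustrates just that rounding; the pad value is a stand-in assumption, since LM_GGML_KQ_MASK_PAD is defined outside this hunk:

    #include <cstdio>

    // Round x up to the next multiple of n (the effect of LM_GGML_PAD(x, n)).
    constexpr int pad_to(int x, int n) { return ((x + n - 1) / n) * n; }

    int main() {
        const int n_tokens    = 70;
        const int kq_mask_pad = 32;  // stand-in for LM_GGML_KQ_MASK_PAD, assumed for the example
        std::printf("padded mask rows: %d\n", pad_to(n_tokens, kq_mask_pad));  // -> 96
        return 0;
    }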
 
- // copy states which won't be changed further (between n_seqs and n_kv)
+ lm_ggml_tensor * llm_graph_context::build_rs(
+ lm_ggml_cgraph * gf,
+ lm_ggml_tensor * s,
+ lm_ggml_tensor * state_copy,
+ int32_t state_size,
+ int32_t n_seqs,
+ uint32_t n_kv,
+ uint32_t kv_head,
+ uint32_t kv_size,
+ int32_t rs_zero,
+ bool avoid_copies) const {
+
+ lm_ggml_tensor * states = lm_ggml_reshape_2d(ctx0, s, state_size, kv_size);
+
+ // Clear a single state which will then be copied to the other cleared states.
+ // Note that this is a no-op when the view is zero-sized.
+ lm_ggml_tensor * state_zero = lm_ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+ lm_ggml_build_forward_expand(gf, lm_ggml_scale_inplace(ctx0, state_zero, 0));
+
+ lm_ggml_tensor * output_states;
+
+ if (!avoid_copies) {
+ // copy states
+ // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+ // {state_size, kv_size} -> {state_size, n_seqs}
+ output_states = lm_ggml_get_rows(ctx0, states, lm_ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+ lm_ggml_build_forward_expand(gf, output_states);
+ } else {
+ // FIXME: make the gathering operation happen before the copy below
+ // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
+ output_states = states;
+ }
+
+ // copy extra states which won't be changed further (between n_seqs and n_kv)
+ lm_ggml_tensor * states_extra = lm_ggml_get_rows(ctx0, states, lm_ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
  lm_ggml_build_forward_expand(gf,
  lm_ggml_cpy(ctx0,
- lm_ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*lm_ggml_element_size(states)),
- lm_ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*lm_ggml_element_size(s))));
+ states_extra,
+ lm_ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*lm_ggml_element_size(s))));
+
+ return output_states;
+ }
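
Two details of build_rs above are easy to miss: when rs_zero is negative, both the element count state_size*(rs_zero >= 0) and the offset collapse to zero, so the scale-by-zero clears nothing; and the rows indexed by state_copy are split into the first n_seqs rows (gathered into output_states) and the remaining n_kv - n_seqs rows (gathered and written back after the used slots). The sketch below replays that indexing on plain arrays with made-up sizes; it is an illustration of the arithmetic only, not code from this package:

    #include <cstdio>
    #include <vector>

    int main() {
        // Made-up example sizes, not taken from the diff.
        const int state_size = 4, kv_size = 6, n_kv = 3, n_seqs = 2, kv_head = 0, rs_zero = 1;
        std::vector<float> states(static_cast<size_t>(state_size) * kv_size, 1.0f);
        std::vector<int>   state_copy = {2, 0, 5};  // source row for each of the n_kv slots

        // Clear a single designated row; when rs_zero < 0 the loop runs zero times,
        // mirroring the zero-sized view + scale_inplace(0) no-op in build_rs.
        for (int i = 0; i < state_size * (rs_zero >= 0); ++i) {
            states[rs_zero * state_size + i] = 0.0f;
        }

        // Gather the first n_seqs rows -> the states this ubatch will update (output_states).
        std::vector<float> output(static_cast<size_t>(state_size) * n_seqs);
        for (int r = 0; r < n_seqs; ++r)
            for (int i = 0; i < state_size; ++i)
                output[r * state_size + i] = states[state_copy[r] * state_size + i];

        // Gather the remaining n_kv - n_seqs rows (states_extra) ...
        std::vector<float> extra(static_cast<size_t>(state_size) * (n_kv - n_seqs));
        for (int r = n_seqs; r < n_kv; ++r)
            for (int i = 0; i < state_size; ++i)
                extra[(r - n_seqs) * state_size + i] = states[state_copy[r] * state_size + i];

        // ... and copy them back contiguously starting at row kv_head + n_seqs,
        // which is what the lm_ggml_cpy into the view of `s` does.
        for (int r = 0; r < n_kv - n_seqs; ++r)
            for (int i = 0; i < state_size; ++i)
                states[(kv_head + n_seqs + r) * state_size + i] = extra[r * state_size + i];

        std::printf("gathered %zu active values, wrote back %zu extra values\n",
                    output.size(), extra.size());
        return 0;
    }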
+
+ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
+ auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
+
+ const auto n_rs = mctx_cur->get_n_rs();
+
+ inp->s_copy = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, n_rs);
+ lm_ggml_set_input(inp->s_copy);
 
- // the part of the states that will be used and modified
- return lm_ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
+ return (llm_graph_input_rs *) res->add_input(std::move(inp));
+ }
+
+ lm_ggml_tensor * llm_graph_context::build_rs(
+ llm_graph_input_rs * inp,
+ lm_ggml_cgraph * gf,
+ lm_ggml_tensor * s,
+ int32_t state_size,
+ int32_t n_seqs,
+ bool avoid_copies) const {
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
+ }
+
+ lm_ggml_tensor * llm_graph_context::build_rs(
+ llm_graph_input_mem_hybrid * inp,
+ lm_ggml_cgraph * gf,
+ lm_ggml_tensor * s,
+ int32_t state_size,
+ int32_t n_seqs,
+ bool avoid_copies) const {
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
  }
 
  lm_ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
- lm_ggml_cgraph * gf,
- lm_ggml_tensor * state_copy,
- lm_ggml_tensor * state_mask,
- const llama_ubatch & ubatch,
+ llm_graph_input_rs * inp,
+ lm_ggml_cgraph * gf,
+ const llama_ubatch & ubatch,
  int il) const {
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
  const auto token_shift_count = hparams.token_shift_count;
 
  const int64_t n_seqs = ubatch.n_seqs;
 
- lm_ggml_tensor * token_shift_all = kv_self->k_l[il];
+ lm_ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
 
- lm_ggml_tensor * token_shift = build_copy_mask_state(
- gf, token_shift_all, state_copy, state_mask,
- hparams.n_embd_k_s(), n_seqs);
+ lm_ggml_tensor * token_shift = build_rs(
+ inp, gf, token_shift_all,
+ hparams.n_embd_r(), n_seqs);
 
  token_shift = lm_ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
 
@@ -1499,19 +1562,19 @@ lm_ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
  lm_ggml_tensor * token_shift,
  const llama_ubatch & ubatch,
  int il) const {
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
  const auto token_shift_count = hparams.token_shift_count;
  const auto n_embd = hparams.n_embd;
 
  const int64_t n_seqs = ubatch.n_seqs;
 
- const auto kv_head = kv_self->head;
+ const auto kv_head = mctx_cur->get_head();
 
  return lm_ggml_cpy(
  ctx0,
  lm_ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
- lm_ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * lm_ggml_element_size(kv_self->k_l[il]))
+ lm_ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*lm_ggml_element_size(mctx_cur->get_r_l(il)))
  );
  }
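
The load/store pair above treats the per-layer recurrent buffer returned by get_r_l(il) as rows of hparams.n_embd_r() values, one row per cache slot, and the store writes n_seqs rows starting kv_head rows into that buffer. The standalone sketch below checks that addressing with concrete numbers; the relation n_embd_r = n_embd * token_shift_count is assumed here to match the 3-D reshape in build_rwkv_token_shift_load, it is not stated in this hunk:

    #include <cassert>
    #include <cstdio>

    int main() {
        // Made-up example sizes, not taken from the diff.
        const long long n_embd = 2048, token_shift_count = 2, n_seqs = 3, kv_head = 5;
        const long long elem_size = 4;  // bytes per element, e.g. f32

        // Assumed relation so that the {n_embd, token_shift_count, n_seqs} reshape is exact.
        const long long n_embd_r = n_embd * token_shift_count;
        assert(n_embd_r * n_seqs == n_embd * token_shift_count * n_seqs);

        // build_rwkv_token_shift_store: a view of n_embd_r * n_seqs elements,
        // starting n_embd_r * kv_head elements (times elem_size bytes) into the layer buffer.
        const long long view_elems  = n_embd_r * n_seqs;
        const long long view_offset = n_embd_r * kv_head * elem_size;
        std::printf("store view: %lld elements at byte offset %lld\n", view_elems, view_offset);
        return 0;
    }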
 
@@ -1562,20 +1625,32 @@ void llm_graph_context::build_pooling(
  lm_ggml_tensor * inp_cls = build_inp_cls();
  inp = lm_ggml_get_rows(ctx0, inp, inp_cls);
 
- // classification head
- // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
- LM_GGML_ASSERT(cls != nullptr);
- LM_GGML_ASSERT(cls_b != nullptr);
-
- cur = lm_ggml_add (ctx0, lm_ggml_mul_mat(ctx0, cls, inp), cls_b);
- cur = lm_ggml_tanh(ctx0, cur);
-
- // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
- // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
- if (cls_out) {
- LM_GGML_ASSERT(cls_out_b != nullptr);
-
- cur = lm_ggml_add (ctx0, lm_ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
+ if (cls) {
+ // classification head
+ // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
+ cur = lm_ggml_mul_mat(ctx0, cls, inp);
+ if (cls_b) {
+ cur = lm_ggml_add(ctx0, cur, cls_b);
+ }
+ cur = lm_ggml_tanh(ctx0, cur);
+
+ // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+ // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+ if (cls_out) {
+ cur = lm_ggml_mul_mat(ctx0, cls_out, cur);
+ if (cls_out_b) {
+ cur = lm_ggml_add(ctx0, cur, cls_out_b);
+ }
+ }
+ } else if (cls_out) {
+ // Single layer classification head (direct projection)
+ // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
+ cur = lm_ggml_mul_mat(ctx0, cls_out, inp);
+ if (cls_out_b) {
+ cur = lm_ggml_add(ctx0, cur, cls_out_b);
+ }
+ } else {
+ LM_GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
  }
  } break;
  default: