cui-llama.rn 1.5.0 → 1.6.1

This diff compares the published contents of two versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (324)
  1. package/LICENSE +20 -20
  2. package/README.md +345 -319
  3. package/android/build.gradle +116 -116
  4. package/android/gradle.properties +5 -5
  5. package/android/src/main/AndroidManifest.xml +4 -4
  6. package/android/src/main/CMakeLists.txt +129 -124
  7. package/android/src/main/java/com/rnllama/LlamaContext.java +648 -645
  8. package/android/src/main/java/com/rnllama/RNLlama.java +695 -695
  9. package/android/src/main/java/com/rnllama/RNLlamaPackage.java +48 -48
  10. package/android/src/main/jni-utils.h +100 -100
  11. package/android/src/main/jni.cpp +1279 -1263
  12. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  14. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  15. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  16. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  17. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  20. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +135 -135
  21. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +136 -136
  22. package/cpp/LICENSE +21 -0
  23. package/cpp/README.md +4 -4
  24. package/cpp/chat.cpp +1 -1
  25. package/cpp/common.cpp +17 -2
  26. package/cpp/common.h +7 -3
  27. package/cpp/ggml-alloc.c +4 -1
  28. package/cpp/ggml-cpp.h +1 -1
  29. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  30. package/cpp/ggml-cpu/amx/amx.h +8 -0
  31. package/cpp/ggml-cpu/amx/common.h +91 -0
  32. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  33. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  34. package/cpp/{binary-ops.h → ggml-cpu/binary-ops.h} +1 -1
  35. package/cpp/ggml-cpu/common.h +72 -0
  36. package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -101
  37. package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +109 -42
  38. package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +3 -0
  39. package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +246 -160
  40. package/cpp/{ops.h → ggml-cpu/ops.h} +2 -20
  41. package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
  42. package/cpp/{simd-mappings.h → ggml-cpu/simd-mappings.h} +7 -3
  43. package/cpp/{unary-ops.h → ggml-cpu/unary-ops.h} +1 -1
  44. package/cpp/ggml-cpu.h +5 -0
  45. package/cpp/ggml-impl.h +16 -9
  46. package/cpp/ggml-llama-sim.metallib +0 -0
  47. package/cpp/ggml-llama.metallib +0 -0
  48. package/cpp/ggml-metal-impl.h +597 -597
  49. package/cpp/ggml-metal.m +496 -47
  50. package/cpp/ggml.c +134 -244
  51. package/cpp/ggml.h +62 -95
  52. package/cpp/json-schema-to-grammar.cpp +3 -0
  53. package/cpp/llama-arch.cpp +46 -17
  54. package/cpp/llama-arch.h +9 -0
  55. package/cpp/llama-batch.cpp +5 -1
  56. package/cpp/llama-batch.h +2 -1
  57. package/cpp/llama-chat.cpp +31 -10
  58. package/cpp/llama-chat.h +3 -2
  59. package/cpp/llama-context.cpp +104 -489
  60. package/cpp/llama-context.h +14 -30
  61. package/cpp/llama-graph.cpp +69 -62
  62. package/cpp/llama-graph.h +21 -18
  63. package/cpp/llama-hparams.h +5 -0
  64. package/cpp/llama-kv-cache.cpp +1497 -391
  65. package/cpp/llama-kv-cache.h +272 -80
  66. package/cpp/llama-memory.h +11 -1
  67. package/cpp/llama-model.cpp +502 -176
  68. package/cpp/llama-model.h +13 -3
  69. package/cpp/llama-sampling.cpp +2 -1
  70. package/cpp/llama-vocab.cpp +8 -1
  71. package/cpp/llama.h +14 -11
  72. package/cpp/rn-llama.cpp +721 -873
  73. package/cpp/rn-llama.h +134 -138
  74. package/cpp/sampling.h +107 -107
  75. package/cpp/unicode-data.cpp +7034 -7034
  76. package/cpp/unicode-data.h +20 -20
  77. package/cpp/unicode.cpp +849 -849
  78. package/cpp/unicode.h +66 -66
  79. package/ios/CMakeLists.txt +119 -108
  80. package/ios/RNLlama.h +13 -7
  81. package/ios/RNLlama.mm +423 -405
  82. package/ios/RNLlamaContext.h +57 -57
  83. package/ios/RNLlamaContext.mm +833 -835
  84. package/ios/rnllama.xcframework/Info.plist +74 -74
  85. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +143 -0
  86. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +681 -0
  87. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/cpu-common.h +72 -0
  88. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-alloc.h +76 -0
  89. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  90. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +354 -0
  91. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +1857 -0
  92. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +39 -0
  93. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +143 -0
  94. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +601 -0
  95. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  96. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal.h +66 -0
  97. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +216 -0
  98. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-quants.h +100 -0
  99. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-threading.h +14 -0
  100. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +2189 -0
  101. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/gguf.h +202 -0
  102. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  103. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json.hpp +24766 -0
  104. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-adapter.h +76 -0
  105. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +437 -0
  106. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +89 -0
  107. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +57 -0
  108. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +249 -0
  109. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +38 -0
  110. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cpp.h +30 -0
  111. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-grammar.h +173 -0
  112. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +595 -0
  113. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +161 -0
  114. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-impl.h +61 -0
  115. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-io.h +35 -0
  116. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +405 -0
  117. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +31 -0
  118. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-mmap.h +68 -0
  119. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-loader.h +169 -0
  120. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +419 -0
  121. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-sampling.h +32 -0
  122. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +125 -0
  123. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +1437 -0
  124. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/log.h +132 -0
  125. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  126. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  127. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +134 -0
  128. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sampling.h +107 -0
  129. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/speculative.h +28 -0
  130. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unicode-data.h +20 -0
  131. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unicode.h +66 -0
  132. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  133. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  134. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  135. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +143 -0
  136. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +681 -0
  137. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/cpu-common.h +72 -0
  138. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-alloc.h +76 -0
  139. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  140. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +354 -0
  141. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +1857 -0
  142. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +39 -0
  143. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +143 -0
  144. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +601 -0
  145. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  146. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal.h +66 -0
  147. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +216 -0
  148. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-quants.h +100 -0
  149. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-threading.h +14 -0
  150. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +2189 -0
  151. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/gguf.h +202 -0
  152. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  153. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +24766 -0
  154. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-adapter.h +76 -0
  155. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +437 -0
  156. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +89 -0
  157. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +57 -0
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +249 -0
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +38 -0
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cpp.h +30 -0
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-grammar.h +173 -0
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +595 -0
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +161 -0
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-impl.h +61 -0
  165. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-io.h +35 -0
  166. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +405 -0
  167. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +31 -0
  168. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-mmap.h +68 -0
  169. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-loader.h +169 -0
  170. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +419 -0
  171. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-sampling.h +32 -0
  172. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +125 -0
  173. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +1437 -0
  174. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/log.h +132 -0
  175. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  176. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  177. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +134 -0
  178. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sampling.h +107 -0
  179. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/speculative.h +28 -0
  180. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unicode-data.h +20 -0
  181. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unicode.h +66 -0
  182. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  183. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +101 -0
  184. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  185. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  186. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +143 -0
  187. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +681 -0
  188. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/cpu-common.h +72 -0
  189. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-alloc.h +76 -0
  190. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  191. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +354 -0
  192. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +1857 -0
  193. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +39 -0
  194. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +143 -0
  195. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +601 -0
  196. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  197. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal.h +66 -0
  198. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +216 -0
  199. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-quants.h +100 -0
  200. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-threading.h +14 -0
  201. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +2189 -0
  202. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/gguf.h +202 -0
  203. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  204. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json.hpp +24766 -0
  205. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-adapter.h +76 -0
  206. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +437 -0
  207. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +89 -0
  208. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +57 -0
  209. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +249 -0
  210. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +38 -0
  211. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cpp.h +30 -0
  212. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-grammar.h +173 -0
  213. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +595 -0
  214. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +161 -0
  215. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-impl.h +61 -0
  216. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-io.h +35 -0
  217. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +405 -0
  218. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +31 -0
  219. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-mmap.h +68 -0
  220. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-loader.h +169 -0
  221. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +419 -0
  222. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-sampling.h +32 -0
  223. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +125 -0
  224. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +1437 -0
  225. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/log.h +132 -0
  226. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  227. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  228. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +134 -0
  229. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sampling.h +107 -0
  230. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/speculative.h +28 -0
  231. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unicode-data.h +20 -0
  232. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unicode.h +66 -0
  233. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  234. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  235. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  236. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +143 -0
  237. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +681 -0
  238. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/cpu-common.h +72 -0
  239. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-alloc.h +76 -0
  240. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  241. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +354 -0
  242. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +1857 -0
  243. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +39 -0
  244. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +143 -0
  245. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +601 -0
  246. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  247. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal.h +66 -0
  248. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +216 -0
  249. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-quants.h +100 -0
  250. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-threading.h +14 -0
  251. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +2189 -0
  252. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/gguf.h +202 -0
  253. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  254. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +24766 -0
  255. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-adapter.h +76 -0
  256. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +437 -0
  257. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +89 -0
  258. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +57 -0
  259. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +249 -0
  260. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +38 -0
  261. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cpp.h +30 -0
  262. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-grammar.h +173 -0
  263. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +595 -0
  264. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +161 -0
  265. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-impl.h +61 -0
  266. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-io.h +35 -0
  267. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +405 -0
  268. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +31 -0
  269. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-mmap.h +68 -0
  270. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-loader.h +169 -0
  271. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +419 -0
  272. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-sampling.h +32 -0
  273. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +125 -0
  274. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +1437 -0
  275. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/log.h +132 -0
  276. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  277. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  278. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +134 -0
  279. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sampling.h +107 -0
  280. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/speculative.h +28 -0
  281. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unicode-data.h +20 -0
  282. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unicode.h +66 -0
  283. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  284. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +101 -0
  285. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  286. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  287. package/jest/mock.js +203 -203
  288. package/lib/commonjs/NativeRNLlama.js +1 -2
  289. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  290. package/lib/commonjs/chat.js.map +1 -1
  291. package/lib/commonjs/grammar.js +12 -31
  292. package/lib/commonjs/grammar.js.map +1 -1
  293. package/lib/commonjs/index.js +47 -47
  294. package/lib/commonjs/index.js.map +1 -1
  295. package/lib/commonjs/package.json +1 -0
  296. package/lib/module/NativeRNLlama.js +2 -0
  297. package/lib/module/NativeRNLlama.js.map +1 -1
  298. package/lib/module/chat.js +2 -0
  299. package/lib/module/chat.js.map +1 -1
  300. package/lib/module/grammar.js +14 -31
  301. package/lib/module/grammar.js.map +1 -1
  302. package/lib/module/index.js +47 -45
  303. package/lib/module/index.js.map +1 -1
  304. package/lib/module/package.json +1 -0
  305. package/lib/typescript/NativeRNLlama.d.ts +10 -4
  306. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  307. package/lib/typescript/index.d.ts.map +1 -1
  308. package/llama-rn.podspec +48 -48
  309. package/package.json +233 -233
  310. package/src/NativeRNLlama.ts +431 -426
  311. package/src/chat.ts +44 -44
  312. package/src/grammar.ts +854 -854
  313. package/src/index.ts +495 -487
  314. /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
  315. /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
  316. /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
  317. /package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +0 -0
  318. /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
  319. /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
  320. /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
  321. /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
  322. /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
  323. /package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -0
  324. /package/cpp/{vec.h → ggml-cpu/vec.h} +0 -0

package/cpp/llama-model.cpp
@@ -40,14 +40,17 @@ const char * llm_type_name(llm_type type) {
  case LLM_TYPE_335M: return "335M";
  case LLM_TYPE_410M: return "410M";
  case LLM_TYPE_450M: return "450M";
+ case LLM_TYPE_475M: return "475M";
  case LLM_TYPE_770M: return "770M";
  case LLM_TYPE_780M: return "780M";
  case LLM_TYPE_0_5B: return "0.5B";
+ case LLM_TYPE_0_6B: return "0.6B";
  case LLM_TYPE_1B: return "1B";
  case LLM_TYPE_1_3B: return "1.3B";
  case LLM_TYPE_1_4B: return "1.4B";
  case LLM_TYPE_1_5B: return "1.5B";
  case LLM_TYPE_1_6B: return "1.6B";
+ case LLM_TYPE_1_7B: return "1.7B";
  case LLM_TYPE_1_8B: return "1.8B";
  case LLM_TYPE_2B: return "2B";
  case LLM_TYPE_2_8B: return "2.8B";
@@ -66,6 +69,7 @@ const char * llm_type_name(llm_type type) {
  case LLM_TYPE_15B: return "15B";
  case LLM_TYPE_16B: return "16B";
  case LLM_TYPE_20B: return "20B";
+ case LLM_TYPE_27B: return "27B";
  case LLM_TYPE_30B: return "30B";
  case LLM_TYPE_32B: return "32B";
  case LLM_TYPE_34B: return "34B";
@@ -74,6 +78,7 @@ const char * llm_type_name(llm_type type) {
  case LLM_TYPE_65B: return "65B";
  case LLM_TYPE_70B: return "70B";
  case LLM_TYPE_236B: return "236B";
+ case LLM_TYPE_290B: return "290B";
  case LLM_TYPE_314B: return "314B";
  case LLM_TYPE_671B: return "671B";
  case LLM_TYPE_SMALL: return "0.1B";
@@ -88,10 +93,10 @@ const char * llm_type_name(llm_type type) {
  case LLM_TYPE_16x3_8B: return "16x3.8B";
  case LLM_TYPE_10B_128x3_66B: return "10B+128x3.66B";
  case LLM_TYPE_57B_A14B: return "57B.A14B";
- case LLM_TYPE_27B: return "27B";
- case LLM_TYPE_290B: return "290B";
  case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
  case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
+ case LLM_TYPE_30B_A3B: return "30B.A3B";
+ case LLM_TYPE_235B_A22B: return "235B.A22B";
  default: return "?B";
  }
  }
@@ -695,13 +700,19 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  }
  } break;
  case LLM_ARCH_NOMIC_BERT:
+ case LLM_ARCH_NOMIC_BERT_MOE:
  {
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
  ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
  ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
+ ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0);

  if (hparams.n_layer == 12 && hparams.n_embd == 768) {
- type = LLM_TYPE_137M;
+ if (arch == LLM_ARCH_NOMIC_BERT) {
+ type = LLM_TYPE_137M;
+ } else if (arch == LLM_ARCH_NOMIC_BERT_MOE && hparams.moe_every_n_layers == 2) {
+ type = LLM_TYPE_475M;
+ }
  }
  } break;
  case LLM_ARCH_BLOOM:
@@ -762,6 +773,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  // fall through
  case LLM_ARCH_QWEN2:
  {
+ ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
  switch (hparams.n_layer) {
  case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break;
@@ -791,6 +803,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  {
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
  switch (hparams.n_layer) {
+ case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break;
+ case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break;
+ case 40: type = LLM_TYPE_14B; break;
+ case 64: type = LLM_TYPE_32B; break;
  default: type = LLM_TYPE_UNKNOWN;
  }
  } break;
@@ -800,6 +816,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {

  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
  switch (hparams.n_layer) {
+ case 48: type = LLM_TYPE_30B_A3B; break;
+ case 94: type = LLM_TYPE_235B_A22B; break;
  default: type = LLM_TYPE_UNKNOWN;
  }
  } break;
@@ -1156,6 +1174,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
  }
  ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
+ ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false);
+ ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false);
  ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
  ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
  ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
@@ -1205,6 +1225,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  default: type = LLM_TYPE_UNKNOWN;
  }
  } break;
+ case LLM_ARCH_GLM4:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ switch (hparams.n_layer) {
+ case 40: type = LLM_TYPE_9B; break;
+ case 61: type = LLM_TYPE_32B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
  case LLM_ARCH_BITNET:
  {
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -2046,6 +2075,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
  } break;
  case LLM_ARCH_BERT:
  case LLM_ARCH_NOMIC_BERT:
+ case LLM_ARCH_NOMIC_BERT_MOE:
  {
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
  type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0);
@@ -2079,20 +2109,31 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
  layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
  }

+ if (arch == LLM_ARCH_NOMIC_BERT_MOE) {
+ layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
+ }
+
  layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);

  layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
  layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0);

- layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
- layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-
- if (arch == LLM_ARCH_BERT) {
+ if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) {
  layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
- layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
- layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
+ layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0);
+ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
+ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
  } else {
- layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+
+ if (arch == LLM_ARCH_BERT || arch == LLM_ARCH_NOMIC_BERT_MOE) {
+ layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+ layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
+ layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
+ } else {
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+ }
  }

  layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
@@ -3196,8 +3237,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
  {
  const bool is_lite = (hparams.n_layer == 27);

+ const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+
+ // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
+ const int64_t n_embd_head_k_mla = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k;
+ const int64_t n_embd_head_v_mla = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v;
+
  const int64_t n_embd_head_qk_rope = hparams.n_rot;
- const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+ const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;

  const int64_t q_lora_rank = hparams.n_lora_q;
  const int64_t kv_lora_rank = hparams.n_lora_kv;
@@ -3223,14 +3270,22 @@ bool llama_model::load_tensors(llama_model_loader & ml) {

  if (!is_lite) {
  layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
- layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0);
+ layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0);
  } else {
- layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0);
  }

- layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
- layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
- layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0);
+ layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0);
+
+ // note: only old legacy GGUF files will have the unsplit wkv_b tensor in
+ if (is_mla) {
+ layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0);
+ layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0);
+ } else {
+ layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v_mla)}, 0);
+ }
+
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0);

  layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);

@@ -3476,6 +3531,45 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
  layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
  }
  } break;
+ case LLM_ARCH_GLM4:
+ {
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+ // if output is NULL, init from the input tok embed
+ if (output == NULL) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+ layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
+ layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
+
+ if (layer.wqkv == nullptr) {
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+ layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+ layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+ layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+ }
+
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+ layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0);
+
+ layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+ }
+ } break;
  case LLM_ARCH_NEMOTRON:
  {
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4242,6 +4336,8 @@ void llama_model::print_info() const {
  LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
  LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
  LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv);
+ LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla);
+ LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla);
  LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
  LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
  LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
@@ -4350,6 +4446,19 @@ const lm_ggml_tensor * llama_model::get_tensor(const char * name) const {
  return it->second;
  }

+ lm_ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const {
+ // choose long/short freq factors based on the context size
+ if (layers[il].rope_freqs != nullptr) {
+ return layers[il].rope_freqs;
+ }
+
+ if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
+ return layers[il].rope_long;
+ }
+
+ return layers[il].rope_short;
+ }
+
  struct llm_build_llama : public llm_graph_context {
  llm_build_llama(const llama_model & model, const llm_graph_params & params, lm_ggml_cgraph * gf) : llm_graph_context(params) {
  const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -4390,7 +4499,7 @@ struct llm_build_llama : public llm_graph_context {
  // self-attention
  {
  // rope freq factors for llama3; may return nullptr for llama2 and other models
- lm_ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+ lm_ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

  // compute Q and K and RoPE them
  lm_ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -4440,15 +4549,15 @@ struct llm_build_llama : public llm_graph_context {

  if (arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) {
  // Llama4TextL2Norm
- Qcur = lm_ggml_rms_norm(ctx0, Qcur, 1e-6);
- Kcur = lm_ggml_rms_norm(ctx0, Kcur, 1e-6);
+ Qcur = lm_ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps);
+ Kcur = lm_ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps);
  cb(Qcur, "Qcur_normed", il);
  cb(Kcur, "Kcur_normed", il);
  }

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
  cb(cur, "attn_out", il);
  }

@@ -4615,7 +4724,7 @@ struct llm_build_deci : public llm_graph_context {
  } else if (n_head > 0) {
  // self-attention
  // rope freq factors for llama3; may return nullptr for llama2 and other models
- lm_ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+ lm_ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

  // compute Q and K and RoPE them
  lm_ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -4661,7 +4770,7 @@ struct llm_build_deci : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
  }

  if (il == n_layer - 1) {
@@ -4803,7 +4912,7 @@ struct llm_build_baichuan : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, NULL,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -4918,7 +5027,7 @@ struct llm_build_xverse : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, NULL,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -5043,7 +5152,7 @@ struct llm_build_falcon : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, NULL,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -5173,7 +5282,7 @@ struct llm_build_grok : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
  }

  if (il == n_layer - 1) {
@@ -5324,7 +5433,7 @@ struct llm_build_dbrx : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, NULL,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -5438,7 +5547,7 @@ struct llm_build_starcoder : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -5537,7 +5646,7 @@ struct llm_build_refact : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, NULL,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -5664,6 +5773,11 @@ struct llm_build_bert : public llm_graph_context {
  cur = build_lora_mm(model.layers[il].wqkv, cur);
  cb(cur, "wqkv", il);

+ if (model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
+ cur = lm_ggml_add(ctx0, cur, model.layers[il].bqkv);
+ cb(cur, "bqkv", il);
+ }
+
  Qcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
  Kcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
  Vcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
@@ -5691,7 +5805,7 @@ struct llm_build_bert : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  cb(cur, "kqv_out", il);

  if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
@@ -5716,13 +5830,29 @@ struct llm_build_bert : public llm_graph_context {
  cb(ffn_inp, "ffn_inp", il);

  // feed-forward network
- if (model.arch == LLM_ARCH_BERT) {
+ if (hparams.moe_every_n_layers > 0 && il % hparams.moe_every_n_layers == 1) {
+ // MoE branch
+ cur = build_moe_ffn(cur,
+ model.layers[il].ffn_gate_inp,
+ model.layers[il].ffn_up_exps,
+ nullptr,
+ model.layers[il].ffn_down_exps,
+ nullptr,
+ hparams.n_expert,
+ hparams.n_expert_used,
+ LLM_FFN_GELU,
+ false, false,
+ 0.0f,
+ LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
+ cb(cur, "ffn_moe_out", il);
+ } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
  cur = build_ffn(cur,
  model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
  NULL, NULL, NULL,
  model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
  NULL,
  LLM_FFN_GELU, LLM_FFN_SEQ, il);
+ cb(cur, "ffn_out", il);
  } else if (model.arch == LLM_ARCH_JINA_BERT_V2) {
  cur = build_ffn(cur,
  model.layers[il].ffn_up, NULL, NULL,
@@ -5730,6 +5860,7 @@ struct llm_build_bert : public llm_graph_context {
  model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
  NULL,
  LLM_FFN_GELU, LLM_FFN_PAR, il);
+ cb(cur, "ffn_out", il);
  } else {
  cur = build_ffn(cur,
  model.layers[il].ffn_up, NULL, NULL,
@@ -5737,8 +5868,8 @@ struct llm_build_bert : public llm_graph_context {
  model.layers[il].ffn_down, NULL, NULL,
  NULL,
  LLM_FFN_SILU, LLM_FFN_PAR, il);
+ cb(cur, "ffn_out", il);
  }
- cb(cur, "ffn_out", il);

  // attentions bypass the intermediate layer
  cur = lm_ggml_add(ctx0, cur, ffn_inp);
@@ -5808,7 +5939,7 @@ struct llm_build_bloom : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -5949,7 +6080,7 @@ struct llm_build_mpt : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -6095,7 +6226,7 @@ struct llm_build_stablelm : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, NULL,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -6218,7 +6349,7 @@ struct llm_build_qwen : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, NULL,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -6338,7 +6469,7 @@ struct llm_build_qwen2 : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -6459,7 +6590,7 @@ struct llm_build_qwen2vl : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -6586,7 +6717,7 @@ struct llm_build_qwen2moe : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -6739,7 +6870,7 @@ struct llm_build_qwen3 : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -6860,7 +6991,7 @@ struct llm_build_qwen3moe : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -7000,7 +7131,7 @@ struct llm_build_phi2 : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
  }

  if (il == n_layer - 1) {
@@ -7075,7 +7206,7 @@ struct llm_build_phi3 : public llm_graph_context {
  // self-attention
  {
  // rope freq factors for 128k context
- lm_ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+ lm_ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

  lm_ggml_tensor* attn_norm_output = build_norm(inpL,
  model.layers[il].attn_norm,
@@ -7129,7 +7260,7 @@ struct llm_build_phi3 : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
  }

  if (il == n_layer - 1) {
@@ -7264,7 +7395,7 @@ struct llm_build_plamo : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, NULL,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }
  lm_ggml_tensor * sa_out = cur;

@@ -7371,7 +7502,7 @@ struct llm_build_gpt2 : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -7487,7 +7618,7 @@ struct llm_build_codeshell : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -7616,7 +7747,7 @@ struct llm_build_orion : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, NULL,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -7743,7 +7874,7 @@ struct llm_build_internlm2 : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -7827,7 +7958,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
  for (int il = 0; il < n_layer; ++il) {
  lm_ggml_tensor * inpSA = inpL;

- lm_ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+ lm_ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

  // norm
  cur = build_norm(inpL,
@@ -7940,7 +8071,7 @@ struct llm_build_minicpm3 : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, NULL,
- q_states, k_states, v_states, nullptr, kq_scale, il);
+ q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
  }

  if (il == n_layer - 1) {
@@ -8070,7 +8201,7 @@ struct llm_build_gemma : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, NULL,
- Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
  }

  if (il == n_layer - 1) {
@@ -8192,7 +8323,7 @@ struct llm_build_gemma2 : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, NULL,
- Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
  }

  cur = build_norm(cur,
@@ -8333,7 +8464,7 @@ struct llm_build_gemma3 : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, NULL,
- Qcur, Kcur, Vcur, nullptr, hparams.f_attention_scale, il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
  }

  cur = build_norm(cur,
@@ -8473,7 +8604,7 @@ struct llm_build_starcoder2 : public llm_graph_context {

  cur = build_attn(inp_attn, gf,
  model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

  if (il == n_layer - 1) {
@@ -8594,7 +8725,7 @@ struct llm_build_mamba : public llm_graph_context {
  lm_ggml_tensor * state_mask,
  const llama_ubatch & ubatch,
  int il) const {
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
8728
+ const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
8598
8729
 
8599
8730
  const auto kv_head = kv_self->head;
8600
8731
 
@@ -8808,7 +8939,7 @@ struct llm_build_command_r : public llm_graph_context {
8808
8939
 
8809
8940
  cur = build_attn(inp_attn, gf,
8810
8941
  model.layers[il].wo, model.layers[il].bo,
8811
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8942
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8812
8943
  }
8813
8944
 
8814
8945
  if (il == n_layer - 1) {
@@ -8895,7 +9026,7 @@ struct llm_build_cohere2 : public llm_graph_context {
8895
9026
  // self-attention
8896
9027
  {
8897
9028
  // rope freq factors for 128k context
8898
- lm_ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
9029
+ lm_ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
8899
9030
 
8900
9031
  // compute Q and K and RoPE them
8901
9032
  lm_ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -8943,7 +9074,7 @@ struct llm_build_cohere2 : public llm_graph_context {
8943
9074
 
8944
9075
  cur = build_attn(inp_attn, gf,
8945
9076
  model.layers[il].wo, model.layers[il].bo,
8946
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9077
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8947
9078
  }
8948
9079
 
8949
9080
  if (il == n_layer - 1) {
@@ -9074,7 +9205,7 @@ struct llm_build_olmo : public llm_graph_context {
9074
9205
 
9075
9206
  cur = build_attn(inp_attn, gf,
9076
9207
  model.layers[il].wo, nullptr,
9077
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9208
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9078
9209
  }
9079
9210
 
9080
9211
  if (il == n_layer - 1) {
@@ -9194,7 +9325,7 @@ struct llm_build_olmo2 : public llm_graph_context {
9194
9325
 
9195
9326
  cur = build_attn(inp_attn, gf,
9196
9327
  model.layers[il].wo, NULL,
9197
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9328
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9198
9329
  }
9199
9330
 
9200
9331
  cur = build_norm(cur,
@@ -9327,7 +9458,7 @@ struct llm_build_olmoe : public llm_graph_context {
9327
9458
 
9328
9459
  cur = build_attn(inp_attn, gf,
9329
9460
  model.layers[il].wo, NULL,
9330
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9461
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9331
9462
  }
9332
9463
 
9333
9464
  if (il == n_layer - 1) {
@@ -9460,7 +9591,7 @@ struct llm_build_openelm : public llm_graph_context {
9460
9591
 
9461
9592
  cur = build_attn(inp_attn, gf,
9462
9593
  model.layers[il].wo, NULL,
9463
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9594
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9464
9595
  }
9465
9596
 
9466
9597
  if (il == n_layer - 1) {
@@ -9574,7 +9705,7 @@ struct llm_build_gptneox : public llm_graph_context {
9574
9705
 
9575
9706
  cur = build_attn(inp_attn, gf,
9576
9707
  model.layers[il].wo, model.layers[il].bo,
9577
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9708
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9578
9709
  }
9579
9710
 
9580
9711
  if (il == n_layer - 1) {
@@ -9724,7 +9855,7 @@ struct llm_build_arctic : public llm_graph_context {
9724
9855
 
9725
9856
  cur = build_attn(inp_attn, gf,
9726
9857
  model.layers[il].wo, NULL,
9727
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9858
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9728
9859
  }
9729
9860
 
9730
9861
  if (il == n_layer - 1) {
@@ -9833,7 +9964,7 @@ struct llm_build_deepseek : public llm_graph_context {
9833
9964
  // self-attention
9834
9965
  {
9835
9966
  // rope freq factors for llama3; may return nullptr for llama2 and other models
9836
- lm_ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
9967
+ lm_ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
9837
9968
 
9838
9969
  // compute Q and K and RoPE them
9839
9970
  lm_ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -9879,7 +10010,7 @@ struct llm_build_deepseek : public llm_graph_context {
9879
10010
 
9880
10011
  cur = build_attn(inp_attn, gf,
9881
10012
  model.layers[il].wo, model.layers[il].bo,
9882
- Qcur, Kcur, Vcur, nullptr, kq_scale, il);
10013
+ Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
9883
10014
  }
9884
10015
 
9885
10016
  if (il == n_layer - 1) {
@@ -9969,15 +10100,22 @@ struct llm_build_deepseek2 : public llm_graph_context {
9969
10100
  llm_build_deepseek2(const llama_model & model, const llm_graph_params & params, lm_ggml_cgraph * gf) : llm_graph_context(params) {
9970
10101
  bool is_lite = (hparams.n_layer == 27);
9971
10102
 
10103
+ const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
10104
+
10105
+ // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
10106
+ const int64_t n_embd_head_k = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k;
10107
+ const int64_t n_embd_head_v = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v;
10108
+
10109
+ const int64_t n_embd_head_qk_rope = hparams.n_rot;
10110
+ const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope;
10111
+
10112
+ const uint32_t kv_lora_rank = hparams.n_lora_kv;
10113
+
9972
10114
  // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
9973
10115
  // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
9974
10116
  const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
9975
- const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k));
9976
- const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
9977
-
9978
- const uint32_t n_embd_head_qk_rope = hparams.n_rot;
9979
- const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
9980
- const uint32_t kv_lora_rank = hparams.n_lora_kv;
10117
+ const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k));
10118
+ const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
9981
10119
 
9982
10120
  lm_ggml_tensor * cur;
9983
10121
  lm_ggml_tensor * inpL;
@@ -10003,16 +10141,14 @@ struct llm_build_deepseek2 : public llm_graph_context {
10003
10141
  {
10004
10142
  lm_ggml_tensor * q = NULL;
10005
10143
  if (!is_lite) {
10006
- // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
10007
10144
  q = lm_ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
10008
10145
  cb(q, "q", il);
10009
10146
 
10010
10147
  q = build_norm(q,
10011
- model.layers[il].attn_q_a_norm, NULL,
10148
+ model.layers[il].attn_q_a_norm, nullptr,
10012
10149
  LLM_NORM_RMS, il);
10013
10150
  cb(q, "q", il);
10014
10151
 
10015
- // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
10016
10152
  q = lm_ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
10017
10153
  cb(q, "q", il);
10018
10154
  } else {
@@ -10020,96 +10156,125 @@ struct llm_build_deepseek2 : public llm_graph_context {
10020
10156
  cb(q, "q", il);
10021
10157
  }
10022
10158
 
10023
- // split into {n_head * n_embd_head_qk_nope, n_tokens}
10024
- lm_ggml_tensor * q_nope = lm_ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
10025
- lm_ggml_row_size(q->type, hparams.n_embd_head_k),
10026
- lm_ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
10159
+ // split into {n_embd_head_qk_nope, n_head, n_tokens}
10160
+ lm_ggml_tensor * q_nope = lm_ggml_view_3d(ctx0, q,
10161
+ n_embd_head_qk_nope, n_head, n_tokens,
10162
+ lm_ggml_row_size(q->type, n_embd_head_k),
10163
+ lm_ggml_row_size(q->type, n_embd_head_k) * n_head,
10027
10164
  0);
10028
10165
  cb(q_nope, "q_nope", il);
10029
10166
 
10030
- // and {n_head * n_embd_head_qk_rope, n_tokens}
10031
- lm_ggml_tensor * q_pe = lm_ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
10032
- lm_ggml_row_size(q->type, hparams.n_embd_head_k),
10033
- lm_ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
10167
+ // and {n_embd_head_qk_rope, n_head, n_tokens}
10168
+ lm_ggml_tensor * q_pe = lm_ggml_view_3d(ctx0, q,
10169
+ n_embd_head_qk_rope, n_head, n_tokens,
10170
+ lm_ggml_row_size(q->type, n_embd_head_k),
10171
+ lm_ggml_row_size(q->type, n_embd_head_k) * n_head,
10034
10172
  lm_ggml_row_size(q->type, n_embd_head_qk_nope));
10035
10173
  cb(q_pe, "q_pe", il);
10036
10174
 
10037
- // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
10038
- lm_ggml_tensor * kv_pe_compresseed = lm_ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
10039
- cb(kv_pe_compresseed, "kv_pe_compresseed", il);
10175
+ lm_ggml_tensor * kv_cmpr_pe = lm_ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
10176
+ cb(kv_cmpr_pe, "kv_cmpr_pe", il);
10040
10177
 
10041
10178
  // split into {kv_lora_rank, n_tokens}
10042
- lm_ggml_tensor * kv_compressed = lm_ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
10043
- kv_pe_compresseed->nb[1],
10179
+ lm_ggml_tensor * kv_cmpr = lm_ggml_view_2d(ctx0, kv_cmpr_pe,
10180
+ kv_lora_rank, n_tokens,
10181
+ lm_ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
10044
10182
  0);
10045
- cb(kv_compressed, "kv_compressed", il);
10183
+ cb(kv_cmpr, "kv_cmpr", il);
10184
+
10185
+ // and {n_embd_head_qk_rope, 1, n_tokens}
10186
+ lm_ggml_tensor * k_pe = lm_ggml_view_3d(ctx0, kv_cmpr_pe,
10187
+ n_embd_head_qk_rope, 1, n_tokens,
10188
+ lm_ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
10189
+ lm_ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
10190
+ lm_ggml_row_size(kv_cmpr_pe->type, kv_lora_rank));
10191
+ cb(k_pe, "k_pe", il);
10046
10192
 
10047
- // and {n_embd_head_qk_rope, n_tokens}
10048
- lm_ggml_tensor * k_pe = lm_ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
10049
- kv_pe_compresseed->nb[1],
10050
- kv_pe_compresseed->nb[1],
10051
- lm_ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
10193
+ q_pe = lm_ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr,
10194
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
10195
+ ext_factor, attn_factor, beta_fast, beta_slow
10196
+ );
10197
+ cb(q_pe, "q_pe", il);
10198
+
10199
+ k_pe = lm_ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr,
10200
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
10201
+ ext_factor, attn_factor, beta_fast, beta_slow
10202
+ );
10052
10203
  cb(k_pe, "k_pe", il);
10053
10204
 
10054
- // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing lm_ggml_cont
10055
- kv_compressed = lm_ggml_cont(ctx0, kv_compressed);
10056
- kv_compressed = build_norm(kv_compressed,
10057
- model.layers[il].attn_kv_a_norm, NULL,
10205
+ kv_cmpr = build_norm(kv_cmpr,
10206
+ model.layers[il].attn_kv_a_norm, nullptr,
10058
10207
  LLM_NORM_RMS, il);
10059
- cb(kv_compressed, "kv_compressed", il);
10208
+ cb(kv_cmpr, "kv_cmpr", il);
10060
10209
 
10061
- // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
10062
- lm_ggml_tensor * kv = lm_ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
10063
- cb(kv, "kv", il);
10210
+ if (is_mla) {
10211
+ // {n_embd_head_qk_nope, n_tokens, n_head}
10212
+ q_nope = lm_ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
10213
+ cb(q_nope, "q_nope_perm", il);
10064
10214
 
10065
- // split into {n_head * n_embd_head_qk_nope, n_tokens}
10066
- lm_ggml_tensor * k_nope = lm_ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
10067
- lm_ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
10068
- lm_ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
10069
- 0);
10070
- cb(k_nope, "k_nope", il);
10215
+ // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
10216
+ lm_ggml_tensor * q_nope_absorbed = lm_ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope);
10217
+ cb(q_nope_absorbed, "q_nope_absorbed", il);
10071
10218
 
10072
- // and {n_head * n_embd_head_v, n_tokens}
10073
- lm_ggml_tensor * v_states = lm_ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
10074
- lm_ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
10075
- lm_ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
10076
- lm_ggml_row_size(kv->type, (n_embd_head_qk_nope)));
10077
- cb(v_states, "v_states", il);
10219
+ // {kv_lora_rank, n_head, n_tokens}
10220
+ q_nope_absorbed = lm_ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
10221
+ cb(q_nope_absorbed, "q_nope_absorbed_perm", il);
10078
10222
 
10079
- v_states = lm_ggml_cont(ctx0, v_states);
10080
- cb(v_states, "v_states", il);
10223
+ // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
10224
+ // note: rope must go first for in-place context shifting in build_rope_shift()
10225
+ lm_ggml_tensor * Qcur = lm_ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
10226
+ cb(Qcur, "Qcur", il);
10081
10227
 
10082
- v_states = lm_ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
10083
- lm_ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
10084
- 0);
10085
- cb(v_states, "v_states", il);
10228
+ kv_cmpr = lm_ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
10229
+ cb(kv_cmpr, "kv_cmpr_reshape", il);
10086
10230
 
10087
- q_pe = lm_ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
10088
- q_pe = lm_ggml_rope_ext(
10089
- ctx0, q_pe, inp_pos, nullptr,
10090
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
10091
- ext_factor, attn_factor_scaled, beta_fast, beta_slow
10092
- );
10093
- cb(q_pe, "q_pe", il);
10231
+ // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
10232
+ lm_ggml_tensor * Kcur = lm_ggml_concat(ctx0, k_pe, kv_cmpr, 0);
10233
+ cb(Kcur, "Kcur", il);
10094
10234
 
10095
- // shared RoPE key
10096
- k_pe = lm_ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
10097
- k_pe = lm_ggml_rope_ext(
10098
- ctx0, k_pe, inp_pos, nullptr,
10099
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
10100
- ext_factor, attn_factor_scaled, beta_fast, beta_slow
10101
- );
10102
- cb(k_pe, "k_pe", il);
10235
+ // {kv_lora_rank, 1, n_tokens}
10236
+ lm_ggml_tensor * Vcur = kv_cmpr;
10237
+ cb(Vcur, "Vcur", il);
10103
10238
 
10104
- lm_ggml_tensor * q_states = lm_ggml_concat(ctx0, q_nope, q_pe, 0);
10105
- cb(q_states, "q_states", il);
10239
+ // note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group)
10240
+ cur = build_attn(inp_attn, gf,
10241
+ model.layers[il].wo, NULL,
10242
+ Qcur, Kcur, Vcur, nullptr, model.layers[il].wv_b, kq_scale, il);
10243
+ } else {
10244
+ lm_ggml_tensor * kv = lm_ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr);
10245
+ cb(kv, "kv", il);
10246
+
10247
+ // split into {n_embd_head_qk_nope, n_head, n_tokens}
10248
+ lm_ggml_tensor * k_nope = lm_ggml_view_3d(ctx0, kv,
10249
+ n_embd_head_qk_nope, n_head, n_tokens,
10250
+ lm_ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v),
10251
+ lm_ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head,
10252
+ 0);
10253
+ cb(k_nope, "k_nope_view", il);
10106
10254
 
10107
- lm_ggml_tensor * k_states = lm_ggml_concat(ctx0, k_nope, lm_ggml_repeat(ctx0, k_pe, q_pe), 0);
10108
- cb(k_states, "k_states", il);
10255
+ // and {n_embd_head_v, n_head, n_tokens}
10256
+ lm_ggml_tensor * Vcur = lm_ggml_view_3d(ctx0, kv,
10257
+ n_embd_head_v, n_head, n_tokens,
10258
+ lm_ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v),
10259
+ lm_ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head,
10260
+ lm_ggml_row_size(kv->type, n_embd_head_qk_nope));
10261
+ cb(Vcur, "Vcur_view", il);
10109
10262
 
10110
- cur = build_attn(inp_attn, gf,
10111
- model.layers[il].wo, NULL,
10112
- q_states, k_states, v_states, nullptr, kq_scale, il);
10263
+ Vcur = lm_ggml_cont(ctx0, Vcur);
10264
+ cb(Vcur, "Vcur_cont", il);
10265
+
10266
+ // note: rope must go first for in-place context shifting in build_rope_shift()
10267
+ lm_ggml_tensor * Qcur = lm_ggml_concat(ctx0, q_pe, q_nope, 0);
10268
+ cb(Qcur, "Qcur", il);
10269
+
10270
+ lm_ggml_tensor * Kcur = lm_ggml_concat(ctx0, lm_ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0);
10271
+ cb(Kcur, "Kcur", il);
10272
+
10273
+ // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups)
10274
+ cur = build_attn(inp_attn, gf,
10275
+ model.layers[il].wo, NULL,
10276
+ Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
10277
+ }
10113
10278
  }
10114
10279
 
10115
10280
  if (il == n_layer - 1) {
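The is_mla branch above implements DeepSeek-style multi-head latent attention with the "absorption" optimization: rather than decompressing per-token keys and values out of the latent kv_cmpr with wkv_b (the else branch), the per-head key projection is folded into the query (q_nope_absorbed), attention runs directly against the cached latent plus the shared RoPE part, and the value heads are recovered only after the softmax via the wv_b tensor now handed to build_attn. The scores are unchanged because, per head, with c the kv_lora_rank latent and W_UK, W_UV the key/value blocks of wkv_b (stored here as wk_b and wv_b):

    % notation mine, not from the source
    q_{\mathrm{nope}}^{\top} k_{\mathrm{nope}}
        = q_{\mathrm{nope}}^{\top} \left( W_{UK}\, c \right)
        = \left( W_{UK}^{\top} q_{\mathrm{nope}} \right)^{\top} c,
    \qquad
    o = W_{UV} \sum_{t} \alpha_{t}\, c_{t}

so the KV cache only needs kv_lora_rank + n_rot values per token and the attention effectively becomes MQA, as the in-code note says.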
@@ -10275,7 +10440,7 @@ struct llm_build_bitnet : public llm_graph_context {
10275
10440
 
10276
10441
  cur = build_attn(inp_attn, gf,
10277
10442
  NULL, NULL,
10278
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
10443
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
10279
10444
 
10280
10445
  cur = build_norm(cur,
10281
10446
  model.layers[il].attn_sub_norm, NULL,
@@ -10398,7 +10563,7 @@ struct llm_build_t5_enc : public llm_graph_context {
10398
10563
 
10399
10564
  cur = build_attn(inp_attn, gf,
10400
10565
  model.layers[il].wo_enc, nullptr,
10401
- Qcur, Kcur, Vcur, kq_b, 1.0f, il);
10566
+ Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il);
10402
10567
  cb(cur, "kqv_out", il);
10403
10568
  }
10404
10569
 
@@ -10504,7 +10669,7 @@ struct llm_build_t5_dec : public llm_graph_context {
10504
10669
 
10505
10670
  cur = build_attn(inp_attn_self, gf,
10506
10671
  model.layers[il].wo, model.layers[il].bo,
10507
- Qcur, Kcur, Vcur, kq_b, 1.0f, il);
10672
+ Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il);
10508
10673
  cb(cur, "kqv_out", il);
10509
10674
  }
10510
10675
 
@@ -10536,7 +10701,7 @@ struct llm_build_t5_dec : public llm_graph_context {
10536
10701
 
10537
10702
  cur = build_attn(inp_attn_cross, gf,
10538
10703
  model.layers[il].wo_cross, nullptr,
10539
- Qcur, Kcur, Vcur, nullptr, 1.0f, il);
10704
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
10540
10705
  cb(cur, "kqv_out", il);
10541
10706
 
10542
10707
  //lm_ggml_tensor * q = lm_ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
@@ -10669,7 +10834,7 @@ struct llm_build_jais : public llm_graph_context {
10669
10834
 
10670
10835
  cur = build_attn(inp_attn, gf,
10671
10836
  model.layers[il].wo, model.layers[il].bo,
10672
- Qcur, Kcur, Vcur, nullptr, 1.0f/float(n_embd_head), il);
10837
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il);
10673
10838
  }
10674
10839
 
10675
10840
  if (il == n_layer - 1) {
@@ -10801,7 +10966,7 @@ struct llm_build_chatglm : public llm_graph_context {
10801
10966
 
10802
10967
  cur = build_attn(inp_attn, gf,
10803
10968
  model.layers[il].wo, NULL,
10804
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
10969
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
10805
10970
  }
10806
10971
 
10807
10972
  if (il == n_layer - 1) {
@@ -10854,6 +11019,157 @@ struct llm_build_chatglm : public llm_graph_context {
10854
11019
  }
10855
11020
  };
10856
11021
 
11022
+ struct llm_build_glm4 : public llm_graph_context {
11023
+ llm_build_glm4(const llama_model & model, const llm_graph_params & params, lm_ggml_cgraph * gf) : llm_graph_context(params) {
11024
+ const int64_t n_embd_head = hparams.n_embd_head_v;
11025
+ const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
11026
+
11027
+ LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
11028
+
11029
+ lm_ggml_tensor * cur;
11030
+ lm_ggml_tensor * inpL;
11031
+
11032
+ inpL = build_inp_embd(model.tok_embd);
11033
+
11034
+ // inp_pos - contains the positions
11035
+ lm_ggml_tensor * inp_pos = build_inp_pos();
11036
+
11037
+ auto * inp_attn = build_attn_inp_kv_unified();
11038
+
11039
+ for (int il = 0; il < n_layer; ++il) {
11040
+ lm_ggml_tensor * inpSA = inpL;
11041
+
11042
+ // Pre-attention norm
11043
+ cur = build_norm(inpL,
11044
+ model.layers[il].attn_norm,
11045
+ NULL,
11046
+ LLM_NORM_RMS, il);
11047
+ cb(cur, "attn_norm", il);
11048
+
11049
+ // self-attention
11050
+ {
11051
+ lm_ggml_tensor * Qcur = nullptr;
11052
+ lm_ggml_tensor * Kcur = nullptr;
11053
+ lm_ggml_tensor * Vcur = nullptr;
11054
+
11055
+ if (model.layers[il].wqkv == nullptr) {
11056
+ Qcur = build_lora_mm(model.layers[il].wq, cur);
11057
+ if (model.layers[il].bq) {
11058
+ Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
11059
+ }
11060
+ Kcur = build_lora_mm(model.layers[il].wk, cur);
11061
+ if (model.layers[il].bk) {
11062
+ Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
11063
+ }
11064
+ Vcur = build_lora_mm(model.layers[il].wv, cur);
11065
+ if (model.layers[il].bv) {
11066
+ Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
11067
+ }
11068
+ } else {
11069
+ cur = build_lora_mm(model.layers[il].wqkv, cur);
11070
+ cb(cur, "wqkv", il);
11071
+ if (model.layers[il].bqkv) {
11072
+ cur = lm_ggml_add(ctx0, cur, model.layers[il].bqkv);
11073
+ cb(cur, "bqkv", il);
11074
+ }
11075
+ Qcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
11076
+ Kcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
11077
+ Vcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
11078
+ }
11079
+
11080
+ Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
11081
+ Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
11082
+ Vcur = lm_ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
11083
+
11084
+ Qcur = lm_ggml_rope_ext(
11085
+ ctx0, Qcur, inp_pos, nullptr,
11086
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
11087
+ ext_factor, attn_factor, beta_fast, beta_slow
11088
+ );
11089
+
11090
+ Kcur = lm_ggml_rope_ext(
11091
+ ctx0, Kcur, inp_pos, nullptr,
11092
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
11093
+ ext_factor, attn_factor, beta_fast, beta_slow
11094
+ );
11095
+
11096
+ cb(Qcur, "Qcur", il);
11097
+ cb(Kcur, "Kcur", il);
11098
+ cb(Vcur, "Vcur", il);
11099
+
11100
+ cur = build_attn(inp_attn, gf,
11101
+ model.layers[il].wo, NULL,
11102
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11103
+ }
11104
+
11105
+ if (il == n_layer - 1) {
11106
+ // skip computing output for unused tokens
11107
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
11108
+ cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
11109
+ inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
11110
+ }
11111
+
11112
+ // Post-attention norm (new!)
11113
+ cur = build_norm(cur,
11114
+ model.layers[il].attn_post_norm,
11115
+ NULL,
11116
+ LLM_NORM_RMS, il);
11117
+ cb(cur, "post_attn_norm", il);
11118
+
11119
+ // Add the input (residual connection after post-attention norm)
11120
+ lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA);
11121
+ cb(ffn_inp, "ffn_inp", il);
11122
+
11123
+ // FF
11124
+ {
11125
+ // Pre-MLP norm
11126
+ cur = build_norm(ffn_inp,
11127
+ model.layers[il].ffn_norm,
11128
+ NULL,
11129
+ LLM_NORM_RMS, il);
11130
+ cb(cur, "ffn_norm", il);
11131
+
11132
+ // MLP
11133
+ cur = build_ffn(cur,
11134
+ model.layers[il].ffn_up, NULL, NULL,
11135
+ NULL, NULL, NULL,
11136
+ model.layers[il].ffn_down, NULL, NULL,
11137
+ NULL,
11138
+ LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
11139
+ cb(cur, "ffn_out", il);
11140
+
11141
+ // Post-MLP norm
11142
+ cur = build_norm(cur,
11143
+ model.layers[il].ffn_post_norm,
11144
+ NULL,
11145
+ LLM_NORM_RMS, il);
11146
+ cb(cur, "post_mlp_norm", il);
11147
+ }
11148
+
11149
+ // Add residual connection after post-MLP norm
11150
+ inpL = lm_ggml_add(ctx0, cur, ffn_inp);
11151
+ cb(inpL, "l_out", il);
11152
+ }
11153
+
11154
+ // Final norm
11155
+ cur = build_norm(inpL,
11156
+ model.output_norm,
11157
+ NULL,
11158
+ LLM_NORM_RMS, -1);
11159
+
11160
+ cb(cur, "result_norm", -1);
11161
+ res->t_embd = cur;
11162
+
11163
+ // Output projection
11164
+ cur = build_lora_mm(model.output, cur);
11165
+
11166
+ cb(cur, "result_output", -1);
11167
+ res->t_logits = cur;
11168
+
11169
+ lm_ggml_build_forward_expand(gf, cur);
11170
+ }
11171
+ };
11172
+
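The added llm_build_glm4 graph differs from the chatglm one mainly in its normalization layout: each block runs RMSNorm both before and after the attention and FFN sub-blocks, and the residual is added only after the post-norm (attn_post_norm / ffn_post_norm), similar to the gemma2 graph earlier in this file. A minimal plain-C++ sketch of that block ordering, independent of ggml (all names below are mine, for illustration only):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // plain-float RMSNorm, just to illustrate the GLM4 block ordering
    static std::vector<float> rms_norm(const std::vector<float> & x, float eps = 1e-5f) {
        float ss = 0.0f;
        for (float v : x) ss += v * v;
        const float scale = 1.0f / std::sqrt(ss / x.size() + eps);
        std::vector<float> out(x.size());
        for (size_t i = 0; i < x.size(); ++i) out[i] = x[i] * scale;
        return out;
    }

    // toy stand-ins for the attention and FFN sub-blocks
    static std::vector<float> attn(std::vector<float> x) { for (float & v : x) v *= 0.5f; return x; }
    static std::vector<float> ffn (std::vector<float> x) { for (float & v : x) v += 1.0f; return x; }

    static std::vector<float> add(const std::vector<float> & a, const std::vector<float> & b) {
        std::vector<float> out(a.size());
        for (size_t i = 0; i < a.size(); ++i) out[i] = a[i] + b[i];
        return out;
    }

    int main() {
        std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};

        // attn_norm -> attention -> attn_post_norm -> residual add
        std::vector<float> h = add(rms_norm(attn(rms_norm(x))), x);

        // ffn_norm -> FFN -> ffn_post_norm -> residual add
        std::vector<float> y = add(rms_norm(ffn(rms_norm(h))), h);

        for (float v : y) std::printf("%.3f ", v);
        std::printf("\n");
        return 0;
    }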
10857
11173
  struct llm_build_nemotron : public llm_graph_context {
10858
11174
  llm_build_nemotron(const llama_model & model, const llm_graph_params & params, lm_ggml_cgraph * gf) : llm_graph_context(params) {
10859
11175
  const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -10927,7 +11243,7 @@ struct llm_build_nemotron : public llm_graph_context {
10927
11243
 
10928
11244
  cur = build_attn(inp_attn, gf,
10929
11245
  model.layers[il].wo, model.layers[il].bo,
10930
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11246
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
10931
11247
  }
10932
11248
 
10933
11249
  if (il == n_layer - 1) {
@@ -11012,7 +11328,7 @@ struct llm_build_exaone : public llm_graph_context {
11012
11328
  // self-attention
11013
11329
  {
11014
11330
  // rope freq factors for llama3; may return nullptr for llama2 and other models
11015
- lm_ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
11331
+ lm_ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
11016
11332
 
11017
11333
  // compute Q and K and RoPE them
11018
11334
  lm_ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -11058,7 +11374,7 @@ struct llm_build_exaone : public llm_graph_context {
11058
11374
 
11059
11375
  cur = build_attn(inp_attn, gf,
11060
11376
  model.layers[il].wo, model.layers[il].bo,
11061
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11377
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11062
11378
  }
11063
11379
 
11064
11380
  if (il == n_layer - 1) {
@@ -11157,7 +11473,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11157
11473
  lm_ggml_tensor * state_mask,
11158
11474
  const llama_ubatch & ubatch,
11159
11475
  int il) const {
11160
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
11476
+ const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
11161
11477
 
11162
11478
  const auto n_tokens = ubatch.n_tokens;
11163
11479
  const auto n_seqs = ubatch.n_seqs;
@@ -11553,7 +11869,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
11553
11869
  lm_ggml_tensor *& first_layer_value,
11554
11870
  const llama_ubatch & ubatch,
11555
11871
  int il) const {
11556
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
11872
+ const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
11557
11873
 
11558
11874
  const auto n_tokens = ubatch.n_tokens;
11559
11875
  const auto n_seqs = ubatch.n_seqs;
@@ -11960,7 +12276,7 @@ struct llm_build_chameleon : public llm_graph_context {
11960
12276
 
11961
12277
  cur = build_attn(inp_attn, gf,
11962
12278
  model.layers[il].wo, nullptr,
11963
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
12279
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11964
12280
 
11965
12281
  if (hparams.swin_norm) {
11966
12282
  cur = build_norm(cur,
@@ -12316,7 +12632,7 @@ struct llm_build_plm : public llm_graph_context {
12316
12632
 
12317
12633
  cur = build_attn(inp_attn, gf,
12318
12634
  model.layers[il].wo, NULL,
12319
- q_states, k_states, v_states, nullptr, kq_scale, il);
12635
+ q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
12320
12636
  }
12321
12637
 
12322
12638
  if (il == n_layer - 1) {
@@ -12393,7 +12709,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
12393
12709
  // self-attention
12394
12710
  {
12395
12711
  // rope freq factors for llama3; may return nullptr for llama2 and other models
12396
- lm_ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
12712
+ lm_ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
12397
12713
 
12398
12714
  // compute Q and K and RoPE them
12399
12715
  lm_ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -12439,7 +12755,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
12439
12755
 
12440
12756
  cur = build_attn(inp_attn, gf,
12441
12757
  model.layers[il].wo, model.layers[il].bo,
12442
- Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_rot)), il);
12758
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il);
12443
12759
  }
12444
12760
 
12445
12761
  if (il == n_layer - 1) {
@@ -12513,7 +12829,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
12513
12829
  }
12514
12830
  };
12515
12831
 
12516
- llama_memory_i * llama_model::create_memory() const {
12832
+ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
12517
12833
  llama_memory_i * res;
12518
12834
 
12519
12835
  switch (arch) {
@@ -12523,26 +12839,29 @@ llama_memory_i * llama_model::create_memory() const {
12523
12839
  case LLM_ARCH_RWKV7:
12524
12840
  case LLM_ARCH_ARWKV7:
12525
12841
  {
12526
- res = new llama_kv_cache_unified(hparams, {
12527
- /*.get_rope_factors =*/ nullptr
12528
- });
12842
+ res = new llama_kv_cache_recurrent(
12843
+ *this,
12844
+ LM_GGML_TYPE_F32,
12845
+ LM_GGML_TYPE_F32,
12846
+ cparams.offload_kqv,
12847
+ std::max((uint32_t) 1, cparams.n_seq_max));
12529
12848
  } break;
12530
12849
  default:
12531
12850
  {
12532
- res = new llama_kv_cache_unified(hparams, {
12533
- /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) {
12534
- // choose long/short freq factors based on the context size
12535
- if (layers[il].rope_freqs != nullptr) {
12536
- return layers[il].rope_freqs;
12537
- }
12851
+ const auto padding = llama_kv_cache_unified::get_padding(cparams);
12538
12852
 
12539
- if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
12540
- return layers[il].rope_long;
12541
- }
12853
+ cparams.n_ctx = LM_GGML_PAD(cparams.n_ctx, padding);
12542
12854
 
12543
- return layers[il].rope_short;
12544
- }
12545
- });
12855
+ LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
12856
+
12857
+ res = new llama_kv_cache_unified(
12858
+ *this,
12859
+ params.type_k,
12860
+ params.type_v,
12861
+ !cparams.flash_attn,
12862
+ cparams.offload_kqv,
12863
+ cparams.n_ctx,
12864
+ padding);
12546
12865
  }
12547
12866
  }
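This hunk pairs with the kv_self changes earlier in the file: create_memory now receives the memory params and cparams, builds a llama_kv_cache_recurrent for the RWKV/Mamba-family architectures and a llama_kv_cache_unified (with n_ctx rounded up to the cache's padding) for everything else, and the get_rope_factors callback that used to be installed on the unified cache is gone. The deleted lambda body is presumably what the new llama_model::get_rope_factors member (used by the phi3, minicpm3, cohere2, deepseek, exaone and bailingmoe hunks above) still does, roughly:

    // sketch reconstructed from the removed callback, not copied from the new source
    lm_ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const {
        // models that ship explicit per-layer rope_freqs take priority
        if (layers[il].rope_freqs != nullptr) {
            return layers[il].rope_freqs;
        }

        // otherwise choose long/short YaRN freq factors based on the context size
        if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
            return layers[il].rope_long;
        }

        return layers[il].rope_short;
    }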
12548
12867
 
@@ -12591,6 +12910,7 @@ llm_graph_result_ptr llama_model::build_graph(
12591
12910
  case LLM_ARCH_BERT:
12592
12911
  case LLM_ARCH_JINA_BERT_V2:
12593
12912
  case LLM_ARCH_NOMIC_BERT:
12913
+ case LLM_ARCH_NOMIC_BERT_MOE:
12594
12914
  {
12595
12915
  llm = std::make_unique<llm_build_bert>(*this, params, gf);
12596
12916
  } break;
@@ -12735,6 +13055,10 @@ llm_graph_result_ptr llama_model::build_graph(
12735
13055
  {
12736
13056
  llm = std::make_unique<llm_build_chatglm>(*this, params, gf);
12737
13057
  } break;
13058
+ case LLM_ARCH_GLM4:
13059
+ {
13060
+ llm = std::make_unique<llm_build_glm4>(*this, params, gf);
13061
+ } break;
12738
13062
  case LLM_ARCH_BITNET:
12739
13063
  {
12740
13064
  llm = std::make_unique<llm_build_bitnet>(*this, params, gf);
@@ -12919,8 +13243,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
12919
13243
  case LLM_ARCH_DECI:
12920
13244
  case LLM_ARCH_BAICHUAN:
12921
13245
  case LLM_ARCH_STARCODER:
12922
- case LLM_ARCH_PLAMO:
12923
- case LLM_ARCH_ORION:
12924
13246
  case LLM_ARCH_INTERNLM2:
12925
13247
  case LLM_ARCH_MINICPM:
12926
13248
  case LLM_ARCH_XVERSE:
@@ -12932,6 +13254,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
12932
13254
  case LLM_ARCH_DEEPSEEK2:
12933
13255
  case LLM_ARCH_PLM:
12934
13256
  case LLM_ARCH_CHATGLM:
13257
+ case LLM_ARCH_GLM4:
12935
13258
  case LLM_ARCH_GRANITE:
12936
13259
  case LLM_ARCH_GRANITE_MOE:
12937
13260
  case LLM_ARCH_CHAMELEON:
@@ -12944,6 +13267,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
12944
13267
  case LLM_ARCH_DBRX:
12945
13268
  case LLM_ARCH_BERT:
12946
13269
  case LLM_ARCH_NOMIC_BERT:
13270
+ case LLM_ARCH_NOMIC_BERT_MOE:
12947
13271
  case LLM_ARCH_STABLELM:
12948
13272
  case LLM_ARCH_BITNET:
12949
13273
  case LLM_ARCH_QWEN:
@@ -12956,6 +13280,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
12956
13280
  case LLM_ARCH_PHI2:
12957
13281
  case LLM_ARCH_PHI3:
12958
13282
  case LLM_ARCH_PHIMOE:
13283
+ case LLM_ARCH_PLAMO:
12959
13284
  case LLM_ARCH_GEMMA:
12960
13285
  case LLM_ARCH_GEMMA2:
12961
13286
  case LLM_ARCH_GEMMA3:
@@ -12963,6 +13288,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
12963
13288
  case LLM_ARCH_OPENELM:
12964
13289
  case LLM_ARCH_GPTNEOX:
12965
13290
  case LLM_ARCH_CODESHELL:
13291
+ case LLM_ARCH_ORION:
12966
13292
  case LLM_ARCH_NEMOTRON:
12967
13293
  case LLM_ARCH_EXAONE:
12968
13294
  case LLM_ARCH_MINICPM3: