cui-llama.rn 1.7.3 → 1.7.6

This diff reflects the changes between publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (276)
  1. package/README.md +217 -17
  2. package/android/src/main/CMakeLists.txt +34 -15
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +94 -8
  4. package/android/src/main/java/com/rnllama/RNLlama.java +247 -0
  5. package/android/src/main/jni.cpp +213 -14
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  14. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +35 -0
  15. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +34 -0
  16. package/cpp/README.md +1 -1
  17. package/cpp/chat-parser.cpp +385 -0
  18. package/cpp/chat-parser.h +120 -0
  19. package/cpp/chat.cpp +726 -596
  20. package/cpp/chat.h +71 -6
  21. package/cpp/common.cpp +56 -38
  22. package/cpp/common.h +9 -3
  23. package/cpp/ggml-backend-reg.cpp +5 -0
  24. package/cpp/ggml-backend.cpp +10 -2
  25. package/cpp/ggml-common.h +4 -0
  26. package/cpp/ggml-cpu/amx/amx.cpp +1 -1
  27. package/cpp/ggml-cpu/amx/mmq.cpp +11 -10
  28. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  29. package/cpp/ggml-cpu/arch/arm/quants.c +4114 -0
  30. package/cpp/ggml-cpu/arch/arm/repack.cpp +2163 -0
  31. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  32. package/cpp/ggml-cpu/arch/x86/quants.c +4311 -0
  33. package/cpp/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  34. package/cpp/ggml-cpu/arch-fallback.h +184 -0
  35. package/cpp/ggml-cpu/common.h +4 -3
  36. package/cpp/ggml-cpu/ggml-cpu-impl.h +21 -16
  37. package/cpp/ggml-cpu/ggml-cpu.c +123 -104
  38. package/cpp/ggml-cpu/ggml-cpu.cpp +11 -8
  39. package/cpp/ggml-cpu/ops.cpp +330 -148
  40. package/cpp/ggml-cpu/ops.h +1 -0
  41. package/cpp/ggml-cpu/quants.c +1158 -0
  42. package/cpp/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  43. package/cpp/ggml-cpu/repack.cpp +1571 -0
  44. package/cpp/ggml-cpu/repack.h +98 -0
  45. package/cpp/ggml-cpu/simd-mappings.h +330 -38
  46. package/cpp/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  47. package/cpp/ggml-cpu/vec.cpp +87 -18
  48. package/cpp/ggml-cpu/vec.h +249 -94
  49. package/cpp/ggml-cpu.h +1 -0
  50. package/cpp/ggml-impl.h +63 -183
  51. package/cpp/ggml-llama-sim.metallib +0 -0
  52. package/cpp/ggml-llama.metallib +0 -0
  53. package/cpp/ggml-metal.m +152 -45
  54. package/cpp/ggml-quants.c +0 -2
  55. package/cpp/ggml.c +61 -21
  56. package/cpp/ggml.h +22 -3
  57. package/cpp/gguf.cpp +24 -3
  58. package/cpp/json-partial.cpp +256 -0
  59. package/cpp/json-partial.h +38 -0
  60. package/cpp/json-schema-to-grammar.cpp +5 -47
  61. package/cpp/json-schema-to-grammar.h +4 -4
  62. package/cpp/llama-arch.cpp +153 -3
  63. package/cpp/llama-arch.h +27 -1
  64. package/cpp/llama-batch.cpp +741 -272
  65. package/cpp/llama-batch.h +112 -54
  66. package/cpp/llama-chat.cpp +30 -8
  67. package/cpp/llama-chat.h +1 -0
  68. package/cpp/llama-context.cpp +524 -339
  69. package/cpp/llama-context.h +38 -17
  70. package/cpp/llama-cparams.cpp +4 -0
  71. package/cpp/llama-cparams.h +2 -0
  72. package/cpp/llama-grammar.cpp +12 -2
  73. package/cpp/llama-graph.cpp +431 -356
  74. package/cpp/llama-graph.h +126 -58
  75. package/cpp/llama-hparams.cpp +10 -2
  76. package/cpp/llama-hparams.h +19 -2
  77. package/cpp/llama-kv-cache-unified-iswa.cpp +279 -0
  78. package/cpp/llama-kv-cache-unified-iswa.h +128 -0
  79. package/cpp/llama-kv-cache-unified.cpp +1841 -0
  80. package/cpp/llama-kv-cache-unified.h +303 -0
  81. package/cpp/llama-kv-cells.h +439 -0
  82. package/cpp/llama-memory-hybrid.cpp +246 -0
  83. package/cpp/llama-memory-hybrid.h +138 -0
  84. package/cpp/llama-memory-recurrent.cpp +1112 -0
  85. package/cpp/llama-memory-recurrent.h +183 -0
  86. package/cpp/llama-memory.cpp +41 -0
  87. package/cpp/llama-memory.h +86 -5
  88. package/cpp/llama-mmap.cpp +1 -1
  89. package/cpp/llama-model-loader.cpp +42 -17
  90. package/cpp/llama-model-saver.cpp +1 -0
  91. package/cpp/llama-model.cpp +1639 -513
  92. package/cpp/llama-model.h +26 -0
  93. package/cpp/llama-sampling.cpp +2 -2
  94. package/cpp/llama-vocab.cpp +65 -28
  95. package/cpp/llama-vocab.h +1 -0
  96. package/cpp/llama.cpp +11 -7
  97. package/cpp/llama.h +150 -42
  98. package/cpp/minja/chat-template.hpp +1 -1
  99. package/cpp/minja/minja.hpp +1 -1
  100. package/cpp/{json.hpp → nlohmann/json.hpp} +3027 -2267
  101. package/cpp/nlohmann/json_fwd.hpp +187 -0
  102. package/cpp/regex-partial.cpp +204 -0
  103. package/cpp/regex-partial.h +56 -0
  104. package/cpp/rn-llama.cpp +646 -35
  105. package/cpp/rn-llama.h +32 -1
  106. package/cpp/rn-tts.h +39 -0
  107. package/cpp/sampling.cpp +7 -8
  108. package/cpp/tools/mtmd/clip-impl.h +5 -0
  109. package/cpp/tools/mtmd/clip.cpp +572 -436
  110. package/cpp/tools/mtmd/clip.h +14 -4
  111. package/cpp/tools/mtmd/mtmd-audio.cpp +0 -86
  112. package/cpp/tools/mtmd/mtmd-audio.h +2 -17
  113. package/cpp/tools/mtmd/mtmd-helper.cpp +175 -12
  114. package/cpp/tools/mtmd/mtmd-helper.h +91 -0
  115. package/cpp/tools/mtmd/mtmd.cpp +368 -248
  116. package/cpp/tools/mtmd/mtmd.h +6 -70
  117. package/cpp/unicode.cpp +5 -0
  118. package/ios/CMakeLists.txt +26 -6
  119. package/ios/RNLlama.h +1 -1
  120. package/ios/RNLlama.mm +153 -3
  121. package/ios/RNLlamaContext.h +9 -1
  122. package/ios/RNLlamaContext.mm +112 -9
  123. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
  124. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +71 -6
  125. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +9 -3
  126. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
  127. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
  128. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
  129. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +22 -3
  130. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-partial.h +38 -0
  131. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  132. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
  133. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
  134. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  135. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +38 -17
  136. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
  137. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
  138. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
  139. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  140. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  141. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  142. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  143. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  144. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
  145. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +26 -0
  146. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
  147. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +150 -42
  148. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/{json.hpp → nlohmann/json.hpp} +3027 -2267
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  157. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
  165. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  166. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
  167. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
  168. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  169. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
  170. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
  171. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
  172. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
  173. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  174. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  175. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  176. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  177. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  178. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
  179. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
  180. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
  181. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
  182. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  183. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
  184. package/ios/rnllama.xcframework/{tvos-arm64/rnllama.framework/Headers → ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
  185. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  186. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
  187. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
  188. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
  189. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  190. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  191. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
  192. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +71 -6
  193. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +9 -3
  194. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
  195. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
  196. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
  197. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +22 -3
  198. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-partial.h +38 -0
  199. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  200. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
  201. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
  202. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  203. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +38 -17
  204. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
  205. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
  206. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
  207. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  208. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  209. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  210. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  211. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  212. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
  213. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +26 -0
  214. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
  215. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +150 -42
  216. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  217. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
  218. package/ios/rnllama.xcframework/{ios-arm64_x86_64-simulator/rnllama.framework/Headers → tvos-arm64/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
  219. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  220. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
  221. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
  222. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
  223. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  224. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  225. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
  226. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
  227. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
  228. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
  229. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
  230. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
  231. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
  232. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
  233. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  234. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
  235. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
  236. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  237. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
  238. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
  239. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
  240. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
  241. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  242. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  243. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  244. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  245. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  246. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
  247. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
  248. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
  249. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
  250. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  251. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
  252. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json.hpp +25526 -0
  253. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  254. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
  255. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
  256. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
  257. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  258. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  259. package/jest/mock.js +24 -0
  260. package/package.json +1 -1
  261. package/src/NativeRNLlama.ts +46 -2
  262. package/src/index.ts +105 -1
  263. package/cpp/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  264. package/cpp/ggml-cpu/ggml-cpu-quants.c +0 -13326
  265. package/cpp/ggml-cpu/sgemm.cpp +0 -3544
  266. package/cpp/ggml-cpu/sgemm.h +0 -14
  267. package/cpp/llama-kv-cache.cpp +0 -2827
  268. package/cpp/llama-kv-cache.h +0 -515
  269. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  270. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  271. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  272. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +0 -24766
  273. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  274. /package/cpp/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  275. /package/cpp/tools/mtmd/{miniaudio.h → miniaudio/miniaudio.h} +0 -0
  276. /package/cpp/tools/mtmd/{stb_image.h → stb/stb_image.h} +0 -0
package/cpp/llama-model.cpp
@@ -5,7 +5,11 @@
  #include "llama-batch.h"
  #include "llama-cparams.h"
  #include "llama-model-loader.h"
- #include "llama-kv-cache.h"
+
+ #include "llama-kv-cache-unified.h"
+ #include "llama-kv-cache-unified-iswa.h"
+ #include "llama-memory-hybrid.h"
+ #include "llama-memory-recurrent.h"

  #include "ggml-cpp.h"

@@ -77,6 +81,7 @@ const char * llm_type_name(llm_type type) {
  case LLM_TYPE_40B: return "40B";
  case LLM_TYPE_65B: return "65B";
  case LLM_TYPE_70B: return "70B";
+ case LLM_TYPE_142B: return "142B";
  case LLM_TYPE_236B: return "236B";
  case LLM_TYPE_290B: return "290B";
  case LLM_TYPE_314B: return "314B";
@@ -98,6 +103,8 @@ const char * llm_type_name(llm_type type) {
  case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
  case LLM_TYPE_30B_A3B: return "30B.A3B";
  case LLM_TYPE_235B_A22B: return "235B.A22B";
+ case LLM_TYPE_E2B: return "E2B";
+ case LLM_TYPE_E4B: return "E4B";
  default: return "?B";
  }
  }
@@ -466,6 +473,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
  std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
  std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
+ std::fill(
+ hparams.recurrent_layer_arr.begin(),
+ hparams.recurrent_layer_arr.end(),
+ llm_arch_is_recurrent(ml.get_arch()));

  std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);

@@ -540,6 +551,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  uint32_t n_vocab = 0;
  ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);

+ // for classifier models
+ ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false);
+ if (!classifier_labels.empty()) {
+ hparams.n_cls_out = classifier_labels.size();
+ }
+
  // arch-specific KVs
  switch (arch) {
  case LLM_ARCH_LLAMA:
@@ -589,6 +606,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  hparams.use_kq_norm = false;
  }
  } break;
+ case LLM_ARCH_ARCEE:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+ // Arcee uses the same structure as Llama
+ switch (hparams.n_layer) {
+ case 36: type = LLM_TYPE_4B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
  case LLM_ARCH_DECI:
  {
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -729,6 +756,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  }
  }
  } break;
+ case LLM_ARCH_NEO_BERT:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
+ ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
+
+ if (hparams.n_layer == 28) {
+ type = LLM_TYPE_250M;
+ }
+ } break;
  case LLM_ARCH_BLOOM:
  {
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -952,6 +989,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  case 46: type = LLM_TYPE_27B; break;
  default: type = LLM_TYPE_UNKNOWN;
  }
+
+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173
+ hparams.f_attention_scale = type == LLM_TYPE_27B
+ ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
+ : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
  } break;
  case LLM_ARCH_GEMMA3:
  {
@@ -972,10 +1014,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  default: type = LLM_TYPE_UNKNOWN;
  }

+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289
  hparams.f_attention_scale = type == LLM_TYPE_27B
  ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
  : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
  } break;
+ case LLM_ARCH_GEMMA3N:
+ {
+ hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+ hparams.set_swa_pattern(5);
+
+ hparams.rope_freq_base_train_swa = 10000.0f;
+ hparams.rope_freq_scale_train_swa = 1.0f;
+ hparams.f_attention_scale = 1.0f;
+
+ ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+ switch (hparams.n_layer) {
+ case 30: type = LLM_TYPE_E2B; break;
+ case 35: type = LLM_TYPE_E4B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
  case LLM_ARCH_STARCODER2:
  {
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -1429,6 +1490,20 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  default: type = LLM_TYPE_UNKNOWN;
  }
  } break;
+ case LLM_ARCH_DOTS1:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+ ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
+ ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
+ ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
+ ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
+ switch (hparams.n_layer) {
+ case 62: type = LLM_TYPE_142B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
  default: throw std::runtime_error("unsupported model architecture");
  }

@@ -2113,7 +2188,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
  case LLM_ARCH_NOMIC_BERT_MOE:
  {
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
- type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0);
+ type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED);

  if (arch == LLM_ARCH_BERT) {
  pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0);
@@ -2121,8 +2196,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
  cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
  cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED);

- cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED);
- cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {1}, TENSOR_NOT_REQUIRED);
+ cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
+ cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
  }

  tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
@@ -2131,7 +2206,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
  for (int i = 0; i < n_layer; ++i) {
  auto & layer = layers[i];

- if (arch == LLM_ARCH_BERT) {
+ layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
+ layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
+
+ if (!layer.wqkv) {
  layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
  layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);

@@ -2140,12 +2218,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {

  layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
  layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
- } else {
- layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
- }
-
- if (arch == LLM_ARCH_NOMIC_BERT_MOE) {
- layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
  }

  layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
@@ -2175,6 +2247,32 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
  layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0);
  }
  } break;
+ case LLM_ARCH_NEO_BERT:
+ {
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
+ cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED);
+
+ cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
+ cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
+
+ output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff*2}, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+ }
+ } break;
  case LLM_ARCH_JINA_BERT_V2:
  {
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings
@@ -2212,8 +2310,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
  layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
  layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);

- layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
- layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, layer.ffn_gate ? n_ff : n_ff * 2}, 0);

  layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
  layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
@@ -2489,7 +2587,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {

  // output
  output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
- output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+ // if output is NULL, init from the input tok embed
+ if (output == NULL) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+ }

  for (int i = 0; i < n_layer; ++i) {
  auto & layer = layers[i];
@@ -2868,6 +2970,62 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
  layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
  }
  } break;
+ case LLM_ARCH_GEMMA3N:
+ {
+ const int64_t n_altup = hparams.n_altup;
+ const int64_t laurel_rank = hparams.laurel_rank;
+ const int64_t n_embd_altup = hparams.n_embd_altup;
+
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+ // if output is NULL, init from the input tok embed
+ if (output == NULL) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+ }
+
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+ tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
+
+ altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
+ altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
+ per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0);
+ per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0);
+
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+ layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+ layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+
+ // altup & laurel
+ layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0);
+ layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0);
+ layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0);
+ layer.altup_correct_coef = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0);
+ layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0);
+ layer.altup_predict_coef = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), {n_altup, n_altup * n_altup}, 0);
+ layer.altup_router = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0);
+ layer.altup_router_norm = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0);
+ layer.laurel_l = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0);
+ layer.laurel_r = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0);
+ layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0);
+ }
+ } break;
  case LLM_ARCH_STARCODER2:
  {
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4107,6 +4265,89 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
  layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
  }
  } break;
+ case LLM_ARCH_DOTS1:
+ {
+ const int64_t n_ff_exp = hparams.n_ff_exp;
+ const int64_t n_expert_shared = hparams.n_expert_shared;
+
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+ if (i < (int) hparams.n_layer_dense_lead) {
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+ } else {
+ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+ layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
+
+ if (n_expert == 0) {
+ throw std::runtime_error("n_expert must be > 0");
+ }
+ if (n_expert_used == 0) {
+ throw std::runtime_error("n_expert_used must be > 0");
+ }
+
+ // MoE branch
+ layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+ layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+
+ // Shared expert branch
+ layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+ layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0);
+ layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+ }
+ }
+ } break;
+ case LLM_ARCH_ARCEE:
+ {
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+ // if output is NULL, init from the input tok embed
+ if (output == NULL) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+ }
+ } break;
  default:
  throw std::runtime_error("unknown architecture");
  }
@@ -4351,6 +4592,15 @@ void llama_model::print_info() const {
  LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
  LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
  LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
+
+ if (!classifier_labels.empty()) {
+ LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out);
+
+ size_t i = 0;
+ for (auto label : classifier_labels) {
+ LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str());
+ }
+ }
  }

  LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str());
@@ -4533,6 +4783,8 @@ struct llm_build_llama : public llm_graph_context {

  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+
  for (int il = 0; il < n_layer; ++il) {
  lm_ggml_tensor * inpSA = inpL;

@@ -4595,9 +4847,7 @@
  cb(cur, "attn_out", il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+ if (il == n_layer - 1 && inp_out_ids) {
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }
@@ -4693,6 +4943,8 @@ struct llm_build_llama_iswa : public llm_graph_context {

  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+
  for (int il = 0; il < n_layer; ++il) {
  lm_ggml_tensor * inpSA = inpL;

@@ -4769,9 +5021,7 @@
  cb(cur, "attn_out", il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+ if (il == n_layer - 1 && inp_out_ids) {
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }
@@ -4871,6 +5121,9 @@ struct llm_build_deci : public llm_graph_context {
  auto * inp_attn = build_attn_inp_kv_unified();

  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+
  for (int il = 0; il < n_layer; ++il) {
  lm_ggml_tensor * inpSA = inpL;
  const int64_t n_head_kv = hparams.n_head_kv(il);
@@ -4944,9 +5197,7 @@
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+ if (il == n_layer - 1 && inp_out_ids) {
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }
@@ -5025,6 +5276,8 @@ struct llm_build_baichuan : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+
  for (int il = 0; il < n_layer; ++il) {
  lm_ggml_tensor * inpSA = inpL;

@@ -5076,9 +5329,7 @@
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+ if (il == n_layer - 1 && inp_out_ids) {
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }
@@ -5147,6 +5398,8 @@ struct llm_build_xverse : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+
  for (int il = 0; il < n_layer; ++il) {
  lm_ggml_tensor * inpSA = inpL;

@@ -5191,9 +5444,7 @@
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+ if (il == n_layer - 1 && inp_out_ids) {
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }
@@ -5261,6 +5512,8 @@ struct llm_build_falcon : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+
  for (int il = 0; il < n_layer; ++il) {
  lm_ggml_tensor * attn_norm;

@@ -5316,9 +5569,7 @@
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+ if (il == n_layer - 1 && inp_out_ids) {
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
  inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
  attn_norm = lm_ggml_get_rows(ctx0, attn_norm, inp_out_ids);
@@ -5387,6 +5638,8 @@ struct llm_build_grok : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+
  for (int il = 0; il < n_layer; ++il) {
  lm_ggml_tensor * inpSA = inpL;

@@ -5446,9 +5699,7 @@
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+ if (il == n_layer - 1 && inp_out_ids) {
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }
@@ -5547,6 +5798,8 @@ struct llm_build_dbrx : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+
  for (int il = 0; il < n_layer; ++il) {
  lm_ggml_tensor * inpSA = inpL;

@@ -5597,9 +5850,7 @@
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+ if (il == n_layer - 1 && inp_out_ids) {
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }
@@ -5679,6 +5930,8 @@ struct llm_build_starcoder : public llm_graph_context {
  inpL = lm_ggml_add(ctx0, inpL, pos);
  cb(inpL, "inpL", -1);

+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+
  for (int il = 0; il < n_layer; ++il) {
  cur = build_norm(inpL,
  model.layers[il].attn_norm,
@@ -5711,9 +5964,7 @@
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+ if (il == n_layer - 1 && inp_out_ids) {
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
  inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
  }
@@ -5778,6 +6029,8 @@ struct llm_build_refact : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+
  for (int il = 0; il < n_layer; ++il) {
  lm_ggml_tensor * inpSA = inpL;

@@ -5810,9 +6063,7 @@
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+ if (il == n_layer - 1 && inp_out_ids) {
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }
@@ -5883,8 +6134,10 @@ struct llm_build_bert : public llm_graph_context {
  inpL = build_inp_embd(model.tok_embd);

  // token types are hardcoded to zero ("Sentence A")
- lm_ggml_tensor * type_row0 = lm_ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
- inpL = lm_ggml_add(ctx0, inpL, type_row0);
+ if (model.type_embd) {
+ lm_ggml_tensor * type_row0 = lm_ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
+ inpL = lm_ggml_add(ctx0, inpL, type_row0);
+ }
  if (model.arch == LLM_ARCH_BERT) {
  inpL = lm_ggml_add(ctx0, lm_ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
  }
@@ -5896,17 +6149,34 @@ struct llm_build_bert : public llm_graph_context {

  auto * inp_attn = build_attn_inp_no_cache();

- // iterate layers
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+
  for (int il = 0; il < n_layer; ++il) {
  lm_ggml_tensor * cur = inpL;

- lm_ggml_tensor * Qcur;
- lm_ggml_tensor * Kcur;
- lm_ggml_tensor * Vcur;
+ {
+ lm_ggml_tensor * Qcur;
+ lm_ggml_tensor * Kcur;
+ lm_ggml_tensor * Vcur;

- // self-attention
- if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
- Qcur = lm_ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
+ // self-attention
+ if (model.layers[il].wqkv) {
+ cur = build_lora_mm(model.layers[il].wqkv, cur);
+ cb(cur, "wqkv", il);
+
+ if (model.layers[il].bqkv) {
+ cur = lm_ggml_add(ctx0, cur, model.layers[il].bqkv);
+ cb(cur, "bqkv", il);
+ }
+
+ Qcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+ Kcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+ Vcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+ } else {
+ Qcur = lm_ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
+ Kcur = lm_ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
+ Vcur = lm_ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
+ }

  if (model.layers[il].attn_q_norm) {
  Qcur = build_norm(Qcur,
@@ -5915,8 +6185,6 @@ struct llm_build_bert : public llm_graph_context {
  LLM_NORM, il);
  }

- Kcur = lm_ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
-
  if (model.layers[il].attn_k_norm) {
  Kcur = build_norm(Kcur,
  model.layers[il].attn_k_norm,
@@ -5924,54 +6192,36 @@ struct llm_build_bert : public llm_graph_context {
  LLM_NORM, il);
  }

- Vcur = lm_ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
-
  Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
  Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
  Vcur = lm_ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
- } else {
- // compute Q and K and RoPE them
- cur = build_lora_mm(model.layers[il].wqkv, cur);
- cb(cur, "wqkv", il);

- if (model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
- cur = lm_ggml_add(ctx0, cur, model.layers[il].bqkv);
- cb(cur, "bqkv", il);
+ // RoPE
+ if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
+ Qcur = lm_ggml_rope_ext(
+ ctx0, Qcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ Kcur = lm_ggml_rope_ext(
+ ctx0, Kcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
  }

- Qcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
- Kcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
- Vcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);

- Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
- Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
- Vcur = lm_ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+ cur = build_attn(inp_attn, gf,
+ model.layers[il].wo, model.layers[il].bo,
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ cb(cur, "kqv_out", il);
+ }

- Qcur = lm_ggml_rope_ext(
- ctx0, Qcur, inp_pos, nullptr,
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- ext_factor, attn_factor, beta_fast, beta_slow
- );
-
- Kcur = lm_ggml_rope_ext(
- ctx0, Kcur, inp_pos, nullptr,
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- ext_factor, attn_factor, beta_fast, beta_slow
- );
- }
-
- cb(Qcur, "Qcur", il);
- cb(Kcur, "Kcur", il);
- cb(Vcur, "Vcur", il);
-
- cur = build_attn(inp_attn, gf,
- model.layers[il].wo, model.layers[il].bo,
- Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
- cb(cur, "kqv_out", il);
-
- if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
- // skip computing output for unused tokens
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+ if (il == n_layer - 1 && inp_out_ids) {
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
  inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
  }
@@ -6020,7 +6270,7 @@ struct llm_build_bert : public llm_graph_context {
6020
6270
  model.layers[il].ffn_gate, NULL, NULL,
6021
6271
  model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
6022
6272
  NULL,
6023
- LLM_FFN_GELU, LLM_FFN_PAR, il);
6273
+ model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_GEGLU, LLM_FFN_PAR, il);
6024
6274
  cb(cur, "ffn_out", il);
6025
6275
  } else {
6026
6276
  cur = build_ffn(cur,
@@ -6051,6 +6301,118 @@ struct llm_build_bert : public llm_graph_context {
6051
6301
  }
6052
6302
  };
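A recurring change across the builders in this diff is that build_inp_out_ids() is now created once, before the per-layer loop, and the last-layer row gather is guarded with "&& inp_out_ids" (the tensor is presumably null when no output filtering is needed). The gather itself only keeps the token rows whose outputs are actually consumed; a plain, self-contained sketch of that row selection, with an illustrative helper name and std::vector standing in for ggml tensors:

#include <cstdint>
#include <vector>

// Sketch of what lm_ggml_get_rows(ctx0, cur, inp_out_ids) does on the last
// layer: gather only the rows (tokens) whose outputs are needed downstream.
static std::vector<std::vector<float>> select_output_rows(
        const std::vector<std::vector<float>> & x,   // [n_tokens][n_embd]
        const std::vector<int32_t> & out_ids) {      // indices of the needed tokens
    std::vector<std::vector<float>> out;
    out.reserve(out_ids.size());
    for (int32_t id : out_ids) {
        out.push_back(x[(size_t) id]);               // copy one token's row
    }
    return out;
}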
6053
6303
 
6304
+ struct llm_build_neo_bert : public llm_graph_context {
6305
+ llm_build_neo_bert(const llama_model & model, const llm_graph_params & params, lm_ggml_cgraph * gf) : llm_graph_context(params) {
6306
+ const int64_t n_embd_head = hparams.n_embd_head_v;
6307
+ const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
6308
+
6309
+ LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
6310
+
6311
+ lm_ggml_tensor * cur;
6312
+ lm_ggml_tensor * inpL;
6313
+ lm_ggml_tensor * inp_pos = build_inp_pos();
6314
+
6315
+ // construct input embeddings (token, type, position)
6316
+ inpL = build_inp_embd(model.tok_embd);
6317
+ cb(inpL, "inp_embd", -1);
6318
+
6319
+ auto * inp_attn = build_attn_inp_no_cache();
6320
+
6321
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
6322
+
6323
+ for (int il = 0; il < n_layer; ++il) {
6324
+ lm_ggml_tensor * cur = inpL;
6325
+
6326
+ // pre-norm
6327
+ cur = build_norm(inpL,
6328
+ model.layers[il].attn_norm, NULL,
6329
+ LLM_NORM_RMS, il);
6330
+
6331
+ {
6332
+ lm_ggml_tensor * Qcur;
6333
+ lm_ggml_tensor * Kcur;
6334
+ lm_ggml_tensor * Vcur;
6335
+
6336
+ // self-attention
6337
+ cur = build_lora_mm(model.layers[il].wqkv, cur);
6338
+ cb(cur, "wqkv", il);
6339
+
6340
+ Qcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
6341
+ Kcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
6342
+ Vcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
6343
+
6344
+ Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
6345
+ Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
6346
+ Vcur = lm_ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
6347
+
6348
+ // RoPE
6349
+ Qcur = lm_ggml_rope_ext(
6350
+ ctx0, Qcur, inp_pos, nullptr,
6351
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
6352
+ ext_factor, attn_factor, beta_fast, beta_slow
6353
+ );
6354
+
6355
+ Kcur = lm_ggml_rope_ext(
6356
+ ctx0, Kcur, inp_pos, nullptr,
6357
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
6358
+ ext_factor, attn_factor, beta_fast, beta_slow
6359
+ );
6360
+
6361
+ cb(Qcur, "Qcur", il);
6362
+ cb(Kcur, "Kcur", il);
6363
+ cb(Vcur, "Vcur", il);
6364
+
6365
+ cur = build_attn(inp_attn, gf,
6366
+ model.layers[il].wo, nullptr,
6367
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6368
+ cb(cur, "kqv_out", il);
6369
+ }
6370
+
6371
+ if (il == n_layer - 1 && inp_out_ids) {
6372
+ cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
6373
+ inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
6374
+ }
6375
+
6376
+ // re-add the layer input
6377
+ cur = lm_ggml_add(ctx0, cur, inpL);
6378
+
6379
+ lm_ggml_tensor * ffn_inp = cur;
6380
+ cb(ffn_inp, "ffn_inp", il);
6381
+
6382
+ // pre-norm
6383
+ cur = build_norm(ffn_inp,
6384
+ model.layers[il].ffn_norm, NULL,
6385
+ LLM_NORM_RMS, il);
6386
+ cb(cur, "ffn_norm", il);
6387
+
6388
+ // feed-forward network
6389
+ cur = build_ffn(cur,
6390
+ model.layers[il].ffn_up,
6391
+ NULL, NULL, NULL, NULL, NULL,
6392
+ model.layers[il].ffn_down,
6393
+ NULL, NULL, NULL,
6394
+ LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
6395
+
6396
+ // attentions bypass the intermediate layer
6397
+ cur = lm_ggml_add(ctx0, cur, ffn_inp);
6398
+
6399
+ // input for next layer
6400
+ inpL = cur;
6401
+ }
6402
+
6403
+ cur = inpL;
6404
+
6405
+ cur = build_norm(cur,
6406
+ model.output_norm_enc, NULL,
6407
+ LLM_NORM_RMS, -1);
6408
+
6409
+ cb(cur, "result_embd", -1);
6410
+ res->t_embd = cur;
6411
+
6412
+ lm_ggml_build_forward_expand(gf, cur);
6413
+ }
6414
+ };
6415
+
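Both the fused-wqkv path in llm_build_bert and llm_build_neo_bert above split a single wqkv matmul output into Q/K/V by taking byte-offset views of each row: Q occupies the first n_embd floats, K the next n_embd_gqa, and V the n_embd_gqa after that, which is exactly what the lm_ggml_view_2d offsets encode. A self-contained sketch of that row layout (the struct and function names here are illustrative, not library API):

#include <cstddef>

// One token's fused QKV activation is laid out as [ Q | K | V ] within a row.
struct qkv_row_view {
    const float * q;   // n_embd values, byte offset 0
    const float * k;   // n_embd_gqa values, byte offset sizeof(float)*n_embd
    const float * v;   // n_embd_gqa values, byte offset sizeof(float)*(n_embd + n_embd_gqa)
};

static qkv_row_view split_qkv_row(const float * row, size_t n_embd, size_t n_embd_gqa) {
    qkv_row_view out;
    out.q = row;
    out.k = row + n_embd;
    out.v = row + n_embd + n_embd_gqa;
    return out;
}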
6054
6416
  struct llm_build_bloom : public llm_graph_context {
6055
6417
  llm_build_bloom(const llama_model & model, const llm_graph_params & params, lm_ggml_cgraph * gf) : llm_graph_context(params) {
6056
6418
  const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -6071,6 +6433,8 @@ struct llm_build_bloom : public llm_graph_context {
6071
6433
  LLM_NORM, -1);
6072
6434
  cb(inpL, "inp_norm", -1);
6073
6435
 
6436
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
6437
+
6074
6438
  for (int il = 0; il < n_layer; ++il) {
6075
6439
  cur = build_norm(inpL,
6076
6440
  model.layers[il].attn_norm,
@@ -6103,9 +6467,7 @@ struct llm_build_bloom : public llm_graph_context {
6103
6467
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6104
6468
  }
6105
6469
 
6106
- if (il == n_layer - 1) {
6107
- // skip computing output for unused tokens
6108
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
6470
+ if (il == n_layer - 1 && inp_out_ids) {
6109
6471
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
6110
6472
  inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
6111
6473
  }
@@ -6182,6 +6544,8 @@ struct llm_build_mpt : public llm_graph_context {
6182
6544
  cb(inpL, "inpL", -1);
6183
6545
  }
6184
6546
 
6547
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
6548
+
6185
6549
  for (int il = 0; il < n_layer; ++il) {
6186
6550
  lm_ggml_tensor * attn_norm;
6187
6551
 
@@ -6244,9 +6608,7 @@ struct llm_build_mpt : public llm_graph_context {
6244
6608
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6245
6609
  }
6246
6610
 
6247
- if (il == n_layer - 1) {
6248
- // skip computing output for unused tokens
6249
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
6611
+ if (il == n_layer - 1 && inp_out_ids) {
6250
6612
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
6251
6613
  inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
6252
6614
  }
@@ -6315,6 +6677,8 @@ struct llm_build_stablelm : public llm_graph_context {
6315
6677
 
6316
6678
  auto * inp_attn = build_attn_inp_kv_unified();
6317
6679
 
6680
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
6681
+
6318
6682
  for (int il = 0; il < n_layer; ++il) {
6319
6683
  // norm
6320
6684
  cur = build_norm(inpL,
@@ -6390,9 +6754,7 @@ struct llm_build_stablelm : public llm_graph_context {
6390
6754
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6391
6755
  }
6392
6756
 
6393
- if (il == n_layer - 1) {
6394
- // skip computing output for unused tokens
6395
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
6757
+ if (il == n_layer - 1 && inp_out_ids) {
6396
6758
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
6397
6759
  inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
6398
6760
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
@@ -6467,6 +6829,8 @@ struct llm_build_qwen : public llm_graph_context {
6467
6829
 
6468
6830
  auto * inp_attn = build_attn_inp_kv_unified();
6469
6831
 
6832
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
6833
+
6470
6834
  for (int il = 0; il < n_layer; ++il) {
6471
6835
  lm_ggml_tensor * inpSA = inpL;
6472
6836
 
@@ -6513,9 +6877,7 @@ struct llm_build_qwen : public llm_graph_context {
6513
6877
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6514
6878
  }
6515
6879
 
6516
- if (il == n_layer - 1) {
6517
- // skip computing output for unused tokens
6518
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
6880
+ if (il == n_layer - 1 && inp_out_ids) {
6519
6881
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
6520
6882
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
6521
6883
  }
@@ -6584,6 +6946,8 @@ struct llm_build_qwen2 : public llm_graph_context {
6584
6946
 
6585
6947
  auto * inp_attn = build_attn_inp_kv_unified();
6586
6948
 
6949
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
6950
+
6587
6951
  for (int il = 0; il < n_layer; ++il) {
6588
6952
  lm_ggml_tensor * inpSA = inpL;
6589
6953
 
@@ -6633,9 +6997,7 @@ struct llm_build_qwen2 : public llm_graph_context {
6633
6997
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6634
6998
  }
6635
6999
 
6636
- if (il == n_layer - 1) {
6637
- // skip computing output for unused tokens
6638
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
7000
+ if (il == n_layer - 1 && inp_out_ids) {
6639
7001
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
6640
7002
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
6641
7003
  }
@@ -6705,6 +7067,8 @@ struct llm_build_qwen2vl : public llm_graph_context {
6705
7067
  int sections[4];
6706
7068
  std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
6707
7069
 
7070
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
7071
+
6708
7072
  for (int il = 0; il < n_layer; ++il) {
6709
7073
  lm_ggml_tensor * inpSA = inpL;
6710
7074
 
@@ -6754,9 +7118,7 @@ struct llm_build_qwen2vl : public llm_graph_context {
6754
7118
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6755
7119
  }
6756
7120
 
6757
- if (il == n_layer - 1) {
6758
- // skip computing output for unused tokens
6759
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
7121
+ if (il == n_layer - 1 && inp_out_ids) {
6760
7122
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
6761
7123
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
6762
7124
  }
@@ -6823,6 +7185,8 @@ struct llm_build_qwen2moe : public llm_graph_context {
6823
7185
 
6824
7186
  auto * inp_attn = build_attn_inp_kv_unified();
6825
7187
 
7188
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
7189
+
6826
7190
  for (int il = 0; il < n_layer; ++il) {
6827
7191
  lm_ggml_tensor * inpSA = inpL;
6828
7192
 
@@ -6881,9 +7245,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
6881
7245
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6882
7246
  }
6883
7247
 
6884
- if (il == n_layer - 1) {
6885
- // skip computing output for unused tokens
6886
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
7248
+ if (il == n_layer - 1 && inp_out_ids) {
6887
7249
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
6888
7250
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
6889
7251
  }
@@ -6982,6 +7344,8 @@ struct llm_build_qwen3 : public llm_graph_context {
6982
7344
 
6983
7345
  auto * inp_attn = build_attn_inp_kv_unified();
6984
7346
 
7347
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
7348
+
6985
7349
  for (int il = 0; il < n_layer; ++il) {
6986
7350
  lm_ggml_tensor * inpSA = inpL;
6987
7351
 
@@ -7034,9 +7398,7 @@ struct llm_build_qwen3 : public llm_graph_context {
7034
7398
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7035
7399
  }
7036
7400
 
7037
- if (il == n_layer - 1) {
7038
- // skip computing output for unused tokens
7039
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
7401
+ if (il == n_layer - 1 && inp_out_ids) {
7040
7402
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
7041
7403
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
7042
7404
  }
@@ -7103,6 +7465,8 @@ struct llm_build_qwen3moe : public llm_graph_context {
7103
7465
 
7104
7466
  auto * inp_attn = build_attn_inp_kv_unified();
7105
7467
 
7468
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
7469
+
7106
7470
  for (int il = 0; il < n_layer; ++il) {
7107
7471
  lm_ggml_tensor * inpSA = inpL;
7108
7472
 
@@ -7155,9 +7519,7 @@ struct llm_build_qwen3moe : public llm_graph_context {
7155
7519
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7156
7520
  }
7157
7521
 
7158
- if (il == n_layer - 1) {
7159
- // skip computing output for unused tokens
7160
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
7522
+ if (il == n_layer - 1 && inp_out_ids) {
7161
7523
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
7162
7524
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
7163
7525
  }
@@ -7233,6 +7595,8 @@ struct llm_build_phi2 : public llm_graph_context {
7233
7595
 
7234
7596
  auto * inp_attn = build_attn_inp_kv_unified();
7235
7597
 
7598
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
7599
+
7236
7600
  for (int il = 0; il < n_layer; ++il) {
7237
7601
  attn_norm_output = build_norm(inpL,
7238
7602
  model.layers[il].attn_norm,
@@ -7295,9 +7659,7 @@ struct llm_build_phi2 : public llm_graph_context {
7295
7659
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
7296
7660
  }
7297
7661
 
7298
- if (il == n_layer - 1) {
7299
- // skip computing output for unused tokens
7300
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
7662
+ if (il == n_layer - 1 && inp_out_ids) {
7301
7663
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
7302
7664
  inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
7303
7665
  attn_norm_output = lm_ggml_get_rows(ctx0, attn_norm_output, inp_out_ids);
@@ -7369,6 +7731,8 @@ struct llm_build_phi3 : public llm_graph_context {
7369
7731
  inp_attn = build_attn_inp_kv_unified();
7370
7732
  }
7371
7733
 
7734
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
7735
+
7372
7736
  for (int il = 0; il < n_layer; ++il) {
7373
7737
  auto * residual = inpL;
7374
7738
 
@@ -7432,9 +7796,7 @@ struct llm_build_phi3 : public llm_graph_context {
7432
7796
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
7433
7797
  }
7434
7798
 
7435
- if (il == n_layer - 1) {
7436
- // skip computing output for unused tokens
7437
- lm_ggml_tensor* inp_out_ids = build_inp_out_ids();
7799
+ if (il == n_layer - 1 && inp_out_ids) {
7438
7800
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
7439
7801
  residual = lm_ggml_get_rows(ctx0, residual, inp_out_ids);
7440
7802
  }
@@ -7520,15 +7882,16 @@ struct llm_build_plamo : public llm_graph_context {
7520
7882
 
7521
7883
  auto * inp_attn = build_attn_inp_kv_unified();
7522
7884
 
7523
- for (int il = 0; il < n_layer; ++il) {
7885
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
7524
7886
 
7887
+ for (int il = 0; il < n_layer; ++il) {
7525
7888
  // norm
7526
7889
  cur = build_norm(inpL,
7527
7890
  model.layers[il].attn_norm, NULL,
7528
7891
  LLM_NORM_RMS, il);
7529
7892
  cb(cur, "attn_norm", il);
7530
7893
 
7531
- lm_ggml_tensor * attention_norm = cur;
7894
+ lm_ggml_tensor * sa_inp = cur;
7532
7895
 
7533
7896
  // self-attention
7534
7897
  {
@@ -7566,18 +7929,17 @@ struct llm_build_plamo : public llm_graph_context {
7566
7929
  model.layers[il].wo, NULL,
7567
7930
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7568
7931
  }
7569
- lm_ggml_tensor * sa_out = cur;
7570
-
7571
- cur = attention_norm;
7572
7932
 
7573
- if (il == n_layer - 1) {
7574
- // skip computing output for unused tokens
7575
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
7933
+ if (il == n_layer - 1 && inp_out_ids) {
7576
7934
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
7577
- sa_out = lm_ggml_get_rows(ctx0, sa_out, inp_out_ids);
7935
+ sa_inp = lm_ggml_get_rows(ctx0, sa_inp, inp_out_ids);
7578
7936
  inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
7579
7937
  }
7580
7938
 
7939
+ lm_ggml_tensor * sa_out = cur;
7940
+
7941
+ cur = sa_inp;
7942
+
7581
7943
  // feed-forward network
7582
7944
  {
7583
7945
  cur = build_ffn(cur,
@@ -7642,6 +8004,8 @@ struct llm_build_gpt2 : public llm_graph_context {
7642
8004
  inpL = lm_ggml_add(ctx0, inpL, pos);
7643
8005
  cb(inpL, "inpL", -1);
7644
8006
 
8007
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
8008
+
7645
8009
  for (int il = 0; il < n_layer; ++il) {
7646
8010
  cur = build_norm(inpL,
7647
8011
  model.layers[il].attn_norm,
@@ -7674,9 +8038,7 @@ struct llm_build_gpt2 : public llm_graph_context {
7674
8038
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7675
8039
  }
7676
8040
 
7677
- if (il == n_layer - 1) {
7678
- // skip computing output for unused tokens
7679
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
8041
+ if (il == n_layer - 1 && inp_out_ids) {
7680
8042
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
7681
8043
  inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
7682
8044
  }
@@ -7746,6 +8108,8 @@ struct llm_build_codeshell : public llm_graph_context {
7746
8108
 
7747
8109
  auto * inp_attn = build_attn_inp_kv_unified();
7748
8110
 
8111
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
8112
+
7749
8113
  for (int il = 0; il < n_layer; ++il) {
7750
8114
  cur = build_norm(inpL,
7751
8115
  model.layers[il].attn_norm,
@@ -7790,9 +8154,7 @@ struct llm_build_codeshell : public llm_graph_context {
7790
8154
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7791
8155
  }
7792
8156
 
7793
- if (il == n_layer - 1) {
7794
- // skip computing output for unused tokens
7795
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
8157
+ if (il == n_layer - 1 && inp_out_ids) {
7796
8158
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
7797
8159
  inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
7798
8160
  }
@@ -7846,128 +8208,128 @@ struct llm_build_codeshell : public llm_graph_context {
7846
8208
 
7847
8209
  struct llm_build_orion : public llm_graph_context {
7848
8210
  llm_build_orion(const llama_model & model, const llm_graph_params & params, lm_ggml_cgraph * gf) : llm_graph_context(params) {
7849
- const int64_t n_embd_head = hparams.n_embd_head_v;
8211
+ const int64_t n_embd_head = hparams.n_embd_head_v;
7850
8212
 
7851
- LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
7852
- LM_GGML_ASSERT(n_embd_head == hparams.n_rot);
8213
+ LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
8214
+ LM_GGML_ASSERT(n_embd_head == hparams.n_rot);
7853
8215
 
7854
- lm_ggml_tensor * cur;
7855
- lm_ggml_tensor * inpL;
8216
+ lm_ggml_tensor * cur;
8217
+ lm_ggml_tensor * inpL;
7856
8218
 
7857
- inpL = build_inp_embd(model.tok_embd);
8219
+ inpL = build_inp_embd(model.tok_embd);
7858
8220
 
7859
- // inp_pos - contains the positions
7860
- lm_ggml_tensor * inp_pos = build_inp_pos();
8221
+ // inp_pos - contains the positions
8222
+ lm_ggml_tensor * inp_pos = build_inp_pos();
7861
8223
 
7862
- auto * inp_attn = build_attn_inp_kv_unified();
8224
+ auto * inp_attn = build_attn_inp_kv_unified();
7863
8225
 
7864
- for (int il = 0; il < n_layer; ++il) {
7865
- lm_ggml_tensor * inpSA = inpL;
8226
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
7866
8227
 
7867
- // norm
7868
- cur = build_norm(inpL,
7869
- model.layers[il].attn_norm, model.layers[il].attn_norm_b,
7870
- LLM_NORM, il);
7871
- cb(cur, "attn_norm", il);
8228
+ for (int il = 0; il < n_layer; ++il) {
8229
+ lm_ggml_tensor * inpSA = inpL;
7872
8230
 
7873
- // self-attention
7874
- {
7875
- // compute Q and K and RoPE them
7876
- lm_ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
7877
- cb(Qcur, "Qcur", il);
7878
- // if (model.layers[il].bq) {
7879
- // Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
7880
- // cb(Qcur, "Qcur", il);
7881
- // }
7882
-
7883
- lm_ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
7884
- cb(Kcur, "Kcur", il);
7885
- // if (model.layers[il].bk) {
7886
- // Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
7887
- // cb(Kcur, "Kcur", il);
7888
- // }
7889
-
7890
- lm_ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
7891
- cb(Vcur, "Vcur", il);
7892
- // if (model.layers[il].bv) {
7893
- // Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
7894
- // cb(Vcur, "Vcur", il);
7895
- // }
7896
-
7897
- Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
7898
- Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
7899
- Vcur = lm_ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
7900
-
7901
- Qcur = lm_ggml_rope_ext(
7902
- ctx0, Qcur, inp_pos, nullptr,
7903
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
7904
- ext_factor, attn_factor, beta_fast, beta_slow
7905
- );
8231
+ // norm
8232
+ cur = build_norm(inpL,
8233
+ model.layers[il].attn_norm, model.layers[il].attn_norm_b,
8234
+ LLM_NORM, il);
8235
+ cb(cur, "attn_norm", il);
7906
8236
 
7907
- Kcur = lm_ggml_rope_ext(
7908
- ctx0, Kcur, inp_pos, nullptr,
7909
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
7910
- ext_factor, attn_factor, beta_fast, beta_slow
7911
- );
8237
+ // self-attention
8238
+ {
8239
+ // compute Q and K and RoPE them
8240
+ lm_ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
8241
+ cb(Qcur, "Qcur", il);
8242
+ // if (model.layers[il].bq) {
8243
+ // Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
8244
+ // cb(Qcur, "Qcur", il);
8245
+ // }
7912
8246
 
7913
- cb(Qcur, "Qcur", il);
7914
- cb(Kcur, "Kcur", il);
7915
- cb(Vcur, "Vcur", il);
8247
+ lm_ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
8248
+ cb(Kcur, "Kcur", il);
8249
+ // if (model.layers[il].bk) {
8250
+ // Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
8251
+ // cb(Kcur, "Kcur", il);
8252
+ // }
7916
8253
 
7917
- cur = build_attn(inp_attn, gf,
7918
- model.layers[il].wo, NULL,
7919
- Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7920
- }
8254
+ lm_ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
8255
+ cb(Vcur, "Vcur", il);
8256
+ // if (model.layers[il].bv) {
8257
+ // Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
8258
+ // cb(Vcur, "Vcur", il);
8259
+ // }
7921
8260
 
7922
- if (il == n_layer - 1) {
7923
- // skip computing output for unused tokens
7924
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
7925
- cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
7926
- inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
7927
- }
8261
+ Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
8262
+ Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
8263
+ Vcur = lm_ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
7928
8264
 
7929
- lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA);
7930
- cb(ffn_inp, "ffn_inp", il);
8265
+ Qcur = lm_ggml_rope_ext(
8266
+ ctx0, Qcur, inp_pos, nullptr,
8267
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
8268
+ ext_factor, attn_factor, beta_fast, beta_slow
8269
+ );
7931
8270
 
7932
- // feed-forward network
7933
- cur = build_norm(ffn_inp,
7934
- model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
7935
- LLM_NORM, il);
7936
- cb(cur, "ffn_norm", il);
8271
+ Kcur = lm_ggml_rope_ext(
8272
+ ctx0, Kcur, inp_pos, nullptr,
8273
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
8274
+ ext_factor, attn_factor, beta_fast, beta_slow
8275
+ );
7937
8276
 
7938
- cur = build_ffn(cur,
7939
- model.layers[il].ffn_up, NULL, NULL,
7940
- model.layers[il].ffn_gate, NULL, NULL,
7941
- model.layers[il].ffn_down, NULL, NULL,
7942
- NULL,
7943
- LLM_FFN_SILU, LLM_FFN_PAR, il);
7944
- cb(cur, "ffn_out", il);
8277
+ cb(Qcur, "Qcur", il);
8278
+ cb(Kcur, "Kcur", il);
8279
+ cb(Vcur, "Vcur", il);
7945
8280
 
7946
- cur = lm_ggml_add(ctx0, cur, ffn_inp);
8281
+ cur = build_attn(inp_attn, gf,
8282
+ model.layers[il].wo, NULL,
8283
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8284
+ }
7947
8285
 
7948
- cur = build_cvec(cur, il);
7949
- cb(cur, "l_out", il);
8286
+ if (il == n_layer - 1 && inp_out_ids) {
8287
+ cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
8288
+ inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
8289
+ }
7950
8290
 
7951
- // input for next layer
7952
- inpL = cur;
7953
- }
8291
+ lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA);
8292
+ cb(ffn_inp, "ffn_inp", il);
8293
+
8294
+ // feed-forward network
8295
+ cur = build_norm(ffn_inp,
8296
+ model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
8297
+ LLM_NORM, il);
8298
+ cb(cur, "ffn_norm", il);
8299
+
8300
+ cur = build_ffn(cur,
8301
+ model.layers[il].ffn_up, NULL, NULL,
8302
+ model.layers[il].ffn_gate, NULL, NULL,
8303
+ model.layers[il].ffn_down, NULL, NULL,
8304
+ NULL,
8305
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
8306
+ cb(cur, "ffn_out", il);
8307
+
8308
+ cur = lm_ggml_add(ctx0, cur, ffn_inp);
8309
+
8310
+ cur = build_cvec(cur, il);
8311
+ cb(cur, "l_out", il);
8312
+
8313
+ // input for next layer
8314
+ inpL = cur;
8315
+ }
7954
8316
 
7955
- cur = inpL;
8317
+ cur = inpL;
7956
8318
 
7957
- cur = build_norm(cur,
7958
- model.output_norm, model.output_norm_b,
7959
- LLM_NORM, -1);
8319
+ cur = build_norm(cur,
8320
+ model.output_norm, model.output_norm_b,
8321
+ LLM_NORM, -1);
7960
8322
 
7961
- cb(cur, "result_norm", -1);
7962
- res->t_embd = cur;
8323
+ cb(cur, "result_norm", -1);
8324
+ res->t_embd = cur;
7963
8325
 
7964
- // lm_head
7965
- cur = build_lora_mm(model.output, cur);
8326
+ // lm_head
8327
+ cur = build_lora_mm(model.output, cur);
7966
8328
 
7967
- cb(cur, "result_output", -1);
7968
- res->t_logits = cur;
8329
+ cb(cur, "result_output", -1);
8330
+ res->t_logits = cur;
7969
8331
 
7970
- lm_ggml_build_forward_expand(gf, cur);
8332
+ lm_ggml_build_forward_expand(gf, cur);
7971
8333
  }
7972
8334
  };
7973
8335
 
@@ -7988,6 +8350,8 @@ struct llm_build_internlm2 : public llm_graph_context {
7988
8350
 
7989
8351
  auto * inp_attn = build_attn_inp_kv_unified();
7990
8352
 
8353
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
8354
+
7991
8355
  for (int il = 0; il < n_layer; ++il) {
7992
8356
  lm_ggml_tensor * inpSA = inpL;
7993
8357
 
@@ -8046,9 +8410,7 @@ struct llm_build_internlm2 : public llm_graph_context {
8046
8410
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8047
8411
  }
8048
8412
 
8049
- if (il == n_layer - 1) {
8050
- // skip computing output for unused tokens
8051
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
8413
+ if (il == n_layer - 1 && inp_out_ids) {
8052
8414
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
8053
8415
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
8054
8416
  }
@@ -8124,6 +8486,8 @@ struct llm_build_minicpm3 : public llm_graph_context {
8124
8486
 
8125
8487
  auto * inp_attn = build_attn_inp_kv_unified();
8126
8488
 
8489
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
8490
+
8127
8491
  for (int il = 0; il < n_layer; ++il) {
8128
8492
  lm_ggml_tensor * inpSA = inpL;
8129
8493
 
@@ -8243,15 +8607,13 @@ struct llm_build_minicpm3 : public llm_graph_context {
8243
8607
  q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
8244
8608
  }
8245
8609
 
8246
- if (il == n_layer - 1) {
8247
- // skip computing output for unused tokens
8248
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
8610
+ if (il == n_layer - 1 && inp_out_ids) {
8249
8611
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
8250
8612
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
8251
8613
  }
8252
8614
 
8253
8615
  // scale_res - scale the hidden states for residual connection
8254
- const float scale_res = scale_depth/sqrtf(float(n_layer));
8616
+ const float scale_res = scale_depth/sqrtf(float(n_layer)); // TODO: is this correct?
8255
8617
  cur = lm_ggml_scale(ctx0, cur, scale_res);
8256
8618
  cb(cur, "hidden_scaled", il);
8257
8619
 
@@ -8328,6 +8690,8 @@ struct llm_build_gemma : public llm_graph_context {
8328
8690
 
8329
8691
  auto * inp_attn = build_attn_inp_kv_unified();
8330
8692
 
8693
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
8694
+
8331
8695
  for (int il = 0; il < n_layer; ++il) {
8332
8696
  // norm
8333
8697
  cur = build_norm(inpL,
@@ -8373,9 +8737,7 @@ struct llm_build_gemma : public llm_graph_context {
8373
8737
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
8374
8738
  }
8375
8739
 
8376
- if (il == n_layer - 1) {
8377
- // skip computing output for unused tokens
8378
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
8740
+ if (il == n_layer - 1 && inp_out_ids) {
8379
8741
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
8380
8742
  inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
8381
8743
  }
@@ -8444,6 +8806,8 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
8444
8806
 
8445
8807
  auto * inp_attn = build_attn_inp_kv_unified_iswa();
8446
8808
 
8809
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
8810
+
8447
8811
  for (int il = 0; il < n_layer; ++il) {
8448
8812
  // norm
8449
8813
  cur = build_norm(inpL,
@@ -8481,32 +8845,23 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
8481
8845
  cb(Kcur, "Kcur", il);
8482
8846
  cb(Vcur, "Vcur", il);
8483
8847
 
8484
- // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
8485
- switch (model.type) {
8486
- case LLM_TYPE_2B:
8487
- case LLM_TYPE_9B:
8488
- case LLM_TYPE_27B: Qcur = lm_ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); break;
8489
- default: LM_GGML_ABORT("fatal error");
8490
- };
8491
- cb(Qcur, "Qcur_scaled", il);
8848
+ Qcur = lm_ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
8492
8849
 
8493
8850
  cur = build_attn(inp_attn, gf,
8494
8851
  model.layers[il].wo, NULL,
8495
8852
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
8496
8853
  }
8497
8854
 
8855
+ if (il == n_layer - 1 && inp_out_ids) {
8856
+ cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
8857
+ inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
8858
+ }
8859
+
8498
8860
  cur = build_norm(cur,
8499
8861
  model.layers[il].attn_post_norm, NULL,
8500
8862
  LLM_NORM_RMS, il);
8501
8863
  cb(cur, "attn_post_norm", il);
8502
8864
 
8503
- if (il == n_layer - 1) {
8504
- // skip computing output for unused tokens
8505
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
8506
- cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
8507
- inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
8508
- }
8509
-
8510
8865
  lm_ggml_tensor * sa_out = lm_ggml_add(ctx0, cur, inpL);
8511
8866
  cb(sa_out, "sa_out", il);
8512
8867
 
@@ -8585,6 +8940,8 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
8585
8940
  // TODO: is causal == true correct? might need some changes
8586
8941
  auto * inp_attn = build_attn_inp_kv_unified_iswa();
8587
8942
 
8943
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
8944
+
8588
8945
  for (int il = 0; il < n_layer; ++il) {
8589
8946
  const float freq_base_l = model.get_rope_freq_base (cparams, il);
8590
8947
  const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
@@ -8629,9 +8986,17 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
8629
8986
  cb(Kcur, "Kcur", il);
8630
8987
  cb(Vcur, "Vcur", il);
8631
8988
 
8989
+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
8990
+ Qcur = lm_ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
8991
+
8632
8992
  cur = build_attn(inp_attn, gf,
8633
8993
  model.layers[il].wo, NULL,
8634
- Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
8994
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
8995
+ }
8996
+
8997
+ if (il == n_layer - 1 && inp_out_ids) {
8998
+ cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
8999
+ inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
8635
9000
  }
8636
9001
 
8637
9002
  cur = build_norm(cur,
@@ -8639,13 +9004,6 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
8639
9004
  LLM_NORM_RMS, il);
8640
9005
  cb(cur, "attn_post_norm", il);
8641
9006
 
8642
- if (il == n_layer - 1) {
8643
- // skip computing output for unused tokens
8644
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
8645
- cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
8646
- inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
8647
- }
8648
-
8649
9007
  lm_ggml_tensor * sa_out = lm_ggml_add(ctx0, cur, inpL);
8650
9008
  cb(sa_out, "sa_out", il);
8651
9009
 
@@ -8698,6 +9056,442 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
8698
9056
  }
8699
9057
  };
8700
9058
 
9059
+ struct llm_build_gemma3n_iswa : public llm_graph_context {
9060
+ const llama_model & model;
9061
+ lm_ggml_cgraph * gf;
9062
+
9063
+ const int64_t n_embd_head;
9064
+ const int64_t n_embd_altup;
9065
+ const int64_t n_altup;
9066
+ const int i_altup_act;
9067
+ const int n_layer_kv = 20; // number of layers having KV [KV_REUSE]
9068
+ const int n_layer_sparsity = 10; // number of layers using activation sparsity
9069
+ const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
9070
+
9071
+ lm_ggml_tensor * one; // containing single element 1.0f
9072
+
9073
+ llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, lm_ggml_cgraph * gf)
9074
+ : llm_graph_context(params),
9075
+ model(model),
9076
+ gf(gf),
9077
+ n_embd_head(model.hparams.n_embd_head_k),
9078
+ n_embd_altup(model.hparams.n_embd_altup),
9079
+ n_altup(model.hparams.n_altup),
9080
+ i_altup_act(model.hparams.i_altup_act) {
9081
+ lm_ggml_tensor * cur;
9082
+ lm_ggml_tensor * inpL;
9083
+
9084
+ // TODO: remove this when lm_ggml_scale_add is implemented
9085
+ one = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_F32, 1);
9086
+ {
9087
+ auto inp = std::make_unique<llm_graph_input_one>();
9088
+ inp->one = one;
9089
+ res->add_input(std::move(inp));
9090
+ }
9091
+
9092
+ inpL = build_inp_embd(model.tok_embd);
9093
+
9094
+ // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
9095
+ if (ubatch.token) {
9096
+ inpL = lm_ggml_scale(ctx0, inpL, sqrtf(n_embd));
9097
+ cb(inpL, "inp_scaled", -1);
9098
+ }
9099
+
9100
+ // inp_pos - contains the positions
9101
+ lm_ggml_tensor * inp_pos = build_inp_pos();
9102
+
9103
+ // TODO: is causal == true correct? might need some changes
9104
+ auto * inp_attn = build_attn_inp_kv_unified_iswa();
9105
+
9106
+ // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
9107
+ lm_ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
9108
+
9109
+ // inpL now has only 1 altup, project it to the rest of the altups
9110
+ // these "added" altups will be concatenated to the last dim of inpL
9111
+ {
9112
+ lm_ggml_tensor * target_magnitude = calc_magnitude(inpL);
9113
+ lm_ggml_tensor * inp_repeated = lm_ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
9114
+ lm_ggml_tensor * altup_added = lm_ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1]
9115
+ lm_ggml_tensor * new_magnitude = calc_magnitude(altup_added);
9116
+ altup_added = lm_ggml_div(ctx0,
9117
+ lm_ggml_mul(ctx0, altup_added, target_magnitude),
9118
+ new_magnitude);
9119
+ inpL = lm_ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
9120
+ cb(inpL, "inp_stacked", -1);
9121
+ }
9122
+
9123
+ // inpL now has shape: [n_embd, n_tokens, n_altup]
9124
+ // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
9125
+
9126
+ for (int il = 0; il < n_layer; ++il) {
9127
+ // this block is made to closely resemble Gemma3p5DecoderLayer in the Python code
9128
+ const bool has_kv = (il < n_layer_kv);
9129
+
9130
+ const float freq_base_l = model.get_rope_freq_base (cparams, il);
9131
+ const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
9132
+
9133
+ lm_ggml_tensor * cur = inpL; // [n_embd, n_tokens, n_altup]
9134
+ lm_ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]
9135
+
9136
+ // predicted value will go through self-attention and laurel
9137
+ lm_ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
9138
+ cur = active_prediction;
9139
+ cb(cur, "active_prediction", il);
9140
+
9141
+ // norm
9142
+ cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
9143
+ cb(cur, "attn_norm", il);
9144
+
9145
+ // laurel
9146
+ lm_ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]
9147
+
9148
+ // self-attention
9149
+ if (has_kv) {
9150
+ // compute Q and K and RoPE them
9151
+ lm_ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
9152
+ cb(Qcur, "Qcur", il);
9153
+
9154
+ lm_ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
9155
+ cb(Kcur, "Kcur", il);
9156
+
9157
+ lm_ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
9158
+ cb(Vcur, "Vcur", il);
9159
+
9160
+ Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
9161
+ Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
9162
+ Vcur = lm_ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
9163
+
9164
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
9165
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
9166
+ Vcur = lm_ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);
9167
+
9168
+ cb(Qcur, "Qcur_normed", il);
9169
+ cb(Kcur, "Kcur_normed", il);
9170
+ cb(Vcur, "Vcur_normed", il);
9171
+
9172
+ Qcur = lm_ggml_rope_ext(
9173
+ ctx0, Qcur, inp_pos, nullptr,
9174
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
9175
+ ext_factor, attn_factor, beta_fast, beta_slow);
9176
+
9177
+ Kcur = lm_ggml_rope_ext(
9178
+ ctx0, Kcur, inp_pos, nullptr,
9179
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
9180
+ ext_factor, attn_factor, beta_fast, beta_slow);
9181
+
9182
+ cb(Qcur, "Qcur_pos", il);
9183
+ cb(Kcur, "Kcur_pos", il);
9184
+
9185
+ cur = build_attn(inp_attn, gf,
9186
+ model.layers[il].wo, NULL,
9187
+ Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
9188
+ } else {
9189
+ // no KV layers
9190
+ lm_ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
9191
+ cb(Qcur, "Qcur", il);
9192
+ Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
9193
+
9194
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
9195
+ cb(Qcur, "Qcur_normed", il);
9196
+
9197
+ Qcur = lm_ggml_rope_ext(
9198
+ ctx0, Qcur, inp_pos, nullptr,
9199
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
9200
+ ext_factor, attn_factor, beta_fast, beta_slow);
9201
+ cb(Qcur, "Qcur_pos", il);
9202
+
9203
+ cur = build_attn(inp_attn, gf,
9204
+ model.layers[il].wo, NULL,
9205
+ Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
9206
+ }
9207
+
9208
+ cur = build_norm(cur,
9209
+ model.layers[il].attn_post_norm, NULL,
9210
+ LLM_NORM_RMS, il);
9211
+ cb(cur, "attn_post_norm", il);
9212
+
9213
+ cur = lm_ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens]
9214
+ cb(cur, "attn_gated", il);
9215
+
9216
+ lm_ggml_tensor * attn_laurel = lm_ggml_scale(ctx0,
9217
+ lm_ggml_add(ctx0, cur, laurel_out),
9218
+ 1.0f / sqrtf(2.0f)); // [n_embd, n_tokens]
9219
+ cb(attn_laurel, "attn_laurel", il);
9220
+
9221
+ cur = build_norm(attn_laurel,
9222
+ model.layers[il].ffn_norm, NULL,
9223
+ LLM_NORM_RMS, il);
9224
+ cb(cur, "ffn_norm", il);
9225
+
9226
+ // feed-forward network
9227
+ {
9228
+ lm_ggml_tensor * up_proj = build_lora_mm(model.layers[il].ffn_up, cur);
9229
+ lm_ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);
9230
+
9231
+ if (il < n_layer_sparsity) {
9232
+ // apply activation sparsity
9233
+ gate_proj = gaussian_topk(gate_proj);
9234
+ }
9235
+ gate_proj = lm_ggml_gelu(ctx0, gate_proj);
9236
+
9237
+ cur = lm_ggml_mul(ctx0, up_proj, gate_proj);
9238
+ cur = build_lora_mm(model.layers[il].ffn_down, cur);
9239
+ cb(cur, "ffn_out", il);
9240
+ }
9241
+
9242
+ cur = build_norm(cur,
9243
+ model.layers[il].ffn_post_norm, NULL,
9244
+ LLM_NORM_RMS, -1);
9245
+ cb(cur, "ffn_post_norm", il);
9246
+
9247
+ lm_ggml_tensor * attn_ffw_laurel_gated = lm_ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens]
9248
+ cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);
9249
+
9250
+ lm_ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup]
9251
+
9252
+ lm_ggml_tensor * first_prediction; // [n_embd, n_tokens]
9253
+ {
9254
+ first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
9255
+ first_prediction = lm_ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
9256
+ first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
9257
+ first_prediction = lm_ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
9258
+ cb(first_prediction, "first_prediction_gated", il);
9259
+ lm_ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
9260
+ first_prediction = lm_ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
9261
+ cb(first_prediction, "first_prediction_scaled", il);
9262
+
9263
+ first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens]
9264
+ first_prediction = build_norm(first_prediction,
9265
+ model.layers[il].per_layer_post_norm, NULL,
9266
+ LLM_NORM_RMS, il);
9267
+ cb(first_prediction, "first_prediction_out", il);
9268
+ }
9269
+
9270
+ // equivalent to python code: corrected_predictions[1:] += first_prediction
9271
+ {
9272
+ lm_ggml_tensor * slice_first = view_2d_slice(corrected, 0);
9273
+ lm_ggml_tensor * slice_rest = lm_ggml_view_3d(ctx0, corrected, n_embd, n_tokens, n_altup - 1,
9274
+ lm_ggml_row_size(corrected->type, n_embd),
9275
+ lm_ggml_row_size(corrected->type, n_embd*n_tokens),
9276
+ n_embd*n_tokens*lm_ggml_element_size(corrected));
9277
+ lm_ggml_tensor * tmp = lm_ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1]
9278
+ corrected = lm_ggml_concat(ctx0, slice_first, tmp, 2); // [n_embd, n_tokens, n_altup]
9279
+ }
9280
+
9281
+ cur = corrected; // [n_embd, n_tokens, n_altup]
9282
+ cur = build_cvec(cur, il);
9283
+ cb(cur, "l_out", il);
9284
+
9285
+ // input for next layer
9286
+ inpL = cur;
9287
+ }
9288
+
9289
+ cur = inpL; // [n_embd, n_tokens, n_altup]
9290
+
9291
+ // cur now has multiple altup(s), we want to merge them back to 1 altup
9292
+ {
9293
+ lm_ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
9294
+ // do a view to skip the first slice (active altup)
9295
+ lm_ggml_tensor * alt_slice = lm_ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1,
9296
+ lm_ggml_row_size(cur->type, n_embd),
9297
+ lm_ggml_row_size(cur->type, n_embd*n_tokens),
9298
+ n_embd*n_tokens*lm_ggml_element_size(cur));
9299
+ lm_ggml_tensor * altup_unembd = lm_ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1]
9300
+ lm_ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
9301
+ altup_unembd = lm_ggml_div(ctx0,
9302
+ lm_ggml_mul(ctx0, altup_unembd, target_magnitude),
9303
+ new_magnitude);
9304
+ cb(altup_unembd, "altup_unembd", -1);
9305
+
9306
+ // equivalent to torch.mean(hidden_states, dim=0)
9307
+ cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
9308
+ for (int i = 0; i < n_altup - 1; ++i) {
9309
+ cur = lm_ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
9310
+ }
9311
+ cur = lm_ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
9312
+ cb(cur, "unembd_merged", -1);
9313
+ }
9314
+
9315
+ // cur now has shape: [n_embd, n_tokens]
9316
+
9317
+ // TODO: move this to right after the last KV layer
9318
+ {
9319
+ // skip computing output for unused tokens
9320
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
9321
+ cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
9322
+ }
9323
+
9324
+ cur = build_norm(cur,
9325
+ model.output_norm, NULL,
9326
+ LLM_NORM_RMS, -1);
9327
+
9328
+ cb(cur, "result_norm", -1);
9329
+ res->t_embd = cur;
9330
+
9331
+ cur = build_lora_mm(model.output, cur);
9332
+
9333
+ {
9334
+ // final logit soft-capping
9335
+ cur = lm_ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
9336
+ cur = lm_ggml_tanh(ctx0, cur);
9337
+ cur = lm_ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
9338
+ }
9339
+
9340
+ cb(cur, "result_output", -1);
9341
+ res->t_logits = cur;
9342
+
9343
+ lm_ggml_build_forward_expand(gf, cur);
9344
+ }
9345
+
9346
+ lm_ggml_tensor * calc_magnitude(lm_ggml_tensor * x) {
9347
+ return lm_ggml_sqrt(ctx0, lm_ggml_sum_rows(ctx0, lm_ggml_sqr(ctx0, x)));
9348
+ }
9349
+
9350
+ // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
9351
+ lm_ggml_tensor * view_2d_slice(lm_ggml_tensor * x, int idx) {
9352
+ LM_GGML_ASSERT(idx < (int)x->ne[2]);
9353
+ return lm_ggml_view_2d(ctx0, x, x->ne[0], x->ne[1],
9354
+ lm_ggml_row_size(x->type, x->ne[0]),
9355
+ idx * x->ne[0] * x->ne[1] * lm_ggml_element_size(x));
9356
+ }
9357
+
9358
+ // equivalent to get_per_layer_inputs() in python code
9359
+ // output shape: [n_embd_altup, n_layer, n_tokens]
9360
+ lm_ggml_tensor * get_per_layer_inputs() {
9361
+ auto inp = std::make_unique<llm_graph_input_embd>();
9362
+ lm_ggml_tensor * inp_per_layer;
9363
+ if (ubatch.token) {
9364
+ inp->tokens = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, ubatch.n_tokens);
9365
+ lm_ggml_set_input(inp->tokens);
9366
+ res->t_tokens = inp->tokens;
9367
+ inp_per_layer = lm_ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
9368
+ inp_per_layer = lm_ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
9369
+ inp_per_layer = lm_ggml_scale(ctx0, inp_per_layer, sqrtf((float)n_embd_altup));
9370
+ cb(inp_per_layer, "inp_per_layer_selected", -1);
9371
+ } else {
9372
+ LM_GGML_ABORT("TODO: support embd input");
9373
+ }
9374
+ res->add_input(std::move(inp));
9375
+ return inp_per_layer;
9376
+ }
9377
+
9378
+ // equivalent to project_per_layer_inputs() in python code
9379
+ // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
9380
+ // output shape: [n_embd_altup, n_tokens, n_layer]
9381
+ lm_ggml_tensor * project_per_layer_inputs(lm_ggml_tensor * inputs_embeds, lm_ggml_tensor * inp_per_layer) {
9382
+ const float per_layer_projection_scale = 1.0f / sqrtf((float)n_embd);
9383
+ const float per_layer_input_scale = 1.0f / sqrtf(2.0f);
9384
+
9385
+ lm_ggml_tensor * per_layer_proj = lm_ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
9386
+ per_layer_proj = lm_ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
9387
+ per_layer_proj = lm_ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
9388
+ per_layer_proj = build_norm(per_layer_proj,
9389
+ model.per_layer_proj_norm, NULL,
9390
+ LLM_NORM_RMS, -1); // [n_embd_altup, n_layer, n_tokens]
9391
+ cb(per_layer_proj, "per_layer_proj", -1);
9392
+
9393
+ inp_per_layer = lm_ggml_add(ctx0, inp_per_layer, per_layer_proj);
9394
+ inp_per_layer = lm_ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
9395
+ cb(inp_per_layer, "inp_per_layer", -1);
9396
+
9397
+ // permute to shape: [n_embd_altup, n_tokens, n_layer]
9398
+ inp_per_layer = lm_ggml_cont(ctx0, lm_ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
9399
+ return inp_per_layer;
9400
+ }
9401
+
9402
+ // input cur shape: [n_embd, n_tokens]
9403
+ // output shape: [n_embd, n_tokens]
9404
+ lm_ggml_tensor * laurel(lm_ggml_tensor * cur, int il) {
9405
+ lm_ggml_tensor * tmp = cur;
9406
+ tmp = build_lora_mm(model.layers[il].laurel_l, tmp);
9407
+ tmp = build_lora_mm(model.layers[il].laurel_r, tmp);
9408
+ tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il);
9409
+ tmp = lm_ggml_add(ctx0, tmp, cur);
9410
+ cb(tmp, "laurel_out", il);
9411
+ return tmp;
9412
+ }
9413
+
9414
+ // input x shape: [n_embd, n_tokens]
9415
+ // output shape: [n_embd, n_tokens]
9416
+ lm_ggml_tensor * gaussian_topk(lm_ggml_tensor * x) {
9417
+ lm_ggml_tensor * mean = lm_ggml_mean(ctx0, x);
9418
+ lm_ggml_tensor * std = lm_ggml_sqrt(ctx0, lm_ggml_scale(ctx0,
9419
+ lm_ggml_sum_rows(ctx0, lm_ggml_sqr(ctx0, lm_ggml_sub(ctx0, x, mean))),
9420
+ 1.0f / (float)(x->ne[0] - 1)
9421
+ ));
9422
+ lm_ggml_tensor * cutoff_x = lm_ggml_add(ctx0, mean, lm_ggml_scale(ctx0, std, f_sparsity_std_mul));
9423
+ return lm_ggml_relu(ctx0, lm_ggml_sub(ctx0, x, cutoff_x));
9424
+ }
9425
+
9426
+ //
9427
+ // altup functions
9428
+ //
9429
+
9430
+ // equivalent to compute_router_modalities() in python code
9431
+ // input x shape: [n_embd, n_tokens]
9432
+ // output shape: [n_altup, n_tokens]
9433
+ lm_ggml_tensor * altup_compute_router_modalities(lm_ggml_tensor * x, int il) {
9434
+ lm_ggml_tensor * router_inputs = build_norm(x,
9435
+ model.layers[il].altup_router_norm, NULL,
9436
+ LLM_NORM_RMS, il);
9437
+
9438
+ // router_input_scale
9439
+ router_inputs = lm_ggml_scale(ctx0, router_inputs, 1.0f / (float)n_embd);
9440
+
9441
+ lm_ggml_tensor * output = lm_ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs);
9442
+ return lm_ggml_tanh(ctx0, output); // [n_altup, n_tokens]
9443
+ }
9444
+
9445
+ // input cur shape: [n_embd, n_tokens, n_altup]
9446
+ // output shape: [n_embd, n_tokens, n_altup]
9447
+ lm_ggml_tensor * altup_predict(lm_ggml_tensor * cur, int il) {
9448
+ lm_ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens]
9449
+ lm_ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
9450
+ cb(modalities, "modalities", il);
9451
+
9452
+ lm_ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
9453
+ cb(all_coefs, "all_coefs", il);
9454
+ // the first dim now has n_altup^2 elements, so we reshape it to 2D (ending up with a 3D tensor)
9455
+ all_coefs = lm_ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);
9456
+
9457
+ // permute to [n_altup, n_embd, n_tokens]
9458
+ lm_ggml_tensor * cur_permuted = lm_ggml_cont(ctx0, lm_ggml_permute(ctx0, cur, 1, 2, 0, 3));
9459
+ lm_ggml_tensor * predictions = lm_ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_altup, n_embd, n_tokens]
9460
+
9461
+ // final shape must be the same as cur: [n_embd, n_tokens, n_altup]
9462
+ predictions = lm_ggml_cont(ctx0, lm_ggml_permute(ctx0, predictions, 0, 2, 1, 3));
9463
+ predictions = lm_ggml_add(ctx0, predictions, cur);
9464
+ cb(predictions, "predictions", il);
9465
+
9466
+ return predictions;
9467
+ }
9468
+
9469
+ // input predictions shape: [n_embd, n_tokens, n_altup]
9470
+ // input activated shape: [n_embd, n_tokens]
9471
+ // output shape: [n_embd, n_tokens, n_altup]
9472
+ lm_ggml_tensor * altup_correct(lm_ggml_tensor * predictions, lm_ggml_tensor * activated, int il) {
9473
+ lm_ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
9474
+ cb(modalities, "modalities", il);
9475
+
9476
+ lm_ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
9477
+ lm_ggml_tensor * innovation = lm_ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
9478
+ cb(innovation, "innovation", il);
9479
+
9480
+ lm_ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens]
9481
+ all_coefs = lm_ggml_add(ctx0, all_coefs, one);
9482
+ cb(all_coefs, "all_coefs", il);
9483
+ all_coefs = lm_ggml_cont(ctx0, lm_ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup]
9484
+ all_coefs = lm_ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup]
9485
+
9486
+ innovation = lm_ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
9487
+ lm_ggml_tensor * corrected = lm_ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup]
9488
+ corrected = lm_ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup]
9489
+ cb(corrected, "corrected", il);
9490
+
9491
+ return corrected;
9492
+ }
9493
+ };
9494
+
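The gaussian_topk helper above implements the activation-sparsity step applied to the gate projection in the first n_layer_sparsity layers: fit a per-row normal distribution, take the cutoff at mean + icdf(0.95) * sample std (hence f_sparsity_std_mul ≈ 1.6449), and relu the values shifted by that cutoff. A scalar sketch of the same computation on one row of activations (the function name and std::vector types are illustrative; the real version builds ggml graph ops):

#include <algorithm>
#include <cmath>
#include <vector>

static std::vector<float> gaussian_topk_row(const std::vector<float> & x,
                                            float std_mul = 1.6448533535003662f) {
    const size_t n = x.size();

    float mean = 0.0f;
    for (float v : x) mean += v;
    mean /= (float) n;

    float var = 0.0f;
    for (float v : x) var += (v - mean) * (v - mean);
    var /= (float) (n - 1);                          // sample variance, matching 1/(ne0 - 1)

    const float cutoff = mean + std_mul * std::sqrt(var);

    std::vector<float> out(n);
    for (size_t i = 0; i < n; ++i) {
        out[i] = std::max(0.0f, x[i] - cutoff);      // relu(x - cutoff_x)
    }
    return out;
}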
8701
9495
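The final logit soft-capping at the end of llm_build_gemma3n_iswa is the usual squashing logits = cap * tanh(logits / cap), with cap taken from hparams.f_final_logit_softcapping, so every logit is bounded to (-cap, cap). A minimal scalar equivalent (the function name is illustrative):

#include <cmath>
#include <vector>

// logits <- cap * tanh(logits / cap)
static void soft_cap_logits(std::vector<float> & logits, float cap) {
    for (float & v : logits) {
        v = cap * std::tanh(v / cap);
    }
}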
  // TODO: move up next to build_starcoder
8702
9496
  struct llm_build_starcoder2 : public llm_graph_context {
8703
9497
  llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, lm_ggml_cgraph * gf) : llm_graph_context(params) {
@@ -8716,6 +9510,8 @@ struct llm_build_starcoder2 : public llm_graph_context {
8716
9510
 
8717
9511
  auto * inp_attn = build_attn_inp_kv_unified();
8718
9512
 
9513
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
9514
+
8719
9515
  for (int il = 0; il < n_layer; ++il) {
8720
9516
  lm_ggml_tensor * inpSA = inpL;
8721
9517
 
@@ -8774,9 +9570,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
8774
9570
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8775
9571
  }
8776
9572
 
8777
- if (il == n_layer - 1) {
8778
- // skip computing output for unused tokens
8779
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
9573
+ if (il == n_layer - 1 && inp_out_ids) {
8780
9574
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
8781
9575
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
8782
9576
  }
@@ -8837,8 +9631,9 @@ struct llm_build_mamba : public llm_graph_context {
8837
9631
  // {n_embd, n_tokens}
8838
9632
  inpL = build_inp_embd(model.tok_embd);
8839
9633
 
8840
- lm_ggml_tensor * state_copy = build_inp_s_copy();
8841
- lm_ggml_tensor * state_mask = build_inp_s_mask();
9634
+ auto * rs_inp = build_rs_inp();
9635
+
9636
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
8842
9637
 
8843
9638
  for (int il = 0; il < n_layer; ++il) {
8844
9639
  // norm
@@ -8847,12 +9642,9 @@ struct llm_build_mamba : public llm_graph_context {
8847
9642
  LLM_NORM_RMS, il);
8848
9643
  cb(cur, "attn_norm", il);
8849
9644
 
8850
- //cur = build_mamba_layer(gf, cur, state_copy, state_mask, il);
8851
- cur = build_mamba_layer(gf, cur, state_copy, state_mask, ubatch, il);
9645
+ cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il);
8852
9646
 
8853
- if (il == n_layer - 1) {
8854
- // skip computing output for unused tokens
8855
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
9647
+ if (il == n_layer - 1 && inp_out_ids) {
8856
9648
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
8857
9649
  inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
8858
9650
  }
@@ -8886,15 +9678,14 @@ struct llm_build_mamba : public llm_graph_context {
8886
9678
 
8887
9679
  // TODO: split
8888
9680
  lm_ggml_tensor * build_mamba_layer(
8889
- lm_ggml_cgraph * gf,
8890
- lm_ggml_tensor * cur,
8891
- lm_ggml_tensor * state_copy,
8892
- lm_ggml_tensor * state_mask,
8893
- const llama_ubatch & ubatch,
8894
- int il) const {
8895
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
9681
+ llm_graph_input_rs * inp,
9682
+ lm_ggml_cgraph * gf,
9683
+ lm_ggml_tensor * cur,
9684
+ const llama_ubatch & ubatch,
9685
+ int il) const {
9686
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
8896
9687
 
8897
- const auto kv_head = kv_self->head;
9688
+ const auto kv_head = mctx_cur->get_head();
8898
9689
 
8899
9690
  const int64_t d_conv = hparams.ssm_d_conv;
8900
9691
  const int64_t d_inner = hparams.ssm_d_inner;
@@ -8912,17 +9703,17 @@ struct llm_build_mamba : public llm_graph_context {
8912
9703
  LM_GGML_ASSERT(ubatch.equal_seqs);
8913
9704
  LM_GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
8914
9705
 
8915
- lm_ggml_tensor * conv_states_all = kv_self->k_l[il];
8916
- lm_ggml_tensor * ssm_states_all = kv_self->v_l[il];
9706
+ lm_ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
9707
+ lm_ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
8917
9708
 
8918
9709
  // (ab)using the KV cache to store the states
8919
- lm_ggml_tensor * conv = build_copy_mask_state(
8920
- gf, conv_states_all, state_copy, state_mask,
8921
- hparams.n_embd_k_s(), n_seqs);
9710
+ lm_ggml_tensor * conv = build_rs(
9711
+ inp, gf, conv_states_all,
9712
+ hparams.n_embd_r(), n_seqs);
8922
9713
  conv = lm_ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
8923
- lm_ggml_tensor * ssm = build_copy_mask_state(
8924
- gf, ssm_states_all, state_copy, state_mask,
8925
- hparams.n_embd_v_s(), n_seqs);
9714
+ lm_ggml_tensor * ssm = build_rs(
9715
+ inp, gf, ssm_states_all,
9716
+ hparams.n_embd_s(), n_seqs);
8926
9717
  ssm = lm_ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
8927
9718
 
8928
9719
  // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
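Note on the recurrent-state change above: loading per-layer states no longer takes explicit state_copy/state_mask tensors. A hedged sketch of the new call shape, assuming the surrounding method holds rs_inp = build_rs_inp() and mctx points at a llama_memory_recurrent_context, as in the diff:

    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

    // conv states live in the "r" slot, ssm states in the "s" slot of the recurrent memory
    lm_ggml_tensor * conv = build_rs(rs_inp, gf, mctx_cur->get_r_l(il), hparams.n_embd_r(), n_seqs);
    lm_ggml_tensor * ssm  = build_rs(rs_inp, gf, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs);

The same build_rs() helper replaces build_copy_mask_state() in the RWKV6/RWKV7 time-mix functions further down.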
@@ -9035,13 +9826,15 @@ struct llm_build_command_r : public llm_graph_context {
9035
9826
 
9036
9827
  auto * inp_attn = build_attn_inp_kv_unified();
9037
9828
 
9038
- for (int il = 0; il < n_layer; ++il) {
9829
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
9039
9830
 
9831
+ for (int il = 0; il < n_layer; ++il) {
9040
9832
  // norm
9041
9833
  cur = build_norm(inpL,
9042
9834
  model.layers[il].attn_norm, NULL,
9043
9835
  LLM_NORM, il);
9044
9836
  cb(cur, "attn_norm", il);
9837
+
9045
9838
  lm_ggml_tensor * ffn_inp = cur;
9046
9839
 
9047
9840
  // self-attention
@@ -9109,9 +9902,7 @@ struct llm_build_command_r : public llm_graph_context {
9109
9902
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9110
9903
  }
9111
9904
 
9112
- if (il == n_layer - 1) {
9113
- // skip computing output for unused tokens
9114
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
9905
+ if (il == n_layer - 1 && inp_out_ids) {
9115
9906
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
9116
9907
  inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
9117
9908
  ffn_inp = lm_ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
@@ -9182,6 +9973,8 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
9182
9973
 
9183
9974
  auto * inp_attn = build_attn_inp_kv_unified_iswa();
9184
9975
 
9976
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
9977
+
9185
9978
  for (int il = 0; il < n_layer; ++il) {
9186
9979
  const bool is_swa = hparams.is_swa(il);
9187
9980
 
@@ -9244,9 +10037,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
9244
10037
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9245
10038
  }
9246
10039
 
9247
- if (il == n_layer - 1) {
9248
- // skip computing output for unused tokens
9249
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
10040
+ if (il == n_layer - 1 && inp_out_ids) {
9250
10041
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
9251
10042
  inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
9252
10043
  ffn_inp = lm_ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
@@ -9317,6 +10108,8 @@ struct llm_build_olmo : public llm_graph_context {
9317
10108
 
9318
10109
  auto * inp_attn = build_attn_inp_kv_unified();
9319
10110
 
10111
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
10112
+
9320
10113
  for (int il = 0; il < n_layer; ++il) {
9321
10114
  lm_ggml_tensor * inpSA = inpL;
9322
10115
 
@@ -9375,9 +10168,7 @@ struct llm_build_olmo : public llm_graph_context {
9375
10168
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9376
10169
  }
9377
10170
 
9378
- if (il == n_layer - 1) {
9379
- // skip computing output for unused tokens
9380
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
10171
+ if (il == n_layer - 1 && inp_out_ids) {
9381
10172
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
9382
10173
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
9383
10174
  }
@@ -9445,6 +10236,8 @@ struct llm_build_olmo2 : public llm_graph_context {
9445
10236
 
9446
10237
  auto * inp_attn = build_attn_inp_kv_unified();
9447
10238
 
10239
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
10240
+
9448
10241
  for (int il = 0; il < n_layer; ++il) {
9449
10242
  lm_ggml_tensor * inpSA = inpL;
9450
10243
 
@@ -9495,18 +10288,16 @@ struct llm_build_olmo2 : public llm_graph_context {
9495
10288
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9496
10289
  }
9497
10290
 
10291
+ if (il == n_layer - 1 && inp_out_ids) {
10292
+ cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
10293
+ inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
10294
+ }
10295
+
9498
10296
  cur = build_norm(cur,
9499
10297
  model.layers[il].attn_post_norm, NULL,
9500
10298
  LLM_NORM_RMS, il);
9501
10299
  cb(cur, "attn_post_norm", il);
9502
10300
 
9503
- if (il == n_layer - 1) {
9504
- // skip computing output for unused tokens
9505
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
9506
- cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
9507
- inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
9508
- }
9509
-
9510
10301
  lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA);
9511
10302
  cb(ffn_inp, "ffn_inp", il);
9512
10303
 
@@ -9574,6 +10365,8 @@ struct llm_build_olmoe : public llm_graph_context {
9574
10365
 
9575
10366
  auto * inp_attn = build_attn_inp_kv_unified();
9576
10367
 
10368
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
10369
+
9577
10370
  for (int il = 0; il < n_layer; ++il) {
9578
10371
  lm_ggml_tensor * inpSA = inpL;
9579
10372
 
@@ -9628,9 +10421,7 @@ struct llm_build_olmoe : public llm_graph_context {
9628
10421
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9629
10422
  }
9630
10423
 
9631
- if (il == n_layer - 1) {
9632
- // skip computing output for unused tokens
9633
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
10424
+ if (il == n_layer - 1 && inp_out_ids) {
9634
10425
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
9635
10426
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
9636
10427
  }
@@ -9700,6 +10491,8 @@ struct llm_build_openelm : public llm_graph_context {
9700
10491
 
9701
10492
  auto * inp_attn = build_attn_inp_kv_unified();
9702
10493
 
10494
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
10495
+
9703
10496
  for (int il = 0; il < n_layer; ++il) {
9704
10497
  const int64_t n_head = hparams.n_head(il);
9705
10498
  const int64_t n_head_kv = hparams.n_head_kv(il);
@@ -9761,11 +10554,9 @@ struct llm_build_openelm : public llm_graph_context {
9761
10554
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9762
10555
  }
9763
10556
 
9764
- if (il == n_layer - 1) {
9765
- // skip computing output for unused tokens
9766
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
10557
+ if (il == n_layer - 1 && inp_out_ids) {
9767
10558
  residual = lm_ggml_get_rows(ctx0, residual, inp_out_ids);
9768
- cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
10559
+ cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
9769
10560
  }
9770
10561
 
9771
10562
  lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, residual, cur);
@@ -9831,6 +10622,8 @@ struct llm_build_gptneox : public llm_graph_context {
9831
10622
 
9832
10623
  auto * inp_attn = build_attn_inp_kv_unified();
9833
10624
 
10625
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
10626
+
9834
10627
  for (int il = 0; il < n_layer; ++il) {
9835
10628
  cur = build_norm(inpL,
9836
10629
  model.layers[il].attn_norm,
@@ -9875,9 +10668,7 @@ struct llm_build_gptneox : public llm_graph_context {
9875
10668
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9876
10669
  }
9877
10670
 
9878
- if (il == n_layer - 1) {
9879
- // skip computing output for unused tokens
9880
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
10671
+ if (il == n_layer - 1 && inp_out_ids) {
9881
10672
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
9882
10673
  inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
9883
10674
  }
@@ -9979,6 +10770,8 @@ struct llm_build_arctic : public llm_graph_context {
9979
10770
 
9980
10771
  auto * inp_attn = build_attn_inp_kv_unified();
9981
10772
 
10773
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
10774
+
9982
10775
  for (int il = 0; il < n_layer; ++il) {
9983
10776
  lm_ggml_tensor * inpSA = inpL;
9984
10777
 
@@ -10025,9 +10818,7 @@ struct llm_build_arctic : public llm_graph_context {
10025
10818
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
10026
10819
  }
10027
10820
 
10028
- if (il == n_layer - 1) {
10029
- // skip computing output for unused tokens
10030
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
10821
+ if (il == n_layer - 1 && inp_out_ids) {
10031
10822
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
10032
10823
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
10033
10824
  }
@@ -10119,6 +10910,8 @@ struct llm_build_deepseek : public llm_graph_context {
10119
10910
 
10120
10911
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
10121
10912
 
10913
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
10914
+
10122
10915
  for (int il = 0; il < n_layer; ++il) {
10123
10916
  lm_ggml_tensor * inpSA = inpL;
10124
10917
 
@@ -10180,14 +10973,11 @@ struct llm_build_deepseek : public llm_graph_context {
10180
10973
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
10181
10974
  }
10182
10975
 
10183
- if (il == n_layer - 1) {
10184
- // skip computing output for unused tokens
10185
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
10976
+ if (il == n_layer - 1 && inp_out_ids) {
10186
10977
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
10187
10978
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
10188
10979
  }
10189
10980
 
10190
-
10191
10981
  lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA);
10192
10982
  cb(ffn_inp, "ffn_inp", il);
10193
10983
 
@@ -10295,6 +11085,8 @@ struct llm_build_deepseek2 : public llm_graph_context {
10295
11085
 
10296
11086
  auto * inp_attn = build_attn_inp_kv_unified();
10297
11087
 
11088
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
11089
+
10298
11090
  for (int il = 0; il < n_layer; ++il) {
10299
11091
  lm_ggml_tensor * inpSA = inpL;
10300
11092
 
@@ -10444,9 +11236,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
10444
11236
  }
10445
11237
  }
10446
11238
 
10447
- if (il == n_layer - 1) {
10448
- // skip computing output for unused tokens
10449
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
11239
+ if (il == n_layer - 1 && inp_out_ids) {
10450
11240
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
10451
11241
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
10452
11242
  }
@@ -10542,6 +11332,8 @@ struct llm_build_bitnet : public llm_graph_context {
10542
11332
 
10543
11333
  auto * inp_attn = build_attn_inp_kv_unified();
10544
11334
 
11335
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
11336
+
10545
11337
  for (int il = 0; il < n_layer; ++il) {
10546
11338
  lm_ggml_tensor * inpSA = inpL;
10547
11339
 
@@ -10624,9 +11416,7 @@ struct llm_build_bitnet : public llm_graph_context {
10624
11416
  cb(cur, "attn_o_out", il);
10625
11417
  }
10626
11418
 
10627
- if (il == n_layer - 1) {
10628
- // skip computing output for unused tokens
10629
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
11419
+ if (il == n_layer - 1 && inp_out_ids) {
10630
11420
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
10631
11421
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
10632
11422
  }
@@ -10701,6 +11491,8 @@ struct llm_build_t5_enc : public llm_graph_context {
10701
11491
 
10702
11492
  auto * inp_attn = build_attn_inp_no_cache();
10703
11493
 
11494
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
11495
+
10704
11496
  for (int il = 0; il < n_layer; ++il) {
10705
11497
  lm_ggml_tensor * inpSA = inpL;
10706
11498
 
@@ -10734,9 +11526,7 @@ struct llm_build_t5_enc : public llm_graph_context {
10734
11526
  cb(cur, "kqv_out", il);
10735
11527
  }
10736
11528
 
10737
- if (il == n_layer - 1) {
10738
- // skip computing output for unused tokens
10739
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
11529
+ if (il == n_layer - 1 && inp_out_ids) {
10740
11530
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
10741
11531
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
10742
11532
  }
@@ -10807,6 +11597,8 @@ struct llm_build_t5_dec : public llm_graph_context {
10807
11597
  auto * inp_attn_self = build_attn_inp_kv_unified();
10808
11598
  auto * inp_attn_cross = build_attn_inp_cross();
10809
11599
 
11600
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
11601
+
10810
11602
  for (int il = 0; il < n_layer; ++il) {
10811
11603
  lm_ggml_tensor * inpSA = inpL;
10812
11604
 
@@ -10898,11 +11690,8 @@ struct llm_build_t5_dec : public llm_graph_context {
10898
11690
  //cb(cur, "kqv_out", il);
10899
11691
  }
10900
11692
 
10901
- if (il == n_layer - 1) {
10902
- // skip computing output for unused tokens
10903
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
11693
+ if (il == n_layer - 1 && inp_out_ids) {
10904
11694
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
10905
- inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
10906
11695
  inpCA = lm_ggml_get_rows(ctx0, inpCA, inp_out_ids);
10907
11696
  }
10908
11697
 
@@ -10972,6 +11761,8 @@ struct llm_build_jais : public llm_graph_context {
10972
11761
 
10973
11762
  auto * inp_attn = build_attn_inp_kv_unified();
10974
11763
 
11764
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
11765
+
10975
11766
  for (int il = 0; il < n_layer; ++il) {
10976
11767
  cur = build_norm(inpL,
10977
11768
  model.layers[il].attn_norm,
@@ -11004,9 +11795,7 @@ struct llm_build_jais : public llm_graph_context {
11004
11795
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il);
11005
11796
  }
11006
11797
 
11007
- if (il == n_layer - 1) {
11008
- // skip computing output for unused tokens
11009
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
11798
+ if (il == n_layer - 1 && inp_out_ids) {
11010
11799
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
11011
11800
  inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
11012
11801
  }
@@ -11070,6 +11859,8 @@ struct llm_build_chatglm : public llm_graph_context {
11070
11859
 
11071
11860
  auto * inp_attn = build_attn_inp_kv_unified();
11072
11861
 
11862
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
11863
+
11073
11864
  for (int il = 0; il < n_layer; ++il) {
11074
11865
  lm_ggml_tensor * inpSA = inpL;
11075
11866
 
@@ -11136,9 +11927,7 @@ struct llm_build_chatglm : public llm_graph_context {
11136
11927
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11137
11928
  }
11138
11929
 
11139
- if (il == n_layer - 1) {
11140
- // skip computing output for unused tokens
11141
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
11930
+ if (il == n_layer - 1 && inp_out_ids) {
11142
11931
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
11143
11932
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
11144
11933
  }
@@ -11203,6 +11992,8 @@ struct llm_build_glm4 : public llm_graph_context {
11203
11992
 
11204
11993
  auto * inp_attn = build_attn_inp_kv_unified();
11205
11994
 
11995
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
11996
+
11206
11997
  for (int il = 0; il < n_layer; ++il) {
11207
11998
  lm_ggml_tensor * inpSA = inpL;
11208
11999
 
@@ -11269,9 +12060,7 @@ struct llm_build_glm4 : public llm_graph_context {
11269
12060
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11270
12061
  }
11271
12062
 
11272
- if (il == n_layer - 1) {
11273
- // skip computing output for unused tokens
11274
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
12063
+ if (il == n_layer - 1 && inp_out_ids) {
11275
12064
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
11276
12065
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
11277
12066
  }
@@ -11354,6 +12143,8 @@ struct llm_build_nemotron : public llm_graph_context {
11354
12143
 
11355
12144
  auto * inp_attn = build_attn_inp_kv_unified();
11356
12145
 
12146
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
12147
+
11357
12148
  for (int il = 0; il < n_layer; ++il) {
11358
12149
  lm_ggml_tensor * inpSA = inpL;
11359
12150
 
@@ -11413,9 +12204,7 @@ struct llm_build_nemotron : public llm_graph_context {
11413
12204
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11414
12205
  }
11415
12206
 
11416
- if (il == n_layer - 1) {
11417
- // skip computing output for unused tokens
11418
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
12207
+ if (il == n_layer - 1 && inp_out_ids) {
11419
12208
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
11420
12209
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
11421
12210
  }
@@ -11483,6 +12272,8 @@ struct llm_build_exaone : public llm_graph_context {
11483
12272
 
11484
12273
  auto * inp_attn = build_attn_inp_kv_unified();
11485
12274
 
12275
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
12276
+
11486
12277
  for (int il = 0; il < n_layer; ++il) {
11487
12278
  lm_ggml_tensor * inpSA = inpL;
11488
12279
 
@@ -11544,9 +12335,7 @@ struct llm_build_exaone : public llm_graph_context {
11544
12335
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11545
12336
  }
11546
12337
 
11547
- if (il == n_layer - 1) {
11548
- // skip computing output for unused tokens
11549
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
12338
+ if (il == n_layer - 1 && inp_out_ids) {
11550
12339
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
11551
12340
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
11552
12341
  }
@@ -11633,14 +12422,13 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11633
12422
  }
11634
12423
 
11635
12424
  lm_ggml_tensor * build_rwkv6_time_mix(
12425
+ llm_graph_input_rs * inp,
11636
12426
  lm_ggml_cgraph * gf,
11637
12427
  lm_ggml_tensor * cur,
11638
12428
  lm_ggml_tensor * x_prev,
11639
- lm_ggml_tensor * state_copy,
11640
- lm_ggml_tensor * state_mask,
11641
12429
  const llama_ubatch & ubatch,
11642
12430
  int il) const {
11643
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
12431
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
11644
12432
 
11645
12433
  const auto n_tokens = ubatch.n_tokens;
11646
12434
  const auto n_seqs = ubatch.n_seqs;
@@ -11650,7 +12438,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11650
12438
  const auto n_head = n_embd / head_size;
11651
12439
  const auto n_head_kv = hparams.n_head_kv(il);
11652
12440
 
11653
- const auto kv_head = kv_self->head;
12441
+ const auto kv_head = mctx_cur->get_head();
11654
12442
 
11655
12443
  const auto & layer = model.layers[il];
11656
12444
 
@@ -11761,9 +12549,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11761
12549
  k = lm_ggml_sub(ctx0, k, lm_ggml_mul(ctx0, k, w));
11762
12550
  }
11763
12551
 
11764
- lm_ggml_tensor * wkv_state = build_copy_mask_state(
11765
- gf, kv_self->v_l[il], state_copy, state_mask,
11766
- hparams.n_embd_v_s(), n_seqs);
12552
+ lm_ggml_tensor * wkv_state = build_rs(
12553
+ inp, gf, mctx_cur->get_s_l(il),
12554
+ hparams.n_embd_s(), n_seqs);
11767
12555
 
11768
12556
  lm_ggml_tensor * wkv_output;
11769
12557
  if (is_qrwkv) {
@@ -11781,9 +12569,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11781
12569
  wkv_state,
11782
12570
  lm_ggml_view_1d(
11783
12571
  ctx0,
11784
- kv_self->v_l[il],
11785
- hparams.n_embd_v_s() * n_seqs,
11786
- hparams.n_embd_v_s() * kv_head * lm_ggml_element_size(kv_self->v_l[il])
12572
+ mctx_cur->get_s_l(il),
12573
+ hparams.n_embd_s() * n_seqs,
12574
+ hparams.n_embd_s() * kv_head * lm_ggml_element_size(mctx_cur->get_s_l(il))
11787
12575
  )
11788
12576
  )
11789
12577
  );
@@ -11817,20 +12605,19 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
11817
12605
  inpL = build_inp_embd(model.tok_embd);
11818
12606
  inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
11819
12607
 
11820
- lm_ggml_tensor * state_copy = build_inp_s_copy();
11821
- lm_ggml_tensor * state_mask = build_inp_s_mask();
12608
+ auto * rs_inp = build_rs_inp();
11822
12609
 
11823
12610
  const auto n_embd = hparams.n_embd;
11824
12611
  const auto n_seq_tokens = ubatch.n_seq_tokens;
11825
12612
  const auto n_seqs = ubatch.n_seqs;
11826
12613
 
12614
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
12615
+
11827
12616
  for (int il = 0; il < n_layer; ++il) {
11828
12617
  const llama_layer * layer = &model.layers[il];
11829
12618
  inpL = lm_ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
11830
12619
 
11831
- lm_ggml_tensor * token_shift = build_rwkv_token_shift_load(
11832
- gf, state_copy, state_mask, ubatch, il
11833
- );
12620
+ lm_ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
11834
12621
 
11835
12622
  lm_ggml_tensor * att_shift = lm_ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
11836
12623
  lm_ggml_tensor * ffn_shift = lm_ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * lm_ggml_element_size(token_shift));
@@ -11845,7 +12632,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
11845
12632
  1
11846
12633
  );
11847
12634
 
11848
- cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il);
12635
+ cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
11849
12636
 
11850
12637
  lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpL);
11851
12638
  cb(ffn_inp, "ffn_inp", il);
@@ -11867,13 +12654,16 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
11867
12654
  );
11868
12655
  lm_ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
11869
12656
 
11870
- if (il == n_layer - 1) {
11871
- // skip computing output for unused tokens
11872
- struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
11873
- ffn_inp = lm_ggml_get_rows(ctx0, lm_ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
11874
- ffn_norm = lm_ggml_get_rows(ctx0, lm_ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens), inp_out_ids);
11875
- x_prev = lm_ggml_get_rows(ctx0, lm_ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens), inp_out_ids);
11876
- cur = lm_ggml_get_rows(ctx0, lm_ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
12657
+ ffn_inp = lm_ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens);
12658
+ ffn_norm = lm_ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens);
12659
+ x_prev = lm_ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens);
12660
+ cur = lm_ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
12661
+
12662
+ if (il == n_layer - 1 && inp_out_ids) {
12663
+ ffn_inp = lm_ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
12664
+ ffn_norm = lm_ggml_get_rows(ctx0, ffn_norm, inp_out_ids);
12665
+ x_prev = lm_ggml_get_rows(ctx0, x_prev, inp_out_ids);
12666
+ cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
11877
12667
  }
11878
12668
 
11879
12669
  cur = build_rwkv6_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6);
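Note on the RWKV6 change above: the 2-D reshape of ffn_inp, ffn_norm and x_prev is now applied unconditionally, so the channel-mix that follows always sees [n_embd, n_tokens] tensors, while the lm_ggml_get_rows() filtering stays guarded by il == n_layer - 1 && inp_out_ids. The RWKV6-Qwen2, RWKV7 and ARWKV7 builders below receive the same treatment.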
@@ -11908,27 +12698,26 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
11908
12698
  // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
11909
12699
  struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
11910
12700
  llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, lm_ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
11911
- LM_GGML_ASSERT(n_embd == hparams.n_embd_k_s());
12701
+ LM_GGML_ASSERT(n_embd == hparams.n_embd_r());
11912
12702
 
11913
12703
  lm_ggml_tensor * cur;
11914
12704
  lm_ggml_tensor * inpL;
11915
12705
 
11916
12706
  inpL = build_inp_embd(model.tok_embd);
11917
12707
 
11918
- lm_ggml_tensor * state_copy = build_inp_s_copy();
11919
- lm_ggml_tensor * state_mask = build_inp_s_mask();
12708
+ auto * rs_inp = build_rs_inp();
11920
12709
 
11921
12710
  const auto n_embd = hparams.n_embd;
11922
12711
  const auto n_seq_tokens = ubatch.n_seq_tokens;
11923
12712
  const auto n_seqs = ubatch.n_seqs;
11924
12713
 
12714
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
12715
+
11925
12716
  for (int il = 0; il < n_layer; ++il) {
11926
12717
  const llama_layer * layer = &model.layers[il];
11927
12718
  inpL = lm_ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
11928
12719
 
11929
- lm_ggml_tensor * token_shift = build_rwkv_token_shift_load(
11930
- gf, state_copy, state_mask, ubatch, il
11931
- );
12720
+ lm_ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
11932
12721
 
11933
12722
  lm_ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
11934
12723
  cb(att_norm, "attn_norm", il);
@@ -11940,7 +12729,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
11940
12729
  1
11941
12730
  );
11942
12731
 
11943
- cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il);
12732
+ cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
11944
12733
 
11945
12734
  token_shift = lm_ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*lm_ggml_element_size(att_norm));
11946
12735
  lm_ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -11948,11 +12737,12 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
11948
12737
  lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpL);
11949
12738
  cb(ffn_inp, "ffn_inp", il);
11950
12739
 
11951
- if (il == n_layer - 1) {
11952
- // skip computing output for unused tokens
11953
- struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
11954
- cur = lm_ggml_get_rows(ctx0, lm_ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
11955
- ffn_inp = lm_ggml_get_rows(ctx0, lm_ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
12740
+ cur = lm_ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
12741
+ ffn_inp = lm_ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens);
12742
+
12743
+ if (il == n_layer - 1 && inp_out_ids) {
12744
+ cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
12745
+ ffn_inp = lm_ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
11956
12746
  }
11957
12747
 
11958
12748
  // feed-forward network
@@ -12028,15 +12818,14 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12028
12818
  }
12029
12819
 
12030
12820
  lm_ggml_tensor * build_rwkv7_time_mix(
12821
+ llm_graph_input_rs * inp,
12031
12822
  lm_ggml_cgraph * gf,
12032
12823
  lm_ggml_tensor * cur,
12033
12824
  lm_ggml_tensor * x_prev,
12034
- lm_ggml_tensor * state_copy,
12035
- lm_ggml_tensor * state_mask,
12036
12825
  lm_ggml_tensor *& first_layer_value,
12037
12826
  const llama_ubatch & ubatch,
12038
12827
  int il) const {
12039
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
12828
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
12040
12829
 
12041
12830
  const auto n_tokens = ubatch.n_tokens;
12042
12831
  const auto n_seqs = ubatch.n_seqs;
@@ -12045,7 +12834,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12045
12834
  const auto head_count = n_embd / head_size;
12046
12835
  const auto n_seq_tokens = ubatch.n_seq_tokens;
12047
12836
 
12048
- const auto kv_head = kv_self->head;
12837
+ const auto kv_head = mctx_cur->get_head();
12049
12838
 
12050
12839
  const auto & layer = model.layers[il];
12051
12840
 
@@ -12115,9 +12904,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12115
12904
  v = lm_ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
12116
12905
  a = lm_ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
12117
12906
 
12118
- lm_ggml_tensor * wkv_state = build_copy_mask_state(
12119
- gf, kv_self->v_l[il], state_copy, state_mask,
12120
- hparams.n_embd_v_s(), n_seqs);
12907
+ lm_ggml_tensor * wkv_state = build_rs(
12908
+ inp, gf, mctx_cur->get_s_l(il),
12909
+ hparams.n_embd_s(), n_seqs);
12121
12910
 
12122
12911
  lm_ggml_tensor * wkv_output = lm_ggml_rwkv_wkv7(ctx0, r, w, k, v, lm_ggml_neg(ctx0, kk), lm_ggml_mul(ctx0, kk, a), wkv_state);
12123
12912
  cur = lm_ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
@@ -12130,9 +12919,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12130
12919
  wkv_state,
12131
12920
  lm_ggml_view_1d(
12132
12921
  ctx0,
12133
- kv_self->v_l[il],
12134
- hparams.n_embd_v_s() * n_seqs,
12135
- hparams.n_embd_v_s() * kv_head * lm_ggml_element_size(kv_self->v_l[il])
12922
+ mctx_cur->get_s_l(il),
12923
+ hparams.n_embd_s() * n_seqs,
12924
+ hparams.n_embd_s() * kv_head * lm_ggml_element_size(mctx_cur->get_s_l(il))
12136
12925
  )
12137
12926
  )
12138
12927
  );
@@ -12173,20 +12962,19 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
12173
12962
  inpL = build_inp_embd(model.tok_embd);
12174
12963
  inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
12175
12964
 
12176
- lm_ggml_tensor * state_copy = build_inp_s_copy();
12177
- lm_ggml_tensor * state_mask = build_inp_s_mask();
12965
+ auto * rs_inp = build_rs_inp();
12178
12966
 
12179
12967
  const auto n_embd = hparams.n_embd;
12180
12968
  const auto n_seq_tokens = ubatch.n_seq_tokens;
12181
12969
  const auto n_seqs = ubatch.n_seqs;
12182
12970
 
12971
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
12972
+
12183
12973
  for (int il = 0; il < n_layer; ++il) {
12184
12974
  const llama_layer * layer = &model.layers[il];
12185
12975
  inpL = lm_ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
12186
12976
 
12187
- lm_ggml_tensor * token_shift = build_rwkv_token_shift_load(
12188
- gf, state_copy, state_mask, ubatch, il
12189
- );
12977
+ lm_ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
12190
12978
 
12191
12979
  lm_ggml_tensor * att_shift = lm_ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
12192
12980
  lm_ggml_tensor * ffn_shift = lm_ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * lm_ggml_element_size(token_shift));
@@ -12201,7 +12989,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
12201
12989
  1
12202
12990
  );
12203
12991
 
12204
- cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
12992
+ cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
12205
12993
 
12206
12994
  lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpL);
12207
12995
  cb(ffn_inp, "ffn_inp", il);
@@ -12223,12 +13011,14 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
12223
13011
  );
12224
13012
  lm_ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
12225
13013
 
12226
- if (il == n_layer - 1) {
12227
- // skip computing output for unused tokens
12228
- struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
12229
- ffn_inp = lm_ggml_get_rows(ctx0, lm_ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
12230
- ffn_norm = lm_ggml_get_rows(ctx0, lm_ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens), inp_out_ids);
12231
- x_prev = lm_ggml_get_rows(ctx0, lm_ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens), inp_out_ids);
13014
+ ffn_inp = lm_ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens);
13015
+ ffn_norm = lm_ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens);
13016
+ x_prev = lm_ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens);
13017
+
13018
+ if (il == n_layer - 1 && inp_out_ids) {
13019
+ ffn_inp = lm_ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
13020
+ ffn_norm = lm_ggml_get_rows(ctx0, ffn_norm, inp_out_ids);
13021
+ x_prev = lm_ggml_get_rows(ctx0, x_prev, inp_out_ids);
12232
13022
  }
12233
13023
 
12234
13024
  cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7);
@@ -12259,7 +13049,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
12259
13049
 
12260
13050
  struct llm_build_arwkv7 : public llm_build_rwkv7_base {
12261
13051
  llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, lm_ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
12262
- LM_GGML_ASSERT(n_embd == hparams.n_embd_k_s());
13052
+ LM_GGML_ASSERT(n_embd == hparams.n_embd_r());
12263
13053
 
12264
13054
  lm_ggml_tensor * cur;
12265
13055
  lm_ggml_tensor * inpL;
@@ -12267,20 +13057,19 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
12267
13057
 
12268
13058
  inpL = build_inp_embd(model.tok_embd);
12269
13059
 
12270
- lm_ggml_tensor * state_copy = build_inp_s_copy();
12271
- lm_ggml_tensor * state_mask = build_inp_s_mask();
13060
+ auto * rs_inp = build_rs_inp();
12272
13061
 
12273
13062
  const auto n_embd = hparams.n_embd;
12274
13063
  const auto n_seq_tokens = ubatch.n_seq_tokens;
12275
13064
  const auto n_seqs = ubatch.n_seqs;
12276
13065
 
13066
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
13067
+
12277
13068
  for (int il = 0; il < n_layer; ++il) {
12278
13069
  const llama_layer * layer = &model.layers[il];
12279
13070
  inpL = lm_ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
12280
13071
 
12281
- lm_ggml_tensor * token_shift = build_rwkv_token_shift_load(
12282
- gf, state_copy, state_mask, ubatch, il
12283
- );
13072
+ lm_ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
12284
13073
 
12285
13074
  lm_ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
12286
13075
  cb(att_norm, "attn_norm", il);
@@ -12292,7 +13081,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
12292
13081
  1
12293
13082
  );
12294
13083
 
12295
- cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
13084
+ cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
12296
13085
 
12297
13086
  token_shift = lm_ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*lm_ggml_element_size(att_norm));
12298
13087
  lm_ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -12300,11 +13089,12 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
12300
13089
  lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpL);
12301
13090
  cb(ffn_inp, "ffn_inp", il);
12302
13091
 
12303
- if (il == n_layer - 1) {
12304
- // skip computing output for unused tokens
12305
- struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
12306
- cur = lm_ggml_get_rows(ctx0, lm_ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
12307
- ffn_inp = lm_ggml_get_rows(ctx0, lm_ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
13092
+ cur = lm_ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
13093
+ ffn_inp = lm_ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens);
13094
+
13095
+ if (il == n_layer - 1 && inp_out_ids) {
13096
+ cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
13097
+ ffn_inp = lm_ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
12308
13098
  }
12309
13099
 
12310
13100
  // feed-forward network
@@ -12373,6 +13163,9 @@ struct llm_build_granite : public llm_graph_context {
12373
13163
  auto * inp_attn = build_attn_inp_kv_unified();
12374
13164
 
12375
13165
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
13166
+
13167
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
13168
+
12376
13169
  for (int il = 0; il < n_layer; ++il) {
12377
13170
  lm_ggml_tensor * inpSA = inpL;
12378
13171
 
@@ -12435,9 +13228,7 @@ struct llm_build_granite : public llm_graph_context {
12435
13228
  cb(cur, "attn_out", il);
12436
13229
  }
12437
13230
 
12438
- if (il == n_layer - 1) {
12439
- // skip computing output for unused tokens
12440
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
13231
+ if (il == n_layer - 1 && inp_out_ids) {
12441
13232
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
12442
13233
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
12443
13234
  }
@@ -12556,6 +13347,8 @@ struct llm_build_chameleon : public llm_graph_context {
12556
13347
 
12557
13348
  auto * inp_attn = build_attn_inp_kv_unified();
12558
13349
 
13350
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
13351
+
12559
13352
  for (int il = 0; il < n_layer; ++il) {
12560
13353
  lm_ggml_tensor * inpSA = inpL;
12561
13354
 
@@ -12632,21 +13425,19 @@ struct llm_build_chameleon : public llm_graph_context {
12632
13425
  cur = build_attn(inp_attn, gf,
12633
13426
  model.layers[il].wo, nullptr,
12634
13427
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
12635
-
12636
- if (hparams.swin_norm) {
12637
- cur = build_norm(cur,
12638
- model.layers[il].attn_norm, NULL,
12639
- LLM_NORM_RMS, il);
12640
- }
12641
13428
  }
12642
13429
 
12643
- if (il == n_layer - 1) {
12644
- // skip computing output for unused tokens
12645
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
13430
+ if (il == n_layer - 1 && inp_out_ids) {
12646
13431
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
12647
13432
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
12648
13433
  }
12649
13434
 
13435
+ if (hparams.swin_norm) {
13436
+ cur = build_norm(cur,
13437
+ model.layers[il].attn_norm, NULL,
13438
+ LLM_NORM_RMS, il);
13439
+ }
13440
+
12650
13441
  lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA);
12651
13442
  cb(ffn_inp, "ffn_inp", il);
12652
13443
 
@@ -12887,6 +13678,8 @@ struct llm_build_plm : public llm_graph_context {
12887
13678
 
12888
13679
  auto * inp_attn = build_attn_inp_kv_unified();
12889
13680
 
13681
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
13682
+
12890
13683
  for (int il = 0; il < n_layer; ++il) {
12891
13684
  lm_ggml_tensor * inpSA = inpL;
12892
13685
 
@@ -12990,9 +13783,7 @@ struct llm_build_plm : public llm_graph_context {
12990
13783
  q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
12991
13784
  }
12992
13785
 
12993
- if (il == n_layer - 1) {
12994
- // skip computing output for unused tokens
12995
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
13786
+ if (il == n_layer - 1 && inp_out_ids) {
12996
13787
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
12997
13788
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
12998
13789
  }
@@ -13052,6 +13843,8 @@ struct llm_build_bailingmoe : public llm_graph_context {
13052
13843
 
13053
13844
  auto * inp_attn = build_attn_inp_kv_unified();
13054
13845
 
13846
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
13847
+
13055
13848
  for (int il = 0; il < n_layer; ++il) {
13056
13849
  lm_ggml_tensor * inpSA = inpL;
13057
13850
 
@@ -13113,9 +13906,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
13113
13906
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il);
13114
13907
  }
13115
13908
 
13116
- if (il == n_layer - 1) {
13117
- // skip computing output for unused tokens
13118
- lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
13909
+ if (il == n_layer - 1 && inp_out_ids) {
13119
13910
  cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
13120
13911
  inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
13121
13912
  }
@@ -13184,69 +13975,375 @@ struct llm_build_bailingmoe : public llm_graph_context {
13184
13975
  }
13185
13976
  };
13186
13977
 
13187
- llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
13188
- llama_memory_i * res;
13978
+ struct llm_build_dots1 : public llm_graph_context {
13979
+ llm_build_dots1(const llama_model & model, const llm_graph_params & params, lm_ggml_cgraph * gf) : llm_graph_context(params) {
13980
+ const int64_t n_embd_head = hparams.n_embd_head_v;
13189
13981
 
13190
- switch (arch) {
13191
- case LLM_ARCH_BERT:
13192
- case LLM_ARCH_JINA_BERT_V2:
13193
- case LLM_ARCH_NOMIC_BERT:
13194
- case LLM_ARCH_NOMIC_BERT_MOE:
13195
- case LLM_ARCH_WAVTOKENIZER_DEC:
13196
- {
13197
- res = nullptr;
13198
- } break;
13199
- case LLM_ARCH_MAMBA:
13200
- case LLM_ARCH_RWKV6:
13201
- case LLM_ARCH_RWKV6QWEN2:
13202
- case LLM_ARCH_RWKV7:
13203
- case LLM_ARCH_ARWKV7:
13982
+ LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
13983
+ LM_GGML_ASSERT(n_embd_head == hparams.n_rot);
13984
+
13985
+ lm_ggml_tensor * cur;
13986
+ lm_ggml_tensor * inpL;
13987
+
13988
+ inpL = build_inp_embd(model.tok_embd);
13989
+
13990
+ // inp_pos - contains the positions
13991
+ lm_ggml_tensor * inp_pos = build_inp_pos();
13992
+
13993
+ auto * inp_attn = build_attn_inp_kv_unified();
13994
+
13995
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
13996
+
13997
+ for (int il = 0; il < n_layer; ++il) {
13998
+ lm_ggml_tensor * inpSA = inpL;
13999
+
14000
+ // norm
14001
+ cur = build_norm(inpL,
14002
+ model.layers[il].attn_norm, NULL,
14003
+ LLM_NORM_RMS, il);
14004
+ cb(cur, "attn_norm", il);
14005
+
14006
+ // self_attention
13204
14007
  {
13205
- res = new llama_kv_cache_recurrent(
13206
- *this,
13207
- LM_GGML_TYPE_F32,
13208
- LM_GGML_TYPE_F32,
13209
- cparams.offload_kqv,
13210
- std::max((uint32_t) 1, cparams.n_seq_max),
13211
- cparams.n_seq_max);
13212
- } break;
13213
- default:
14008
+ // compute Q and K and RoPE them
14009
+ lm_ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
14010
+ cb(Qcur, "Qcur", il);
14011
+
14012
+ lm_ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
14013
+ cb(Kcur, "Kcur", il);
14014
+
14015
+ lm_ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
14016
+ cb(Vcur, "Vcur", il);
14017
+
14018
+ Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
14019
+ Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
14020
+ Vcur = lm_ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
14021
+
14022
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
14023
+ cb(Qcur, "Qcur_normed", il);
14024
+
14025
+ Qcur = lm_ggml_rope_ext(
14026
+ ctx0, Qcur, inp_pos, nullptr,
14027
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
14028
+ ext_factor, attn_factor, beta_fast, beta_slow
14029
+ );
14030
+
14031
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
14032
+ cb(Kcur, "Kcur_normed", il);
14033
+
14034
+ Kcur = lm_ggml_rope_ext(
14035
+ ctx0, Kcur, inp_pos, nullptr,
14036
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
14037
+ ext_factor, attn_factor, beta_fast, beta_slow
14038
+ );
14039
+
14040
+ cb(Qcur, "Qcur", il);
14041
+ cb(Kcur, "Kcur", il);
14042
+ cb(Vcur, "Vcur", il);
14043
+
14044
+ cur = build_attn(inp_attn, gf,
14045
+ model.layers[il].wo, model.layers[il].bo,
14046
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
14047
+ }
14048
+
14049
+ if (il == n_layer - 1 && inp_out_ids) {
14050
+ cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
14051
+ inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
14052
+ }
14053
+
14054
+ lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA);
14055
+ cb(ffn_inp, "ffn_inp", il);
14056
+
14057
+ // MoE branch
14058
+ cur = build_norm(ffn_inp,
14059
+ model.layers[il].ffn_norm, NULL,
14060
+ LLM_NORM_RMS, il);
14061
+ cb(cur, "ffn_norm", il);
14062
+
14063
+ if ((uint32_t) il < hparams.n_layer_dense_lead) {
14064
+ cur = build_ffn(cur,
14065
+ model.layers[il].ffn_up, NULL, NULL,
14066
+ model.layers[il].ffn_gate, NULL, NULL,
14067
+ model.layers[il].ffn_down, NULL, NULL,
14068
+ NULL,
14069
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
14070
+ cb(cur, "ffn_out", il);
14071
+ } else {
14072
+ lm_ggml_tensor * moe_out =
14073
+ build_moe_ffn(cur,
14074
+ model.layers[il].ffn_gate_inp,
14075
+ model.layers[il].ffn_up_exps,
14076
+ model.layers[il].ffn_gate_exps,
14077
+ model.layers[il].ffn_down_exps,
14078
+ model.layers[il].ffn_exp_probs_b,
14079
+ n_expert, n_expert_used,
14080
+ LLM_FFN_SILU, hparams.expert_weights_norm,
14081
+ true, hparams.expert_weights_scale,
14082
+ (llama_expert_gating_func_type) hparams.expert_gating_func,
14083
+ il);
14084
+ cb(moe_out, "ffn_moe_out", il);
14085
+
14086
+ {
14087
+ lm_ggml_tensor * ffn_shexp = build_ffn(cur,
14088
+ model.layers[il].ffn_up_shexp, NULL, NULL,
14089
+ model.layers[il].ffn_gate_shexp, NULL, NULL,
14090
+ model.layers[il].ffn_down_shexp, NULL, NULL,
14091
+ NULL,
14092
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
14093
+ cb(ffn_shexp, "ffn_shexp", il);
14094
+
14095
+ cur = lm_ggml_add(ctx0, moe_out, ffn_shexp);
14096
+ cb(cur, "ffn_out", il);
14097
+ }
14098
+ }
14099
+
14100
+ cur = lm_ggml_add(ctx0, cur, ffn_inp);
14101
+
14102
+ cur = build_cvec(cur, il);
14103
+ cb(cur, "l_out", il);
14104
+
14105
+ // input for next layer
14106
+ inpL = cur;
14107
+ }
14108
+
14109
+ cur = inpL;
14110
+
14111
+ cur = build_norm(cur,
14112
+ model.output_norm, NULL,
14113
+ LLM_NORM_RMS, -1);
14114
+
14115
+ cb(cur, "result_norm", -1);
14116
+ res->t_embd = cur;
14117
+
14118
+ // lm_head
14119
+ cur = build_lora_mm(model.output, cur);
14120
+
14121
+ cb(cur, "result_output", -1);
14122
+ res->t_logits = cur;
14123
+
14124
+ lm_ggml_build_forward_expand(gf, cur);
14125
+ }
14126
+ };
14127
+
14128
+ struct llm_build_arcee : public llm_graph_context {
14129
+ llm_build_arcee(const llama_model & model, const llm_graph_params & params, lm_ggml_cgraph * gf) : llm_graph_context(params) {
14130
+ const int64_t n_embd_head = hparams.n_embd_head_v;
14131
+
14132
+ LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
14133
+ LM_GGML_ASSERT(n_embd_head == hparams.n_rot);
14134
+
14135
+ lm_ggml_tensor * cur;
14136
+ lm_ggml_tensor * inpL;
14137
+
14138
+ inpL = build_inp_embd(model.tok_embd);
14139
+
14140
+ // inp_pos - contains the positions
14141
+ lm_ggml_tensor * inp_pos = build_inp_pos();
14142
+
14143
+ auto * inp_attn = build_attn_inp_kv_unified();
14144
+
14145
+ const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
14146
+
14147
+ lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
14148
+
14149
+ for (int il = 0; il < n_layer; ++il) {
14150
+ lm_ggml_tensor * inpSA = inpL;
14151
+
14152
+ // norm
14153
+ cur = build_norm(inpL,
14154
+ model.layers[il].attn_norm, NULL,
14155
+ LLM_NORM_RMS, il);
14156
+ cb(cur, "attn_norm", il);
14157
+
14158
+ // self-attention
13214
14159
  {
13215
- const auto padding = llama_kv_cache_unified::get_padding(cparams);
14160
+ // rope freq factors for llama3; may return nullptr for llama2 and other models
14161
+ lm_ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
13216
14162
 
13217
- cparams.n_ctx = LM_GGML_PAD(cparams.n_ctx, padding);
14163
+ // compute Q and K and RoPE them
14164
+ lm_ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
14165
+ cb(Qcur, "Qcur", il);
14166
+ if (model.layers[il].bq) {
14167
+ Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq);
14168
+ cb(Qcur, "Qcur", il);
14169
+ }
13218
14170
 
13219
- LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
14171
+ lm_ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
14172
+ cb(Kcur, "Kcur", il);
14173
+ if (model.layers[il].bk) {
14174
+ Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk);
14175
+ cb(Kcur, "Kcur", il);
14176
+ }
13220
14177
 
13221
- if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
13222
- LM_GGML_ASSERT(hparams.is_swa_any());
14178
+ lm_ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
14179
+ cb(Vcur, "Vcur", il);
14180
+ if (model.layers[il].bv) {
14181
+ Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv);
14182
+ cb(Vcur, "Vcur", il);
14183
+ }
13223
14184
 
13224
- res = new llama_kv_cache_unified_iswa(
13225
- *this,
13226
- params.type_k,
13227
- params.type_v,
13228
- !cparams.flash_attn,
13229
- cparams.offload_kqv,
13230
- params.swa_full,
13231
- cparams.n_ctx,
13232
- cparams.n_seq_max,
13233
- cparams.n_batch,
13234
- padding);
13235
- } else {
13236
- LM_GGML_ASSERT(!hparams.is_swa_any());
14185
+ Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
14186
+ Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
14187
+ Vcur = lm_ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
14188
+
14189
+ Qcur = lm_ggml_rope_ext(
14190
+ ctx0, Qcur, inp_pos, rope_factors,
14191
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
14192
+ ext_factor, attn_factor, beta_fast, beta_slow
14193
+ );
14194
+
14195
+ Kcur = lm_ggml_rope_ext(
14196
+ ctx0, Kcur, inp_pos, rope_factors,
14197
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
14198
+ ext_factor, attn_factor, beta_fast, beta_slow
14199
+ );
14200
+
14201
+ cb(Qcur, "Qcur", il);
14202
+ cb(Kcur, "Kcur", il);
14203
+ cb(Vcur, "Vcur", il);
14204
+
14205
+ cur = build_attn(inp_attn, gf,
14206
+ model.layers[il].wo, model.layers[il].bo,
14207
+ Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
14208
+ cb(cur, "attn_out", il);
14209
+ }
14210
+
14211
+ if (il == n_layer - 1 && inp_out_ids) {
14212
+ cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
14213
+ inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids);
14214
+ }
13237
14215
 
13238
- res = new llama_kv_cache_unified(
14216
+ lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA);
14217
+ cb(ffn_inp, "ffn_inp", il);
14218
+
14219
+ // feed-forward network
14220
+ // ARCEE uses relu^2 instead of silu
14221
+ cur = build_norm(ffn_inp,
14222
+ model.layers[il].ffn_norm, NULL,
14223
+ LLM_NORM_RMS, il);
14224
+ cb(cur, "ffn_norm", il);
14225
+
14226
+ cur = build_ffn(cur,
14227
+ model.layers[il].ffn_up, NULL, NULL,
14228
+ NULL, NULL, NULL,
14229
+ model.layers[il].ffn_down, NULL, NULL,
14230
+ NULL,
14231
+ LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
14232
+ cb(cur, "ffn_out", il);
14233
+
14234
+ cur = lm_ggml_add(ctx0, cur, ffn_inp);
14235
+ cb(cur, "ffn_out", il);
14236
+
14237
+ cur = build_cvec(cur, il);
14238
+ cb(cur, "l_out", il);
14239
+
14240
+ // input for next layer
14241
+ inpL = cur;
14242
+ }
14243
+
14244
+ cur = inpL;
14245
+
14246
+ cur = build_norm(cur,
14247
+ model.output_norm, NULL,
14248
+ LLM_NORM_RMS, -1);
14249
+
14250
+ cb(cur, "result_norm", -1);
14251
+ res->t_embd = cur;
14252
+
14253
+ // lm_head
14254
+ cur = build_lora_mm(model.output, cur);
14255
+
14256
+ cb(cur, "result_output", -1);
14257
+ res->t_logits = cur;
14258
+
14259
+ lm_ggml_build_forward_expand(gf, cur);
14260
+ }
14261
+ };
14262
+
14263
+ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
14264
+ llama_memory_i * res;
14265
+
14266
+ switch (arch) {
14267
+ // Models that need specific instantiation should be handled in the
14268
+ // switch statement
14269
+ case LLM_ARCH_BERT:
14270
+ case LLM_ARCH_JINA_BERT_V2:
14271
+ case LLM_ARCH_NOMIC_BERT:
14272
+ case LLM_ARCH_NOMIC_BERT_MOE:
14273
+ case LLM_ARCH_NEO_BERT:
14274
+ case LLM_ARCH_WAVTOKENIZER_DEC:
14275
+ {
14276
+ res = nullptr;
14277
+ } break;
14278
+ // Models that need standard caching should rely on recurrent/hybrid
14279
+ // checks
14280
+ default:
14281
+ {
14282
+ if (llm_arch_is_recurrent(arch)) {
14283
+ res = new llama_memory_recurrent(
13239
14284
  *this,
13240
14285
  nullptr,
13241
- params.type_k,
13242
- params.type_v,
13243
- !cparams.flash_attn,
14286
+ LM_GGML_TYPE_F32,
14287
+ LM_GGML_TYPE_F32,
13244
14288
  cparams.offload_kqv,
13245
- cparams.n_ctx,
13246
- cparams.n_seq_max,
13247
- padding,
13248
- hparams.n_swa,
13249
- hparams.swa_type);
14289
+ std::max((uint32_t) 1, cparams.n_seq_max),
14290
+ cparams.n_seq_max);
14291
+ } else if (llm_arch_is_hybrid(arch)) {
14292
+ const auto padding = llama_kv_cache_unified::get_padding(cparams);
14293
+
14294
+ cparams.n_ctx = LM_GGML_PAD(cparams.n_ctx, padding);
14295
+
14296
+ res = new llama_memory_hybrid(
14297
+ /* model */ *this,
14298
+ /* attn_type_k */ params.type_k,
14299
+ /* attn_type_v */ params.type_v,
14300
+ /* attn_v_trans */ !cparams.flash_attn,
14301
+ /* attn_kv_size */ cparams.n_ctx,
14302
+ /* attn_n_pad */ padding,
14303
+ /* attn_n_swa */ hparams.n_swa,
14304
+ /* attn_swa_type */ hparams.swa_type,
14305
+ /* recurrent_type_k */ LM_GGML_TYPE_F32,
14306
+ /* recurrent_type_v */ LM_GGML_TYPE_F32,
14307
+ /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
14308
+ /* n_seq_max */ cparams.n_seq_max,
14309
+ /* offload */ cparams.offload_kqv);
14310
+ } else {
14311
+ const auto padding = llama_kv_cache_unified::get_padding(cparams);
14312
+
14313
+ cparams.n_ctx = LM_GGML_PAD(cparams.n_ctx, padding);
14314
+
14315
+ LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
14316
+
14317
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
14318
+ LM_GGML_ASSERT(hparams.is_swa_any());
14319
+
14320
+ res = new llama_kv_cache_unified_iswa(
14321
+ *this,
14322
+ params.type_k,
14323
+ params.type_v,
14324
+ !cparams.flash_attn,
14325
+ cparams.offload_kqv,
14326
+ params.swa_full,
14327
+ cparams.n_ctx,
14328
+ cparams.n_seq_max,
14329
+ cparams.n_ubatch,
14330
+ padding);
14331
+ } else {
14332
+ LM_GGML_ASSERT(!hparams.is_swa_any());
14333
+
14334
+ res = new llama_kv_cache_unified(
14335
+ *this,
14336
+ nullptr,
14337
+ params.type_k,
14338
+ params.type_v,
14339
+ !cparams.flash_attn,
14340
+ cparams.offload_kqv,
14341
+ cparams.n_ctx,
14342
+ cparams.n_seq_max,
14343
+ padding,
14344
+ hparams.n_swa,
14345
+ hparams.swa_type);
14346
+ }
13250
14347
  }
13251
14348
  }
13252
14349
  }
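Summary of the create_memory() rewrite above, as read from the diff: embedding-only architectures (the BERT family, NEO_BERT, WAVTOKENIZER_DEC) still get no memory (nullptr); architectures for which llm_arch_is_recurrent() holds now get a llama_memory_recurrent instead of the old llama_kv_cache_recurrent; hybrid architectures (llm_arch_is_hybrid()) get the new llama_memory_hybrid, which combines an attention KV cache with recurrent state; everything else keeps the unified KV cache, with the iSWA variant chosen when hparams.swa_type is set (that constructor now also receives cparams.n_ubatch instead of cparams.n_batch).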
@@ -13262,7 +14359,6 @@ llm_graph_result_ptr llama_model::build_graph(
13262
14359
 
13263
14360
  switch (arch) {
13264
14361
  case LLM_ARCH_LLAMA:
13265
- case LLM_ARCH_MINICPM:
13266
14362
  {
13267
14363
  llm = std::make_unique<llm_build_llama>(*this, params, gf);
13268
14364
  } break;
@@ -13301,6 +14397,10 @@ llm_graph_result_ptr llama_model::build_graph(
13301
14397
  {
13302
14398
  llm = std::make_unique<llm_build_bert>(*this, params, gf);
13303
14399
  } break;
14400
+ case LLM_ARCH_NEO_BERT:
14401
+ {
14402
+ llm = std::make_unique<llm_build_neo_bert>(*this, params, gf);
14403
+ } break;
13304
14404
  case LLM_ARCH_BLOOM:
13305
14405
  {
13306
14406
  llm = std::make_unique<llm_build_bloom>(*this, params, gf);
@@ -13386,6 +14486,10 @@ llm_graph_result_ptr llama_model::build_graph(
13386
14486
  {
13387
14487
  llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
13388
14488
  } break;
14489
+ case LLM_ARCH_GEMMA3N:
14490
+ {
14491
+ llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params, gf);
14492
+ } break;
13389
14493
  case LLM_ARCH_STARCODER2:
13390
14494
  {
13391
14495
  llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
@@ -13503,6 +14607,7 @@ llm_graph_result_ptr llama_model::build_graph(
13503
14607
  } break;
13504
14608
  case LLM_ARCH_GRANITE:
13505
14609
  case LLM_ARCH_GRANITE_MOE:
14610
+ case LLM_ARCH_MINICPM:
13506
14611
  {
13507
14612
  llm = std::make_unique<llm_build_granite>(*this, params, gf);
13508
14613
  } break;
@@ -13522,6 +14627,14 @@ llm_graph_result_ptr llama_model::build_graph(
13522
14627
  {
13523
14628
  llm = std::make_unique<llm_build_bailingmoe>(*this, params, gf);
13524
14629
  } break;
14630
+ case LLM_ARCH_DOTS1:
14631
+ {
14632
+ llm = std::make_unique<llm_build_dots1>(*this, params, gf);
14633
+ } break;
14634
+ case LLM_ARCH_ARCEE:
14635
+ {
14636
+ llm = std::make_unique<llm_build_arcee>(*this, params, gf);
14637
+ } break;
13525
14638
  default:
13526
14639
  LM_GGML_ABORT("fatal error");
13527
14640
  }
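Note on the graph-builder dispatch above: new cases route LLM_ARCH_NEO_BERT, LLM_ARCH_GEMMA3N, LLM_ARCH_DOTS1 and LLM_ARCH_ARCEE to their new builders, while LLM_ARCH_MINICPM is moved off the llama builder to share llm_build_granite.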
@@ -13593,6 +14706,22 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
13593
14706
  return model->hparams.n_head_kv();
13594
14707
  }
13595
14708
 
14709
+ int32_t llama_model_n_swa(const llama_model * model) {
14710
+ return model->hparams.n_swa;
14711
+ }
14712
+
14713
+ uint32_t llama_model_n_cls_out(const struct llama_model * model) {
14714
+ return model->hparams.n_cls_out;
14715
+ }
14716
+
14717
+ const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) {
14718
+ if (i < model->classifier_labels.size()) {
14719
+ return model->classifier_labels[i].c_str();
14720
+ }
14721
+
14722
+ return nullptr;
14723
+ }
14724
+
13596
14725
  // deprecated
13597
14726
  int32_t llama_n_ctx_train(const llama_model * model) {
13598
14727
  return llama_model_n_ctx_train(model);
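The three functions added above expose the sliding-window size and the classifier-head metadata through the C API. A hedged usage sketch (not taken from the diff), assuming `model` is an already loaded llama_model with classifier labels:

    #include <cstdio>

    static void print_cls_labels(const llama_model * model) {
        const uint32_t n = llama_model_n_cls_out(model);
        for (uint32_t i = 0; i < n; ++i) {
            // llama_model_cls_label() returns nullptr when i is out of range
            const char * label = llama_model_cls_label(model, i);
            std::printf("class %u: %s\n", i, label ? label : "(unnamed)");
        }
    }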
@@ -13655,6 +14784,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
13655
14784
  case LLM_ARCH_GRANITE_MOE:
13656
14785
  case LLM_ARCH_CHAMELEON:
13657
14786
  case LLM_ARCH_BAILINGMOE:
14787
+ case LLM_ARCH_NEO_BERT:
14788
+ case LLM_ARCH_ARCEE:
13658
14789
  return LLAMA_ROPE_TYPE_NORM;
13659
14790
 
13660
14791
  // the pairs of head values are offset by n_rot/2
@@ -13680,6 +14811,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
13680
14811
  case LLM_ARCH_GEMMA:
13681
14812
  case LLM_ARCH_GEMMA2:
13682
14813
  case LLM_ARCH_GEMMA3:
14814
+ case LLM_ARCH_GEMMA3N:
13683
14815
  case LLM_ARCH_STARCODER2:
13684
14816
  case LLM_ARCH_OPENELM:
13685
14817
  case LLM_ARCH_GPTNEOX:
@@ -13688,6 +14820,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
13688
14820
  case LLM_ARCH_NEMOTRON:
13689
14821
  case LLM_ARCH_EXAONE:
13690
14822
  case LLM_ARCH_MINICPM3:
14823
+ case LLM_ARCH_DOTS1:
13691
14824
  return LLAMA_ROPE_TYPE_NEOX;
13692
14825
 
13693
14826
  case LLM_ARCH_QWEN2VL:
@@ -13753,7 +14886,7 @@ uint64_t llama_model_size(const llama_model * model) {
13753
14886
  }
13754
14887
 
13755
14888
  const char * llama_model_chat_template(const llama_model * model, const char * name) {
13756
- const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
14889
+ const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE)
13757
14890
  : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
13758
14891
  const auto & it = model->lm_gguf_kv.find(key);
13759
14892
  if (it == model->lm_gguf_kv.end()) {
@@ -13761,7 +14894,7 @@ const char * llama_model_chat_template(const llama_model * model, const char * n
13761
14894
  // do not extend this list unless absolutely necessary
13762
14895
  // Mistral-Small-2503 does not have built-in chat template
13763
14896
  llama_vocab_pre_type pre_type = model->vocab.get_pre_type();
13764
- if (pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
14897
+ if (!name && pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
13765
14898
  return "mistral-v7-tekken";
13766
14899
  }
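Two behavioural fixes are visible in the chat-template hunks above: the lookup for a named template now formats the key with LLM_KV_TOKENIZER_CHAT_TEMPLATE through the LLM_KV(arch, name) wrapper rather than the _N variant, and the hard-coded "mistral-v7-tekken" fallback applies only when no template name was requested. A hedged usage sketch (the template name "tool_use" is hypothetical):

    const char * tmpl_default = llama_model_chat_template(model, /* name */ nullptr);
    const char * tmpl_named   = llama_model_chat_template(model, "tool_use"); // may be nullptr if the model ships no such template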
13767
14900
 
@@ -13795,14 +14928,7 @@ llama_token llama_model_decoder_start_token(const llama_model * model) {
13795
14928
  }
13796
14929
 
13797
14930
  bool llama_model_is_recurrent(const llama_model * model) {
13798
- switch (model->arch) {
13799
- case LLM_ARCH_MAMBA: return true;
13800
- case LLM_ARCH_RWKV6: return true;
13801
- case LLM_ARCH_RWKV6QWEN2: return true;
13802
- case LLM_ARCH_RWKV7: return true;
13803
- case LLM_ARCH_ARWKV7: return true;
13804
- default: return false;
13805
- }
14931
+ return llm_arch_is_recurrent(model->arch);
13806
14932
  }
13807
14933
 
13808
14934
  const std::vector<std::pair<std::string, lm_ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {