whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
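For anyone consuming this release rather than auditing it, a minimal sketch of picking up the new version through the usual RubyGems workflow (the gem name and version come from the header above; the Gemfile/`gem install` usage is standard Bundler/RubyGems practice, not something stated in this diff):

    # Gemfile — pin to the 1.3.x series so 1.3.3 is picked up
    gem 'whispercpp', '~> 1.3.3'

    # or, without Bundler:
    #   gem install whispercpp -v 1.3.3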
Files changed (244)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -5,7 +5,11 @@
5
5
  #include "llama-batch.h"
6
6
  #include "llama-cparams.h"
7
7
  #include "llama-model-loader.h"
8
- #include "llama-kv-cache.h"
8
+
9
+ #include "llama-kv-cache-unified.h"
10
+ #include "llama-kv-cache-unified-iswa.h"
11
+ #include "llama-memory-hybrid.h"
12
+ #include "llama-memory-recurrent.h"
9
13
 
10
14
  #include "ggml-cpp.h"
11
15
 
@@ -43,6 +47,7 @@ const char * llm_type_name(llm_type type) {
43
47
  case LLM_TYPE_475M: return "475M";
44
48
  case LLM_TYPE_770M: return "770M";
45
49
  case LLM_TYPE_780M: return "780M";
50
+ case LLM_TYPE_0_3B: return "0.3B";
46
51
  case LLM_TYPE_0_5B: return "0.5B";
47
52
  case LLM_TYPE_0_6B: return "0.6B";
48
53
  case LLM_TYPE_1B: return "1B";
@@ -77,6 +82,7 @@ const char * llm_type_name(llm_type type) {
77
82
  case LLM_TYPE_40B: return "40B";
78
83
  case LLM_TYPE_65B: return "65B";
79
84
  case LLM_TYPE_70B: return "70B";
85
+ case LLM_TYPE_142B: return "142B";
80
86
  case LLM_TYPE_236B: return "236B";
81
87
  case LLM_TYPE_290B: return "290B";
82
88
  case LLM_TYPE_314B: return "314B";
@@ -98,6 +104,8 @@ const char * llm_type_name(llm_type type) {
98
104
  case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
99
105
  case LLM_TYPE_30B_A3B: return "30B.A3B";
100
106
  case LLM_TYPE_235B_A22B: return "235B.A22B";
107
+ case LLM_TYPE_E2B: return "E2B";
108
+ case LLM_TYPE_E4B: return "E4B";
101
109
  default: return "?B";
102
110
  }
103
111
  }
@@ -466,6 +474,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
466
474
  std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
467
475
  std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
468
476
  std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
477
+ std::fill(
478
+ hparams.recurrent_layer_arr.begin(),
479
+ hparams.recurrent_layer_arr.end(),
480
+ llm_arch_is_recurrent(ml.get_arch()));
469
481
 
470
482
  std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
471
483
 
@@ -540,6 +552,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
540
552
  uint32_t n_vocab = 0;
541
553
  ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);
542
554
 
555
+ // for classifier models
556
+ ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false);
557
+ if (!classifier_labels.empty()) {
558
+ hparams.n_cls_out = classifier_labels.size();
559
+ }
560
+
543
561
  // arch-specific KVs
544
562
  switch (arch) {
545
563
  case LLM_ARCH_LLAMA:
@@ -589,6 +607,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
589
607
  hparams.use_kq_norm = false;
590
608
  }
591
609
  } break;
610
+ case LLM_ARCH_ARCEE:
611
+ {
612
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
613
+
614
+ // Arcee uses the same structure as Llama
615
+ switch (hparams.n_layer) {
616
+ case 36: type = LLM_TYPE_4B; break;
617
+ default: type = LLM_TYPE_UNKNOWN;
618
+ }
619
+ } break;
592
620
  case LLM_ARCH_DECI:
593
621
  {
594
622
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -729,6 +757,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
729
757
  }
730
758
  }
731
759
  } break;
760
+ case LLM_ARCH_NEO_BERT:
761
+ {
762
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
763
+ ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
764
+ ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
765
+
766
+ if (hparams.n_layer == 28) {
767
+ type = LLM_TYPE_250M;
768
+ }
769
+ } break;
732
770
  case LLM_ARCH_BLOOM:
733
771
  {
734
772
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -952,6 +990,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
952
990
  case 46: type = LLM_TYPE_27B; break;
953
991
  default: type = LLM_TYPE_UNKNOWN;
954
992
  }
993
+
994
+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173
995
+ hparams.f_attention_scale = type == LLM_TYPE_27B
996
+ ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
997
+ : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
955
998
  } break;
956
999
  case LLM_ARCH_GEMMA3:
957
1000
  {
@@ -972,10 +1015,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
972
1015
  default: type = LLM_TYPE_UNKNOWN;
973
1016
  }
974
1017
 
1018
+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289
975
1019
  hparams.f_attention_scale = type == LLM_TYPE_27B
976
1020
  ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
977
1021
  : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
978
1022
  } break;
1023
+ case LLM_ARCH_GEMMA3N:
1024
+ {
1025
+ hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
1026
+ hparams.set_swa_pattern(5);
1027
+
1028
+ hparams.rope_freq_base_train_swa = 10000.0f;
1029
+ hparams.rope_freq_scale_train_swa = 1.0f;
1030
+ hparams.f_attention_scale = 1.0f;
1031
+
1032
+ ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
1033
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1034
+
1035
+ switch (hparams.n_layer) {
1036
+ case 30: type = LLM_TYPE_E2B; break;
1037
+ case 35: type = LLM_TYPE_E4B; break;
1038
+ default: type = LLM_TYPE_UNKNOWN;
1039
+ }
1040
+ } break;
979
1041
  case LLM_ARCH_STARCODER2:
980
1042
  {
981
1043
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -1429,6 +1491,28 @@ void llama_model::load_hparams(llama_model_loader & ml) {
1429
1491
  default: type = LLM_TYPE_UNKNOWN;
1430
1492
  }
1431
1493
  } break;
1494
+ case LLM_ARCH_DOTS1:
1495
+ {
1496
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1497
+ ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
1498
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
1499
+ ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
1500
+ ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
1501
+ ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
1502
+ ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
1503
+ switch (hparams.n_layer) {
1504
+ case 62: type = LLM_TYPE_142B; break;
1505
+ default: type = LLM_TYPE_UNKNOWN;
1506
+ }
1507
+ } break;
1508
+ case LLM_ARCH_ERNIE4_5:
1509
+ {
1510
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1511
+ switch (hparams.n_layer) {
1512
+ case 18: type = LLM_TYPE_0_3B; break;
1513
+ default: type = LLM_TYPE_UNKNOWN;
1514
+ }
1515
+ } break;
1432
1516
  default: throw std::runtime_error("unsupported model architecture");
1433
1517
  }
1434
1518
 
@@ -2113,7 +2197,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
2113
2197
  case LLM_ARCH_NOMIC_BERT_MOE:
2114
2198
  {
2115
2199
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
2116
- type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0);
2200
+ type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED);
2117
2201
 
2118
2202
  if (arch == LLM_ARCH_BERT) {
2119
2203
  pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0);
@@ -2121,8 +2205,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
2121
2205
  cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
2122
2206
  cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED);
2123
2207
 
2124
- cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED);
2125
- cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {1}, TENSOR_NOT_REQUIRED);
2208
+ cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
2209
+ cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
2126
2210
  }
2127
2211
 
2128
2212
  tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
@@ -2131,7 +2215,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
2131
2215
  for (int i = 0; i < n_layer; ++i) {
2132
2216
  auto & layer = layers[i];
2133
2217
 
2134
- if (arch == LLM_ARCH_BERT) {
2218
+ layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
2219
+ layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
2220
+
2221
+ if (!layer.wqkv) {
2135
2222
  layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
2136
2223
  layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
2137
2224
 
@@ -2140,12 +2227,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
2140
2227
 
2141
2228
  layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
2142
2229
  layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
2143
- } else {
2144
- layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
2145
- }
2146
-
2147
- if (arch == LLM_ARCH_NOMIC_BERT_MOE) {
2148
- layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
2149
2230
  }
2150
2231
 
2151
2232
  layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
@@ -2175,6 +2256,32 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
2175
2256
  layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0);
2176
2257
  }
2177
2258
  } break;
2259
+ case LLM_ARCH_NEO_BERT:
2260
+ {
2261
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
2262
+
2263
+ cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
2264
+ cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED);
2265
+
2266
+ cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
2267
+ cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
2268
+
2269
+ output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
2270
+
2271
+ for (int i = 0; i < n_layer; ++i) {
2272
+ auto & layer = layers[i];
2273
+
2274
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
2275
+
2276
+ layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
2277
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
2278
+
2279
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
2280
+
2281
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff*2}, 0);
2282
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
2283
+ }
2284
+ } break;
2178
2285
  case LLM_ARCH_JINA_BERT_V2:
2179
2286
  {
2180
2287
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings
@@ -2212,8 +2319,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
2212
2319
  layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
2213
2320
  layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
2214
2321
 
2215
- layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
2216
- layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
2322
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
2323
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, layer.ffn_gate ? n_ff : n_ff * 2}, 0);
2217
2324
 
2218
2325
  layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
2219
2326
  layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
@@ -2872,6 +2979,62 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
2872
2979
  layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
2873
2980
  }
2874
2981
  } break;
2982
+ case LLM_ARCH_GEMMA3N:
2983
+ {
2984
+ const int64_t n_altup = hparams.n_altup;
2985
+ const int64_t laurel_rank = hparams.laurel_rank;
2986
+ const int64_t n_embd_altup = hparams.n_embd_altup;
2987
+
2988
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
2989
+ // if output is NULL, init from the input tok embed
2990
+ if (output == NULL) {
2991
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
2992
+ }
2993
+
2994
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
2995
+ tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
2996
+
2997
+ altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
2998
+ altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
2999
+ per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0);
3000
+ per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0);
3001
+
3002
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
3003
+
3004
+ for (int i = 0; i < n_layer; ++i) {
3005
+ auto & layer = layers[i];
3006
+
3007
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
3008
+
3009
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
3010
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
3011
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
3012
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
3013
+
3014
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
3015
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
3016
+ layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
3017
+
3018
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
3019
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
3020
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
3021
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
3022
+ layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
3023
+
3024
+ // altup & laurel
3025
+ layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0);
3026
+ layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0);
3027
+ layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0);
3028
+ layer.altup_correct_coef = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0);
3029
+ layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0);
3030
+ layer.altup_predict_coef = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), {n_altup, n_altup * n_altup}, 0);
3031
+ layer.altup_router = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0);
3032
+ layer.altup_router_norm = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0);
3033
+ layer.laurel_l = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0);
3034
+ layer.laurel_r = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0);
3035
+ layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0);
3036
+ }
3037
+ } break;
2875
3038
  case LLM_ARCH_STARCODER2:
2876
3039
  {
2877
3040
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4111,6 +4274,123 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
4111
4274
  layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
4112
4275
  }
4113
4276
  } break;
4277
+ case LLM_ARCH_DOTS1:
4278
+ {
4279
+ const int64_t n_ff_exp = hparams.n_ff_exp;
4280
+ const int64_t n_expert_shared = hparams.n_expert_shared;
4281
+
4282
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
4283
+
4284
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
4285
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
4286
+
4287
+ for (int i = 0; i < n_layer; ++i) {
4288
+ auto & layer = layers[i];
4289
+
4290
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
4291
+
4292
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
4293
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
4294
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
4295
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
4296
+
4297
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
4298
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
4299
+
4300
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
4301
+
4302
+ if (i < (int) hparams.n_layer_dense_lead) {
4303
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
4304
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
4305
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4306
+ } else {
4307
+ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
4308
+ layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
4309
+
4310
+ if (n_expert == 0) {
4311
+ throw std::runtime_error("n_expert must be > 0");
4312
+ }
4313
+ if (n_expert_used == 0) {
4314
+ throw std::runtime_error("n_expert_used must be > 0");
4315
+ }
4316
+
4317
+ // MoE branch
4318
+ layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
4319
+ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
4320
+ layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
4321
+
4322
+ // Shared expert branch
4323
+ layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
4324
+ layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0);
4325
+ layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
4326
+ }
4327
+ }
4328
+ } break;
4329
+ case LLM_ARCH_ARCEE:
4330
+ {
4331
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
4332
+
4333
+ // output
4334
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
4335
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
4336
+
4337
+ // if output is NULL, init from the input tok embed
4338
+ if (output == NULL) {
4339
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
4340
+ }
4341
+
4342
+ for (int i = 0; i < n_layer; ++i) {
4343
+ auto & layer = layers[i];
4344
+
4345
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
4346
+
4347
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
4348
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
4349
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
4350
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
4351
+
4352
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
4353
+
4354
+ layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
4355
+
4356
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
4357
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4358
+ }
4359
+ } break;
4360
+ case LLM_ARCH_ERNIE4_5:
4361
+ {
4362
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
4363
+
4364
+ // output
4365
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
4366
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
4367
+ // if output is NULL, init from the input tok embed
4368
+ if (output == NULL) {
4369
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
4370
+ }
4371
+
4372
+ for (int i = 0; i < n_layer; ++i) {
4373
+ auto & layer = layers[i];
4374
+
4375
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
4376
+
4377
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
4378
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
4379
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
4380
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
4381
+
4382
+ // optional bias tensors
4383
+ layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
4384
+ layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
4385
+ layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
4386
+ layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
4387
+
4388
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
4389
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
4390
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
4391
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4392
+ }
4393
+ } break;
4114
4394
  default:
4115
4395
  throw std::runtime_error("unknown architecture");
4116
4396
  }
@@ -4355,6 +4635,15 @@ void llama_model::print_info() const {
4355
4635
  LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
4356
4636
  LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
4357
4637
  LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
4638
+
4639
+ if (!classifier_labels.empty()) {
4640
+ LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out);
4641
+
4642
+ size_t i = 0;
4643
+ for (auto label : classifier_labels) {
4644
+ LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str());
4645
+ }
4646
+ }
4358
4647
  }
4359
4648
 
4360
4649
  LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str());
@@ -4537,6 +4826,8 @@ struct llm_build_llama : public llm_graph_context {
4537
4826
 
4538
4827
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
4539
4828
 
4829
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
4830
+
4540
4831
  for (int il = 0; il < n_layer; ++il) {
4541
4832
  ggml_tensor * inpSA = inpL;
4542
4833
 
@@ -4599,9 +4890,7 @@ struct llm_build_llama : public llm_graph_context {
4599
4890
  cb(cur, "attn_out", il);
4600
4891
  }
4601
4892
 
4602
- if (il == n_layer - 1) {
4603
- // skip computing output for unused tokens
4604
- ggml_tensor * inp_out_ids = build_inp_out_ids();
4893
+ if (il == n_layer - 1 && inp_out_ids) {
4605
4894
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
4606
4895
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
4607
4896
  }
@@ -4697,6 +4986,8 @@ struct llm_build_llama_iswa : public llm_graph_context {
4697
4986
 
4698
4987
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
4699
4988
 
4989
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
4990
+
4700
4991
  for (int il = 0; il < n_layer; ++il) {
4701
4992
  ggml_tensor * inpSA = inpL;
4702
4993
 
@@ -4773,9 +5064,7 @@ struct llm_build_llama_iswa : public llm_graph_context {
4773
5064
  cb(cur, "attn_out", il);
4774
5065
  }
4775
5066
 
4776
- if (il == n_layer - 1) {
4777
- // skip computing output for unused tokens
4778
- ggml_tensor * inp_out_ids = build_inp_out_ids();
5067
+ if (il == n_layer - 1 && inp_out_ids) {
4779
5068
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
4780
5069
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
4781
5070
  }
@@ -4875,6 +5164,9 @@ struct llm_build_deci : public llm_graph_context {
4875
5164
  auto * inp_attn = build_attn_inp_kv_unified();
4876
5165
 
4877
5166
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
5167
+
5168
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
5169
+
4878
5170
  for (int il = 0; il < n_layer; ++il) {
4879
5171
  ggml_tensor * inpSA = inpL;
4880
5172
  const int64_t n_head_kv = hparams.n_head_kv(il);
@@ -4948,9 +5240,7 @@ struct llm_build_deci : public llm_graph_context {
4948
5240
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
4949
5241
  }
4950
5242
 
4951
- if (il == n_layer - 1) {
4952
- // skip computing output for unused tokens
4953
- ggml_tensor * inp_out_ids = build_inp_out_ids();
5243
+ if (il == n_layer - 1 && inp_out_ids) {
4954
5244
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
4955
5245
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
4956
5246
  }
@@ -5029,6 +5319,8 @@ struct llm_build_baichuan : public llm_graph_context {
5029
5319
 
5030
5320
  auto * inp_attn = build_attn_inp_kv_unified();
5031
5321
 
5322
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
5323
+
5032
5324
  for (int il = 0; il < n_layer; ++il) {
5033
5325
  ggml_tensor * inpSA = inpL;
5034
5326
 
@@ -5080,9 +5372,7 @@ struct llm_build_baichuan : public llm_graph_context {
5080
5372
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5081
5373
  }
5082
5374
 
5083
- if (il == n_layer - 1) {
5084
- // skip computing output for unused tokens
5085
- ggml_tensor * inp_out_ids = build_inp_out_ids();
5375
+ if (il == n_layer - 1 && inp_out_ids) {
5086
5376
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5087
5377
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
5088
5378
  }
@@ -5151,6 +5441,8 @@ struct llm_build_xverse : public llm_graph_context {
5151
5441
 
5152
5442
  auto * inp_attn = build_attn_inp_kv_unified();
5153
5443
 
5444
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
5445
+
5154
5446
  for (int il = 0; il < n_layer; ++il) {
5155
5447
  ggml_tensor * inpSA = inpL;
5156
5448
 
@@ -5195,9 +5487,7 @@ struct llm_build_xverse : public llm_graph_context {
5195
5487
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5196
5488
  }
5197
5489
 
5198
- if (il == n_layer - 1) {
5199
- // skip computing output for unused tokens
5200
- ggml_tensor * inp_out_ids = build_inp_out_ids();
5490
+ if (il == n_layer - 1 && inp_out_ids) {
5201
5491
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5202
5492
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
5203
5493
  }
@@ -5265,6 +5555,8 @@ struct llm_build_falcon : public llm_graph_context {
5265
5555
 
5266
5556
  auto * inp_attn = build_attn_inp_kv_unified();
5267
5557
 
5558
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
5559
+
5268
5560
  for (int il = 0; il < n_layer; ++il) {
5269
5561
  ggml_tensor * attn_norm;
5270
5562
 
@@ -5320,9 +5612,7 @@ struct llm_build_falcon : public llm_graph_context {
5320
5612
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5321
5613
  }
5322
5614
 
5323
- if (il == n_layer - 1) {
5324
- // skip computing output for unused tokens
5325
- ggml_tensor * inp_out_ids = build_inp_out_ids();
5615
+ if (il == n_layer - 1 && inp_out_ids) {
5326
5616
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5327
5617
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
5328
5618
  attn_norm = ggml_get_rows(ctx0, attn_norm, inp_out_ids);
@@ -5391,6 +5681,8 @@ struct llm_build_grok : public llm_graph_context {
5391
5681
 
5392
5682
  auto * inp_attn = build_attn_inp_kv_unified();
5393
5683
 
5684
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
5685
+
5394
5686
  for (int il = 0; il < n_layer; ++il) {
5395
5687
  ggml_tensor * inpSA = inpL;
5396
5688
 
@@ -5450,9 +5742,7 @@ struct llm_build_grok : public llm_graph_context {
5450
5742
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
5451
5743
  }
5452
5744
 
5453
- if (il == n_layer - 1) {
5454
- // skip computing output for unused tokens
5455
- ggml_tensor * inp_out_ids = build_inp_out_ids();
5745
+ if (il == n_layer - 1 && inp_out_ids) {
5456
5746
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5457
5747
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
5458
5748
  }
@@ -5551,6 +5841,8 @@ struct llm_build_dbrx : public llm_graph_context {
5551
5841
 
5552
5842
  auto * inp_attn = build_attn_inp_kv_unified();
5553
5843
 
5844
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
5845
+
5554
5846
  for (int il = 0; il < n_layer; ++il) {
5555
5847
  ggml_tensor * inpSA = inpL;
5556
5848
 
@@ -5601,9 +5893,7 @@ struct llm_build_dbrx : public llm_graph_context {
5601
5893
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5602
5894
  }
5603
5895
 
5604
- if (il == n_layer - 1) {
5605
- // skip computing output for unused tokens
5606
- ggml_tensor * inp_out_ids = build_inp_out_ids();
5896
+ if (il == n_layer - 1 && inp_out_ids) {
5607
5897
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5608
5898
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
5609
5899
  }
@@ -5683,6 +5973,8 @@ struct llm_build_starcoder : public llm_graph_context {
5683
5973
  inpL = ggml_add(ctx0, inpL, pos);
5684
5974
  cb(inpL, "inpL", -1);
5685
5975
 
5976
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
5977
+
5686
5978
  for (int il = 0; il < n_layer; ++il) {
5687
5979
  cur = build_norm(inpL,
5688
5980
  model.layers[il].attn_norm,
@@ -5715,9 +6007,7 @@ struct llm_build_starcoder : public llm_graph_context {
5715
6007
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5716
6008
  }
5717
6009
 
5718
- if (il == n_layer - 1) {
5719
- // skip computing output for unused tokens
5720
- ggml_tensor * inp_out_ids = build_inp_out_ids();
6010
+ if (il == n_layer - 1 && inp_out_ids) {
5721
6011
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5722
6012
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
5723
6013
  }
@@ -5782,6 +6072,8 @@ struct llm_build_refact : public llm_graph_context {
5782
6072
 
5783
6073
  auto * inp_attn = build_attn_inp_kv_unified();
5784
6074
 
6075
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
6076
+
5785
6077
  for (int il = 0; il < n_layer; ++il) {
5786
6078
  ggml_tensor * inpSA = inpL;
5787
6079
 
@@ -5814,9 +6106,7 @@ struct llm_build_refact : public llm_graph_context {
5814
6106
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5815
6107
  }
5816
6108
 
5817
- if (il == n_layer - 1) {
5818
- // skip computing output for unused tokens
5819
- ggml_tensor * inp_out_ids = build_inp_out_ids();
6109
+ if (il == n_layer - 1 && inp_out_ids) {
5820
6110
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5821
6111
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
5822
6112
  }
@@ -5887,8 +6177,10 @@ struct llm_build_bert : public llm_graph_context {
5887
6177
  inpL = build_inp_embd(model.tok_embd);
5888
6178
 
5889
6179
  // token types are hardcoded to zero ("Sentence A")
5890
- ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
5891
- inpL = ggml_add(ctx0, inpL, type_row0);
6180
+ if (model.type_embd) {
6181
+ ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
6182
+ inpL = ggml_add(ctx0, inpL, type_row0);
6183
+ }
5892
6184
  if (model.arch == LLM_ARCH_BERT) {
5893
6185
  inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
5894
6186
  }
@@ -5900,17 +6192,34 @@ struct llm_build_bert : public llm_graph_context {
5900
6192
 
5901
6193
  auto * inp_attn = build_attn_inp_no_cache();
5902
6194
 
5903
- // iterate layers
6195
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
6196
+
5904
6197
  for (int il = 0; il < n_layer; ++il) {
5905
6198
  ggml_tensor * cur = inpL;
5906
6199
 
5907
- ggml_tensor * Qcur;
5908
- ggml_tensor * Kcur;
5909
- ggml_tensor * Vcur;
6200
+ {
6201
+ ggml_tensor * Qcur;
6202
+ ggml_tensor * Kcur;
6203
+ ggml_tensor * Vcur;
5910
6204
 
5911
- // self-attention
5912
- if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
5913
- Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
6205
+ // self-attention
6206
+ if (model.layers[il].wqkv) {
6207
+ cur = build_lora_mm(model.layers[il].wqkv, cur);
6208
+ cb(cur, "wqkv", il);
6209
+
6210
+ if (model.layers[il].bqkv) {
6211
+ cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
6212
+ cb(cur, "bqkv", il);
6213
+ }
6214
+
6215
+ Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
6216
+ Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
6217
+ Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
6218
+ } else {
6219
+ Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
6220
+ Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
6221
+ Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
6222
+ }
5914
6223
 
5915
6224
  if (model.layers[il].attn_q_norm) {
5916
6225
  Qcur = build_norm(Qcur,
@@ -5919,8 +6228,6 @@ struct llm_build_bert : public llm_graph_context {
5919
6228
  LLM_NORM, il);
5920
6229
  }
5921
6230
 
5922
- Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
5923
-
5924
6231
  if (model.layers[il].attn_k_norm) {
5925
6232
  Kcur = build_norm(Kcur,
5926
6233
  model.layers[il].attn_k_norm,
@@ -5928,54 +6235,36 @@ struct llm_build_bert : public llm_graph_context {
5928
6235
  LLM_NORM, il);
5929
6236
  }
5930
6237
 
5931
- Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
5932
-
5933
6238
  Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
5934
6239
  Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
5935
6240
  Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
5936
- } else {
5937
- // compute Q and K and RoPE them
5938
- cur = build_lora_mm(model.layers[il].wqkv, cur);
5939
- cb(cur, "wqkv", il);
5940
-
5941
- if (model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
5942
- cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
5943
- cb(cur, "bqkv", il);
5944
- }
5945
6241
 
5946
- Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
5947
- Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
5948
- Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
6242
+ // RoPE
6243
+ if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
6244
+ Qcur = ggml_rope_ext(
6245
+ ctx0, Qcur, inp_pos, nullptr,
6246
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
6247
+ ext_factor, attn_factor, beta_fast, beta_slow
6248
+ );
5949
6249
 
5950
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
5951
- Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
5952
- Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
6250
+ Kcur = ggml_rope_ext(
6251
+ ctx0, Kcur, inp_pos, nullptr,
6252
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
6253
+ ext_factor, attn_factor, beta_fast, beta_slow
6254
+ );
6255
+ }
5953
6256
 
5954
- Qcur = ggml_rope_ext(
5955
- ctx0, Qcur, inp_pos, nullptr,
5956
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
5957
- ext_factor, attn_factor, beta_fast, beta_slow
5958
- );
6257
+ cb(Qcur, "Qcur", il);
6258
+ cb(Kcur, "Kcur", il);
6259
+ cb(Vcur, "Vcur", il);
5959
6260
 
5960
- Kcur = ggml_rope_ext(
5961
- ctx0, Kcur, inp_pos, nullptr,
5962
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
5963
- ext_factor, attn_factor, beta_fast, beta_slow
5964
- );
6261
+ cur = build_attn(inp_attn, gf,
6262
+ model.layers[il].wo, model.layers[il].bo,
6263
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6264
+ cb(cur, "kqv_out", il);
5965
6265
  }
5966
6266
 
5967
- cb(Qcur, "Qcur", il);
5968
- cb(Kcur, "Kcur", il);
5969
- cb(Vcur, "Vcur", il);
5970
-
5971
- cur = build_attn(inp_attn, gf,
5972
- model.layers[il].wo, model.layers[il].bo,
5973
- Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5974
- cb(cur, "kqv_out", il);
5975
-
5976
- if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
5977
- // skip computing output for unused tokens
5978
- ggml_tensor * inp_out_ids = build_inp_out_ids();
6267
+ if (il == n_layer - 1 && inp_out_ids) {
5979
6268
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5980
6269
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
5981
6270
  }
@@ -6024,7 +6313,7 @@ struct llm_build_bert : public llm_graph_context {
6024
6313
  model.layers[il].ffn_gate, NULL, NULL,
6025
6314
  model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
6026
6315
  NULL,
6027
- LLM_FFN_GELU, LLM_FFN_PAR, il);
6316
+ model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_GEGLU, LLM_FFN_PAR, il);
6028
6317
  cb(cur, "ffn_out", il);
6029
6318
  } else {
6030
6319
  cur = build_ffn(cur,
@@ -6055,6 +6344,118 @@ struct llm_build_bert : public llm_graph_context {
6055
6344
  }
6056
6345
  };
6057
6346
 
6347
+ struct llm_build_neo_bert : public llm_graph_context {
6348
+ llm_build_neo_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
6349
+ const int64_t n_embd_head = hparams.n_embd_head_v;
6350
+ const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
6351
+
6352
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
6353
+
6354
+ ggml_tensor * cur;
6355
+ ggml_tensor * inpL;
6356
+ ggml_tensor * inp_pos = build_inp_pos();
6357
+
6358
+ // construct input embeddings (token, type, position)
6359
+ inpL = build_inp_embd(model.tok_embd);
6360
+ cb(inpL, "inp_embd", -1);
6361
+
6362
+ auto * inp_attn = build_attn_inp_no_cache();
6363
+
6364
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
6365
+
6366
+ for (int il = 0; il < n_layer; ++il) {
6367
+ ggml_tensor * cur = inpL;
6368
+
6369
+ // pre-norm
6370
+ cur = build_norm(inpL,
6371
+ model.layers[il].attn_norm, NULL,
6372
+ LLM_NORM_RMS, il);
6373
+
6374
+ {
6375
+ ggml_tensor * Qcur;
6376
+ ggml_tensor * Kcur;
6377
+ ggml_tensor * Vcur;
6378
+
6379
+ // self-attention
6380
+ cur = build_lora_mm(model.layers[il].wqkv, cur);
6381
+ cb(cur, "wqkv", il);
6382
+
6383
+ Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
6384
+ Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
6385
+ Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
6386
+
6387
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
6388
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
6389
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
6390
+
6391
+ // RoPE
6392
+ Qcur = ggml_rope_ext(
6393
+ ctx0, Qcur, inp_pos, nullptr,
6394
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
6395
+ ext_factor, attn_factor, beta_fast, beta_slow
6396
+ );
6397
+
6398
+ Kcur = ggml_rope_ext(
6399
+ ctx0, Kcur, inp_pos, nullptr,
6400
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
6401
+ ext_factor, attn_factor, beta_fast, beta_slow
6402
+ );
6403
+
6404
+ cb(Qcur, "Qcur", il);
6405
+ cb(Kcur, "Kcur", il);
6406
+ cb(Vcur, "Vcur", il);
6407
+
6408
+ cur = build_attn(inp_attn, gf,
6409
+ model.layers[il].wo, nullptr,
6410
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6411
+ cb(cur, "kqv_out", il);
6412
+ }
6413
+
6414
+ if (il == n_layer - 1 && inp_out_ids) {
6415
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6416
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
6417
+ }
6418
+
6419
+ // re-add the layer input
6420
+ cur = ggml_add(ctx0, cur, inpL);
6421
+
6422
+ ggml_tensor * ffn_inp = cur;
6423
+ cb(ffn_inp, "ffn_inp", il);
6424
+
6425
+ // pre-norm
6426
+ cur = build_norm(ffn_inp,
6427
+ model.layers[il].ffn_norm, NULL,
6428
+ LLM_NORM_RMS, il);
6429
+ cb(cur, "ffn_norm", il);
6430
+
6431
+ // feed-forward network
6432
+ cur = build_ffn(cur,
6433
+ model.layers[il].ffn_up,
6434
+ NULL, NULL, NULL, NULL, NULL,
6435
+ model.layers[il].ffn_down,
6436
+ NULL, NULL, NULL,
6437
+ LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
6438
+
6439
+ // attentions bypass the intermediate layer
6440
+ cur = ggml_add(ctx0, cur, ffn_inp);
6441
+
6442
+ // input for next layer
6443
+ inpL = cur;
6444
+ }
6445
+
6446
+ cur = inpL;
6447
+
6448
+ cur = build_norm(cur,
6449
+ model.output_norm_enc, NULL,
6450
+ LLM_NORM_RMS, -1);
6451
+
6452
+ cb(cur, "result_embd", -1);
6453
+ res->t_embd = cur;
6454
+
6455
+ ggml_build_forward_expand(gf, cur);
6456
+ }
6457
+ };
6458
+
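In the new neo-bert builder above, Q, K and V are carved out of the single fused wqkv projection with strided 2D views at fixed byte offsets. A minimal standalone sketch of that row layout, assuming a packed [n_embd | n_embd_gqa | n_embd_gqa] row as the offsets imply (the sizes below are hypothetical, for illustration only):

    #include <vector>

    int main() {
        const int n_embd     = 8; // hypothetical size
        const int n_embd_gqa = 4; // hypothetical size

        std::vector<float> row(n_embd + 2 * n_embd_gqa); // one token's fused QKV output

        float * q = row.data();                          // offset 0
        float * k = row.data() + n_embd;                 // offset n_embd
        float * v = row.data() + n_embd + n_embd_gqa;    // offset n_embd + n_embd_gqa

        (void)q; (void)k; (void)v;
        return 0;
    }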
6058
6459
  struct llm_build_bloom : public llm_graph_context {
6059
6460
  llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
6060
6461
  const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -6075,6 +6476,8 @@ struct llm_build_bloom : public llm_graph_context {
6075
6476
  LLM_NORM, -1);
6076
6477
  cb(inpL, "inp_norm", -1);
6077
6478
 
6479
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
6480
+
6078
6481
  for (int il = 0; il < n_layer; ++il) {
6079
6482
  cur = build_norm(inpL,
6080
6483
  model.layers[il].attn_norm,
@@ -6107,9 +6510,7 @@ struct llm_build_bloom : public llm_graph_context {
6107
6510
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6108
6511
  }
6109
6512
 
6110
- if (il == n_layer - 1) {
6111
- // skip computing output for unused tokens
6112
- ggml_tensor * inp_out_ids = build_inp_out_ids();
6513
+ if (il == n_layer - 1 && inp_out_ids) {
6113
6514
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6114
6515
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
6115
6516
  }
@@ -6186,6 +6587,8 @@ struct llm_build_mpt : public llm_graph_context {
6186
6587
  cb(inpL, "inpL", -1);
6187
6588
  }
6188
6589
 
6590
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
6591
+
6189
6592
  for (int il = 0; il < n_layer; ++il) {
6190
6593
  ggml_tensor * attn_norm;
6191
6594
 
@@ -6248,9 +6651,7 @@ struct llm_build_mpt : public llm_graph_context {
6248
6651
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6249
6652
  }
6250
6653
 
6251
- if (il == n_layer - 1) {
6252
- // skip computing output for unused tokens
6253
- ggml_tensor * inp_out_ids = build_inp_out_ids();
6654
+ if (il == n_layer - 1 && inp_out_ids) {
6254
6655
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6255
6656
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
6256
6657
  }
@@ -6319,6 +6720,8 @@ struct llm_build_stablelm : public llm_graph_context {
6319
6720
 
6320
6721
  auto * inp_attn = build_attn_inp_kv_unified();
6321
6722
 
6723
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
6724
+
6322
6725
  for (int il = 0; il < n_layer; ++il) {
6323
6726
  // norm
6324
6727
  cur = build_norm(inpL,
@@ -6394,9 +6797,7 @@ struct llm_build_stablelm : public llm_graph_context {
6394
6797
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6395
6798
  }
6396
6799
 
6397
- if (il == n_layer - 1) {
6398
- // skip computing output for unused tokens
6399
- ggml_tensor * inp_out_ids = build_inp_out_ids();
6800
+ if (il == n_layer - 1 && inp_out_ids) {
6400
6801
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6401
6802
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
6402
6803
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
@@ -6471,6 +6872,8 @@ struct llm_build_qwen : public llm_graph_context {
6471
6872
 
6472
6873
  auto * inp_attn = build_attn_inp_kv_unified();
6473
6874
 
6875
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
6876
+
6474
6877
  for (int il = 0; il < n_layer; ++il) {
6475
6878
  ggml_tensor * inpSA = inpL;
6476
6879
 
@@ -6517,9 +6920,7 @@ struct llm_build_qwen : public llm_graph_context {
6517
6920
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6518
6921
  }
6519
6922
 
6520
- if (il == n_layer - 1) {
6521
- // skip computing output for unused tokens
6522
- ggml_tensor * inp_out_ids = build_inp_out_ids();
6923
+ if (il == n_layer - 1 && inp_out_ids) {
6523
6924
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6524
6925
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
6525
6926
  }
@@ -6588,6 +6989,8 @@ struct llm_build_qwen2 : public llm_graph_context {
6588
6989
 
6589
6990
  auto * inp_attn = build_attn_inp_kv_unified();
6590
6991
 
6992
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
6993
+
6591
6994
  for (int il = 0; il < n_layer; ++il) {
6592
6995
  ggml_tensor * inpSA = inpL;
6593
6996
 
@@ -6637,9 +7040,7 @@ struct llm_build_qwen2 : public llm_graph_context {
6637
7040
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6638
7041
  }
6639
7042
 
6640
- if (il == n_layer - 1) {
6641
- // skip computing output for unused tokens
6642
- ggml_tensor * inp_out_ids = build_inp_out_ids();
7043
+ if (il == n_layer - 1 && inp_out_ids) {
6643
7044
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6644
7045
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
6645
7046
  }
@@ -6709,6 +7110,8 @@ struct llm_build_qwen2vl : public llm_graph_context {
6709
7110
  int sections[4];
6710
7111
  std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
6711
7112
 
7113
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
7114
+
6712
7115
  for (int il = 0; il < n_layer; ++il) {
6713
7116
  ggml_tensor * inpSA = inpL;
6714
7117
 
@@ -6758,9 +7161,7 @@ struct llm_build_qwen2vl : public llm_graph_context {
6758
7161
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6759
7162
  }
6760
7163
 
6761
- if (il == n_layer - 1) {
6762
- // skip computing output for unused tokens
6763
- ggml_tensor * inp_out_ids = build_inp_out_ids();
7164
+ if (il == n_layer - 1 && inp_out_ids) {
6764
7165
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6765
7166
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
6766
7167
  }
@@ -6827,6 +7228,8 @@ struct llm_build_qwen2moe : public llm_graph_context {
6827
7228
 
6828
7229
  auto * inp_attn = build_attn_inp_kv_unified();
6829
7230
 
7231
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
7232
+
6830
7233
  for (int il = 0; il < n_layer; ++il) {
6831
7234
  ggml_tensor * inpSA = inpL;
6832
7235
 
@@ -6885,9 +7288,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
6885
7288
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6886
7289
  }
6887
7290
 
6888
- if (il == n_layer - 1) {
6889
- // skip computing output for unused tokens
6890
- ggml_tensor * inp_out_ids = build_inp_out_ids();
7291
+ if (il == n_layer - 1 && inp_out_ids) {
6891
7292
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6892
7293
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
6893
7294
  }
@@ -6986,6 +7387,8 @@ struct llm_build_qwen3 : public llm_graph_context {
6986
7387
 
6987
7388
  auto * inp_attn = build_attn_inp_kv_unified();
6988
7389
 
7390
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
7391
+
6989
7392
  for (int il = 0; il < n_layer; ++il) {
6990
7393
  ggml_tensor * inpSA = inpL;
6991
7394
 
@@ -7038,9 +7441,7 @@ struct llm_build_qwen3 : public llm_graph_context {
7038
7441
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7039
7442
  }
7040
7443
 
7041
- if (il == n_layer - 1) {
7042
- // skip computing output for unused tokens
7043
- ggml_tensor * inp_out_ids = build_inp_out_ids();
7444
+ if (il == n_layer - 1 && inp_out_ids) {
7044
7445
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7045
7446
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
7046
7447
  }
@@ -7107,6 +7508,8 @@ struct llm_build_qwen3moe : public llm_graph_context {
7107
7508
 
7108
7509
  auto * inp_attn = build_attn_inp_kv_unified();
7109
7510
 
7511
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
7512
+
7110
7513
  for (int il = 0; il < n_layer; ++il) {
7111
7514
  ggml_tensor * inpSA = inpL;
7112
7515
 
@@ -7159,9 +7562,7 @@ struct llm_build_qwen3moe : public llm_graph_context {
7159
7562
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7160
7563
  }
7161
7564
 
7162
- if (il == n_layer - 1) {
7163
- // skip computing output for unused tokens
7164
- ggml_tensor * inp_out_ids = build_inp_out_ids();
7565
+ if (il == n_layer - 1 && inp_out_ids) {
7165
7566
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7166
7567
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
7167
7568
  }
@@ -7237,6 +7638,8 @@ struct llm_build_phi2 : public llm_graph_context {
7237
7638
 
7238
7639
  auto * inp_attn = build_attn_inp_kv_unified();
7239
7640
 
7641
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
7642
+
7240
7643
  for (int il = 0; il < n_layer; ++il) {
7241
7644
  attn_norm_output = build_norm(inpL,
7242
7645
  model.layers[il].attn_norm,
@@ -7299,9 +7702,7 @@ struct llm_build_phi2 : public llm_graph_context {
7299
7702
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
7300
7703
  }
7301
7704
 
7302
- if (il == n_layer - 1) {
7303
- // skip computing output for unused tokens
7304
- ggml_tensor * inp_out_ids = build_inp_out_ids();
7705
+ if (il == n_layer - 1 && inp_out_ids) {
7305
7706
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7306
7707
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
7307
7708
  attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids);
@@ -7373,6 +7774,8 @@ struct llm_build_phi3 : public llm_graph_context {
7373
7774
  inp_attn = build_attn_inp_kv_unified();
7374
7775
  }
7375
7776
 
7777
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
7778
+
7376
7779
  for (int il = 0; il < n_layer; ++il) {
7377
7780
  auto * residual = inpL;
7378
7781
 
@@ -7436,9 +7839,7 @@ struct llm_build_phi3 : public llm_graph_context {
7436
7839
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
7437
7840
  }
7438
7841
 
7439
- if (il == n_layer - 1) {
7440
- // skip computing output for unused tokens
7441
- ggml_tensor* inp_out_ids = build_inp_out_ids();
7842
+ if (il == n_layer - 1 && inp_out_ids) {
7442
7843
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7443
7844
  residual = ggml_get_rows(ctx0, residual, inp_out_ids);
7444
7845
  }
@@ -7524,15 +7925,16 @@ struct llm_build_plamo : public llm_graph_context {
7524
7925
 
7525
7926
  auto * inp_attn = build_attn_inp_kv_unified();
7526
7927
 
7527
- for (int il = 0; il < n_layer; ++il) {
7928
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
7528
7929
 
7930
+ for (int il = 0; il < n_layer; ++il) {
7529
7931
  // norm
7530
7932
  cur = build_norm(inpL,
7531
7933
  model.layers[il].attn_norm, NULL,
7532
7934
  LLM_NORM_RMS, il);
7533
7935
  cb(cur, "attn_norm", il);
7534
7936
 
7535
- ggml_tensor * attention_norm = cur;
7937
+ ggml_tensor * sa_inp = cur;
7536
7938
 
7537
7939
  // self-attention
7538
7940
  {
@@ -7570,18 +7972,17 @@ struct llm_build_plamo : public llm_graph_context {
7570
7972
  model.layers[il].wo, NULL,
7571
7973
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7572
7974
  }
7573
- ggml_tensor * sa_out = cur;
7574
7975
 
7575
- cur = attention_norm;
7576
-
7577
- if (il == n_layer - 1) {
7578
- // skip computing output for unused tokens
7579
- ggml_tensor * inp_out_ids = build_inp_out_ids();
7976
+ if (il == n_layer - 1 && inp_out_ids) {
7580
7977
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7581
- sa_out = ggml_get_rows(ctx0, sa_out, inp_out_ids);
7978
+ sa_inp = ggml_get_rows(ctx0, sa_inp, inp_out_ids);
7582
7979
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
7583
7980
  }
7584
7981
 
7982
+ ggml_tensor * sa_out = cur;
7983
+
7984
+ cur = sa_inp;
7985
+
7585
7986
  // feed-forward network
7586
7987
  {
7587
7988
  cur = build_ffn(cur,
@@ -7646,6 +8047,8 @@ struct llm_build_gpt2 : public llm_graph_context {
7646
8047
  inpL = ggml_add(ctx0, inpL, pos);
7647
8048
  cb(inpL, "inpL", -1);
7648
8049
 
8050
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
8051
+
7649
8052
  for (int il = 0; il < n_layer; ++il) {
7650
8053
  cur = build_norm(inpL,
7651
8054
  model.layers[il].attn_norm,
@@ -7678,9 +8081,7 @@ struct llm_build_gpt2 : public llm_graph_context {
7678
8081
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7679
8082
  }
7680
8083
 
7681
- if (il == n_layer - 1) {
7682
- // skip computing output for unused tokens
7683
- ggml_tensor * inp_out_ids = build_inp_out_ids();
8084
+ if (il == n_layer - 1 && inp_out_ids) {
7684
8085
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7685
8086
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
7686
8087
  }
@@ -7750,6 +8151,8 @@ struct llm_build_codeshell : public llm_graph_context {
7750
8151
 
7751
8152
  auto * inp_attn = build_attn_inp_kv_unified();
7752
8153
 
8154
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
8155
+
7753
8156
  for (int il = 0; il < n_layer; ++il) {
7754
8157
  cur = build_norm(inpL,
7755
8158
  model.layers[il].attn_norm,
@@ -7794,9 +8197,7 @@ struct llm_build_codeshell : public llm_graph_context {
7794
8197
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7795
8198
  }
7796
8199
 
7797
- if (il == n_layer - 1) {
7798
- // skip computing output for unused tokens
7799
- ggml_tensor * inp_out_ids = build_inp_out_ids();
8200
+ if (il == n_layer - 1 && inp_out_ids) {
7800
8201
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7801
8202
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
7802
8203
  }
@@ -7850,128 +8251,128 @@ struct llm_build_codeshell : public llm_graph_context {
7850
8251
 
7851
8252
  struct llm_build_orion : public llm_graph_context {
7852
8253
  llm_build_orion(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
7853
- const int64_t n_embd_head = hparams.n_embd_head_v;
8254
+ const int64_t n_embd_head = hparams.n_embd_head_v;
7854
8255
 
7855
- GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
7856
- GGML_ASSERT(n_embd_head == hparams.n_rot);
8256
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
8257
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
7857
8258
 
7858
- ggml_tensor * cur;
7859
- ggml_tensor * inpL;
8259
+ ggml_tensor * cur;
8260
+ ggml_tensor * inpL;
7860
8261
 
7861
- inpL = build_inp_embd(model.tok_embd);
8262
+ inpL = build_inp_embd(model.tok_embd);
7862
8263
 
7863
- // inp_pos - contains the positions
7864
- ggml_tensor * inp_pos = build_inp_pos();
8264
+ // inp_pos - contains the positions
8265
+ ggml_tensor * inp_pos = build_inp_pos();
7865
8266
 
7866
- auto * inp_attn = build_attn_inp_kv_unified();
8267
+ auto * inp_attn = build_attn_inp_kv_unified();
7867
8268
 
7868
- for (int il = 0; il < n_layer; ++il) {
7869
- ggml_tensor * inpSA = inpL;
8269
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
7870
8270
 
7871
- // norm
7872
- cur = build_norm(inpL,
7873
- model.layers[il].attn_norm, model.layers[il].attn_norm_b,
7874
- LLM_NORM, il);
7875
- cb(cur, "attn_norm", il);
8271
+ for (int il = 0; il < n_layer; ++il) {
8272
+ ggml_tensor * inpSA = inpL;
7876
8273
 
7877
- // self-attention
7878
- {
7879
- // compute Q and K and RoPE them
7880
- ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
7881
- cb(Qcur, "Qcur", il);
7882
- // if (model.layers[il].bq) {
7883
- // Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
7884
- // cb(Qcur, "Qcur", il);
7885
- // }
7886
-
7887
- ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
7888
- cb(Kcur, "Kcur", il);
7889
- // if (model.layers[il].bk) {
7890
- // Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
7891
- // cb(Kcur, "Kcur", il);
7892
- // }
7893
-
7894
- ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
7895
- cb(Vcur, "Vcur", il);
7896
- // if (model.layers[il].bv) {
7897
- // Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
7898
- // cb(Vcur, "Vcur", il);
7899
- // }
7900
-
7901
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
7902
- Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
7903
- Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
7904
-
7905
- Qcur = ggml_rope_ext(
7906
- ctx0, Qcur, inp_pos, nullptr,
7907
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
7908
- ext_factor, attn_factor, beta_fast, beta_slow
7909
- );
8274
+ // norm
8275
+ cur = build_norm(inpL,
8276
+ model.layers[il].attn_norm, model.layers[il].attn_norm_b,
8277
+ LLM_NORM, il);
8278
+ cb(cur, "attn_norm", il);
7910
8279
 
7911
- Kcur = ggml_rope_ext(
7912
- ctx0, Kcur, inp_pos, nullptr,
7913
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
7914
- ext_factor, attn_factor, beta_fast, beta_slow
7915
- );
8280
+ // self-attention
8281
+ {
8282
+ // compute Q and K and RoPE them
8283
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
8284
+ cb(Qcur, "Qcur", il);
8285
+ // if (model.layers[il].bq) {
8286
+ // Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
8287
+ // cb(Qcur, "Qcur", il);
8288
+ // }
7916
8289
 
7917
- cb(Qcur, "Qcur", il);
7918
- cb(Kcur, "Kcur", il);
7919
- cb(Vcur, "Vcur", il);
8290
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
8291
+ cb(Kcur, "Kcur", il);
8292
+ // if (model.layers[il].bk) {
8293
+ // Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
8294
+ // cb(Kcur, "Kcur", il);
8295
+ // }
7920
8296
 
7921
- cur = build_attn(inp_attn, gf,
7922
- model.layers[il].wo, NULL,
7923
- Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7924
- }
8297
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
8298
+ cb(Vcur, "Vcur", il);
8299
+ // if (model.layers[il].bv) {
8300
+ // Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
8301
+ // cb(Vcur, "Vcur", il);
8302
+ // }
7925
8303
 
7926
- if (il == n_layer - 1) {
7927
- // skip computing output for unused tokens
7928
- ggml_tensor * inp_out_ids = build_inp_out_ids();
7929
- cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7930
- inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
7931
- }
8304
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
8305
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
8306
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
7932
8307
 
7933
- ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
7934
- cb(ffn_inp, "ffn_inp", il);
8308
+ Qcur = ggml_rope_ext(
8309
+ ctx0, Qcur, inp_pos, nullptr,
8310
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
8311
+ ext_factor, attn_factor, beta_fast, beta_slow
8312
+ );
7935
8313
 
7936
- // feed-forward network
7937
- cur = build_norm(ffn_inp,
7938
- model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
7939
- LLM_NORM, il);
7940
- cb(cur, "ffn_norm", il);
8314
+ Kcur = ggml_rope_ext(
8315
+ ctx0, Kcur, inp_pos, nullptr,
8316
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
8317
+ ext_factor, attn_factor, beta_fast, beta_slow
8318
+ );
7941
8319
 
7942
- cur = build_ffn(cur,
7943
- model.layers[il].ffn_up, NULL, NULL,
7944
- model.layers[il].ffn_gate, NULL, NULL,
7945
- model.layers[il].ffn_down, NULL, NULL,
7946
- NULL,
7947
- LLM_FFN_SILU, LLM_FFN_PAR, il);
7948
- cb(cur, "ffn_out", il);
8320
+ cb(Qcur, "Qcur", il);
8321
+ cb(Kcur, "Kcur", il);
8322
+ cb(Vcur, "Vcur", il);
7949
8323
 
7950
- cur = ggml_add(ctx0, cur, ffn_inp);
8324
+ cur = build_attn(inp_attn, gf,
8325
+ model.layers[il].wo, NULL,
8326
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8327
+ }
7951
8328
 
7952
- cur = build_cvec(cur, il);
7953
- cb(cur, "l_out", il);
8329
+ if (il == n_layer - 1 && inp_out_ids) {
8330
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8331
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
8332
+ }
7954
8333
 
7955
- // input for next layer
7956
- inpL = cur;
7957
- }
8334
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
8335
+ cb(ffn_inp, "ffn_inp", il);
8336
+
8337
+ // feed-forward network
8338
+ cur = build_norm(ffn_inp,
8339
+ model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
8340
+ LLM_NORM, il);
8341
+ cb(cur, "ffn_norm", il);
8342
+
8343
+ cur = build_ffn(cur,
8344
+ model.layers[il].ffn_up, NULL, NULL,
8345
+ model.layers[il].ffn_gate, NULL, NULL,
8346
+ model.layers[il].ffn_down, NULL, NULL,
8347
+ NULL,
8348
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
8349
+ cb(cur, "ffn_out", il);
8350
+
8351
+ cur = ggml_add(ctx0, cur, ffn_inp);
8352
+
8353
+ cur = build_cvec(cur, il);
8354
+ cb(cur, "l_out", il);
8355
+
8356
+ // input for next layer
8357
+ inpL = cur;
8358
+ }
7958
8359
 
7959
- cur = inpL;
8360
+ cur = inpL;
7960
8361
 
7961
- cur = build_norm(cur,
7962
- model.output_norm, model.output_norm_b,
7963
- LLM_NORM, -1);
8362
+ cur = build_norm(cur,
8363
+ model.output_norm, model.output_norm_b,
8364
+ LLM_NORM, -1);
7964
8365
 
7965
- cb(cur, "result_norm", -1);
7966
- res->t_embd = cur;
8366
+ cb(cur, "result_norm", -1);
8367
+ res->t_embd = cur;
7967
8368
 
7968
- // lm_head
7969
- cur = build_lora_mm(model.output, cur);
8369
+ // lm_head
8370
+ cur = build_lora_mm(model.output, cur);
7970
8371
 
7971
- cb(cur, "result_output", -1);
7972
- res->t_logits = cur;
8372
+ cb(cur, "result_output", -1);
8373
+ res->t_logits = cur;
7973
8374
 
7974
- ggml_build_forward_expand(gf, cur);
8375
+ ggml_build_forward_expand(gf, cur);
7975
8376
  }
7976
8377
  };
7977
8378
 
@@ -7992,6 +8393,8 @@ struct llm_build_internlm2 : public llm_graph_context {
7992
8393
 
7993
8394
  auto * inp_attn = build_attn_inp_kv_unified();
7994
8395
 
8396
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
8397
+
7995
8398
  for (int il = 0; il < n_layer; ++il) {
7996
8399
  ggml_tensor * inpSA = inpL;
7997
8400
 
@@ -8050,9 +8453,7 @@ struct llm_build_internlm2 : public llm_graph_context {
8050
8453
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8051
8454
  }
8052
8455
 
8053
- if (il == n_layer - 1) {
8054
- // skip computing output for unused tokens
8055
- ggml_tensor * inp_out_ids = build_inp_out_ids();
8456
+ if (il == n_layer - 1 && inp_out_ids) {
8056
8457
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8057
8458
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
8058
8459
  }
@@ -8128,6 +8529,8 @@ struct llm_build_minicpm3 : public llm_graph_context {
8128
8529
 
8129
8530
  auto * inp_attn = build_attn_inp_kv_unified();
8130
8531
 
8532
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
8533
+
8131
8534
  for (int il = 0; il < n_layer; ++il) {
8132
8535
  ggml_tensor * inpSA = inpL;
8133
8536
 
@@ -8247,15 +8650,13 @@ struct llm_build_minicpm3 : public llm_graph_context {
8247
8650
  q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
8248
8651
  }
8249
8652
 
8250
- if (il == n_layer - 1) {
8251
- // skip computing output for unused tokens
8252
- ggml_tensor * inp_out_ids = build_inp_out_ids();
8653
+ if (il == n_layer - 1 && inp_out_ids) {
8253
8654
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8254
8655
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
8255
8656
  }
8256
8657
 
8257
8658
  // scale_res - scale the hidden states for residual connection
8258
- const float scale_res = scale_depth/sqrtf(float(n_layer));
8659
+ const float scale_res = scale_depth/sqrtf(float(n_layer)); // TODO: is this correct?
8259
8660
  cur = ggml_scale(ctx0, cur, scale_res);
8260
8661
  cb(cur, "hidden_scaled", il);
8261
8662
 
@@ -8332,6 +8733,8 @@ struct llm_build_gemma : public llm_graph_context {
8332
8733
 
8333
8734
  auto * inp_attn = build_attn_inp_kv_unified();
8334
8735
 
8736
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
8737
+
8335
8738
  for (int il = 0; il < n_layer; ++il) {
8336
8739
  // norm
8337
8740
  cur = build_norm(inpL,
@@ -8377,9 +8780,7 @@ struct llm_build_gemma : public llm_graph_context {
8377
8780
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
8378
8781
  }
8379
8782
 
8380
- if (il == n_layer - 1) {
8381
- // skip computing output for unused tokens
8382
- ggml_tensor * inp_out_ids = build_inp_out_ids();
8783
+ if (il == n_layer - 1 && inp_out_ids) {
8383
8784
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8384
8785
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
8385
8786
  }
@@ -8448,6 +8849,8 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
8448
8849
 
8449
8850
  auto * inp_attn = build_attn_inp_kv_unified_iswa();
8450
8851
 
8852
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
8853
+
8451
8854
  for (int il = 0; il < n_layer; ++il) {
8452
8855
  // norm
8453
8856
  cur = build_norm(inpL,
@@ -8485,32 +8888,23 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
8485
8888
  cb(Kcur, "Kcur", il);
8486
8889
  cb(Vcur, "Vcur", il);
8487
8890
 
8488
- // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
8489
- switch (model.type) {
8490
- case LLM_TYPE_2B:
8491
- case LLM_TYPE_9B:
8492
- case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); break;
8493
- default: GGML_ABORT("fatal error");
8494
- };
8495
- cb(Qcur, "Qcur_scaled", il);
8891
+ Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
8496
8892
 
8497
8893
  cur = build_attn(inp_attn, gf,
8498
8894
  model.layers[il].wo, NULL,
8499
8895
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
8500
8896
  }
8501
8897
 
8898
+ if (il == n_layer - 1 && inp_out_ids) {
8899
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8900
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
8901
+ }
8902
+
8502
8903
  cur = build_norm(cur,
8503
8904
  model.layers[il].attn_post_norm, NULL,
8504
8905
  LLM_NORM_RMS, il);
8505
8906
  cb(cur, "attn_post_norm", il);
8506
8907
 
8507
- if (il == n_layer - 1) {
8508
- // skip computing output for unused tokens
8509
- ggml_tensor * inp_out_ids = build_inp_out_ids();
8510
- cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8511
- inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
8512
- }
8513
-
8514
8908
  ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
8515
8909
  cb(sa_out, "sa_out", il);
8516
8910
 
@@ -8589,6 +8983,8 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
8589
8983
  // TODO: is causal == true correct? might need some changes
8590
8984
  auto * inp_attn = build_attn_inp_kv_unified_iswa();
8591
8985
 
8986
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
8987
+
8592
8988
  for (int il = 0; il < n_layer; ++il) {
8593
8989
  const float freq_base_l = model.get_rope_freq_base (cparams, il);
8594
8990
  const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
@@ -8633,9 +9029,17 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
8633
9029
  cb(Kcur, "Kcur", il);
8634
9030
  cb(Vcur, "Vcur", il);
8635
9031
 
9032
+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
9033
+ Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
9034
+
8636
9035
  cur = build_attn(inp_attn, gf,
8637
9036
  model.layers[il].wo, NULL,
8638
- Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
9037
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
9038
+ }
9039
+
9040
+ if (il == n_layer - 1 && inp_out_ids) {
9041
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9042
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
8639
9043
  }
8640
9044
 
8641
9045
  cur = build_norm(cur,
@@ -8643,13 +9047,6 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
8643
9047
  LLM_NORM_RMS, il);
8644
9048
  cb(cur, "attn_post_norm", il);
8645
9049
 
8646
- if (il == n_layer - 1) {
8647
- // skip computing output for unused tokens
8648
- ggml_tensor * inp_out_ids = build_inp_out_ids();
8649
- cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8650
- inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
8651
- }
8652
-
8653
9050
  ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
8654
9051
  cb(sa_out, "sa_out", il);
8655
9052
 
@@ -8702,109 +9099,229 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
8702
9099
  }
8703
9100
  };
8704
9101
 
8705
- // TODO: move up next to build_starcoder
8706
- struct llm_build_starcoder2 : public llm_graph_context {
8707
- llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
8708
- const int64_t n_embd_head = hparams.n_embd_head_v;
8709
-
8710
- GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
8711
- GGML_ASSERT(n_embd_head == hparams.n_rot);
8712
-
9102
+ struct llm_build_gemma3n_iswa : public llm_graph_context {
9103
+ const llama_model & model;
9104
+ ggml_cgraph * gf;
9105
+
9106
+ const int64_t n_embd_head;
9107
+ const int64_t n_embd_altup;
9108
+ const int64_t n_altup;
9109
+ const int i_altup_act;
9110
+ const int n_layer_kv = 20; // number of layers having KV [KV_REUSE]
9111
+ const int n_layer_sparsity = 10; // number of layers using activation sparsity
9112
+ const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
9113
+
9114
+ ggml_tensor * one; // containing single element 1.0f
9115
+
9116
+ llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf)
9117
+ : llm_graph_context(params),
9118
+ model(model),
9119
+ gf(gf),
9120
+ n_embd_head(model.hparams.n_embd_head_k),
9121
+ n_embd_altup(model.hparams.n_embd_altup),
9122
+ n_altup(model.hparams.n_altup),
9123
+ i_altup_act(model.hparams.i_altup_act) {
8713
9124
  ggml_tensor * cur;
8714
9125
  ggml_tensor * inpL;
8715
9126
 
9127
+ // TODO: remove this when ggml_scale_add is implemented
9128
+ one = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
9129
+ {
9130
+ auto inp = std::make_unique<llm_graph_input_one>();
9131
+ inp->one = one;
9132
+ res->add_input(std::move(inp));
9133
+ }
9134
+
8716
9135
  inpL = build_inp_embd(model.tok_embd);
8717
9136
 
9137
+ // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
9138
+ if (ubatch.token) {
9139
+ inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
9140
+ cb(inpL, "inp_scaled", -1);
9141
+ }
9142
+
8718
9143
  // inp_pos - contains the positions
8719
9144
  ggml_tensor * inp_pos = build_inp_pos();
8720
9145
 
8721
- auto * inp_attn = build_attn_inp_kv_unified();
9146
+ // TODO: is causal == true correct? might need some changes
9147
+ auto * inp_attn = build_attn_inp_kv_unified_iswa();
9148
+
9149
+ // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
9150
+ ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
9151
+
9152
+ // inpL now has only 1 altup, project it to the rest of the altups
9153
+ // these "added" altups will be concat to the last dim of inpL
9154
+ {
9155
+ ggml_tensor * target_magnitude = calc_magnitude(inpL);
9156
+ ggml_tensor * inp_repeated = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
9157
+ ggml_tensor * altup_added = ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1]
9158
+ ggml_tensor * new_magnitude = calc_magnitude(altup_added);
9159
+ altup_added = ggml_div(ctx0,
9160
+ ggml_mul(ctx0, altup_added, target_magnitude),
9161
+ new_magnitude);
9162
+ inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
9163
+ cb(inpL, "inp_stacked", -1);
9164
+ }
9165
+
9166
+ // inpL now has shape: [n_embd, n_tokens, n_altup]
9167
+ // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
8722
9168
 
8723
9169
  for (int il = 0; il < n_layer; ++il) {
8724
- ggml_tensor * inpSA = inpL;
9170
+ // this block is made to closely resemble Gemma3p5DecoderLayer in the python code
9171
+ const bool has_kv = (il < n_layer_kv);
9172
+
9173
+ const float freq_base_l = model.get_rope_freq_base (cparams, il);
9174
+ const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
9175
+
9176
+ ggml_tensor * cur = inpL; // [n_embd, n_tokens, n_altup]
9177
+ ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]
9178
+
9179
+ // predicted value will go through self-attention and laurel
9180
+ ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
9181
+ cur = active_prediction;
9182
+ cb(cur, "active_prediction", il);
8725
9183
 
8726
9184
  // norm
8727
- cur = build_norm(inpL,
8728
- model.layers[il].attn_norm, model.layers[il].attn_norm_b,
8729
- LLM_NORM, il);
9185
+ cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
8730
9186
  cb(cur, "attn_norm", il);
8731
9187
 
9188
+ // laurel
9189
+ ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]
9190
+
8732
9191
  // self-attention
8733
- {
9192
+ if (has_kv) {
8734
9193
  // compute Q and K and RoPE them
8735
9194
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
8736
9195
  cb(Qcur, "Qcur", il);
8737
- if (model.layers[il].bq) {
8738
- Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
8739
- cb(Qcur, "Qcur", il);
8740
- }
8741
9196
 
8742
9197
  ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
8743
9198
  cb(Kcur, "Kcur", il);
8744
- if (model.layers[il].bk) {
8745
- Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
8746
- cb(Kcur, "Kcur", il);
8747
- }
8748
9199
 
8749
9200
  ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
8750
9201
  cb(Vcur, "Vcur", il);
8751
- if (model.layers[il].bv) {
8752
- Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
8753
- cb(Vcur, "Vcur", il);
8754
- }
8755
9202
 
8756
9203
  Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
8757
9204
  Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
8758
9205
  Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
8759
9206
 
9207
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
9208
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
9209
+ Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);
9210
+
9211
+ cb(Qcur, "Qcur_normed", il);
9212
+ cb(Kcur, "Kcur_normed", il);
9213
+ cb(Vcur, "Vcur_normed", il);
9214
+
8760
9215
  Qcur = ggml_rope_ext(
8761
9216
  ctx0, Qcur, inp_pos, nullptr,
8762
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
8763
- ext_factor, attn_factor, beta_fast, beta_slow
8764
- );
9217
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
9218
+ ext_factor, attn_factor, beta_fast, beta_slow);
8765
9219
 
8766
9220
  Kcur = ggml_rope_ext(
8767
9221
  ctx0, Kcur, inp_pos, nullptr,
8768
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
8769
- ext_factor, attn_factor, beta_fast, beta_slow
8770
- );
9222
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
9223
+ ext_factor, attn_factor, beta_fast, beta_slow);
8771
9224
 
9225
+ cb(Qcur, "Qcur_pos", il);
9226
+ cb(Kcur, "Kcur_pos", il);
9227
+
9228
+ cur = build_attn(inp_attn, gf,
9229
+ model.layers[il].wo, NULL,
9230
+ Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
9231
+ } else {
9232
+ // no KV layers
9233
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
8772
9234
  cb(Qcur, "Qcur", il);
8773
- cb(Kcur, "Kcur", il);
8774
- cb(Vcur, "Vcur", il);
9235
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
9236
+
9237
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
9238
+ cb(Qcur, "Qcur_normed", il);
9239
+
9240
+ Qcur = ggml_rope_ext(
9241
+ ctx0, Qcur, inp_pos, nullptr,
9242
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
9243
+ ext_factor, attn_factor, beta_fast, beta_slow);
9244
+ cb(Qcur, "Qcur_pos", il);
8775
9245
 
8776
9246
  cur = build_attn(inp_attn, gf,
8777
- model.layers[il].wo, model.layers[il].bo,
8778
- Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9247
+ model.layers[il].wo, NULL,
9248
+ Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
8779
9249
  }
8780
9250
 
8781
- if (il == n_layer - 1) {
8782
- // skip computing output for unused tokens
8783
- ggml_tensor * inp_out_ids = build_inp_out_ids();
8784
- cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8785
- inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
8786
- }
9251
+ cur = build_norm(cur,
9252
+ model.layers[il].attn_post_norm, NULL,
9253
+ LLM_NORM_RMS, il);
9254
+ cb(cur, "attn_post_norm", il);
8787
9255
 
8788
- ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
8789
- cb(ffn_inp, "ffn_inp", il);
9256
+ cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens]
9257
+ cb(cur, "attn_gated", il);
8790
9258
 
8791
- // feed-forward network
9259
+ ggml_tensor * attn_laurel = ggml_scale(ctx0,
9260
+ ggml_add(ctx0, cur, laurel_out),
9261
+ 1.0f / sqrtf(2.0f)); // [n_embd, n_tokens]
9262
+ cb(attn_laurel, "attn_laurel", il);
8792
9263
 
8793
- cur = build_norm(ffn_inp,
8794
- model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
8795
- LLM_NORM, il);
9264
+ cur = build_norm(attn_laurel,
9265
+ model.layers[il].ffn_norm, NULL,
9266
+ LLM_NORM_RMS, il);
8796
9267
  cb(cur, "ffn_norm", il);
8797
9268
 
8798
- cur = build_ffn(cur,
8799
- model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
8800
- NULL, NULL, NULL,
8801
- model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
8802
- NULL,
8803
- LLM_FFN_GELU, LLM_FFN_SEQ, il);
8804
- cb(cur, "ffn_out", il);
9269
+ // feed-forward network
9270
+ {
9271
+ ggml_tensor * up_proj = build_lora_mm(model.layers[il].ffn_up, cur);
9272
+ ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);
8805
9273
 
8806
- cur = ggml_add(ctx0, cur, ffn_inp);
9274
+ if (il < n_layer_sparsity) {
9275
+ // apply activation sparsity
9276
+ gate_proj = gaussian_topk(gate_proj);
9277
+ }
9278
+ gate_proj = ggml_gelu(ctx0, gate_proj);
9279
+
9280
+ cur = ggml_mul(ctx0, up_proj, gate_proj);
9281
+ cur = build_lora_mm(model.layers[il].ffn_down, cur);
9282
+ cb(cur, "ffn_out", il);
9283
+ }
9284
+
9285
+ cur = build_norm(cur,
9286
+ model.layers[il].ffn_post_norm, NULL,
9287
+ LLM_NORM_RMS, -1);
9288
+ cb(cur, "ffn_post_norm", il);
9289
+
9290
+ ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens]
9291
+ cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);
8807
9292
 
9293
+ ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup]
9294
+
9295
+ ggml_tensor * first_prediction; // [n_embd, n_tokens]
9296
+ {
9297
+ first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
9298
+ first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
9299
+ first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
9300
+ first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
9301
+ cb(first_prediction, "first_prediction_gated", il);
9302
+ ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
9303
+ first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
9304
+ cb(first_prediction, "first_prediction_scaled", il);
9305
+
9306
+ first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens]
9307
+ first_prediction = build_norm(first_prediction,
9308
+ model.layers[il].per_layer_post_norm, NULL,
9309
+ LLM_NORM_RMS, il);
9310
+ cb(first_prediction, "first_prediction_out", il);
9311
+ }
9312
+
9313
+ // equivalent to python code: corrected_predictions[1:] += first_prediction
9314
+ {
9315
+ ggml_tensor * slice_first = view_2d_slice(corrected, 0);
9316
+ ggml_tensor * slice_rest = ggml_view_3d(ctx0, corrected, n_embd, n_tokens, n_altup - 1,
9317
+ ggml_row_size(corrected->type, n_embd),
9318
+ ggml_row_size(corrected->type, n_embd*n_tokens),
9319
+ n_embd*n_tokens*ggml_element_size(corrected));
9320
+ ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1]
9321
+ corrected = ggml_concat(ctx0, slice_first, tmp, 2); // [n_embd, n_tokens, n_altup]
9322
+ }
9323
+
9324
+ cur = corrected; // [n_embd, n_tokens, n_altup]
8808
9325
  cur = build_cvec(cur, il);
8809
9326
  cb(cur, "l_out", il);
8810
9327
 
@@ -8812,13 +9329,329 @@ struct llm_build_starcoder2 : public llm_graph_context {
8812
9329
  inpL = cur;
8813
9330
  }
8814
9331
 
8815
- cur = inpL;
9332
+ cur = inpL; // [n_embd, n_tokens, n_altup]
8816
9333
 
8817
- cur = build_norm(cur,
8818
- model.output_norm, model.output_norm_b,
8819
- LLM_NORM, -1);
9334
+ // cur now has multiple altup(s), we want to merge them back to 1 altup
9335
+ {
9336
+ ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
9337
+ // do a view to skip the first slice (active altup)
9338
+ ggml_tensor * alt_slice = ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1,
9339
+ ggml_row_size(cur->type, n_embd),
9340
+ ggml_row_size(cur->type, n_embd*n_tokens),
9341
+ n_embd*n_tokens*ggml_element_size(cur));
9342
+ ggml_tensor * altup_unembd = ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1]
9343
+ ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
9344
+ altup_unembd = ggml_div(ctx0,
9345
+ ggml_mul(ctx0, altup_unembd, target_magnitude),
9346
+ new_magnitude);
9347
+ cb(altup_unembd, "altup_unembd", -1);
9348
+
9349
+ // equivalent to torch.mean(hidden_states, dim=0)
9350
+ cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
9351
+ for (int i = 0; i < n_altup - 1; ++i) {
9352
+ cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
9353
+ }
9354
+ cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
9355
+ cb(cur, "unembd_merged", -1);
9356
+ }
8820
9357
 
8821
- cb(cur, "result_norm", -1);
9358
+ // cur now has shape: [n_embd, n_tokens]
9359
+
9360
+ // TODO: move this to right after the last KV layer
9361
+ {
9362
+ // skip computing output for unused tokens
9363
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
9364
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9365
+ }
9366
+
9367
+ cur = build_norm(cur,
9368
+ model.output_norm, NULL,
9369
+ LLM_NORM_RMS, -1);
9370
+
9371
+ cb(cur, "result_norm", -1);
9372
+ res->t_embd = cur;
9373
+
9374
+ cur = build_lora_mm(model.output, cur);
9375
+
9376
+ {
9377
+ // final logit soft-capping
9378
+ cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
9379
+ cur = ggml_tanh(ctx0, cur);
9380
+ cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
9381
+ }
9382
+
9383
+ cb(cur, "result_output", -1);
9384
+ res->t_logits = cur;
9385
+
9386
+ ggml_build_forward_expand(gf, cur);
9387
+ }
9388
+
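The graph builder above ends with Gemma-style final logit soft-capping: scale by 1/f_final_logit_softcapping, tanh, then scale back up. A minimal standalone sketch of the same transform (the helper name softcap is made up for illustration):

    #include <cmath>

    // squashes a logit smoothly into (-cap, +cap): cap * tanh(logit / cap)
    static float softcap(float logit, float cap) {
        return cap * std::tanh(logit / cap);
    }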
9389
+ ggml_tensor * calc_magnitude(ggml_tensor * x) {
9390
+ return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
9391
+ }
9392
+
9393
+ // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
9394
+ ggml_tensor * view_2d_slice(ggml_tensor * x, int idx) {
9395
+ GGML_ASSERT(idx < (int)x->ne[2]);
9396
+ return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1],
9397
+ ggml_row_size(x->type, x->ne[0]),
9398
+ idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
9399
+ }
9400
+
9401
+ // equivalent to get_per_layer_inputs() in python code
9402
+ // output shape: [n_embd_altup, n_layer, n_tokens]
9403
+ ggml_tensor * get_per_layer_inputs() {
9404
+ auto inp = std::make_unique<llm_graph_input_embd>();
9405
+ ggml_tensor * inp_per_layer;
9406
+ if (ubatch.token) {
9407
+ inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
9408
+ ggml_set_input(inp->tokens);
9409
+ res->t_tokens = inp->tokens;
9410
+ inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
9411
+ inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
9412
+ inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float)n_embd_altup));
9413
+ cb(inp_per_layer, "inp_per_layer_selected", -1);
9414
+ } else {
9415
+ GGML_ABORT("TODO: support embd input");
9416
+ }
9417
+ res->add_input(std::move(inp));
9418
+ return inp_per_layer;
9419
+ }
9420
+
9421
+ // equivalent to project_per_layer_inputs() in python code
9422
+ // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
9423
+ // output shape: [n_embd_altup, n_tokens, n_layer]
9424
+ ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
9425
+ const float per_layer_projection_scale = 1.0f / sqrtf((float)n_embd);
9426
+ const float per_layer_input_scale = 1.0f / sqrtf(2.0f);
9427
+
9428
+ ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
9429
+ per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
9430
+ per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
9431
+ per_layer_proj = build_norm(per_layer_proj,
9432
+ model.per_layer_proj_norm, NULL,
9433
+ LLM_NORM_RMS, -1); // [n_embd_altup, n_layer, n_tokens]
9434
+ cb(per_layer_proj, "per_layer_proj", -1);
9435
+
9436
+ inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj);
9437
+ inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
9438
+ cb(inp_per_layer, "inp_per_layer", -1);
9439
+
9440
+ // permute to shape: [n_embd_altup, n_tokens, n_layer]
9441
+ inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
9442
+ return inp_per_layer;
9443
+ }
9444
+
9445
+ // input cur shape: [n_embd, n_tokens]
9447
+ // output shape: [n_embd, n_tokens]
9447
+ ggml_tensor * laurel(ggml_tensor * cur, int il) {
9448
+ ggml_tensor * tmp = cur;
9449
+ tmp = build_lora_mm(model.layers[il].laurel_l, tmp);
9450
+ tmp = build_lora_mm(model.layers[il].laurel_r, tmp);
9451
+ tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il);
9452
+ tmp = ggml_add(ctx0, tmp, cur);
9453
+ cb(tmp, "laurel_out", il);
9454
+ return tmp;
9455
+ }
9456
+
9457
+ // input x shape: [n_embd, n_tokens]
9458
+ // output shape: [n_embd, n_tokens]
9459
+ ggml_tensor * gaussian_topk(ggml_tensor * x) {
9460
+ ggml_tensor * mean = ggml_mean(ctx0, x);
9461
+ ggml_tensor * std = ggml_sqrt(ctx0, ggml_scale(ctx0,
9462
+ ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))),
9463
+ 1.0f / (float)(x->ne[0] - 1)
9464
+ ));
9465
+ ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul));
9466
+ return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x));
9467
+ }
9468
+
9469
+ //
9470
+ // altup functions
9471
+ //
9472
+
9473
+ // equivalent to compute_router_modalities() in python code
9474
+ // input x shape: [n_embd, n_tokens]
9475
+ // output shape: [n_altup, n_tokens]
9476
+ ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il) {
9477
+ ggml_tensor * router_inputs = build_norm(x,
9478
+ model.layers[il].altup_router_norm, NULL,
9479
+ LLM_NORM_RMS, il);
9480
+
9481
+ // router_input_scale
9482
+ router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float)n_embd);
9483
+
9484
+ ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs);
9485
+ return ggml_tanh(ctx0, output); // [n_altup, n_tokens]
9486
+ }
9487
+
9488
+ // input cur shape: [n_embd, n_tokens, n_altup]
9489
+ // output shape: [n_embd, n_tokens, n_altup]
9490
+ ggml_tensor * altup_predict(ggml_tensor * cur, int il) {
9491
+ ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens]
9492
+ ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
9493
+ cb(modalities, "modalities", il);
9494
+
9495
+ ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
9496
+ cb(all_coefs, "all_coefs", il);
9497
+ // first dim now having n_altup^2 elements, we reshape it to 2D (so we end up with 3D tensor)
9498
+ all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);
9499
+
9500
+ // permute to [n_altup, n_embd, n_tokens]
9501
+ ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
9502
+ ggml_tensor * predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_altup, n_embd, n_tokens]
9503
+
9504
+ // final shape must be the same as cur: [n_embd, n_tokens, n_altup]
9505
+ predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3));
9506
+ predictions = ggml_add(ctx0, predictions, cur);
9507
+ cb(predictions, "predictions", il);
9508
+
9509
+ return predictions;
9510
+ }
9511
+
9512
+ // input predictions shape: [n_embd, n_tokens, n_altup]
9513
+ // input activated shape: [n_embd, n_tokens]
9514
+ // output shape: [n_embd, n_tokens, n_altup]
9515
+ ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) {
9516
+ ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
9517
+ cb(modalities, "modalities", il);
9518
+
9519
+ ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
9520
+ ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
9521
+ cb(innovation, "innovation", il);
9522
+
9523
+ ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens]
9524
+ all_coefs = ggml_add(ctx0, all_coefs, one);
9525
+ cb(all_coefs, "all_coefs", il);
9526
+ all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup]
9527
+ all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup]
9528
+
9529
+ innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
9530
+ ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup]
9531
+ corrected = ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup]
9532
+ cb(corrected, "corrected", il);
9533
+
9534
+ return corrected;
9535
+ }
9536
+ };
9537
+
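Several places in the gemma3n builder above (the altup_proj stacking and the altup_unembd_proj merge) rescale a projected stream so its per-token L2 norm matches the active stream, via calc_magnitude() plus a multiply and divide. A minimal standalone sketch of that magnitude matching on plain vectors (the function name match_magnitude is made up for illustration; it assumes the projected vector has a non-zero norm):

    #include <cmath>
    #include <vector>

    static void match_magnitude(std::vector<float> & projected, const std::vector<float> & target) {
        auto l2 = [](const std::vector<float> & v) {
            float s = 0.0f;
            for (float x : v) s += x * x;
            return std::sqrt(s);
        };
        const float scale = l2(target) / l2(projected);
        for (float & x : projected) x *= scale;
    }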
9538
+ // TODO: move up next to build_starcoder
9539
+ struct llm_build_starcoder2 : public llm_graph_context {
9540
+ llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
9541
+ const int64_t n_embd_head = hparams.n_embd_head_v;
9542
+
9543
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
9544
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
9545
+
9546
+ ggml_tensor * cur;
9547
+ ggml_tensor * inpL;
9548
+
9549
+ inpL = build_inp_embd(model.tok_embd);
9550
+
9551
+ // inp_pos - contains the positions
9552
+ ggml_tensor * inp_pos = build_inp_pos();
9553
+
9554
+ auto * inp_attn = build_attn_inp_kv_unified();
9555
+
9556
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
9557
+
9558
+ for (int il = 0; il < n_layer; ++il) {
9559
+ ggml_tensor * inpSA = inpL;
9560
+
9561
+ // norm
9562
+ cur = build_norm(inpL,
9563
+ model.layers[il].attn_norm, model.layers[il].attn_norm_b,
9564
+ LLM_NORM, il);
9565
+ cb(cur, "attn_norm", il);
9566
+
9567
+ // self-attention
9568
+ {
9569
+ // compute Q and K and RoPE them
9570
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
9571
+ cb(Qcur, "Qcur", il);
9572
+ if (model.layers[il].bq) {
9573
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
9574
+ cb(Qcur, "Qcur", il);
9575
+ }
9576
+
9577
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
9578
+ cb(Kcur, "Kcur", il);
9579
+ if (model.layers[il].bk) {
9580
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
9581
+ cb(Kcur, "Kcur", il);
9582
+ }
9583
+
9584
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
9585
+ cb(Vcur, "Vcur", il);
9586
+ if (model.layers[il].bv) {
9587
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
9588
+ cb(Vcur, "Vcur", il);
9589
+ }
9590
+
9591
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
9592
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
9593
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
9594
+
9595
+ Qcur = ggml_rope_ext(
9596
+ ctx0, Qcur, inp_pos, nullptr,
9597
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
9598
+ ext_factor, attn_factor, beta_fast, beta_slow
9599
+ );
9600
+
9601
+ Kcur = ggml_rope_ext(
9602
+ ctx0, Kcur, inp_pos, nullptr,
9603
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
9604
+ ext_factor, attn_factor, beta_fast, beta_slow
9605
+ );
9606
+
9607
+ cb(Qcur, "Qcur", il);
9608
+ cb(Kcur, "Kcur", il);
9609
+ cb(Vcur, "Vcur", il);
9610
+
9611
+ cur = build_attn(inp_attn, gf,
9612
+ model.layers[il].wo, model.layers[il].bo,
9613
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9614
+ }
9615
+
9616
+ if (il == n_layer - 1 && inp_out_ids) {
9617
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9618
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
9619
+ }
9620
+
9621
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
9622
+ cb(ffn_inp, "ffn_inp", il);
9623
+
9624
+ // feed-forward network
9625
+
9626
+ cur = build_norm(ffn_inp,
9627
+ model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
9628
+ LLM_NORM, il);
9629
+ cb(cur, "ffn_norm", il);
9630
+
9631
+ cur = build_ffn(cur,
9632
+ model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
9633
+ NULL, NULL, NULL,
9634
+ model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
9635
+ NULL,
9636
+ LLM_FFN_GELU, LLM_FFN_SEQ, il);
9637
+ cb(cur, "ffn_out", il);
9638
+
9639
+ cur = ggml_add(ctx0, cur, ffn_inp);
9640
+
9641
+ cur = build_cvec(cur, il);
9642
+ cb(cur, "l_out", il);
9643
+
9644
+ // input for next layer
9645
+ inpL = cur;
9646
+ }
9647
+
9648
+ cur = inpL;
9649
+
9650
+ cur = build_norm(cur,
9651
+ model.output_norm, model.output_norm_b,
9652
+ LLM_NORM, -1);
9653
+
9654
+ cb(cur, "result_norm", -1);
8822
9655
  res->t_embd = cur;
8823
9656
 
8824
9657
  // lm_head
@@ -8841,8 +9674,9 @@ struct llm_build_mamba : public llm_graph_context {
8841
9674
  // {n_embd, n_tokens}
8842
9675
  inpL = build_inp_embd(model.tok_embd);
8843
9676
 
8844
- ggml_tensor * state_copy = build_inp_s_copy();
8845
- ggml_tensor * state_mask = build_inp_s_mask();
9677
+ auto * rs_inp = build_rs_inp();
9678
+
9679
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
8846
9680
 
8847
9681
  for (int il = 0; il < n_layer; ++il) {
8848
9682
  // norm
@@ -8851,12 +9685,9 @@ struct llm_build_mamba : public llm_graph_context {
8851
9685
  LLM_NORM_RMS, il);
8852
9686
  cb(cur, "attn_norm", il);
8853
9687
 
8854
- //cur = build_mamba_layer(gf, cur, state_copy, state_mask, il);
8855
- cur = build_mamba_layer(gf, cur, state_copy, state_mask, ubatch, il);
9688
+ cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il);
8856
9689
 
8857
- if (il == n_layer - 1) {
8858
- // skip computing output for unused tokens
8859
- ggml_tensor * inp_out_ids = build_inp_out_ids();
9690
+ if (il == n_layer - 1 && inp_out_ids) {
8860
9691
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8861
9692
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
8862
9693
  }
@@ -8890,15 +9721,14 @@ struct llm_build_mamba : public llm_graph_context {
8890
9721
 
8891
9722
  // TODO: split
8892
9723
  ggml_tensor * build_mamba_layer(
8893
- ggml_cgraph * gf,
8894
- ggml_tensor * cur,
8895
- ggml_tensor * state_copy,
8896
- ggml_tensor * state_mask,
8897
- const llama_ubatch & ubatch,
8898
- int il) const {
8899
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
9724
+ llm_graph_input_rs * inp,
9725
+ ggml_cgraph * gf,
9726
+ ggml_tensor * cur,
9727
+ const llama_ubatch & ubatch,
9728
+ int il) const {
9729
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
8900
9730
 
8901
- const auto kv_head = kv_self->head;
9731
+ const auto kv_head = mctx_cur->get_head();
8902
9732
 
8903
9733
  const int64_t d_conv = hparams.ssm_d_conv;
8904
9734
  const int64_t d_inner = hparams.ssm_d_inner;
@@ -8916,17 +9746,17 @@ struct llm_build_mamba : public llm_graph_context {
8916
9746
  GGML_ASSERT(ubatch.equal_seqs);
8917
9747
  GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
8918
9748
 
8919
- ggml_tensor * conv_states_all = kv_self->k_l[il];
8920
- ggml_tensor * ssm_states_all = kv_self->v_l[il];
9749
+ ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
9750
+ ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
8921
9751
 
8922
9752
  // (ab)using the KV cache to store the states
8923
- ggml_tensor * conv = build_copy_mask_state(
8924
- gf, conv_states_all, state_copy, state_mask,
8925
- hparams.n_embd_k_s(), n_seqs);
9753
+ ggml_tensor * conv = build_rs(
9754
+ inp, gf, conv_states_all,
9755
+ hparams.n_embd_r(), n_seqs);
8926
9756
  conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
8927
- ggml_tensor * ssm = build_copy_mask_state(
8928
- gf, ssm_states_all, state_copy, state_mask,
8929
- hparams.n_embd_v_s(), n_seqs);
9757
+ ggml_tensor * ssm = build_rs(
9758
+ inp, gf, ssm_states_all,
9759
+ hparams.n_embd_s(), n_seqs);
8930
9760
  ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
8931
9761
 
8932
9762
  // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
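
The Mamba layer (and the RWKV builders further down) now reads its per-layer states through the recurrent memory context: get_r_l(il)/get_s_l(il) replace the k_l/v_l tensors of the old recurrent KV cache, and the sizes come from hparams.n_embd_r()/n_embd_s(). The shape of that interface, sketched with hypothetical stand-in types rather than the real llama_memory_recurrent_context:

    // Hypothetical stand-in for the per-layer recurrent state accessors;
    // the real llama.cpp class returns ggml tensors, not std::vector.
    #include <cstdio>
    #include <vector>

    struct recurrent_state_ctx {
        std::vector<std::vector<float>> r_l; // conv ("r") state per layer
        std::vector<std::vector<float>> s_l; // ssm  ("s") state per layer
        unsigned head = 0;                   // first cell used by the current ubatch

        const std::vector<float> & get_r_l(int il) const { return r_l[il]; }
        const std::vector<float> & get_s_l(int il) const { return s_l[il]; }
        unsigned get_head() const { return head; }
    };

    int main() {
        const int n_layer = 2, n_embd_r = 4, n_embd_s = 8;

        recurrent_state_ctx mctx;
        mctx.r_l.assign(n_layer, std::vector<float>(n_embd_r, 0.f));
        mctx.s_l.assign(n_layer, std::vector<float>(n_embd_s, 0.f));

        for (int il = 0; il < n_layer; ++il) {
            std::printf("layer %d: conv state = %zu floats, ssm state = %zu floats (head = %u)\n",
                        il, mctx.get_r_l(il).size(), mctx.get_s_l(il).size(), mctx.get_head());
        }
        return 0;
    }
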
@@ -9039,13 +9869,15 @@ struct llm_build_command_r : public llm_graph_context {
9039
9869
 
9040
9870
  auto * inp_attn = build_attn_inp_kv_unified();
9041
9871
 
9042
- for (int il = 0; il < n_layer; ++il) {
9872
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
9043
9873
 
9874
+ for (int il = 0; il < n_layer; ++il) {
9044
9875
  // norm
9045
9876
  cur = build_norm(inpL,
9046
9877
  model.layers[il].attn_norm, NULL,
9047
9878
  LLM_NORM, il);
9048
9879
  cb(cur, "attn_norm", il);
9880
+
9049
9881
  ggml_tensor * ffn_inp = cur;
9050
9882
 
9051
9883
  // self-attention
@@ -9113,9 +9945,7 @@ struct llm_build_command_r : public llm_graph_context {
9113
9945
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9114
9946
  }
9115
9947
 
9116
- if (il == n_layer - 1) {
9117
- // skip computing output for unused tokens
9118
- ggml_tensor * inp_out_ids = build_inp_out_ids();
9948
+ if (il == n_layer - 1 && inp_out_ids) {
9119
9949
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9120
9950
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
9121
9951
  ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
@@ -9186,6 +10016,8 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
9186
10016
 
9187
10017
  auto * inp_attn = build_attn_inp_kv_unified_iswa();
9188
10018
 
10019
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
10020
+
9189
10021
  for (int il = 0; il < n_layer; ++il) {
9190
10022
  const bool is_swa = hparams.is_swa(il);
9191
10023
 
@@ -9248,9 +10080,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
9248
10080
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9249
10081
  }
9250
10082
 
9251
- if (il == n_layer - 1) {
9252
- // skip computing output for unused tokens
9253
- ggml_tensor * inp_out_ids = build_inp_out_ids();
10083
+ if (il == n_layer - 1 && inp_out_ids) {
9254
10084
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9255
10085
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
9256
10086
  ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
@@ -9321,6 +10151,8 @@ struct llm_build_olmo : public llm_graph_context {
9321
10151
 
9322
10152
  auto * inp_attn = build_attn_inp_kv_unified();
9323
10153
 
10154
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
10155
+
9324
10156
  for (int il = 0; il < n_layer; ++il) {
9325
10157
  ggml_tensor * inpSA = inpL;
9326
10158
 
@@ -9379,9 +10211,7 @@ struct llm_build_olmo : public llm_graph_context {
9379
10211
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9380
10212
  }
9381
10213
 
9382
- if (il == n_layer - 1) {
9383
- // skip computing output for unused tokens
9384
- ggml_tensor * inp_out_ids = build_inp_out_ids();
10214
+ if (il == n_layer - 1 && inp_out_ids) {
9385
10215
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9386
10216
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
9387
10217
  }
@@ -9449,6 +10279,8 @@ struct llm_build_olmo2 : public llm_graph_context {
9449
10279
 
9450
10280
  auto * inp_attn = build_attn_inp_kv_unified();
9451
10281
 
10282
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
10283
+
9452
10284
  for (int il = 0; il < n_layer; ++il) {
9453
10285
  ggml_tensor * inpSA = inpL;
9454
10286
 
@@ -9499,18 +10331,16 @@ struct llm_build_olmo2 : public llm_graph_context {
9499
10331
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9500
10332
  }
9501
10333
 
10334
+ if (il == n_layer - 1 && inp_out_ids) {
10335
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10336
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
10337
+ }
10338
+
9502
10339
  cur = build_norm(cur,
9503
10340
  model.layers[il].attn_post_norm, NULL,
9504
10341
  LLM_NORM_RMS, il);
9505
10342
  cb(cur, "attn_post_norm", il);
9506
10343
 
9507
- if (il == n_layer - 1) {
9508
- // skip computing output for unused tokens
9509
- ggml_tensor * inp_out_ids = build_inp_out_ids();
9510
- cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9511
- inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
9512
- }
9513
-
9514
10344
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
9515
10345
  cb(ffn_inp, "ffn_inp", il);
9516
10346
 
@@ -9578,6 +10408,8 @@ struct llm_build_olmoe : public llm_graph_context {
9578
10408
 
9579
10409
  auto * inp_attn = build_attn_inp_kv_unified();
9580
10410
 
10411
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
10412
+
9581
10413
  for (int il = 0; il < n_layer; ++il) {
9582
10414
  ggml_tensor * inpSA = inpL;
9583
10415
 
@@ -9632,9 +10464,7 @@ struct llm_build_olmoe : public llm_graph_context {
9632
10464
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9633
10465
  }
9634
10466
 
9635
- if (il == n_layer - 1) {
9636
- // skip computing output for unused tokens
9637
- ggml_tensor * inp_out_ids = build_inp_out_ids();
10467
+ if (il == n_layer - 1 && inp_out_ids) {
9638
10468
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9639
10469
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
9640
10470
  }
@@ -9704,6 +10534,8 @@ struct llm_build_openelm : public llm_graph_context {
9704
10534
 
9705
10535
  auto * inp_attn = build_attn_inp_kv_unified();
9706
10536
 
10537
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
10538
+
9707
10539
  for (int il = 0; il < n_layer; ++il) {
9708
10540
  const int64_t n_head = hparams.n_head(il);
9709
10541
  const int64_t n_head_kv = hparams.n_head_kv(il);
@@ -9765,11 +10597,9 @@ struct llm_build_openelm : public llm_graph_context {
9765
10597
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9766
10598
  }
9767
10599
 
9768
- if (il == n_layer - 1) {
9769
- // skip computing output for unused tokens
9770
- ggml_tensor * inp_out_ids = build_inp_out_ids();
10600
+ if (il == n_layer - 1 && inp_out_ids) {
9771
10601
  residual = ggml_get_rows(ctx0, residual, inp_out_ids);
9772
- cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10602
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9773
10603
  }
9774
10604
 
9775
10605
  ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
@@ -9835,6 +10665,8 @@ struct llm_build_gptneox : public llm_graph_context {
9835
10665
 
9836
10666
  auto * inp_attn = build_attn_inp_kv_unified();
9837
10667
 
10668
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
10669
+
9838
10670
  for (int il = 0; il < n_layer; ++il) {
9839
10671
  cur = build_norm(inpL,
9840
10672
  model.layers[il].attn_norm,
@@ -9879,9 +10711,7 @@ struct llm_build_gptneox : public llm_graph_context {
9879
10711
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9880
10712
  }
9881
10713
 
9882
- if (il == n_layer - 1) {
9883
- // skip computing output for unused tokens
9884
- ggml_tensor * inp_out_ids = build_inp_out_ids();
10714
+ if (il == n_layer - 1 && inp_out_ids) {
9885
10715
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9886
10716
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
9887
10717
  }
@@ -9983,6 +10813,8 @@ struct llm_build_arctic : public llm_graph_context {
9983
10813
 
9984
10814
  auto * inp_attn = build_attn_inp_kv_unified();
9985
10815
 
10816
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
10817
+
9986
10818
  for (int il = 0; il < n_layer; ++il) {
9987
10819
  ggml_tensor * inpSA = inpL;
9988
10820
 
@@ -10029,9 +10861,7 @@ struct llm_build_arctic : public llm_graph_context {
10029
10861
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
10030
10862
  }
10031
10863
 
10032
- if (il == n_layer - 1) {
10033
- // skip computing output for unused tokens
10034
- ggml_tensor * inp_out_ids = build_inp_out_ids();
10864
+ if (il == n_layer - 1 && inp_out_ids) {
10035
10865
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10036
10866
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
10037
10867
  }
@@ -10123,6 +10953,8 @@ struct llm_build_deepseek : public llm_graph_context {
10123
10953
 
10124
10954
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
10125
10955
 
10956
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
10957
+
10126
10958
  for (int il = 0; il < n_layer; ++il) {
10127
10959
  ggml_tensor * inpSA = inpL;
10128
10960
 
@@ -10184,14 +11016,11 @@ struct llm_build_deepseek : public llm_graph_context {
10184
11016
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
10185
11017
  }
10186
11018
 
10187
- if (il == n_layer - 1) {
10188
- // skip computing output for unused tokens
10189
- ggml_tensor * inp_out_ids = build_inp_out_ids();
11019
+ if (il == n_layer - 1 && inp_out_ids) {
10190
11020
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10191
11021
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
10192
11022
  }
10193
11023
 
10194
-
10195
11024
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
10196
11025
  cb(ffn_inp, "ffn_inp", il);
10197
11026
 
@@ -10299,6 +11128,8 @@ struct llm_build_deepseek2 : public llm_graph_context {
10299
11128
 
10300
11129
  auto * inp_attn = build_attn_inp_kv_unified();
10301
11130
 
11131
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
11132
+
10302
11133
  for (int il = 0; il < n_layer; ++il) {
10303
11134
  ggml_tensor * inpSA = inpL;
10304
11135
 
@@ -10448,9 +11279,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
10448
11279
  }
10449
11280
  }
10450
11281
 
10451
- if (il == n_layer - 1) {
10452
- // skip computing output for unused tokens
10453
- ggml_tensor * inp_out_ids = build_inp_out_ids();
11282
+ if (il == n_layer - 1 && inp_out_ids) {
10454
11283
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10455
11284
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
10456
11285
  }
@@ -10546,6 +11375,8 @@ struct llm_build_bitnet : public llm_graph_context {
10546
11375
 
10547
11376
  auto * inp_attn = build_attn_inp_kv_unified();
10548
11377
 
11378
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
11379
+
10549
11380
  for (int il = 0; il < n_layer; ++il) {
10550
11381
  ggml_tensor * inpSA = inpL;
10551
11382
 
@@ -10628,9 +11459,7 @@ struct llm_build_bitnet : public llm_graph_context {
10628
11459
  cb(cur, "attn_o_out", il);
10629
11460
  }
10630
11461
 
10631
- if (il == n_layer - 1) {
10632
- // skip computing output for unused tokens
10633
- ggml_tensor * inp_out_ids = build_inp_out_ids();
11462
+ if (il == n_layer - 1 && inp_out_ids) {
10634
11463
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10635
11464
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
10636
11465
  }
@@ -10705,6 +11534,8 @@ struct llm_build_t5_enc : public llm_graph_context {
10705
11534
 
10706
11535
  auto * inp_attn = build_attn_inp_no_cache();
10707
11536
 
11537
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
11538
+
10708
11539
  for (int il = 0; il < n_layer; ++il) {
10709
11540
  ggml_tensor * inpSA = inpL;
10710
11541
 
@@ -10738,9 +11569,7 @@ struct llm_build_t5_enc : public llm_graph_context {
10738
11569
  cb(cur, "kqv_out", il);
10739
11570
  }
10740
11571
 
10741
- if (il == n_layer - 1) {
10742
- // skip computing output for unused tokens
10743
- ggml_tensor * inp_out_ids = build_inp_out_ids();
11572
+ if (il == n_layer - 1 && inp_out_ids) {
10744
11573
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10745
11574
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
10746
11575
  }
@@ -10811,6 +11640,8 @@ struct llm_build_t5_dec : public llm_graph_context {
10811
11640
  auto * inp_attn_self = build_attn_inp_kv_unified();
10812
11641
  auto * inp_attn_cross = build_attn_inp_cross();
10813
11642
 
11643
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
11644
+
10814
11645
  for (int il = 0; il < n_layer; ++il) {
10815
11646
  ggml_tensor * inpSA = inpL;
10816
11647
 
@@ -10902,11 +11733,8 @@ struct llm_build_t5_dec : public llm_graph_context {
10902
11733
  //cb(cur, "kqv_out", il);
10903
11734
  }
10904
11735
 
10905
- if (il == n_layer - 1) {
10906
- // skip computing output for unused tokens
10907
- ggml_tensor * inp_out_ids = build_inp_out_ids();
11736
+ if (il == n_layer - 1 && inp_out_ids) {
10908
11737
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10909
- inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
10910
11738
  inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
10911
11739
  }
10912
11740
 
@@ -10976,6 +11804,8 @@ struct llm_build_jais : public llm_graph_context {
10976
11804
 
10977
11805
  auto * inp_attn = build_attn_inp_kv_unified();
10978
11806
 
11807
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
11808
+
10979
11809
  for (int il = 0; il < n_layer; ++il) {
10980
11810
  cur = build_norm(inpL,
10981
11811
  model.layers[il].attn_norm,
@@ -11008,9 +11838,7 @@ struct llm_build_jais : public llm_graph_context {
11008
11838
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il);
11009
11839
  }
11010
11840
 
11011
- if (il == n_layer - 1) {
11012
- // skip computing output for unused tokens
11013
- ggml_tensor * inp_out_ids = build_inp_out_ids();
11841
+ if (il == n_layer - 1 && inp_out_ids) {
11014
11842
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11015
11843
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
11016
11844
  }
@@ -11074,6 +11902,8 @@ struct llm_build_chatglm : public llm_graph_context {
11074
11902
 
11075
11903
  auto * inp_attn = build_attn_inp_kv_unified();
11076
11904
 
11905
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
11906
+
11077
11907
  for (int il = 0; il < n_layer; ++il) {
11078
11908
  ggml_tensor * inpSA = inpL;
11079
11909
 
@@ -11140,9 +11970,7 @@ struct llm_build_chatglm : public llm_graph_context {
11140
11970
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11141
11971
  }
11142
11972
 
11143
- if (il == n_layer - 1) {
11144
- // skip computing output for unused tokens
11145
- ggml_tensor * inp_out_ids = build_inp_out_ids();
11973
+ if (il == n_layer - 1 && inp_out_ids) {
11146
11974
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11147
11975
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
11148
11976
  }
@@ -11207,6 +12035,8 @@ struct llm_build_glm4 : public llm_graph_context {
11207
12035
 
11208
12036
  auto * inp_attn = build_attn_inp_kv_unified();
11209
12037
 
12038
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
12039
+
11210
12040
  for (int il = 0; il < n_layer; ++il) {
11211
12041
  ggml_tensor * inpSA = inpL;
11212
12042
 
@@ -11273,9 +12103,7 @@ struct llm_build_glm4 : public llm_graph_context {
11273
12103
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11274
12104
  }
11275
12105
 
11276
- if (il == n_layer - 1) {
11277
- // skip computing output for unused tokens
11278
- ggml_tensor * inp_out_ids = build_inp_out_ids();
12106
+ if (il == n_layer - 1 && inp_out_ids) {
11279
12107
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11280
12108
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
11281
12109
  }
@@ -11358,6 +12186,8 @@ struct llm_build_nemotron : public llm_graph_context {
11358
12186
 
11359
12187
  auto * inp_attn = build_attn_inp_kv_unified();
11360
12188
 
12189
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
12190
+
11361
12191
  for (int il = 0; il < n_layer; ++il) {
11362
12192
  ggml_tensor * inpSA = inpL;
11363
12193
 
@@ -11417,9 +12247,7 @@ struct llm_build_nemotron : public llm_graph_context {
11417
12247
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11418
12248
  }
11419
12249
 
11420
- if (il == n_layer - 1) {
11421
- // skip computing output for unused tokens
11422
- ggml_tensor * inp_out_ids = build_inp_out_ids();
12250
+ if (il == n_layer - 1 && inp_out_ids) {
11423
12251
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11424
12252
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
11425
12253
  }
@@ -11487,6 +12315,8 @@ struct llm_build_exaone : public llm_graph_context {
11487
12315
 
11488
12316
  auto * inp_attn = build_attn_inp_kv_unified();
11489
12317
 
12318
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
12319
+
11490
12320
  for (int il = 0; il < n_layer; ++il) {
11491
12321
  ggml_tensor * inpSA = inpL;
11492
12322
 
@@ -11548,9 +12378,7 @@ struct llm_build_exaone : public llm_graph_context {
11548
12378
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11549
12379
  }
11550
12380
 
11551
- if (il == n_layer - 1) {
11552
- // skip computing output for unused tokens
11553
- ggml_tensor * inp_out_ids = build_inp_out_ids();
12381
+ if (il == n_layer - 1 && inp_out_ids) {
11554
12382
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11555
12383
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
11556
12384
  }
@@ -11637,14 +12465,13 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11637
12465
  }
11638
12466
 
11639
12467
  ggml_tensor * build_rwkv6_time_mix(
12468
+ llm_graph_input_rs * inp,
11640
12469
  ggml_cgraph * gf,
11641
12470
  ggml_tensor * cur,
11642
12471
  ggml_tensor * x_prev,
11643
- ggml_tensor * state_copy,
11644
- ggml_tensor * state_mask,
11645
12472
  const llama_ubatch & ubatch,
11646
12473
  int il) const {
11647
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
12474
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
11648
12475
 
11649
12476
  const auto n_tokens = ubatch.n_tokens;
11650
12477
  const auto n_seqs = ubatch.n_seqs;
@@ -11654,7 +12481,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11654
12481
  const auto n_head = n_embd / head_size;
11655
12482
  const auto n_head_kv = hparams.n_head_kv(il);
11656
12483
 
11657
- const auto kv_head = kv_self->head;
12484
+ const auto kv_head = mctx_cur->get_head();
11658
12485
 
11659
12486
  const auto & layer = model.layers[il];
11660
12487
 
@@ -11765,9 +12592,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11765
12592
  k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
11766
12593
  }
11767
12594
 
11768
- ggml_tensor * wkv_state = build_copy_mask_state(
11769
- gf, kv_self->v_l[il], state_copy, state_mask,
11770
- hparams.n_embd_v_s(), n_seqs);
12595
+ ggml_tensor * wkv_state = build_rs(
12596
+ inp, gf, mctx_cur->get_s_l(il),
12597
+ hparams.n_embd_s(), n_seqs);
11771
12598
 
11772
12599
  ggml_tensor * wkv_output;
11773
12600
  if (is_qrwkv) {
@@ -11785,9 +12612,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11785
12612
  wkv_state,
11786
12613
  ggml_view_1d(
11787
12614
  ctx0,
11788
- kv_self->v_l[il],
11789
- hparams.n_embd_v_s() * n_seqs,
11790
- hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
12615
+ mctx_cur->get_s_l(il),
12616
+ hparams.n_embd_s() * n_seqs,
12617
+ hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
11791
12618
  )
11792
12619
  )
11793
12620
  );
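
The ggml_view_1d call above selects the slice of the per-layer state buffer belonging to the current ubatch: n_embd_s() elements per sequence, n_seqs sequences, starting at cell kv_head. The offset arithmetic in isolation (plain C++; the concrete values are made up for illustration):

    #include <cstdio>
    #include <cstddef>

    int main() {
        const std::size_t n_embd_s = 1024;            // per-sequence state size (illustrative)
        const std::size_t n_seqs   = 3;               // sequences in the current ubatch
        const std::size_t kv_head  = 5;               // first state cell used by this ubatch
        const std::size_t elt_size = sizeof(float);   // ggml_element_size(...) for f32 states

        const std::size_t view_elems  = n_embd_s * n_seqs;             // length of the 1-D view
        const std::size_t view_offset = n_embd_s * kv_head * elt_size; // byte offset into the buffer

        std::printf("view of %zu elements starting at byte offset %zu\n", view_elems, view_offset);
        return 0;
    }
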
@@ -11821,20 +12648,19 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
11821
12648
  inpL = build_inp_embd(model.tok_embd);
11822
12649
  inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
11823
12650
 
11824
- ggml_tensor * state_copy = build_inp_s_copy();
11825
- ggml_tensor * state_mask = build_inp_s_mask();
12651
+ auto * rs_inp = build_rs_inp();
11826
12652
 
11827
12653
  const auto n_embd = hparams.n_embd;
11828
12654
  const auto n_seq_tokens = ubatch.n_seq_tokens;
11829
12655
  const auto n_seqs = ubatch.n_seqs;
11830
12656
 
12657
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
12658
+
11831
12659
  for (int il = 0; il < n_layer; ++il) {
11832
12660
  const llama_layer * layer = &model.layers[il];
11833
12661
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
11834
12662
 
11835
- ggml_tensor * token_shift = build_rwkv_token_shift_load(
11836
- gf, state_copy, state_mask, ubatch, il
11837
- );
12663
+ ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
11838
12664
 
11839
12665
  ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
11840
12666
  ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
@@ -11849,7 +12675,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
11849
12675
  1
11850
12676
  );
11851
12677
 
11852
- cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il);
12678
+ cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
11853
12679
 
11854
12680
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
11855
12681
  cb(ffn_inp, "ffn_inp", il);
@@ -11871,13 +12697,16 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
11871
12697
  );
11872
12698
  ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
11873
12699
 
11874
- if (il == n_layer - 1) {
11875
- // skip computing output for unused tokens
11876
- struct ggml_tensor * inp_out_ids = build_inp_out_ids();
11877
- ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
11878
- ffn_norm = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens), inp_out_ids);
11879
- x_prev = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens), inp_out_ids);
11880
- cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
12700
+ ffn_inp = ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens);
12701
+ ffn_norm = ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens);
12702
+ x_prev = ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens);
12703
+ cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
12704
+
12705
+ if (il == n_layer - 1 && inp_out_ids) {
12706
+ ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
12707
+ ffn_norm = ggml_get_rows(ctx0, ffn_norm, inp_out_ids);
12708
+ x_prev = ggml_get_rows(ctx0, x_prev, inp_out_ids);
12709
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11881
12710
  }
11882
12711
 
11883
12712
  cur = build_rwkv6_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6);
@@ -11912,27 +12741,26 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
11912
12741
  // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
11913
12742
  struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
11914
12743
  llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
11915
- GGML_ASSERT(n_embd == hparams.n_embd_k_s());
12744
+ GGML_ASSERT(n_embd == hparams.n_embd_r());
11916
12745
 
11917
12746
  ggml_tensor * cur;
11918
12747
  ggml_tensor * inpL;
11919
12748
 
11920
12749
  inpL = build_inp_embd(model.tok_embd);
11921
12750
 
11922
- ggml_tensor * state_copy = build_inp_s_copy();
11923
- ggml_tensor * state_mask = build_inp_s_mask();
12751
+ auto * rs_inp = build_rs_inp();
11924
12752
 
11925
12753
  const auto n_embd = hparams.n_embd;
11926
12754
  const auto n_seq_tokens = ubatch.n_seq_tokens;
11927
12755
  const auto n_seqs = ubatch.n_seqs;
11928
12756
 
12757
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
12758
+
11929
12759
  for (int il = 0; il < n_layer; ++il) {
11930
12760
  const llama_layer * layer = &model.layers[il];
11931
12761
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
11932
12762
 
11933
- ggml_tensor * token_shift = build_rwkv_token_shift_load(
11934
- gf, state_copy, state_mask, ubatch, il
11935
- );
12763
+ ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
11936
12764
 
11937
12765
  ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
11938
12766
  cb(att_norm, "attn_norm", il);
@@ -11944,7 +12772,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
11944
12772
  1
11945
12773
  );
11946
12774
 
11947
- cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il);
12775
+ cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
11948
12776
 
11949
12777
  token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
11950
12778
  ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -11952,11 +12780,12 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
11952
12780
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
11953
12781
  cb(ffn_inp, "ffn_inp", il);
11954
12782
 
11955
- if (il == n_layer - 1) {
11956
- // skip computing output for unused tokens
11957
- struct ggml_tensor * inp_out_ids = build_inp_out_ids();
11958
- cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
11959
- ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
12783
+ cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
12784
+ ffn_inp = ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens);
12785
+
12786
+ if (il == n_layer - 1 && inp_out_ids) {
12787
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
12788
+ ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
11960
12789
  }
11961
12790
 
11962
12791
  // feed-forward network
@@ -12032,15 +12861,14 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12032
12861
  }
12033
12862
 
12034
12863
  ggml_tensor * build_rwkv7_time_mix(
12864
+ llm_graph_input_rs * inp,
12035
12865
  ggml_cgraph * gf,
12036
12866
  ggml_tensor * cur,
12037
12867
  ggml_tensor * x_prev,
12038
- ggml_tensor * state_copy,
12039
- ggml_tensor * state_mask,
12040
12868
  ggml_tensor *& first_layer_value,
12041
12869
  const llama_ubatch & ubatch,
12042
12870
  int il) const {
12043
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
12871
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
12044
12872
 
12045
12873
  const auto n_tokens = ubatch.n_tokens;
12046
12874
  const auto n_seqs = ubatch.n_seqs;
@@ -12049,7 +12877,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12049
12877
  const auto head_count = n_embd / head_size;
12050
12878
  const auto n_seq_tokens = ubatch.n_seq_tokens;
12051
12879
 
12052
- const auto kv_head = kv_self->head;
12880
+ const auto kv_head = mctx_cur->get_head();
12053
12881
 
12054
12882
  const auto & layer = model.layers[il];
12055
12883
 
@@ -12119,9 +12947,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12119
12947
  v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
12120
12948
  a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
12121
12949
 
12122
- ggml_tensor * wkv_state = build_copy_mask_state(
12123
- gf, kv_self->v_l[il], state_copy, state_mask,
12124
- hparams.n_embd_v_s(), n_seqs);
12950
+ ggml_tensor * wkv_state = build_rs(
12951
+ inp, gf, mctx_cur->get_s_l(il),
12952
+ hparams.n_embd_s(), n_seqs);
12125
12953
 
12126
12954
  ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
12127
12955
  cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
@@ -12134,9 +12962,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12134
12962
  wkv_state,
12135
12963
  ggml_view_1d(
12136
12964
  ctx0,
12137
- kv_self->v_l[il],
12138
- hparams.n_embd_v_s() * n_seqs,
12139
- hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
12965
+ mctx_cur->get_s_l(il),
12966
+ hparams.n_embd_s() * n_seqs,
12967
+ hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
12140
12968
  )
12141
12969
  )
12142
12970
  );
@@ -12177,20 +13005,19 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
12177
13005
  inpL = build_inp_embd(model.tok_embd);
12178
13006
  inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
12179
13007
 
12180
- ggml_tensor * state_copy = build_inp_s_copy();
12181
- ggml_tensor * state_mask = build_inp_s_mask();
13008
+ auto * rs_inp = build_rs_inp();
12182
13009
 
12183
13010
  const auto n_embd = hparams.n_embd;
12184
13011
  const auto n_seq_tokens = ubatch.n_seq_tokens;
12185
13012
  const auto n_seqs = ubatch.n_seqs;
12186
13013
 
13014
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
13015
+
12187
13016
  for (int il = 0; il < n_layer; ++il) {
12188
13017
  const llama_layer * layer = &model.layers[il];
12189
13018
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
12190
13019
 
12191
- ggml_tensor * token_shift = build_rwkv_token_shift_load(
12192
- gf, state_copy, state_mask, ubatch, il
12193
- );
13020
+ ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
12194
13021
 
12195
13022
  ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
12196
13023
  ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
@@ -12205,7 +13032,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
12205
13032
  1
12206
13033
  );
12207
13034
 
12208
- cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
13035
+ cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
12209
13036
 
12210
13037
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
12211
13038
  cb(ffn_inp, "ffn_inp", il);
@@ -12227,12 +13054,14 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
12227
13054
  );
12228
13055
  ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
12229
13056
 
12230
- if (il == n_layer - 1) {
12231
- // skip computing output for unused tokens
12232
- struct ggml_tensor * inp_out_ids = build_inp_out_ids();
12233
- ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
12234
- ffn_norm = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens), inp_out_ids);
12235
- x_prev = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens), inp_out_ids);
13057
+ ffn_inp = ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens);
13058
+ ffn_norm = ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens);
13059
+ x_prev = ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens);
13060
+
13061
+ if (il == n_layer - 1 && inp_out_ids) {
13062
+ ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
13063
+ ffn_norm = ggml_get_rows(ctx0, ffn_norm, inp_out_ids);
13064
+ x_prev = ggml_get_rows(ctx0, x_prev, inp_out_ids);
12236
13065
  }
12237
13066
 
12238
13067
  cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7);
@@ -12263,7 +13092,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
12263
13092
 
12264
13093
  struct llm_build_arwkv7 : public llm_build_rwkv7_base {
12265
13094
  llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
12266
- GGML_ASSERT(n_embd == hparams.n_embd_k_s());
13095
+ GGML_ASSERT(n_embd == hparams.n_embd_r());
12267
13096
 
12268
13097
  ggml_tensor * cur;
12269
13098
  ggml_tensor * inpL;
@@ -12271,20 +13100,19 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
12271
13100
 
12272
13101
  inpL = build_inp_embd(model.tok_embd);
12273
13102
 
12274
- ggml_tensor * state_copy = build_inp_s_copy();
12275
- ggml_tensor * state_mask = build_inp_s_mask();
13103
+ auto * rs_inp = build_rs_inp();
12276
13104
 
12277
13105
  const auto n_embd = hparams.n_embd;
12278
13106
  const auto n_seq_tokens = ubatch.n_seq_tokens;
12279
13107
  const auto n_seqs = ubatch.n_seqs;
12280
13108
 
13109
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
13110
+
12281
13111
  for (int il = 0; il < n_layer; ++il) {
12282
13112
  const llama_layer * layer = &model.layers[il];
12283
13113
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
12284
13114
 
12285
- ggml_tensor * token_shift = build_rwkv_token_shift_load(
12286
- gf, state_copy, state_mask, ubatch, il
12287
- );
13115
+ ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
12288
13116
 
12289
13117
  ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
12290
13118
  cb(att_norm, "attn_norm", il);
@@ -12296,7 +13124,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
12296
13124
  1
12297
13125
  );
12298
13126
 
12299
- cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
13127
+ cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
12300
13128
 
12301
13129
  token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
12302
13130
  ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -12304,11 +13132,12 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
12304
13132
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
12305
13133
  cb(ffn_inp, "ffn_inp", il);
12306
13134
 
12307
- if (il == n_layer - 1) {
12308
- // skip computing output for unused tokens
12309
- struct ggml_tensor * inp_out_ids = build_inp_out_ids();
12310
- cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
12311
- ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
13135
+ cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
13136
+ ffn_inp = ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens);
13137
+
13138
+ if (il == n_layer - 1 && inp_out_ids) {
13139
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
13140
+ ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
12312
13141
  }
12313
13142
 
12314
13143
  // feed-forward network
@@ -12377,6 +13206,9 @@ struct llm_build_granite : public llm_graph_context {
12377
13206
  auto * inp_attn = build_attn_inp_kv_unified();
12378
13207
 
12379
13208
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
13209
+
13210
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
13211
+
12380
13212
  for (int il = 0; il < n_layer; ++il) {
12381
13213
  ggml_tensor * inpSA = inpL;
12382
13214
 
@@ -12439,9 +13271,7 @@ struct llm_build_granite : public llm_graph_context {
12439
13271
  cb(cur, "attn_out", il);
12440
13272
  }
12441
13273
 
12442
- if (il == n_layer - 1) {
12443
- // skip computing output for unused tokens
12444
- ggml_tensor * inp_out_ids = build_inp_out_ids();
13274
+ if (il == n_layer - 1 && inp_out_ids) {
12445
13275
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
12446
13276
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
12447
13277
  }
@@ -12560,6 +13390,8 @@ struct llm_build_chameleon : public llm_graph_context {
12560
13390
 
12561
13391
  auto * inp_attn = build_attn_inp_kv_unified();
12562
13392
 
13393
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
13394
+
12563
13395
  for (int il = 0; il < n_layer; ++il) {
12564
13396
  ggml_tensor * inpSA = inpL;
12565
13397
 
@@ -12636,21 +13468,19 @@ struct llm_build_chameleon : public llm_graph_context {
12636
13468
  cur = build_attn(inp_attn, gf,
12637
13469
  model.layers[il].wo, nullptr,
12638
13470
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
12639
-
12640
- if (hparams.swin_norm) {
12641
- cur = build_norm(cur,
12642
- model.layers[il].attn_norm, NULL,
12643
- LLM_NORM_RMS, il);
12644
- }
12645
13471
  }
12646
13472
 
12647
- if (il == n_layer - 1) {
12648
- // skip computing output for unused tokens
12649
- ggml_tensor * inp_out_ids = build_inp_out_ids();
13473
+ if (il == n_layer - 1 && inp_out_ids) {
12650
13474
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
12651
13475
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
12652
13476
  }
12653
13477
 
13478
+ if (hparams.swin_norm) {
13479
+ cur = build_norm(cur,
13480
+ model.layers[il].attn_norm, NULL,
13481
+ LLM_NORM_RMS, il);
13482
+ }
13483
+
12654
13484
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
12655
13485
  cb(ffn_inp, "ffn_inp", il);
12656
13486
 
@@ -12891,6 +13721,8 @@ struct llm_build_plm : public llm_graph_context {
12891
13721
 
12892
13722
  auto * inp_attn = build_attn_inp_kv_unified();
12893
13723
 
13724
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
13725
+
12894
13726
  for (int il = 0; il < n_layer; ++il) {
12895
13727
  ggml_tensor * inpSA = inpL;
12896
13728
 
@@ -12994,9 +13826,7 @@ struct llm_build_plm : public llm_graph_context {
12994
13826
  q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
12995
13827
  }
12996
13828
 
12997
- if (il == n_layer - 1) {
12998
- // skip computing output for unused tokens
12999
- ggml_tensor * inp_out_ids = build_inp_out_ids();
13829
+ if (il == n_layer - 1 && inp_out_ids) {
13000
13830
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
13001
13831
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
13002
13832
  }
@@ -13056,6 +13886,8 @@ struct llm_build_bailingmoe : public llm_graph_context {
13056
13886
 
13057
13887
  auto * inp_attn = build_attn_inp_kv_unified();
13058
13888
 
13889
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
13890
+
13059
13891
  for (int il = 0; il < n_layer; ++il) {
13060
13892
  ggml_tensor * inpSA = inpL;
13061
13893
 
@@ -13117,9 +13949,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
13117
13949
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il);
13118
13950
  }
13119
13951
 
13120
- if (il == n_layer - 1) {
13121
- // skip computing output for unused tokens
13122
- ggml_tensor * inp_out_ids = build_inp_out_ids();
13952
+ if (il == n_layer - 1 && inp_out_ids) {
13123
13953
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
13124
13954
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
13125
13955
  }
@@ -13188,69 +14018,505 @@ struct llm_build_bailingmoe : public llm_graph_context {
13188
14018
  }
13189
14019
  };
13190
14020
 
14021
+ struct llm_build_dots1 : public llm_graph_context {
14022
+ llm_build_dots1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
14023
+ const int64_t n_embd_head = hparams.n_embd_head_v;
14024
+
14025
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
14026
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
14027
+
14028
+ ggml_tensor * cur;
14029
+ ggml_tensor * inpL;
14030
+
14031
+ inpL = build_inp_embd(model.tok_embd);
14032
+
14033
+ // inp_pos - contains the positions
14034
+ ggml_tensor * inp_pos = build_inp_pos();
14035
+
14036
+ auto * inp_attn = build_attn_inp_kv_unified();
14037
+
14038
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
14039
+
14040
+ for (int il = 0; il < n_layer; ++il) {
14041
+ ggml_tensor * inpSA = inpL;
14042
+
14043
+ // norm
14044
+ cur = build_norm(inpL,
14045
+ model.layers[il].attn_norm, NULL,
14046
+ LLM_NORM_RMS, il);
14047
+ cb(cur, "attn_norm", il);
14048
+
14049
+ // self_attention
14050
+ {
14051
+ // compute Q and K and RoPE them
14052
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
14053
+ cb(Qcur, "Qcur", il);
14054
+
14055
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
14056
+ cb(Kcur, "Kcur", il);
14057
+
14058
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
14059
+ cb(Vcur, "Vcur", il);
14060
+
14061
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
14062
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
14063
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
14064
+
14065
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
14066
+ cb(Qcur, "Qcur_normed", il);
14067
+
14068
+ Qcur = ggml_rope_ext(
14069
+ ctx0, Qcur, inp_pos, nullptr,
14070
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
14071
+ ext_factor, attn_factor, beta_fast, beta_slow
14072
+ );
14073
+
14074
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
14075
+ cb(Kcur, "Kcur_normed", il);
14076
+
14077
+ Kcur = ggml_rope_ext(
14078
+ ctx0, Kcur, inp_pos, nullptr,
14079
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
14080
+ ext_factor, attn_factor, beta_fast, beta_slow
14081
+ );
14082
+
14083
+ cb(Qcur, "Qcur", il);
14084
+ cb(Kcur, "Kcur", il);
14085
+ cb(Vcur, "Vcur", il);
14086
+
14087
+ cur = build_attn(inp_attn, gf,
14088
+ model.layers[il].wo, model.layers[il].bo,
14089
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
14090
+ }
14091
+
14092
+ if (il == n_layer - 1 && inp_out_ids) {
14093
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
14094
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
14095
+ }
14096
+
14097
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
14098
+ cb(ffn_inp, "ffn_inp", il);
14099
+
14100
+ // MoE branch
14101
+ cur = build_norm(ffn_inp,
14102
+ model.layers[il].ffn_norm, NULL,
14103
+ LLM_NORM_RMS, il);
14104
+ cb(cur, "ffn_norm", il);
14105
+
14106
+ if ((uint32_t) il < hparams.n_layer_dense_lead) {
14107
+ cur = build_ffn(cur,
14108
+ model.layers[il].ffn_up, NULL, NULL,
14109
+ model.layers[il].ffn_gate, NULL, NULL,
14110
+ model.layers[il].ffn_down, NULL, NULL,
14111
+ NULL,
14112
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
14113
+ cb(cur, "ffn_out", il);
14114
+ } else {
14115
+ ggml_tensor * moe_out =
14116
+ build_moe_ffn(cur,
14117
+ model.layers[il].ffn_gate_inp,
14118
+ model.layers[il].ffn_up_exps,
14119
+ model.layers[il].ffn_gate_exps,
14120
+ model.layers[il].ffn_down_exps,
14121
+ model.layers[il].ffn_exp_probs_b,
14122
+ n_expert, n_expert_used,
14123
+ LLM_FFN_SILU, hparams.expert_weights_norm,
14124
+ true, hparams.expert_weights_scale,
14125
+ (llama_expert_gating_func_type) hparams.expert_gating_func,
14126
+ il);
14127
+ cb(moe_out, "ffn_moe_out", il);
14128
+
14129
+ {
14130
+ ggml_tensor * ffn_shexp = build_ffn(cur,
14131
+ model.layers[il].ffn_up_shexp, NULL, NULL,
14132
+ model.layers[il].ffn_gate_shexp, NULL, NULL,
14133
+ model.layers[il].ffn_down_shexp, NULL, NULL,
14134
+ NULL,
14135
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
14136
+ cb(ffn_shexp, "ffn_shexp", il);
14137
+
14138
+ cur = ggml_add(ctx0, moe_out, ffn_shexp);
14139
+ cb(cur, "ffn_out", il);
14140
+ }
14141
+ }
14142
+
14143
+ cur = ggml_add(ctx0, cur, ffn_inp);
14144
+
14145
+ cur = build_cvec(cur, il);
14146
+ cb(cur, "l_out", il);
14147
+
14148
+ // input for next layer
14149
+ inpL = cur;
14150
+ }
14151
+
14152
+ cur = inpL;
14153
+
14154
+ cur = build_norm(cur,
14155
+ model.output_norm, NULL,
14156
+ LLM_NORM_RMS, -1);
14157
+
14158
+ cb(cur, "result_norm", -1);
14159
+ res->t_embd = cur;
14160
+
14161
+ // lm_head
14162
+ cur = build_lora_mm(model.output, cur);
14163
+
14164
+ cb(cur, "result_output", -1);
14165
+ res->t_logits = cur;
14166
+
14167
+ ggml_build_forward_expand(gf, cur);
14168
+ }
14169
+ };
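
The dots1 MoE branch above adds a shared expert's output to the routed experts' output (cur = ggml_add(ctx0, moe_out, ffn_shexp)). A toy element-wise sketch of that combination, unrelated to the real ggml tensors:

    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<float> moe_out = {0.05f, -0.10f, 0.20f};  // routed experts (already weighted)
        std::vector<float> shexp   = {0.02f,  0.01f, -0.03f}; // shared expert output
        std::vector<float> cur(moe_out.size());

        for (size_t i = 0; i < cur.size(); ++i) {
            cur[i] = moe_out[i] + shexp[i];                   // cur = moe_out + ffn_shexp
        }
        std::printf("combined[0] = %.3f\n", cur[0]);
        return 0;
    }
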
14170
+
14171
+ struct llm_build_ernie4_5 : public llm_graph_context {
14172
+ llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
14173
+ const int64_t n_embd_head = hparams.n_embd_head_v;
14174
+
14175
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
14176
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
14177
+
14178
+ ggml_tensor * cur;
14179
+ ggml_tensor * inpL;
14180
+
14181
+ inpL = build_inp_embd(model.tok_embd);
14182
+
14183
+ // inp_pos - contains the positions
14184
+ ggml_tensor * inp_pos = build_inp_pos();
14185
+
14186
+ auto * inp_attn = build_attn_inp_kv_unified();
14187
+
14188
+ for (int il = 0; il < n_layer; ++il) {
14189
+ ggml_tensor * inpSA = inpL;
14190
+
14191
+ // norm
14192
+ {
14193
+ cur = build_norm(inpL,
14194
+ model.layers[il].attn_norm, NULL,
14195
+ LLM_NORM_RMS, il);
14196
+ cb(cur, "attn_norm", il);
14197
+ }
14198
+
14199
+ // self-attention
14200
+ {
14201
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
14202
+ cb(Qcur, "Qcur", il);
14203
+ if (model.layers[il].bq) {
14204
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
14205
+ cb(Qcur, "Qcur", il);
14206
+ }
14207
+
14208
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
14209
+ cb(Kcur, "Kcur", il);
14210
+ if (model.layers[il].bk) {
14211
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
14212
+ cb(Kcur, "Kcur", il);
14213
+ }
14214
+
14215
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
14216
+ cb(Vcur, "Vcur", il);
14217
+ if (model.layers[il].bv) {
14218
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
14219
+ cb(Vcur, "Vcur", il);
14220
+ }
14221
+
14222
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
14223
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
14224
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
14225
+
14226
+ Qcur = ggml_rope_ext(
14227
+ ctx0, Qcur, inp_pos, nullptr,
14228
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
14229
+ ext_factor, attn_factor, beta_fast, beta_slow
14230
+ );
14231
+
14232
+ Kcur = ggml_rope_ext(
14233
+ ctx0, Kcur, inp_pos, nullptr,
14234
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
14235
+ ext_factor, attn_factor, beta_fast, beta_slow
14236
+ );
14237
+
14238
+ cb(Qcur, "Qcur", il);
14239
+ cb(Kcur, "Kcur", il);
14240
+ cb(Vcur, "Vcur", il);
14241
+
14242
+ cur = build_attn(inp_attn, gf,
14243
+ model.layers[il].wo, NULL,
14244
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
14245
+ }
14246
+
14247
+ if (il == n_layer - 1) {
14248
+ // skip computing output for unused tokens
14249
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
14250
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
14251
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
14252
+ }
14253
+
14254
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
14255
+ cb(ffn_inp, "ffn_inp", il);
14256
+
14257
+ // feed-forward network
14258
+ {
14259
+ cur = build_norm(ffn_inp,
14260
+ model.layers[il].ffn_norm, NULL,
14261
+ LLM_NORM_RMS, il);
14262
+ cb(cur, "ffn_norm", il);
14263
+
14264
+ cur = build_ffn(cur,
14265
+ model.layers[il].ffn_up, NULL, NULL,
14266
+ model.layers[il].ffn_gate, NULL, NULL,
14267
+ model.layers[il].ffn_down, NULL, NULL,
14268
+ NULL,
14269
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
14270
+ cb(cur, "ffn_out", il);
14271
+ }
14272
+
14273
+ cur = ggml_add(ctx0, cur, ffn_inp);
14274
+
14275
+ cur = build_cvec(cur, il);
14276
+ cb(cur, "l_out", il);
14277
+
14278
+ // input for next layer
14279
+ inpL = cur;
14280
+ }
14281
+
14282
+ cur = inpL;
14283
+
14284
+ cur = build_norm(cur,
14285
+ model.output_norm, NULL,
14286
+ LLM_NORM_RMS, -1);
14287
+
14288
+ cb(cur, "result_norm", -1);
14289
+ res->t_embd = cur;
14290
+
14291
+ // lm_head
14292
+ cur = build_lora_mm(model.output, cur);
14293
+
14294
+ cb(cur, "result_output", -1);
14295
+ res->t_logits = cur;
14296
+
14297
+ ggml_build_forward_expand(gf, cur);
14298
+ }
14299
+ };
14300
+
14301
+ struct llm_build_arcee : public llm_graph_context {
14302
+ llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
14303
+ const int64_t n_embd_head = hparams.n_embd_head_v;
14304
+
14305
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
14306
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
14307
+
14308
+ ggml_tensor * cur;
14309
+ ggml_tensor * inpL;
14310
+
14311
+ inpL = build_inp_embd(model.tok_embd);
14312
+
14313
+ // inp_pos - contains the positions
14314
+ ggml_tensor * inp_pos = build_inp_pos();
14315
+
14316
+ auto * inp_attn = build_attn_inp_kv_unified();
14317
+
14318
+ const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
14319
+
14320
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
14321
+
14322
+ for (int il = 0; il < n_layer; ++il) {
14323
+ ggml_tensor * inpSA = inpL;
14324
+
14325
+ // norm
14326
+ cur = build_norm(inpL,
14327
+ model.layers[il].attn_norm, NULL,
14328
+ LLM_NORM_RMS, il);
14329
+ cb(cur, "attn_norm", il);
14330
+
14331
+ // self-attention
14332
+ {
14333
+ // rope freq factors for llama3; may return nullptr for llama2 and other models
14334
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
14335
+
14336
+ // compute Q and K and RoPE them
14337
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
14338
+ cb(Qcur, "Qcur", il);
14339
+ if (model.layers[il].bq) {
14340
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
14341
+ cb(Qcur, "Qcur", il);
14342
+ }
14343
+
14344
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
14345
+ cb(Kcur, "Kcur", il);
14346
+ if (model.layers[il].bk) {
14347
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
14348
+ cb(Kcur, "Kcur", il);
14349
+ }
14350
+
14351
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
14352
+ cb(Vcur, "Vcur", il);
14353
+ if (model.layers[il].bv) {
14354
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
14355
+ cb(Vcur, "Vcur", il);
14356
+ }
14357
+
14358
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
14359
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
14360
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
14361
+
14362
+ Qcur = ggml_rope_ext(
14363
+ ctx0, Qcur, inp_pos, rope_factors,
14364
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
14365
+ ext_factor, attn_factor, beta_fast, beta_slow
14366
+ );
14367
+
14368
+ Kcur = ggml_rope_ext(
14369
+ ctx0, Kcur, inp_pos, rope_factors,
14370
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
14371
+ ext_factor, attn_factor, beta_fast, beta_slow
14372
+ );
14373
+
14374
+ cb(Qcur, "Qcur", il);
14375
+ cb(Kcur, "Kcur", il);
14376
+ cb(Vcur, "Vcur", il);
14377
+
14378
+ cur = build_attn(inp_attn, gf,
14379
+ model.layers[il].wo, model.layers[il].bo,
14380
+ Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
14381
+ cb(cur, "attn_out", il);
14382
+ }
14383
+
14384
+ if (il == n_layer - 1 && inp_out_ids) {
14385
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
14386
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
14387
+ }
14388
+
14389
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
14390
+ cb(ffn_inp, "ffn_inp", il);
14391
+
14392
+ // feed-forward network
14393
+ // ARCEE uses relu^2 instead of silu
14394
+ cur = build_norm(ffn_inp,
14395
+ model.layers[il].ffn_norm, NULL,
14396
+ LLM_NORM_RMS, il);
14397
+ cb(cur, "ffn_norm", il);
14398
+
14399
+ cur = build_ffn(cur,
14400
+ model.layers[il].ffn_up, NULL, NULL,
14401
+ NULL, NULL, NULL,
14402
+ model.layers[il].ffn_down, NULL, NULL,
14403
+ NULL,
14404
+ LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
14405
+ cb(cur, "ffn_out", il);
14406
+
14407
+ cur = ggml_add(ctx0, cur, ffn_inp);
14408
+ cb(cur, "ffn_out", il);
14409
+
14410
+ cur = build_cvec(cur, il);
14411
+ cb(cur, "l_out", il);
14412
+
14413
+ // input for next layer
14414
+ inpL = cur;
14415
+ }
14416
+
14417
+ cur = inpL;
14418
+
14419
+ cur = build_norm(cur,
14420
+ model.output_norm, NULL,
14421
+ LLM_NORM_RMS, -1);
14422
+
14423
+ cb(cur, "result_norm", -1);
14424
+ res->t_embd = cur;
14425
+
14426
+ // lm_head
14427
+ cur = build_lora_mm(model.output, cur);
14428
+
14429
+ cb(cur, "result_output", -1);
14430
+ res->t_logits = cur;
14431
+
14432
+ ggml_build_forward_expand(gf, cur);
14433
+ }
14434
+ };
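
As the comment in the ARCEE FFN notes, the block is built with LLM_FFN_RELU_SQR rather than the usual SiLU gate. A quick numeric comparison of the two activations (plain C++, unrelated to ggml's kernels):

    #include <cmath>
    #include <cstdio>

    static float relu_sqr(float x) { float r = x > 0.f ? x : 0.f; return r * r; }
    static float silu    (float x) { return x / (1.f + std::exp(-x)); }

    int main() {
        for (float x : {-2.f, -0.5f, 0.f, 0.5f, 2.f}) {
            std::printf("x=%5.2f  relu^2=%7.4f  silu=%7.4f\n", x, relu_sqr(x), silu(x));
        }
        return 0;
    }
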
14435
+
13191
14436
  llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
13192
14437
  llama_memory_i * res;
13193
14438
 
13194
14439
  switch (arch) {
14440
+ // Models that need specific instantiation should be handled in the
14441
+ // switch statement
13195
14442
  case LLM_ARCH_BERT:
13196
14443
  case LLM_ARCH_JINA_BERT_V2:
13197
14444
  case LLM_ARCH_NOMIC_BERT:
13198
14445
  case LLM_ARCH_NOMIC_BERT_MOE:
14446
+ case LLM_ARCH_NEO_BERT:
13199
14447
  case LLM_ARCH_WAVTOKENIZER_DEC:
13200
14448
  {
13201
14449
  res = nullptr;
13202
14450
  } break;
13203
- case LLM_ARCH_MAMBA:
13204
- case LLM_ARCH_RWKV6:
13205
- case LLM_ARCH_RWKV6QWEN2:
13206
- case LLM_ARCH_RWKV7:
13207
- case LLM_ARCH_ARWKV7:
13208
- {
13209
- res = new llama_kv_cache_recurrent(
13210
- *this,
13211
- GGML_TYPE_F32,
13212
- GGML_TYPE_F32,
13213
- cparams.offload_kqv,
13214
- std::max((uint32_t) 1, cparams.n_seq_max),
13215
- cparams.n_seq_max);
13216
- } break;
14451
+ // Models that need standard caching should rely on recurrent/hybrid
14452
+ // checks
13217
14453
  default:
13218
14454
  {
13219
- const auto padding = llama_kv_cache_unified::get_padding(cparams);
13220
-
13221
- cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
13222
-
13223
- LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
13224
-
13225
- if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
13226
- GGML_ASSERT(hparams.is_swa_any());
13227
-
13228
- res = new llama_kv_cache_unified_iswa(
13229
- *this,
13230
- params.type_k,
13231
- params.type_v,
13232
- !cparams.flash_attn,
13233
- cparams.offload_kqv,
13234
- params.swa_full,
13235
- cparams.n_ctx,
13236
- cparams.n_seq_max,
13237
- cparams.n_batch,
13238
- padding);
13239
- } else {
13240
- GGML_ASSERT(!hparams.is_swa_any());
13241
-
13242
- res = new llama_kv_cache_unified(
14455
+ if (llm_arch_is_recurrent(arch)) {
14456
+ res = new llama_memory_recurrent(
13243
14457
  *this,
13244
14458
  nullptr,
13245
- params.type_k,
13246
- params.type_v,
13247
- !cparams.flash_attn,
14459
+ GGML_TYPE_F32,
14460
+ GGML_TYPE_F32,
13248
14461
  cparams.offload_kqv,
13249
- cparams.n_ctx,
13250
- cparams.n_seq_max,
13251
- padding,
13252
- hparams.n_swa,
13253
- hparams.swa_type);
14462
+ std::max((uint32_t) 1, cparams.n_seq_max),
14463
+ cparams.n_seq_max);
14464
+ } else if (llm_arch_is_hybrid(arch)) {
14465
+ const auto padding = llama_kv_cache_unified::get_padding(cparams);
14466
+
14467
+ cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
14468
+
14469
+ res = new llama_memory_hybrid(
14470
+ /* model */ *this,
14471
+ /* attn_type_k */ params.type_k,
14472
+ /* attn_type_v */ params.type_v,
14473
+ /* attn_v_trans */ !cparams.flash_attn,
14474
+ /* attn_kv_size */ cparams.n_ctx,
14475
+ /* attn_n_pad */ padding,
14476
+ /* attn_n_swa */ hparams.n_swa,
14477
+ /* attn_swa_type */ hparams.swa_type,
14478
+ /* recurrent_type_k */ GGML_TYPE_F32,
14479
+ /* recurrent_type_v */ GGML_TYPE_F32,
14480
+ /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
14481
+ /* n_seq_max */ cparams.n_seq_max,
14482
+ /* offload */ cparams.offload_kqv);
14483
+ } else {
14484
+ const auto padding = llama_kv_cache_unified::get_padding(cparams);
14485
+
14486
+ cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
14487
+
14488
+ LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
14489
+
14490
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
14491
+ GGML_ASSERT(hparams.is_swa_any());
14492
+
14493
+ res = new llama_kv_cache_unified_iswa(
14494
+ *this,
14495
+ params.type_k,
14496
+ params.type_v,
14497
+ !cparams.flash_attn,
14498
+ cparams.offload_kqv,
14499
+ params.swa_full,
14500
+ cparams.n_ctx,
14501
+ cparams.n_seq_max,
14502
+ cparams.n_ubatch,
14503
+ padding);
14504
+ } else {
14505
+ GGML_ASSERT(!hparams.is_swa_any());
14506
+
14507
+ res = new llama_kv_cache_unified(
14508
+ *this,
14509
+ nullptr,
14510
+ params.type_k,
14511
+ params.type_v,
14512
+ !cparams.flash_attn,
14513
+ cparams.offload_kqv,
14514
+ cparams.n_ctx,
14515
+ cparams.n_seq_max,
14516
+ padding,
14517
+ hparams.n_swa,
14518
+ hparams.swa_type);
14519
+ }
13254
14520
  }
13255
14521
  }
13256
14522
  }
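
After this change, create_memory no longer lists the recurrent architectures explicitly; the choice is driven by llm_arch_is_recurrent / llm_arch_is_hybrid, with everything else falling back to the unified KV cache (SWA-aware when the model uses sliding-window attention). A condensed sketch of that decision tree, with stand-in predicates and a result enum instead of the real constructors:

    // Stand-ins only: the real code constructs llama_memory_recurrent,
    // llama_memory_hybrid, or llama_kv_cache_unified(_iswa).
    #include <cstdio>
    #include <string>

    enum class memory_kind { none, recurrent, hybrid, kv_unified_iswa, kv_unified };

    static bool arch_is_recurrent(const std::string & a) { return a == "mamba" || a.rfind("rwkv", 0) == 0; }
    static bool arch_is_hybrid   (const std::string & a) { return a == "hybrid-example"; } // hypothetical

    static memory_kind select_memory(const std::string & arch, bool has_swa) {
        if (arch == "bert" || arch == "wavtokenizer-dec") return memory_kind::none; // no KV memory needed
        if (arch_is_recurrent(arch)) return memory_kind::recurrent;
        if (arch_is_hybrid(arch))    return memory_kind::hybrid;
        return has_swa ? memory_kind::kv_unified_iswa : memory_kind::kv_unified;
    }

    int main() {
        std::printf("mamba  -> %d\n", (int) select_memory("mamba",  false));
        std::printf("llama  -> %d\n", (int) select_memory("llama",  false));
        std::printf("gemma3 -> %d\n", (int) select_memory("gemma3", true));
        return 0;
    }
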
@@ -13266,7 +14532,6 @@ llm_graph_result_ptr llama_model::build_graph(
13266
14532
 
13267
14533
  switch (arch) {
13268
14534
  case LLM_ARCH_LLAMA:
13269
- case LLM_ARCH_MINICPM:
13270
14535
  {
13271
14536
  llm = std::make_unique<llm_build_llama>(*this, params, gf);
13272
14537
  } break;
@@ -13305,6 +14570,10 @@ llm_graph_result_ptr llama_model::build_graph(
13305
14570
  {
13306
14571
  llm = std::make_unique<llm_build_bert>(*this, params, gf);
13307
14572
  } break;
14573
+ case LLM_ARCH_NEO_BERT:
14574
+ {
14575
+ llm = std::make_unique<llm_build_neo_bert>(*this, params, gf);
14576
+ } break;
13308
14577
  case LLM_ARCH_BLOOM:
13309
14578
  {
13310
14579
  llm = std::make_unique<llm_build_bloom>(*this, params, gf);
@@ -13390,6 +14659,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
             } break;
+        case LLM_ARCH_GEMMA3N:
+            {
+                llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params, gf);
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
@@ -13507,6 +14780,7 @@ llm_graph_result_ptr llama_model::build_graph(
             } break;
         case LLM_ARCH_GRANITE:
         case LLM_ARCH_GRANITE_MOE:
+        case LLM_ARCH_MINICPM:
             {
                 llm = std::make_unique<llm_build_granite>(*this, params, gf);
             } break;
@@ -13526,6 +14800,18 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_bailingmoe>(*this, params, gf);
             } break;
+        case LLM_ARCH_DOTS1:
+            {
+                llm = std::make_unique<llm_build_dots1>(*this, params, gf);
+            } break;
+        case LLM_ARCH_ARCEE:
+            {
+                llm = std::make_unique<llm_build_arcee>(*this, params, gf);
+            } break;
+        case LLM_ARCH_ERNIE4_5:
+            {
+                llm = std::make_unique<llm_build_ernie4_5>(*this, params, gf);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -13597,6 +14883,22 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
     return model->hparams.n_head_kv();
 }
 
+int32_t llama_model_n_swa(const llama_model * model) {
+    return model->hparams.n_swa;
+}
+
+uint32_t llama_model_n_cls_out(const struct llama_model * model) {
+    return model->hparams.n_cls_out;
+}
+
+const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) {
+    if (i < model->classifier_labels.size()) {
+        return model->classifier_labels[i].c_str();
+    }
+
+    return nullptr;
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const llama_model * model) {
     return llama_model_n_ctx_train(model);
@@ -13659,6 +14961,9 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GRANITE_MOE:
         case LLM_ARCH_CHAMELEON:
        case LLM_ARCH_BAILINGMOE:
+        case LLM_ARCH_NEO_BERT:
+        case LLM_ARCH_ARCEE:
+        case LLM_ARCH_ERNIE4_5:
            return LLAMA_ROPE_TYPE_NORM;
 
        // the pairs of head values are offset by n_rot/2
@@ -13684,6 +14989,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
         case LLM_ARCH_GEMMA3:
+        case LLM_ARCH_GEMMA3N:
         case LLM_ARCH_STARCODER2:
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
@@ -13692,6 +14998,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_NEMOTRON:
         case LLM_ARCH_EXAONE:
         case LLM_ARCH_MINICPM3:
+        case LLM_ARCH_DOTS1:
            return LLAMA_ROPE_TYPE_NEOX;
 
        case LLM_ARCH_QWEN2VL:
@@ -13757,7 +15064,7 @@ uint64_t llama_model_size(const llama_model * model) {
 }
 
 const char * llama_model_chat_template(const llama_model * model, const char * name) {
-    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
+    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE)
        : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
     const auto & it = model->gguf_kv.find(key);
     if (it == model->gguf_kv.end()) {
@@ -13765,7 +15072,7 @@ const char * llama_model_chat_template(const llama_model * model, const char * n
         // do not extend this list unless absolutely necessary
         // Mistral-Small-2503 does not have built-in chat template
         llama_vocab_pre_type pre_type = model->vocab.get_pre_type();
-        if (pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
+        if (!name && pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
            return "mistral-v7-tekken";
        }
 
@@ -13799,14 +15106,7 @@ llama_token llama_model_decoder_start_token(const llama_model * model) {
 }
 
 bool llama_model_is_recurrent(const llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_MAMBA:      return true;
-        case LLM_ARCH_RWKV6:      return true;
-        case LLM_ARCH_RWKV6QWEN2: return true;
-        case LLM_ARCH_RWKV7:      return true;
-        case LLM_ARCH_ARWKV7:     return true;
-        default:                  return false;
-    }
+    return llm_arch_is_recurrent(model->arch);
 }
 
 const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {