local-llm-rn 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (626)
  1. package/cpp/CMakeLists.txt +285 -0
  2. package/cpp/common/CMakeLists.txt +149 -0
  3. package/cpp/common/arg.cpp +3799 -0
  4. package/cpp/common/arg.h +131 -0
  5. package/cpp/common/base64.hpp +392 -0
  6. package/cpp/common/build-info.cpp.in +4 -0
  7. package/cpp/common/chat-parser-xml-toolcall.cpp +879 -0
  8. package/cpp/common/chat-parser-xml-toolcall.h +45 -0
  9. package/cpp/common/chat-parser.cpp +1649 -0
  10. package/cpp/common/chat-parser.h +133 -0
  11. package/cpp/common/chat-peg-parser.cpp +124 -0
  12. package/cpp/common/chat-peg-parser.h +105 -0
  13. package/cpp/common/chat.cpp +3355 -0
  14. package/cpp/common/chat.h +252 -0
  15. package/cpp/common/common.cpp +1824 -0
  16. package/cpp/common/common.h +930 -0
  17. package/cpp/common/console.cpp +1137 -0
  18. package/cpp/common/console.h +41 -0
  19. package/cpp/common/debug.cpp +167 -0
  20. package/cpp/common/debug.h +43 -0
  21. package/cpp/common/download.cpp +792 -0
  22. package/cpp/common/download.h +84 -0
  23. package/cpp/common/http.h +84 -0
  24. package/cpp/common/jinja/README.md +88 -0
  25. package/cpp/common/jinja/caps.cpp +285 -0
  26. package/cpp/common/jinja/caps.h +30 -0
  27. package/cpp/common/jinja/lexer.cpp +341 -0
  28. package/cpp/common/jinja/lexer.h +157 -0
  29. package/cpp/common/jinja/parser.cpp +591 -0
  30. package/cpp/common/jinja/parser.h +21 -0
  31. package/cpp/common/jinja/runtime.cpp +867 -0
  32. package/cpp/common/jinja/runtime.h +638 -0
  33. package/cpp/common/jinja/string.cpp +213 -0
  34. package/cpp/common/jinja/string.h +61 -0
  35. package/cpp/common/jinja/utils.h +149 -0
  36. package/cpp/common/jinja/value.cpp +1393 -0
  37. package/cpp/common/jinja/value.h +756 -0
  38. package/cpp/common/json-partial.cpp +324 -0
  39. package/cpp/common/json-partial.h +39 -0
  40. package/cpp/common/json-schema-to-grammar.cpp +1153 -0
  41. package/cpp/common/json-schema-to-grammar.h +43 -0
  42. package/cpp/common/llguidance.cpp +258 -0
  43. package/cpp/common/log.cpp +446 -0
  44. package/cpp/common/log.h +119 -0
  45. package/cpp/common/ngram-cache.cpp +285 -0
  46. package/cpp/common/ngram-cache.h +101 -0
  47. package/cpp/common/ngram-map.cpp +530 -0
  48. package/cpp/common/ngram-map.h +115 -0
  49. package/cpp/common/ngram-mod.cpp +60 -0
  50. package/cpp/common/ngram-mod.h +38 -0
  51. package/cpp/common/peg-parser.cpp +1712 -0
  52. package/cpp/common/peg-parser.h +459 -0
  53. package/cpp/common/preset.cpp +483 -0
  54. package/cpp/common/preset.h +83 -0
  55. package/cpp/common/regex-partial.cpp +204 -0
  56. package/cpp/common/regex-partial.h +56 -0
  57. package/cpp/common/sampling.cpp +745 -0
  58. package/cpp/common/sampling.h +119 -0
  59. package/cpp/common/speculative.cpp +1074 -0
  60. package/cpp/common/speculative.h +41 -0
  61. package/cpp/common/unicode.cpp +64 -0
  62. package/cpp/common/unicode.h +22 -0
  63. package/cpp/ggml/CMakeLists.txt +494 -0
  64. package/cpp/ggml/cmake/GitVars.cmake +22 -0
  65. package/cpp/ggml/cmake/common.cmake +50 -0
  66. package/cpp/ggml/cmake/ggml-config.cmake.in +191 -0
  67. package/cpp/ggml/include/ggml-alloc.h +85 -0
  68. package/cpp/ggml/include/ggml-backend.h +373 -0
  69. package/cpp/ggml/include/ggml-blas.h +25 -0
  70. package/cpp/ggml/include/ggml-cann.h +123 -0
  71. package/cpp/ggml/include/ggml-cpp.h +39 -0
  72. package/cpp/ggml/include/ggml-cpu.h +151 -0
  73. package/cpp/ggml/include/ggml-cuda.h +47 -0
  74. package/cpp/ggml/include/ggml-hexagon.h +19 -0
  75. package/cpp/ggml/include/ggml-metal.h +61 -0
  76. package/cpp/ggml/include/ggml-opencl.h +26 -0
  77. package/cpp/ggml/include/ggml-opt.h +256 -0
  78. package/cpp/ggml/include/ggml-rpc.h +30 -0
  79. package/cpp/ggml/include/ggml-sycl.h +49 -0
  80. package/cpp/ggml/include/ggml-virtgpu.h +14 -0
  81. package/cpp/ggml/include/ggml-vulkan.h +29 -0
  82. package/cpp/ggml/include/ggml-webgpu.h +19 -0
  83. package/cpp/ggml/include/ggml-zdnn.h +17 -0
  84. package/cpp/ggml/include/ggml-zendnn.h +22 -0
  85. package/cpp/ggml/include/ggml.h +2753 -0
  86. package/cpp/ggml/include/gguf.h +204 -0
  87. package/cpp/ggml/src/CMakeLists.txt +492 -0
  88. package/cpp/ggml/src/ggml-alloc.c +1244 -0
  89. package/cpp/ggml/src/ggml-backend-dl.cpp +48 -0
  90. package/cpp/ggml/src/ggml-backend-dl.h +45 -0
  91. package/cpp/ggml/src/ggml-backend-impl.h +255 -0
  92. package/cpp/ggml/src/ggml-backend-reg.cpp +566 -0
  93. package/cpp/ggml/src/ggml-backend.cpp +2270 -0
  94. package/cpp/ggml/src/ggml-blas/CMakeLists.txt +101 -0
  95. package/cpp/ggml/src/ggml-blas/ggml-blas.cpp +518 -0
  96. package/cpp/ggml/src/ggml-common.h +1878 -0
  97. package/cpp/ggml/src/ggml-cpu/CMakeLists.txt +691 -0
  98. package/cpp/ggml/src/ggml-cpu/amx/amx.cpp +247 -0
  99. package/cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  100. package/cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
  101. package/cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2512 -0
  102. package/cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  103. package/cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +98 -0
  104. package/cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4052 -0
  105. package/cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +4935 -0
  106. package/cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2159 -0
  107. package/cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  108. package/cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  109. package/cpp/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  110. package/cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2726 -0
  111. package/cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  112. package/cpp/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  113. package/cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  114. package/cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  115. package/cpp/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  116. package/cpp/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  117. package/cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  118. package/cpp/ggml/src/ggml-cpu/arch-fallback.h +313 -0
  119. package/cpp/ggml/src/ggml-cpu/binary-ops.cpp +154 -0
  120. package/cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  121. package/cpp/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  122. package/cpp/ggml/src/ggml-cpu/common.h +95 -0
  123. package/cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +529 -0
  124. package/cpp/ggml/src/ggml-cpu/ggml-cpu.c +3734 -0
  125. package/cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +701 -0
  126. package/cpp/ggml/src/ggml-cpu/hbm.cpp +55 -0
  127. package/cpp/ggml/src/ggml-cpu/hbm.h +8 -0
  128. package/cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +938 -0
  129. package/cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +90 -0
  130. package/cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +798 -0
  131. package/cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  132. package/cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +4033 -0
  133. package/cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +25 -0
  134. package/cpp/ggml/src/ggml-cpu/ops.cpp +10978 -0
  135. package/cpp/ggml/src/ggml-cpu/ops.h +116 -0
  136. package/cpp/ggml/src/ggml-cpu/quants.c +1193 -0
  137. package/cpp/ggml/src/ggml-cpu/quants.h +97 -0
  138. package/cpp/ggml/src/ggml-cpu/repack.cpp +3316 -0
  139. package/cpp/ggml/src/ggml-cpu/repack.h +173 -0
  140. package/cpp/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  141. package/cpp/ggml/src/ggml-cpu/simd-mappings.h +1279 -0
  142. package/cpp/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  143. package/cpp/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  144. package/cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  145. package/cpp/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  146. package/cpp/ggml/src/ggml-cpu/traits.cpp +36 -0
  147. package/cpp/ggml/src/ggml-cpu/traits.h +38 -0
  148. package/cpp/ggml/src/ggml-cpu/unary-ops.cpp +337 -0
  149. package/cpp/ggml/src/ggml-cpu/unary-ops.h +35 -0
  150. package/cpp/ggml/src/ggml-cpu/vec.cpp +629 -0
  151. package/cpp/ggml/src/ggml-cpu/vec.h +1585 -0
  152. package/cpp/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  153. package/cpp/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3232 -0
  154. package/cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -0
  155. package/cpp/ggml/src/ggml-hexagon/htp/act-ops.c +815 -0
  156. package/cpp/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  157. package/cpp/ggml/src/ggml-hexagon/htp/binary-ops.c +827 -0
  158. package/cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  159. package/cpp/ggml/src/ggml-hexagon/htp/cpy-ops.c +251 -0
  160. package/cpp/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +666 -0
  161. package/cpp/ggml/src/ggml-hexagon/htp/get-rows-ops.c +111 -0
  162. package/cpp/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  163. package/cpp/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  164. package/cpp/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  165. package/cpp/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  166. package/cpp/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  167. package/cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  168. package/cpp/ggml/src/ggml-hexagon/htp/htp-msg.h +154 -0
  169. package/cpp/ggml/src/ggml-hexagon/htp/htp-ops.h +65 -0
  170. package/cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  171. package/cpp/ggml/src/ggml-hexagon/htp/hvx-arith.h +470 -0
  172. package/cpp/ggml/src/ggml-hexagon/htp/hvx-base.h +173 -0
  173. package/cpp/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  174. package/cpp/ggml/src/ggml-hexagon/htp/hvx-div.h +116 -0
  175. package/cpp/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  176. package/cpp/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  177. package/cpp/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  178. package/cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.h +176 -0
  179. package/cpp/ggml/src/ggml-hexagon/htp/hvx-reduce.h +266 -0
  180. package/cpp/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  181. package/cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  182. package/cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  183. package/cpp/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  184. package/cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -0
  185. package/cpp/ggml/src/ggml-hexagon/htp/main.c +1150 -0
  186. package/cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c +2595 -0
  187. package/cpp/ggml/src/ggml-hexagon/htp/rope-ops.c +498 -0
  188. package/cpp/ggml/src/ggml-hexagon/htp/set-rows-ops.c +167 -0
  189. package/cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c +421 -0
  190. package/cpp/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +130 -0
  191. package/cpp/ggml/src/ggml-hexagon/htp/unary-ops.c +384 -0
  192. package/cpp/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  193. package/cpp/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  194. package/cpp/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  195. package/cpp/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  196. package/cpp/ggml/src/ggml-hexagon/libdl.h +79 -0
  197. package/cpp/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  198. package/cpp/ggml/src/ggml-hexagon/op-desc.h +153 -0
  199. package/cpp/ggml/src/ggml-impl.h +724 -0
  200. package/cpp/ggml/src/ggml-metal/CMakeLists.txt +124 -0
  201. package/cpp/ggml/src/ggml-metal/ggml-metal-common.cpp +457 -0
  202. package/cpp/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  203. package/cpp/ggml/src/ggml-metal/ggml-metal-context.h +41 -0
  204. package/cpp/ggml/src/ggml-metal/ggml-metal-context.m +702 -0
  205. package/cpp/ggml/src/ggml-metal/ggml-metal-device.cpp +1890 -0
  206. package/cpp/ggml/src/ggml-metal/ggml-metal-device.h +290 -0
  207. package/cpp/ggml/src/ggml-metal/ggml-metal-device.m +1749 -0
  208. package/cpp/ggml/src/ggml-metal/ggml-metal-impl.h +1054 -0
  209. package/cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp +4370 -0
  210. package/cpp/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  211. package/cpp/ggml/src/ggml-metal/ggml-metal.cpp +937 -0
  212. package/cpp/ggml/src/ggml-metal/ggml-metal.metal +9819 -0
  213. package/cpp/ggml/src/ggml-musa/CMakeLists.txt +125 -0
  214. package/cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
  215. package/cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
  216. package/cpp/ggml/src/ggml-opencl/CMakeLists.txt +150 -0
  217. package/cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +11553 -0
  218. package/cpp/ggml/src/ggml-opencl/kernels/add.cl +190 -0
  219. package/cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  220. package/cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  221. package/cpp/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  222. package/cpp/ggml/src/ggml-opencl/kernels/concat.cl +51 -0
  223. package/cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  224. package/cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  225. package/cpp/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  226. package/cpp/ggml/src/ggml-opencl/kernels/cvt.cl +417 -0
  227. package/cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  228. package/cpp/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  229. package/cpp/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  230. package/cpp/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  231. package/cpp/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  232. package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  233. package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  234. package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  235. package/cpp/ggml/src/ggml-opencl/kernels/gelu.cl +89 -0
  236. package/cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  237. package/cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  238. package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  239. package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  240. package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  241. package/cpp/ggml/src/ggml-opencl/kernels/get_rows.cl +187 -0
  242. package/cpp/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  243. package/cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  244. package/cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  245. package/cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  246. package/cpp/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  247. package/cpp/ggml/src/ggml-opencl/kernels/mul.cl +152 -0
  248. package/cpp/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  249. package/cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  250. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  251. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  252. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  253. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  254. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  255. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  256. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  257. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  258. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  259. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  260. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  261. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  262. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  263. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  264. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  265. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  266. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  267. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  268. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  269. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  270. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  271. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  272. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  273. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  274. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  275. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  276. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  277. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  278. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl +194 -0
  279. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  280. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  281. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  282. package/cpp/ggml/src/ggml-opencl/kernels/norm.cl +161 -0
  283. package/cpp/ggml/src/ggml-opencl/kernels/pad.cl +39 -0
  284. package/cpp/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  285. package/cpp/ggml/src/ggml-opencl/kernels/repeat.cl +38 -0
  286. package/cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +190 -0
  287. package/cpp/ggml/src/ggml-opencl/kernels/rope.cl +747 -0
  288. package/cpp/ggml/src/ggml-opencl/kernels/scale.cl +27 -0
  289. package/cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  290. package/cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  291. package/cpp/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  292. package/cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +108 -0
  293. package/cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +108 -0
  294. package/cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +107 -0
  295. package/cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +107 -0
  296. package/cpp/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  297. package/cpp/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  298. package/cpp/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  299. package/cpp/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  300. package/cpp/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  301. package/cpp/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  302. package/cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +140 -0
  303. package/cpp/ggml/src/ggml-opencl/kernels/tanh.cl +109 -0
  304. package/cpp/ggml/src/ggml-opencl/kernels/transpose.cl +117 -0
  305. package/cpp/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  306. package/cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  307. package/cpp/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  308. package/cpp/ggml/src/ggml-opt.cpp +1093 -0
  309. package/cpp/ggml/src/ggml-quants.c +5325 -0
  310. package/cpp/ggml/src/ggml-quants.h +106 -0
  311. package/cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  312. package/cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2118 -0
  313. package/cpp/ggml/src/ggml-threading.cpp +12 -0
  314. package/cpp/ggml/src/ggml-threading.h +14 -0
  315. package/cpp/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  316. package/cpp/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  317. package/cpp/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  318. package/cpp/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  319. package/cpp/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  320. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  321. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  322. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  323. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  324. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  325. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  326. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  327. package/cpp/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  328. package/cpp/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  329. package/cpp/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  330. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  331. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  332. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  333. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  334. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  335. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  336. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  337. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  338. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  339. package/cpp/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  340. package/cpp/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  341. package/cpp/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  342. package/cpp/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  343. package/cpp/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  344. package/cpp/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  345. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  346. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  347. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  348. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  349. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  350. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  351. package/cpp/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  352. package/cpp/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  353. package/cpp/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  354. package/cpp/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  355. package/cpp/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  356. package/cpp/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  357. package/cpp/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  358. package/cpp/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1231 -0
  359. package/cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3150 -0
  360. package/cpp/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  361. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  362. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  363. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  364. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +107 -0
  365. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +923 -0
  366. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  367. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  368. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +182 -0
  369. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  370. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +668 -0
  371. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  372. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  373. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +713 -0
  374. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +103 -0
  375. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +138 -0
  376. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +188 -0
  377. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +194 -0
  378. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  379. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  380. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  381. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  382. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +109 -0
  383. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  384. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  385. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  386. package/cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  387. package/cpp/ggml/src/ggml-zdnn/common.hpp +59 -0
  388. package/cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +633 -0
  389. package/cpp/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  390. package/cpp/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  391. package/cpp/ggml/src/ggml-zdnn/utils.cpp +79 -0
  392. package/cpp/ggml/src/ggml-zdnn/utils.hpp +19 -0
  393. package/cpp/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  394. package/cpp/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  395. package/cpp/ggml/src/ggml.c +7669 -0
  396. package/cpp/ggml/src/ggml.cpp +26 -0
  397. package/cpp/ggml/src/gguf.cpp +1699 -0
  398. package/cpp/include/llama-cpp.h +32 -0
  399. package/cpp/include/llama.h +1568 -0
  400. package/cpp/mtmd/CMakeLists.txt +98 -0
  401. package/cpp/mtmd/README.md +63 -0
  402. package/cpp/mtmd/clip-graph.h +117 -0
  403. package/cpp/mtmd/clip-impl.h +586 -0
  404. package/cpp/mtmd/clip-model.h +390 -0
  405. package/cpp/mtmd/clip.cpp +4154 -0
  406. package/cpp/mtmd/clip.h +121 -0
  407. package/cpp/mtmd/deprecation-warning.cpp +22 -0
  408. package/cpp/mtmd/legacy-models/convert_image_encoder_to_gguf.py +412 -0
  409. package/cpp/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py +280 -0
  410. package/cpp/mtmd/legacy-models/glmedge-surgery.py +33 -0
  411. package/cpp/mtmd/legacy-models/llava_surgery.py +38 -0
  412. package/cpp/mtmd/legacy-models/llava_surgery_v2.py +180 -0
  413. package/cpp/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py +892 -0
  414. package/cpp/mtmd/legacy-models/minicpmv-surgery.py +47 -0
  415. package/cpp/mtmd/models/cogvlm.cpp +98 -0
  416. package/cpp/mtmd/models/conformer.cpp +216 -0
  417. package/cpp/mtmd/models/glm4v.cpp +122 -0
  418. package/cpp/mtmd/models/internvl.cpp +69 -0
  419. package/cpp/mtmd/models/kimik25.cpp +101 -0
  420. package/cpp/mtmd/models/kimivl.cpp +63 -0
  421. package/cpp/mtmd/models/llama4.cpp +96 -0
  422. package/cpp/mtmd/models/llava.cpp +374 -0
  423. package/cpp/mtmd/models/minicpmv.cpp +114 -0
  424. package/cpp/mtmd/models/mobilenetv5.cpp +451 -0
  425. package/cpp/mtmd/models/models.h +128 -0
  426. package/cpp/mtmd/models/nemotron-v2-vl.cpp +35 -0
  427. package/cpp/mtmd/models/paddleocr.cpp +52 -0
  428. package/cpp/mtmd/models/pixtral.cpp +86 -0
  429. package/cpp/mtmd/models/qwen2vl.cpp +183 -0
  430. package/cpp/mtmd/models/qwen3vl.cpp +193 -0
  431. package/cpp/mtmd/models/siglip.cpp +86 -0
  432. package/cpp/mtmd/models/whisper-enc.cpp +115 -0
  433. package/cpp/mtmd/models/youtuvl.cpp +179 -0
  434. package/cpp/mtmd/mtmd-audio.cpp +730 -0
  435. package/cpp/mtmd/mtmd-audio.h +113 -0
  436. package/cpp/mtmd/mtmd-cli.cpp +437 -0
  437. package/cpp/mtmd/mtmd-helper.cpp +521 -0
  438. package/cpp/mtmd/mtmd-helper.h +96 -0
  439. package/cpp/mtmd/mtmd.cpp +1156 -0
  440. package/cpp/mtmd/mtmd.h +319 -0
  441. package/cpp/mtmd/requirements.txt +5 -0
  442. package/cpp/mtmd/test-1.jpeg +0 -0
  443. package/cpp/mtmd/test-2.mp3 +0 -0
  444. package/cpp/mtmd/tests.sh +192 -0
  445. package/cpp/src/CMakeLists.txt +169 -0
  446. package/cpp/src/llama-adapter.cpp +488 -0
  447. package/cpp/src/llama-adapter.h +89 -0
  448. package/cpp/src/llama-arch.cpp +2855 -0
  449. package/cpp/src/llama-arch.h +619 -0
  450. package/cpp/src/llama-batch.cpp +917 -0
  451. package/cpp/src/llama-batch.h +173 -0
  452. package/cpp/src/llama-chat.cpp +896 -0
  453. package/cpp/src/llama-chat.h +71 -0
  454. package/cpp/src/llama-context.cpp +3512 -0
  455. package/cpp/src/llama-context.h +359 -0
  456. package/cpp/src/llama-cparams.cpp +5 -0
  457. package/cpp/src/llama-cparams.h +44 -0
  458. package/cpp/src/llama-grammar.cpp +1464 -0
  459. package/cpp/src/llama-grammar.h +194 -0
  460. package/cpp/src/llama-graph.cpp +2685 -0
  461. package/cpp/src/llama-graph.h +1026 -0
  462. package/cpp/src/llama-hparams.cpp +234 -0
  463. package/cpp/src/llama-hparams.h +339 -0
  464. package/cpp/src/llama-impl.cpp +171 -0
  465. package/cpp/src/llama-impl.h +73 -0
  466. package/cpp/src/llama-io.cpp +15 -0
  467. package/cpp/src/llama-io.h +35 -0
  468. package/cpp/src/llama-kv-cache-iswa.cpp +330 -0
  469. package/cpp/src/llama-kv-cache-iswa.h +137 -0
  470. package/cpp/src/llama-kv-cache.cpp +2271 -0
  471. package/cpp/src/llama-kv-cache.h +388 -0
  472. package/cpp/src/llama-kv-cells.h +533 -0
  473. package/cpp/src/llama-memory-hybrid-iswa.cpp +275 -0
  474. package/cpp/src/llama-memory-hybrid-iswa.h +140 -0
  475. package/cpp/src/llama-memory-hybrid.cpp +268 -0
  476. package/cpp/src/llama-memory-hybrid.h +139 -0
  477. package/cpp/src/llama-memory-recurrent.cpp +1165 -0
  478. package/cpp/src/llama-memory-recurrent.h +182 -0
  479. package/cpp/src/llama-memory.cpp +59 -0
  480. package/cpp/src/llama-memory.h +122 -0
  481. package/cpp/src/llama-mmap.cpp +785 -0
  482. package/cpp/src/llama-mmap.h +92 -0
  483. package/cpp/src/llama-model-loader.cpp +1414 -0
  484. package/cpp/src/llama-model-loader.h +203 -0
  485. package/cpp/src/llama-model-saver.cpp +286 -0
  486. package/cpp/src/llama-model-saver.h +37 -0
  487. package/cpp/src/llama-model.cpp +9253 -0
  488. package/cpp/src/llama-model.h +576 -0
  489. package/cpp/src/llama-quant.cpp +1119 -0
  490. package/cpp/src/llama-quant.h +1 -0
  491. package/cpp/src/llama-sampler.cpp +3885 -0
  492. package/cpp/src/llama-sampler.h +42 -0
  493. package/cpp/src/llama-vocab.cpp +3970 -0
  494. package/cpp/src/llama-vocab.h +187 -0
  495. package/cpp/src/llama.cpp +1313 -0
  496. package/cpp/src/models/afmoe.cpp +191 -0
  497. package/cpp/src/models/apertus.cpp +125 -0
  498. package/cpp/src/models/arcee.cpp +135 -0
  499. package/cpp/src/models/arctic.cpp +138 -0
  500. package/cpp/src/models/arwkv7.cpp +86 -0
  501. package/cpp/src/models/baichuan.cpp +122 -0
  502. package/cpp/src/models/bailingmoe.cpp +144 -0
  503. package/cpp/src/models/bailingmoe2.cpp +135 -0
  504. package/cpp/src/models/bert.cpp +178 -0
  505. package/cpp/src/models/bitnet.cpp +160 -0
  506. package/cpp/src/models/bloom.cpp +101 -0
  507. package/cpp/src/models/chameleon.cpp +178 -0
  508. package/cpp/src/models/chatglm.cpp +132 -0
  509. package/cpp/src/models/codeshell.cpp +111 -0
  510. package/cpp/src/models/cogvlm.cpp +102 -0
  511. package/cpp/src/models/cohere2-iswa.cpp +134 -0
  512. package/cpp/src/models/command-r.cpp +122 -0
  513. package/cpp/src/models/dbrx.cpp +123 -0
  514. package/cpp/src/models/deci.cpp +135 -0
  515. package/cpp/src/models/deepseek.cpp +144 -0
  516. package/cpp/src/models/deepseek2.cpp +262 -0
  517. package/cpp/src/models/delta-net-base.cpp +376 -0
  518. package/cpp/src/models/dots1.cpp +134 -0
  519. package/cpp/src/models/dream.cpp +105 -0
  520. package/cpp/src/models/ernie4-5-moe.cpp +150 -0
  521. package/cpp/src/models/ernie4-5.cpp +110 -0
  522. package/cpp/src/models/eurobert.cpp +97 -0
  523. package/cpp/src/models/exaone-moe.cpp +146 -0
  524. package/cpp/src/models/exaone.cpp +114 -0
  525. package/cpp/src/models/exaone4.cpp +123 -0
  526. package/cpp/src/models/falcon-h1.cpp +111 -0
  527. package/cpp/src/models/falcon.cpp +120 -0
  528. package/cpp/src/models/gemma-embedding.cpp +116 -0
  529. package/cpp/src/models/gemma.cpp +112 -0
  530. package/cpp/src/models/gemma2-iswa.cpp +128 -0
  531. package/cpp/src/models/gemma3.cpp +155 -0
  532. package/cpp/src/models/gemma3n-iswa.cpp +384 -0
  533. package/cpp/src/models/glm4-moe.cpp +170 -0
  534. package/cpp/src/models/glm4.cpp +157 -0
  535. package/cpp/src/models/gpt2.cpp +105 -0
  536. package/cpp/src/models/gptneox.cpp +144 -0
  537. package/cpp/src/models/granite-hybrid.cpp +196 -0
  538. package/cpp/src/models/granite.cpp +211 -0
  539. package/cpp/src/models/grok.cpp +159 -0
  540. package/cpp/src/models/grovemoe.cpp +141 -0
  541. package/cpp/src/models/hunyuan-dense.cpp +132 -0
  542. package/cpp/src/models/hunyuan-moe.cpp +154 -0
  543. package/cpp/src/models/internlm2.cpp +120 -0
  544. package/cpp/src/models/jais.cpp +86 -0
  545. package/cpp/src/models/jais2.cpp +123 -0
  546. package/cpp/src/models/jamba.cpp +106 -0
  547. package/cpp/src/models/kimi-linear.cpp +392 -0
  548. package/cpp/src/models/lfm2.cpp +190 -0
  549. package/cpp/src/models/llada-moe.cpp +122 -0
  550. package/cpp/src/models/llada.cpp +99 -0
  551. package/cpp/src/models/llama-iswa.cpp +178 -0
  552. package/cpp/src/models/llama.cpp +168 -0
  553. package/cpp/src/models/maincoder.cpp +117 -0
  554. package/cpp/src/models/mamba-base.cpp +285 -0
  555. package/cpp/src/models/mamba.cpp +54 -0
  556. package/cpp/src/models/mimo2-iswa.cpp +123 -0
  557. package/cpp/src/models/minicpm3.cpp +200 -0
  558. package/cpp/src/models/minimax-m2.cpp +124 -0
  559. package/cpp/src/models/mistral3.cpp +160 -0
  560. package/cpp/src/models/models.h +684 -0
  561. package/cpp/src/models/modern-bert.cpp +109 -0
  562. package/cpp/src/models/mpt.cpp +126 -0
  563. package/cpp/src/models/nemotron-h.cpp +148 -0
  564. package/cpp/src/models/nemotron.cpp +122 -0
  565. package/cpp/src/models/neo-bert.cpp +104 -0
  566. package/cpp/src/models/olmo.cpp +121 -0
  567. package/cpp/src/models/olmo2.cpp +150 -0
  568. package/cpp/src/models/olmoe.cpp +124 -0
  569. package/cpp/src/models/openai-moe-iswa.cpp +127 -0
  570. package/cpp/src/models/openelm.cpp +124 -0
  571. package/cpp/src/models/orion.cpp +123 -0
  572. package/cpp/src/models/paddleocr.cpp +122 -0
  573. package/cpp/src/models/pangu-embedded.cpp +121 -0
  574. package/cpp/src/models/phi2.cpp +121 -0
  575. package/cpp/src/models/phi3.cpp +152 -0
  576. package/cpp/src/models/plamo.cpp +110 -0
  577. package/cpp/src/models/plamo2.cpp +318 -0
  578. package/cpp/src/models/plamo3.cpp +128 -0
  579. package/cpp/src/models/plm.cpp +169 -0
  580. package/cpp/src/models/qwen.cpp +108 -0
  581. package/cpp/src/models/qwen2.cpp +126 -0
  582. package/cpp/src/models/qwen2moe.cpp +151 -0
  583. package/cpp/src/models/qwen2vl.cpp +117 -0
  584. package/cpp/src/models/qwen3.cpp +117 -0
  585. package/cpp/src/models/qwen35.cpp +386 -0
  586. package/cpp/src/models/qwen35moe.cpp +420 -0
  587. package/cpp/src/models/qwen3moe.cpp +124 -0
  588. package/cpp/src/models/qwen3next.cpp +525 -0
  589. package/cpp/src/models/qwen3vl-moe.cpp +140 -0
  590. package/cpp/src/models/qwen3vl.cpp +132 -0
  591. package/cpp/src/models/refact.cpp +94 -0
  592. package/cpp/src/models/rnd1.cpp +126 -0
  593. package/cpp/src/models/rwkv6-base.cpp +164 -0
  594. package/cpp/src/models/rwkv6.cpp +94 -0
  595. package/cpp/src/models/rwkv6qwen2.cpp +86 -0
  596. package/cpp/src/models/rwkv7-base.cpp +137 -0
  597. package/cpp/src/models/rwkv7.cpp +90 -0
  598. package/cpp/src/models/seed-oss.cpp +124 -0
  599. package/cpp/src/models/smallthinker.cpp +126 -0
  600. package/cpp/src/models/smollm3.cpp +128 -0
  601. package/cpp/src/models/stablelm.cpp +146 -0
  602. package/cpp/src/models/starcoder.cpp +100 -0
  603. package/cpp/src/models/starcoder2.cpp +121 -0
  604. package/cpp/src/models/step35-iswa.cpp +168 -0
  605. package/cpp/src/models/t5-dec.cpp +166 -0
  606. package/cpp/src/models/t5-enc.cpp +96 -0
  607. package/cpp/src/models/wavtokenizer-dec.cpp +149 -0
  608. package/cpp/src/models/xverse.cpp +108 -0
  609. package/cpp/src/unicode-data.cpp +7034 -0
  610. package/cpp/src/unicode-data.h +20 -0
  611. package/cpp/src/unicode.cpp +1103 -0
  612. package/cpp/src/unicode.h +111 -0
  613. package/cpp/vendor/nlohmann/json.hpp +25526 -0
  614. package/cpp/vendor/nlohmann/json_fwd.hpp +187 -0
  615. package/cpp/vendor/stb/stb_image.h +7988 -0
  616. package/ios/LocalLLM-Bridging-Header.h +2 -0
  617. package/ios/LocalLLM.h +5 -0
  618. package/ios/LocalLLM.mm +1267 -0
  619. package/local-llm-rn.podspec +60 -0
  620. package/package.json +35 -0
  621. package/src/NativeLocalLLM.ts +73 -0
  622. package/src/device.ts +50 -0
  623. package/src/download-adapter.ts +17 -0
  624. package/src/index.ts +21 -0
  625. package/src/native-bridge.ts +142 -0
  626. package/src/rn-downloader.ts +37 -0
@@ -0,0 +1,2685 @@
1
+ #include "llama-graph.h"
2
+
3
+ #include "llama-impl.h"
4
+ #include "llama-batch.h"
5
+ #include "llama-cparams.h"
6
+
7
+ #include "llama-kv-cache.h"
8
+ #include "llama-kv-cache-iswa.h"
9
+ #include "llama-memory-hybrid.h"
10
+ #include "llama-memory-hybrid-iswa.h"
11
+ #include "llama-memory-recurrent.h"
12
+
13
+ #include <cassert>
14
+ #include <cmath>
15
+ #include <cstring>
16
+ #include <numeric>
17
+ #include <sstream>
18
+ #include <unordered_set>
19
+
20
+ // dedup helpers
21
+
22
+ static ggml_tensor * build_kq_mask(
23
+ ggml_context * ctx,
24
+ const llama_kv_cache_context * mctx,
25
+ const llama_ubatch & ubatch,
26
+ const llama_cparams & cparams) {
27
+ const auto n_kv = mctx->get_n_kv();
28
+ const auto n_tokens = ubatch.n_tokens;
29
+ const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
30
+
31
+ return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
32
+ }
33
+
34
+ static bool can_reuse_kq_mask(
35
+ ggml_tensor * kq_mask,
36
+ const llama_kv_cache_context * mctx,
37
+ const llama_ubatch & ubatch,
38
+ const llama_cparams & cparams) {
39
+ const auto n_kv = mctx->get_n_kv();
40
+ const auto n_tokens = ubatch.n_tokens;
41
+ const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
42
+
43
+ bool res = true;
44
+
45
+ res &= (kq_mask->ne[0] == n_kv);
46
+ res &= (kq_mask->ne[1] == n_tokens/n_stream);
47
+ res &= (kq_mask->ne[2] == 1);
48
+ res &= (kq_mask->ne[3] == n_stream);
49
+
50
+ return res;
51
+ }
52
+
53
+ // impl
54
+
55
+ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
56
+ if (ubatch->token) {
57
+ const int64_t n_tokens = ubatch->n_tokens;
58
+
59
+ ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
60
+ }
61
+
62
+ if (ubatch->embd) {
63
+ GGML_ASSERT(n_embd == embd->ne[0]);
64
+
65
+ const int64_t n_tokens = ubatch->n_tokens;
66
+
67
+ ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
68
+ }
69
+ }
70
+
71
+ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
72
+ bool res = true;
73
+
74
+ res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
75
+ res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
76
+
77
+ return res;
78
+ }
79
+
80
+ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
81
+ if (ubatch->pos && pos) {
82
+ const int64_t n_tokens = ubatch->n_tokens;
83
+
84
+ if (ubatch->token && n_pos_per_embd == 4) {
85
+ // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
86
+ // the 3 first dims are the same, and 4th dim is all 0
87
+ std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
88
+ // copy the first dimension
89
+ for (int i = 0; i < n_tokens; ++i) {
90
+ pos_data[ i] = ubatch->pos[i];
91
+ pos_data[ n_tokens + i] = ubatch->pos[i];
92
+ pos_data[2 * n_tokens + i] = ubatch->pos[i];
93
+ pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
94
+ }
95
+ ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
96
+ } else {
97
+ ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
98
+ }
99
+ }
100
+ }
101
+
102
+ bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
103
+ bool res = true;
104
+
105
+ res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
106
+
107
+ return res;
108
+ }
109
+
110
+ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
111
+ if (ubatch->pos && attn_scale) {
112
+ const int64_t n_tokens = ubatch->n_tokens;
113
+
114
+ GGML_ASSERT(f_attn_temp_scale != 0.0f);
115
+ GGML_ASSERT(n_attn_temp_floor_scale != 0);
116
+
117
+ std::vector<float> attn_scale_data(n_tokens, 0.0f);
118
+ for (int i = 0; i < n_tokens; ++i) {
119
+ const float pos = ubatch->pos[i];
120
+ attn_scale_data[i] = std::log(
121
+ std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0
122
+ ) * f_attn_temp_scale + 1.0;
123
+ }
124
+
125
+ ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
126
+ }
127
+ }
128
+
129
+ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
130
+ if (pos_bucket) {
131
+ const int64_t n_tokens = ubatch->n_tokens;
132
+
133
+ GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
134
+ GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
135
+
136
+ int32_t * data = (int32_t *) pos_bucket->data;
137
+
138
+ for (int j = 0; j < n_tokens; ++j) {
139
+ for (int i = 0; i < n_tokens; ++i) {
140
+ data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
141
+ }
142
+ }
143
+ }
144
+ }
145
+
146
+ void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
147
+ if (pos_bucket) {
148
+ mctx->set_input_pos_bucket(pos_bucket, ubatch);
149
+ }
150
+ }
151
+
152
+ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
153
+ GGML_ASSERT(out_ids);
154
+
155
+ const int64_t n_tokens = ubatch->n_tokens;
156
+
157
+ GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
158
+ int32_t * data = (int32_t *) out_ids->data;
159
+
160
+ if (n_outputs == n_tokens) {
161
+ for (int i = 0; i < n_tokens; ++i) {
162
+ data[i] = i;
163
+ }
164
+
165
+ return;
166
+ }
167
+
168
+ GGML_ASSERT(ubatch->output);
169
+
170
+ int n_outputs = 0;
171
+
172
+ for (int i = 0; i < n_tokens; ++i) {
173
+ if (ubatch->output[i]) {
174
+ data[n_outputs++] = i;
175
+ }
176
+ }
177
+ }
178
+
179
+ bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
180
+ bool res = true;
181
+
182
+ res &= n_outputs == params.n_outputs;
183
+
184
+ return res;
185
+ }
186
+
187
+ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
188
+ if (cparams.embeddings &&
189
+ (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN ||
190
+ cparams.pooling_type == LLAMA_POOLING_TYPE_RANK )) {
191
+
192
+ const int64_t n_tokens = ubatch->n_tokens;
193
+ const int64_t n_seq_tokens = ubatch->n_seq_tokens;
194
+ const int64_t n_seqs_unq = ubatch->n_seqs_unq;
195
+
196
+ GGML_ASSERT(mean);
197
+ GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
198
+
199
+ float * data = (float *) mean->data;
200
+ memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
201
+
202
+ std::vector<uint64_t> sums(n_seqs_unq, 0);
203
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
204
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
205
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
206
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
207
+
208
+ sums[seq_idx] += ubatch->n_seq_tokens;
209
+ }
210
+ }
211
+
212
+ std::vector<float> div(n_seqs_unq, 0.0f);
213
+ for (int s = 0; s < n_seqs_unq; ++s) {
214
+ const uint64_t sum = sums[s];
215
+ if (sum > 0) {
216
+ div[s] = 1.0f/float(sum);
217
+ }
218
+ }
219
+
220
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
221
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
222
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
223
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
224
+
225
+ for (int j = 0; j < n_seq_tokens; ++j) {
226
+ data[seq_idx*n_tokens + i + j] = div[seq_idx];
227
+ }
228
+ }
229
+ }
230
+ }
231
+ }
232
+
233
+ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
234
+ const int64_t n_tokens = ubatch->n_tokens;
235
+ const int64_t n_seqs_unq = ubatch->n_seqs_unq;
236
+
237
+ if (cparams.embeddings && (
238
+ cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
239
+ cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
240
+ cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
241
+ )) {
242
+ GGML_ASSERT(cls);
243
+ GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
244
+
245
+ uint32_t * data = (uint32_t *) cls->data;
246
+ memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
247
+
248
+ std::vector<int> target_pos(n_seqs_unq, -1);
249
+ std::vector<int> target_row(n_seqs_unq, -1);
250
+
251
+ const bool last = (
252
+ cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
253
+ (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
254
+ );
255
+
256
+ for (int i = 0; i < n_tokens; ++i) {
257
+ const llama_pos pos = ubatch->pos[i];
258
+
259
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
260
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
261
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
262
+
263
+ if (
264
+ (target_pos[seq_idx] == -1) ||
265
+ ( last && pos >= target_pos[seq_idx]) ||
266
+ (!last && pos < target_pos[seq_idx])
267
+ ) {
268
+ target_pos[seq_idx] = pos;
269
+ target_row[seq_idx] = i;
270
+ }
271
+ }
272
+ }
273
+
274
+ for (int s = 0; s < n_seqs_unq; ++s) {
275
+ if (target_row[s] >= 0) {
276
+ data[s] = target_row[s];
277
+ }
278
+ }
279
+ }
280
+ }
281
+
282
+ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
283
+ GGML_UNUSED(ubatch);
284
+
285
+ const int64_t n_rs = mctx->get_n_rs();
286
+
287
+ if (s_copy) {
288
+ GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
289
+ int32_t * data = (int32_t *) s_copy->data;
290
+
291
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
292
+ for (uint32_t i = 0; i < n_rs; ++i) {
293
+ data[i] = mctx->s_copy(i);
294
+ }
295
+ }
296
+ }
297
+
298
+ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
299
+ const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
300
+
301
+ this->mctx = mctx;
302
+
303
+ bool res = true;
304
+
305
+ res &= s_copy->ne[0] == mctx->get_n_rs();
306
+
307
+ res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
308
+ res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
309
+
310
+ res &= head == mctx->get_head();
311
+ res &= rs_z == mctx->get_rs_z();
312
+
313
+ return res;
314
+ }
315
+
316
+ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
317
+ GGML_UNUSED(ubatch);
318
+
319
+ if (cross_embd && !cross->v_embd.empty()) {
320
+ assert(cross_embd->type == GGML_TYPE_F32);
321
+
322
+ ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd));
323
+ }
324
+ }
325
+
326
+ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
327
+ LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
328
+ const char * swa_type_str = "unknown";
329
+
330
+ switch (swa_type) {
331
+ case LLAMA_SWA_TYPE_NONE: swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
332
+ case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
333
+ case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
334
+ case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
335
+ };
336
+
337
+ LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
338
+ LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
339
+ LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
340
+
341
+ LLAMA_LOG_DEBUG(" ");
342
+ for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
343
+ LLAMA_LOG_DEBUG("%2d", j);
344
+ }
345
+ LLAMA_LOG_DEBUG("\n");
346
+
347
+ for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
348
+ LLAMA_LOG_DEBUG(" %2d ", i);
349
+ for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
350
+ float val = data[i * n_kv + j];
351
+ if (val == -INFINITY) {
352
+ LLAMA_LOG_DEBUG(" ∞");
353
+ } else {
354
+ LLAMA_LOG_DEBUG(" 0");
355
+ }
356
+ }
357
+ LLAMA_LOG_DEBUG("\n");
358
+ }
359
+ }
360
+
361
+ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
362
+ const int64_t n_kv = ubatch->n_tokens;
363
+ const int64_t n_tokens = ubatch->n_tokens;
364
+
365
+ const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
366
+ for (int i1 = 0; i1 < n_tokens; ++i1) {
367
+ const llama_seq_id s1 = ubatch->seq_id[i1][0];
368
+ const llama_pos p1 = ubatch->pos[i1];
369
+
370
+ const uint64_t idst = i1*n_kv;
371
+
372
+ for (int i0 = 0; i0 < n_tokens; ++i0) {
373
+ const llama_seq_id s0 = ubatch->seq_id[i0][0];
374
+ const llama_pos p0 = ubatch->pos[i0];
375
+
376
+ // mask different sequences
377
+ if (s0 != s1) {
378
+ continue;
379
+ }
380
+
381
+ // mask future tokens
382
+ if (cparams.causal_attn && p0 > p1) {
383
+ continue;
384
+ }
385
+
386
+ // apply SWA if any
387
+ if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
388
+ continue;
389
+ }
390
+
391
+ data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
392
+ }
393
+ }
394
+ };
395
+
396
+ {
397
+ GGML_ASSERT(self_kq_mask);
398
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
399
+
400
+ float * data = (float *) self_kq_mask->data;
401
+
402
+ std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
403
+
404
+ fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
405
+
406
+ if (debug) {
407
+ print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
408
+ }
409
+ }
410
+
411
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
412
+ GGML_ASSERT(self_kq_mask_swa);
413
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
414
+
415
+ float * data = (float *) self_kq_mask_swa->data;
416
+
417
+ std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
418
+
419
+ fill_mask(data, hparams.n_swa, hparams.swa_type);
420
+
421
+ if (debug) {
422
+ print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
423
+ }
424
+ }
425
+ }
426
+
427
+ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
428
+ mctx->set_input_k_idxs(self_k_idxs, ubatch);
429
+ mctx->set_input_v_idxs(self_v_idxs, ubatch);
430
+
431
+ mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
432
+ }
433
+
434
+ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
435
+ const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
436
+
437
+ this->mctx = mctx;
438
+
439
+ bool res = true;
440
+
441
+ res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
442
+ //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
443
+
444
+ res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
445
+
446
+ return res;
447
+ }
448
+
449
+ void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
450
+ mctx->set_input_k_idxs(self_k_idxs, ubatch);
451
+
452
+ mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
453
+ }
454
+
455
+ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
456
+ const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
457
+
458
+ this->mctx = mctx;
459
+
460
+ bool res = true;
461
+
462
+ res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
463
+
464
+ res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
465
+
466
+ return res;
467
+ }
468
+
469
+ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
470
+ mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
471
+ mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
472
+
473
+ mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
474
+
475
+ mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
476
+ mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
477
+
478
+ mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
479
+ }
480
+
481
+ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
482
+ const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
483
+
484
+ this->mctx = mctx;
485
+
486
+ bool res = true;
487
+
488
+ res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
489
+ //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
490
+
491
+ res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
492
+ //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
493
+
494
+ res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
495
+ res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
496
+
497
+ return res;
498
+ }
499
+
500
+ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
501
+ GGML_ASSERT(cross_kq_mask);
502
+
503
+ const int64_t n_enc = cross_kq_mask->ne[0];
504
+ const int64_t n_tokens = ubatch->n_tokens;
505
+
506
+ GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
507
+ GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
508
+
509
+ float * data = (float *) cross_kq_mask->data;
510
+
511
+ for (int i = 0; i < n_tokens; ++i) {
512
+ for (int j = 0; j < n_enc; ++j) {
513
+ float f = -INFINITY;
514
+
515
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
516
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
517
+
518
+ if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
519
+ f = 0.0f;
520
+ }
521
+ }
522
+
523
+ data[i*n_enc + j] = f;
524
+ }
525
+ }
526
+ }
527
+
528
+ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
529
+ mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
530
+ mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
531
+
532
+ mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
533
+
534
+ const int64_t n_rs = mctx->get_recr()->get_n_rs();
535
+
536
+ if (inp_rs->s_copy) {
537
+ GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
538
+ int32_t * data = (int32_t *) inp_rs->s_copy->data;
539
+
540
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
541
+ for (uint32_t i = 0; i < n_rs; ++i) {
542
+ data[i] = mctx->get_recr()->s_copy(i);
543
+ }
544
+ }
545
+ }
546
+
547
+ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
548
+ const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
549
+
550
+ this->mctx = mctx;
551
+
552
+ bool res = true;
553
+
554
+ res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
555
+ //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
556
+
557
+ res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
558
+
559
+ res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
560
+
561
+ res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
562
+ res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
563
+
564
+ res &= inp_rs->head == mctx->get_recr()->get_head();
565
+ res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
566
+
567
+ return res;
568
+ }
569
+
570
+ // TODO: Hybrid input classes are a bit redundant.
571
+ // Instead of creating a hybrid input, the graph can simply create 2 separate inputs.
572
+ // Refactoring is required in the future.
573
+ void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) {
574
+ mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
575
+
576
+ mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
577
+
578
+ const int64_t n_rs = mctx->get_recr()->get_n_rs();
579
+
580
+ if (inp_rs->s_copy) {
581
+ GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
582
+ int32_t * data = (int32_t *) inp_rs->s_copy->data;
583
+
584
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
585
+ for (uint32_t i = 0; i < n_rs; ++i) {
586
+ data[i] = mctx->get_recr()->s_copy(i);
587
+ }
588
+ }
589
+ }
590
+
591
+ bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) {
592
+ const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
593
+
594
+ this->mctx = mctx;
595
+
596
+ bool res = true;
597
+
598
+ res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
599
+
600
+ res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
601
+
602
+ res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
603
+
604
+ res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
605
+ res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
606
+
607
+ res &= inp_rs->head == mctx->get_recr()->get_head();
608
+ res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
609
+
610
+ return res;
611
+ }
612
+
613
+ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
614
+ const auto * attn_ctx = mctx->get_attn();
615
+
616
+ // base tensors may not be allocated if there are no non-SWA attention layers
617
+ if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
618
+ attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
619
+ attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
620
+
621
+ attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
622
+ }
623
+
624
+ // swa tensors may not be allocated if there are no SWA attention layers
625
+ if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
626
+ attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
627
+ attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
628
+
629
+ attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
630
+ }
631
+
632
+ const int64_t n_rs = mctx->get_recr()->get_n_rs();
633
+
634
+ if (inp_rs->s_copy) {
635
+ GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
636
+ int32_t * data = (int32_t *) inp_rs->s_copy->data;
637
+
638
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
639
+ for (uint32_t i = 0; i < n_rs; ++i) {
640
+ data[i] = mctx->get_recr()->s_copy(i);
641
+ }
642
+ }
643
+ }
644
+
645
+ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) {
646
+ const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx);
647
+
648
+ this->mctx = mctx;
649
+
650
+ bool res = true;
651
+
652
+ const auto * attn_ctx = mctx->get_attn();
653
+
654
+ // base tensors may not be allocated if there are no non-SWA attention layers
655
+ if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
656
+ res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
657
+ //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
658
+
659
+ res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
660
+ }
661
+
662
+ // swa tensors may not be allocated if there are no SWA attention layers
663
+ if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
664
+ res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
665
+ //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
666
+
667
+ res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
668
+ }
669
+
670
+ res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
671
+
672
+ res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
673
+ res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
674
+
675
+ res &= inp_rs->head == mctx->get_recr()->get_head();
676
+ res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
677
+
678
+ return res;
679
+ }
680
+
681
+ void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
682
+ // set the inputs only for the active samplers in the current ubatch
683
+ std::unordered_set<llama_seq_id> active_samplers;
684
+ for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
685
+ if (ubatch->output[i]) {
686
+ llama_seq_id seq_id = ubatch->seq_id[i][0];
687
+ active_samplers.insert(seq_id);
688
+ }
689
+ }
690
+
691
+ for (auto seq_id : active_samplers) {
692
+ if (samplers.find(seq_id) == samplers.end()) {
693
+ continue;
694
+ }
695
+
696
+ auto & sampler = samplers[seq_id];
697
+
698
+ if (sampler->iface->backend_set_input) {
699
+ sampler->iface->backend_set_input(sampler);
700
+ }
701
+ }
702
+ }
703
+
704
+ bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
705
+ if (samplers.size() != params.samplers.size()) {
706
+ return false;
707
+ }
708
+
709
+ for (const auto & [seq_id, sampler] : params.samplers) {
710
+ if (samplers[seq_id] != sampler) {
711
+ return false;
712
+ }
713
+ }
714
+
715
+ return true;
716
+ }
717
+
718
+ //
719
+ // llm_graph_result
720
+ //
721
+
722
+ llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
723
+ reset();
724
+
725
+ const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
726
+ debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
727
+ }
728
+
729
+ int64_t llm_graph_result::get_max_nodes() const {
730
+ return max_nodes;
731
+ }
732
+
733
+ void llm_graph_result::reset() {
734
+ t_inp_tokens = nullptr;
735
+ t_inp_embd = nullptr;
736
+ t_logits = nullptr;
737
+ t_embd = nullptr;
738
+ t_embd_pooled = nullptr;
739
+ t_sampled.clear();
740
+ t_sampled_probs.clear();
741
+ t_sampled_logits.clear();
742
+ t_candidates.clear();
743
+
744
+ params = {};
745
+
746
+ inputs.clear();
747
+
748
+ buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
749
+
750
+ ggml_init_params params = {
751
+ /*.mem_size =*/ buf_compute_meta.size(),
752
+ /*.mem_buffer =*/ buf_compute_meta.data(),
753
+ /*.no_alloc =*/ true,
754
+ };
755
+
756
+ ctx_compute.reset(ggml_init(params));
757
+
758
+ gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
759
+ }
760
+
761
+ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
762
+ for (auto & input : inputs) {
763
+ input->set_input(ubatch);
764
+ }
765
+ }
766
+
767
+ void llm_graph_result::set_outputs() {
768
+ if (t_logits != nullptr) {
769
+ ggml_set_output(t_logits);
770
+ }
771
+ if (t_embd != nullptr) {
772
+ ggml_set_output(t_embd);
773
+ }
774
+ if (t_embd_pooled != nullptr) {
775
+ ggml_set_output(t_embd_pooled);
776
+ }
777
+ for (auto & [seq_id, t] : t_sampled) {
778
+ if (t != nullptr) {
779
+ ggml_set_output(t);
780
+ }
781
+ }
782
+ for (auto & [seq_id, t] : t_sampled_probs) {
783
+ if (t != nullptr) {
784
+ ggml_set_output(t);
785
+ }
786
+ }
787
+ for (auto & [seq_id, t] : t_sampled_logits) {
788
+ if (t != nullptr) {
789
+ ggml_set_output(t);
790
+ }
791
+ }
792
+ for (auto & [seq_id, t] : t_candidates) {
793
+ if (t != nullptr) {
794
+ ggml_set_output(t);
795
+ }
796
+ }
797
+ }
798
+
799
+ bool llm_graph_result::can_reuse(const llm_graph_params & params) {
800
+ if (!this->params.allow_reuse(params)) {
801
+ if (debug > 1) {
802
+ LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
803
+ }
804
+
805
+ return false;
806
+ }
807
+
808
+ if (debug > 1) {
809
+ LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
810
+ }
811
+
812
+ bool res = true;
813
+
814
+ for (auto & input : inputs) {
815
+ const bool cur = input->can_reuse(params);
816
+
817
+ if (debug > 1) {
818
+ LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
819
+ }
820
+
821
+ res = res && cur;
822
+ }
823
+
824
+ if (debug > 0) {
825
+ LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
826
+ }
827
+
828
+ return res;
829
+ }
830
+
831
+ llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
832
+ inputs.emplace_back(std::move(input));
833
+ return inputs.back().get();
834
+ }
835
+
836
+ void llm_graph_result::set_params(const llm_graph_params & params) {
837
+ this->params = params;
838
+ }
839
+
840
+ //
841
+ // llm_graph_context
842
+ //
843
+
// Builds a graph context from the given graph parameters.
//
// Most members are plain copies of fields from params, hparams or cparams.
// Several initializers read the hparams/cparams/ubatch *members* rather than
// params, so they rely on those members being initialized first
// (NOTE(review): this depends on the declaration order in the class - confirm in the header).
// The constructor body hands the parameters to the result object via set_params().
844
+ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
845
+ arch (params.arch),
846
+ hparams (params.hparams),
847
+ cparams (params.cparams),
848
+ ubatch (params.ubatch),
849
+ n_embd (hparams.n_embd),
850
+ n_layer (hparams.n_layer),
851
+ n_rot (hparams.n_rot),
852
+ n_ctx (cparams.n_ctx),
853
+ n_head (hparams.n_head()),
854
+ n_head_kv (hparams.n_head_kv()),
855
+ n_embd_head_k (hparams.n_embd_head_k),
856
+ n_embd_k_gqa (hparams.n_embd_k_gqa()),
857
+ n_embd_head_v (hparams.n_embd_head_v),
858
+ n_embd_v_gqa (hparams.n_embd_v_gqa()),
859
+ n_expert (hparams.n_expert),
// during warmup all experts are used instead of hparams.n_expert_used
860
+ n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
861
+ freq_base (cparams.rope_freq_base),
862
+ freq_scale (cparams.rope_freq_scale),
863
+ ext_factor (cparams.yarn_ext_factor),
864
+ attn_factor (cparams.yarn_attn_factor),
865
+ beta_fast (cparams.yarn_beta_fast),
866
+ beta_slow (cparams.yarn_beta_slow),
867
+ norm_eps (hparams.f_norm_eps),
868
+ norm_rms_eps (hparams.f_norm_rms_eps),
869
+ n_tokens (ubatch.n_tokens),
870
+ n_outputs (params.n_outputs),
871
+ n_ctx_orig (cparams.n_ctx_orig_yarn),
872
+ pooling_type (cparams.pooling_type),
873
+ rope_type (hparams.rope_type),
874
+ sched (params.sched),
875
+ backend_cpu (params.backend_cpu),
876
+ cvec (params.cvec),
877
+ loras (params.loras),
878
+ mctx (params.mctx),
879
+ cross (params.cross),
880
+ samplers (params.samplers),
881
+ cb_func (params.cb),
882
+ res (params.res),
// the ggml context and graph are obtained from the result object
883
+ ctx0 (res->get_ctx()),
884
+ gf (res->get_gf()) {
// record the parameters this graph is being built with on the result object
885
+ res->set_params(params);
886
+ }
887
+
888
+ void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
889
+ if (cb_func) {
890
+ cb_func(ubatch, cur, name, il);
891
+ }
892
+ }
893
+
894
+ ggml_tensor * llm_graph_context::build_cvec(
895
+ ggml_tensor * cur,
896
+ int il) const {
897
+ return cvec->apply_to(ctx0, cur, il);
898
+ }
899
+
900
+ ggml_tensor * llm_graph_context::build_lora_mm(
901
+ ggml_tensor * w,
902
+ ggml_tensor * cur) const {
903
+ ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
904
+
905
+ for (const auto & lora : *loras) {
906
+ llama_adapter_lora_weight * lw = lora.first->get_weight(w);
907
+ if (lw == nullptr) {
908
+ continue;
909
+ }
910
+
911
+ const float adapter_scale = lora.second;
912
+ const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
913
+
914
+ ggml_tensor * ab_cur = ggml_mul_mat(
915
+ ctx0, lw->b,
916
+ ggml_mul_mat(ctx0, lw->a, cur)
917
+ );
918
+
919
+ ab_cur = ggml_scale(ctx0, ab_cur, scale);
920
+ res = ggml_add(ctx0, res, ab_cur);
921
+ }
922
+
923
+ return res;
924
+ }
925
+
926
+ ggml_tensor * llm_graph_context::build_lora_mm_id(
927
+ ggml_tensor * w, // ggml_tensor * as
928
+ ggml_tensor * cur, // ggml_tensor * b
929
+ ggml_tensor * ids) const {
930
+ ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
931
+ for (const auto & lora : *loras) {
932
+ llama_adapter_lora_weight * lw = lora.first->get_weight(w);
933
+ if (lw == nullptr) {
934
+ continue;
935
+ }
936
+
937
+ const float alpha = lora.first->alpha;
938
+ const float rank = (float) lw->b->ne[0];
939
+ const float scale = alpha ? lora.second * alpha / rank : lora.second;
940
+
941
+ ggml_tensor * ab_cur = ggml_mul_mat_id(
942
+ ctx0, lw->b,
943
+ ggml_mul_mat_id(ctx0, lw->a, cur, ids),
944
+ ids
945
+ );
946
+
947
+ ab_cur = ggml_scale(ctx0, ab_cur, scale);
948
+ res = ggml_add(ctx0, res, ab_cur);
949
+ }
950
+
951
+ return res;
952
+ }
953
+
954
+ ggml_tensor * llm_graph_context::build_norm(
955
+ ggml_tensor * cur,
956
+ ggml_tensor * mw,
957
+ ggml_tensor * mb,
958
+ llm_norm_type type,
959
+ int il) const {
960
+ switch (type) {
961
+ case LLM_NORM: cur = ggml_norm (ctx0, cur, hparams.f_norm_eps); break;
962
+ case LLM_NORM_RMS: cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break;
963
+ case LLM_NORM_GROUP:
964
+ {
965
+ cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]);
966
+ cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
967
+ cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[2]);
968
+ } break;
969
+ }
970
+
971
+ if (mw || mb) {
972
+ cb(cur, "norm", il);
973
+ }
974
+
975
+ if (mw) {
976
+ cur = ggml_mul(ctx0, cur, mw);
977
+ if (mb) {
978
+ cb(cur, "norm_w", il);
979
+ }
980
+ }
981
+
982
+ if (mb) {
983
+ cur = ggml_add(ctx0, cur, mb);
984
+ }
985
+
986
+ return cur;
987
+ }
988
+
989
+ ggml_tensor * llm_graph_context::build_ffn(
990
+ ggml_tensor * cur,
991
+ ggml_tensor * up,
992
+ ggml_tensor * up_b,
993
+ ggml_tensor * up_s,
994
+ ggml_tensor * gate,
995
+ ggml_tensor * gate_b,
996
+ ggml_tensor * gate_s,
997
+ ggml_tensor * down,
998
+ ggml_tensor * down_b,
999
+ ggml_tensor * down_s,
1000
+ ggml_tensor * act_scales,
1001
+ llm_ffn_op_type type_op,
1002
+ llm_ffn_gate_type type_gate,
1003
+ int il) const {
1004
+ ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur;
1005
+ cb(tmp, "ffn_up", il);
1006
+
1007
+ if (up_b) {
1008
+ tmp = ggml_add(ctx0, tmp, up_b);
1009
+ cb(tmp, "ffn_up_b", il);
1010
+ }
1011
+
1012
+ if (up_s) {
1013
+ tmp = ggml_mul(ctx0, tmp, up_s);
1014
+ cb(tmp, "ffn_up_s", il);
1015
+ }
1016
+
1017
+ if (gate) {
1018
+ switch (type_gate) {
1019
+ case LLM_FFN_SEQ:
1020
+ {
1021
+ cur = build_lora_mm(gate, tmp);
1022
+ cb(cur, "ffn_gate", il);
1023
+ } break;
1024
+ case LLM_FFN_PAR:
1025
+ {
1026
+ cur = build_lora_mm(gate, cur);
1027
+ cb(cur, "ffn_gate", il);
1028
+ } break;
1029
+ }
1030
+
1031
+ if (gate_b) {
1032
+ cur = ggml_add(ctx0, cur, gate_b);
1033
+ cb(cur, "ffn_gate_b", il);
1034
+ }
1035
+
1036
+ if (gate_s) {
1037
+ cur = ggml_mul(ctx0, cur, gate_s);
1038
+ cb(cur, "ffn_gate_s", il);
1039
+ }
1040
+
1041
+ } else {
1042
+ cur = tmp;
1043
+ }
1044
+
1045
+ switch (type_op) {
1046
+ case LLM_FFN_SILU:
1047
+ if (gate && type_gate == LLM_FFN_PAR) {
1048
+ // Step35: HF clamps gate (after SiLU) and up before multiplication
1049
+ if (arch == LLM_ARCH_STEP35 && il >= 0) {
1050
+ const float limit = hparams.swiglu_clamp_shexp[il];
1051
+ constexpr float eps = 1e-6f;
1052
+ if (limit > eps) {
1053
+ ggml_tensor * gate_act = ggml_silu(ctx0, cur);
1054
+ cb(gate_act, "ffn_silu", il);
1055
+ gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
1056
+ cb(gate_act, "ffn_silu_clamped", il);
1057
+
1058
+ tmp = ggml_clamp(ctx0, tmp, -limit, limit);
1059
+ cb(tmp, "ffn_up_clamped", il);
1060
+
1061
+ cur = ggml_mul(ctx0, gate_act, tmp);
1062
+ cb(cur, "ffn_swiglu_limited", il);
1063
+ type_gate = LLM_FFN_SEQ;
1064
+ break;
1065
+ }
1066
+ }
1067
+
1068
+ cur = ggml_swiglu_split(ctx0, cur, tmp);
1069
+ cb(cur, "ffn_swiglu", il);
1070
+ type_gate = LLM_FFN_SEQ;
1071
+ } else {
1072
+ cur = ggml_silu(ctx0, cur);
1073
+ cb(cur, "ffn_silu", il);
1074
+ } break;
1075
+ case LLM_FFN_GELU:
1076
+ if (gate && type_gate == LLM_FFN_PAR) {
1077
+ cur = ggml_geglu_split(ctx0, cur, tmp);
1078
+ cb(cur, "ffn_geglu", il);
1079
+ type_gate = LLM_FFN_SEQ;
1080
+ } else {
1081
+ cur = ggml_gelu(ctx0, cur);
1082
+ cb(cur, "ffn_gelu", il);
1083
+ if (act_scales != NULL) {
1084
+ cur = ggml_div(ctx0, cur, act_scales);
1085
+ cb(cur, "ffn_act", il);
1086
+ }
1087
+ } break;
1088
+ case LLM_FFN_RELU:
1089
+ if (gate && type_gate == LLM_FFN_PAR) {
1090
+ cur = ggml_reglu_split(ctx0, cur, tmp);
1091
+ cb(cur, "ffn_reglu", il);
1092
+ type_gate = LLM_FFN_SEQ;
1093
+ } else {
1094
+ cur = ggml_relu(ctx0, cur);
1095
+ cb(cur, "ffn_relu", il);
1096
+ } break;
1097
+ case LLM_FFN_RELU_SQR:
1098
+ {
1099
+ cur = ggml_relu(ctx0, cur);
1100
+ cb(cur, "ffn_relu", il);
1101
+
1102
+ cur = ggml_sqr(ctx0, cur);
1103
+ cb(cur, "ffn_sqr(relu)", il);
1104
+ } break;
1105
+ case LLM_FFN_SWIGLU:
1106
+ {
1107
+ cur = ggml_swiglu(ctx0, cur);
1108
+ cb(cur, "ffn_swiglu", il);
1109
+ } break;
1110
+ case LLM_FFN_GEGLU:
1111
+ {
1112
+ cur = ggml_geglu(ctx0, cur);
1113
+ cb(cur, "ffn_geglu", il);
1114
+ } break;
1115
+ case LLM_FFN_REGLU:
1116
+ {
1117
+ cur = ggml_reglu(ctx0, cur);
1118
+ cb(cur, "ffn_reglu", il);
1119
+ } break;
1120
+ default:
1121
+ GGML_ABORT("fatal error");
1122
+ }
1123
+
1124
+ if (gate && type_gate == LLM_FFN_PAR) {
1125
+ cur = ggml_mul(ctx0, cur, tmp);
1126
+ cb(cur, "ffn_gate_par", il);
1127
+ }
1128
+
1129
+ if (down) {
1130
+ cur = build_lora_mm(down, cur);
1131
+ if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
1132
+ // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
1133
+ ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1134
+ }
1135
+ }
1136
+
1137
+ if (down_b) {
1138
+ cb(cur, "ffn_down", il);
1139
+ }
1140
+
1141
+ if (down_b) {
1142
+ cur = ggml_add(ctx0, cur, down_b);
1143
+ }
1144
+
1145
+ if (down_s) {
1146
+ cur = ggml_mul(ctx0, cur, down_s);
1147
+ cb(cur, "ffn_down_s", il);
1148
+ }
1149
+
1150
+ return cur;
1151
+ }
1152
+
1153
+ ggml_tensor * llm_graph_context::build_moe_ffn(
1154
+ ggml_tensor * cur,
1155
+ ggml_tensor * gate_inp,
1156
+ ggml_tensor * up_exps,
1157
+ ggml_tensor * gate_exps,
1158
+ ggml_tensor * down_exps,
1159
+ ggml_tensor * exp_probs_b,
1160
+ int64_t n_expert,
1161
+ int64_t n_expert_used,
1162
+ llm_ffn_op_type type_op,
1163
+ bool norm_w,
1164
+ bool scale_w,
1165
+ float w_scale,
1166
+ llama_expert_gating_func_type gating_op,
1167
+ int il,
1168
+ ggml_tensor * probs_in,
1169
+ ggml_tensor * gate_up_exps) const {
1170
+ return build_moe_ffn(
1171
+ cur,
1172
+ gate_inp, /* gate_inp_b */ nullptr,
1173
+ up_exps, /* up_exps_b */ nullptr,
1174
+ gate_exps, /* gate_exps_b */ nullptr,
1175
+ down_exps, /* down_exps_b */ nullptr,
1176
+ exp_probs_b,
1177
+ n_expert,
1178
+ n_expert_used,
1179
+ type_op,
1180
+ norm_w,
1181
+ scale_w,
1182
+ w_scale,
1183
+ gating_op,
1184
+ il,
1185
+ probs_in,
1186
+ gate_up_exps
1187
+ );
1188
+ }
1189
+
1190
+ ggml_tensor * llm_graph_context::build_moe_ffn(
1191
+ ggml_tensor * cur,
1192
+ ggml_tensor * gate_inp,
1193
+ ggml_tensor * gate_inp_b,
1194
+ ggml_tensor * up_exps,
1195
+ ggml_tensor * up_exps_b,
1196
+ ggml_tensor * gate_exps,
1197
+ ggml_tensor * gate_exps_b,
1198
+ ggml_tensor * down_exps,
1199
+ ggml_tensor * down_exps_b,
1200
+ ggml_tensor * exp_probs_b,
1201
+ int64_t n_expert,
1202
+ int64_t n_expert_used,
1203
+ llm_ffn_op_type type_op,
1204
+ bool norm_w,
1205
+ bool scale_w,
1206
+ float w_scale,
1207
+ llama_expert_gating_func_type gating_op,
1208
+ int il,
1209
+ ggml_tensor * probs_in,
1210
+ ggml_tensor * gate_up_exps,
1211
+ ggml_tensor * gate_up_exps_b) const {
1212
+ const int64_t n_embd = cur->ne[0];
1213
+ const int64_t n_tokens = cur->ne[1];
1214
+ const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
1215
+
1216
+ ggml_tensor * logits = nullptr;
1217
+
1218
+ if (probs_in == nullptr) {
1219
+ logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
1220
+ cb(logits, "ffn_moe_logits", il);
1221
+ } else {
1222
+ logits = probs_in;
1223
+ }
1224
+
1225
+ if (gate_inp_b) {
1226
+ logits = ggml_add(ctx0, logits, gate_inp_b);
1227
+ cb(logits, "ffn_moe_logits_biased", il);
1228
+ }
1229
+
1230
+ ggml_tensor * probs = nullptr;
1231
+ switch (gating_op) {
1232
+ case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
1233
+ {
1234
+ probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
1235
+ } break;
1236
+ case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
1237
+ {
1238
+ probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
1239
+ } break;
1240
+ case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
1241
+ {
1242
+ probs = logits; // [n_expert, n_tokens]
1243
+ } break;
1244
+ default:
1245
+ GGML_ABORT("fatal error");
1246
+ }
1247
+ cb(probs, "ffn_moe_probs", il);
1248
+
1249
+ // add experts selection bias - introduced in DeepSeek V3
1250
+ // leave probs unbiased as it's later used to get expert weights
1251
+ ggml_tensor * selection_probs = probs;
1252
+ if (exp_probs_b != nullptr) {
1253
+ selection_probs = ggml_add(ctx0, probs, exp_probs_b);
1254
+ cb(selection_probs, "ffn_moe_probs_biased", il);
1255
+ }
1256
+
1257
+ // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
1258
+ // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
1259
+ if (arch == LLM_ARCH_LLAMA4) {
1260
+ selection_probs = logits;
1261
+ }
1262
+
1263
+ if (arch == LLM_ARCH_GROVEMOE) {
1264
+ selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
1265
+ cb(selection_probs, "ffn_moe_probs_biased", il);
1266
+ }
1267
+
1268
+ // select top n_group_used expert groups
1269
+ // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
1270
+ if (hparams.n_expert_groups > 1 && n_tokens > 0) {
1271
+ const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
1272
+
1273
+ // organize experts into n_expert_groups
1274
+ ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
1275
+
1276
+ ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
1277
+ group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
1278
+
1279
+ // get top n_group_used expert groups
1280
+ group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
1281
+ group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
1282
+
1283
+ ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
1284
+ cb(expert_groups, "ffn_moe_group_topk", il);
1285
+
1286
+ // mask out the other groups
1287
+ selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
1288
+ selection_probs = ggml_set_rows(ctx0, ggml_fill(ctx0, selection_groups, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
1289
+ selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
1290
+ cb(selection_probs, "ffn_moe_probs_masked", il);
1291
+ }
1292
+
1293
+ // select experts
1294
+ ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
1295
+ cb(selected_experts->src[0], "ffn_moe_argsort", il);
1296
+ cb(selected_experts, "ffn_moe_topk", il);
1297
+
1298
+ if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
1299
+ // TODO: Use scalar div instead when/if implemented
1300
+ ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
1301
+ selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
1302
+ probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
1303
+ } else {
1304
+ probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
1305
+ }
1306
+
1307
+ ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
1308
+ cb(weights, "ffn_moe_weights", il);
1309
+
1310
+
1311
+ if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
1312
+ weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
1313
+ weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
1314
+ weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
1315
+ cb(weights, "ffn_moe_weights_softmax", il);
1316
+ }
1317
+
1318
+ if (norm_w) {
1319
+ weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
1320
+
1321
+ ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
1322
+ cb(weights_sum, "ffn_moe_weights_sum", il);
1323
+
1324
+ // Avoid division by zero, clamp to smallest number representable by F16
1325
+ weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY);
1326
+ cb(weights_sum, "ffn_moe_weights_sum_clamped", il);
1327
+
1328
+ weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
1329
+ cb(weights, "ffn_moe_weights_norm", il);
1330
+
1331
+ weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
1332
+ }
1333
+ if (scale_w) {
1334
+ weights = ggml_scale(ctx0, weights, w_scale);
1335
+ cb(weights, "ffn_moe_weights_scaled", il);
1336
+ }
1337
+
1338
+ //call early so that topk-moe can be used
1339
+ ggml_build_forward_expand(gf, weights);
1340
+
1341
+ cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
1342
+
1343
+ if (weight_before_ffn) {
1344
+ // repeat cur to [n_embd, n_expert_used, n_tokens]
1345
+ ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
1346
+ cur = ggml_mul(ctx0, repeated, weights);
1347
+ cb(cur, "ffn_moe_weighted", il);
1348
+ }
1349
+
1350
+ ggml_tensor * up = nullptr;
1351
+ ggml_tensor * experts = nullptr;
1352
+
1353
+ if (gate_up_exps) {
1354
+ // merged gate_up path: one mul_mat_id, then split into gate and up views
1355
+ ggml_tensor * gate_up = build_lora_mm_id(gate_up_exps, cur, selected_experts); // [n_ff*2, n_expert_used, n_tokens]
1356
+ cb(gate_up, "ffn_moe_gate_up", il);
1357
+
1358
+ if (gate_up_exps_b) {
1359
+ gate_up = ggml_add_id(ctx0, gate_up, gate_up_exps_b, selected_experts);
1360
+ cb(gate_up, "ffn_moe_gate_up_biased", il);
1361
+ }
1362
+
1363
+ const int64_t n_ff = gate_up->ne[0] / 2;
1364
+ cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0);
1365
+ cb(cur, "ffn_moe_gate", il);
1366
+ up = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], n_ff * gate_up->nb[0]);
1367
+ cb(up, "ffn_moe_up", il);
1368
+ } else {
1369
+ // separate gate and up path
1370
+ up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
1371
+ cb(up, "ffn_moe_up", il);
1372
+
1373
+ if (up_exps_b) {
1374
+ up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
1375
+ cb(up, "ffn_moe_up_biased", il);
1376
+ }
1377
+
1378
+ if (gate_exps) {
1379
+ cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
1380
+ cb(cur, "ffn_moe_gate", il);
1381
+ } else {
1382
+ cur = up;
1383
+ }
1384
+
1385
+ if (gate_exps_b) {
1386
+ cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
1387
+ cb(cur, "ffn_moe_gate_biased", il);
1388
+ }
1389
+ }
1390
+
1391
+ const bool has_gate = gate_exps || gate_up_exps;
1392
+
1393
+ switch (type_op) {
1394
+ case LLM_FFN_SILU:
1395
+ if (gate_exps) {
1396
+ // Step35: per-layer clamp for routed experts
1397
+ if (arch == LLM_ARCH_STEP35 && il >= 0) {
1398
+ const float limit = hparams.swiglu_clamp_exp[il];
1399
+ constexpr float eps = 1e-6f;
1400
+ if (limit > eps) {
1401
+ ggml_tensor * gate_act = ggml_silu(ctx0, cur);
1402
+ cb(gate_act, "ffn_moe_silu", il);
1403
+ gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
1404
+ cb(gate_act, "ffn_moe_silu_clamped", il);
1405
+
1406
+ up = ggml_clamp(ctx0, up, -limit, limit);
1407
+ cb(up, "ffn_moe_up_clamped", il);
1408
+
1409
+ cur = ggml_mul(ctx0, gate_act, up);
1410
+ cb(cur, "ffn_moe_swiglu_limited", il);
1411
+ break;
1412
+ }
1413
+ }
1414
+ }
1415
+
1416
+ if (has_gate) {
1417
+ cur = ggml_swiglu_split(ctx0, cur, up);
1418
+ cb(cur, "ffn_moe_swiglu", il);
1419
+ } else {
1420
+ cur = ggml_silu(ctx0, cur);
1421
+ cb(cur, "ffn_moe_silu", il);
1422
+ } break;
1423
+ case LLM_FFN_GELU:
1424
+ if (has_gate) {
1425
+ cur = ggml_geglu_split(ctx0, cur, up);
1426
+ cb(cur, "ffn_moe_geglu", il);
1427
+ } else {
1428
+ cur = ggml_gelu(ctx0, cur);
1429
+ cb(cur, "ffn_moe_gelu", il);
1430
+ } break;
1431
+ case LLM_FFN_SWIGLU_OAI_MOE:
1432
+ {
1433
+ // TODO: move to hparams?
1434
+ constexpr float alpha = 1.702f;
1435
+ constexpr float limit = 7.0f;
1436
+ cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
1437
+ cb(cur, "ffn_moe_swiglu_oai", il);
1438
+ } break;
1439
+ case LLM_FFN_RELU:
1440
+ if (has_gate) {
1441
+ cur = ggml_reglu_split(ctx0, cur, up);
1442
+ cb(cur, "ffn_moe_reglu", il);
1443
+ } else {
1444
+ cur = ggml_relu(ctx0, cur);
1445
+ cb(cur, "ffn_moe_relu", il);
1446
+ } break;
1447
+ case LLM_FFN_RELU_SQR:
1448
+ if (has_gate) {
1449
+ // TODO: add support for gated squared relu
1450
+ GGML_ABORT("fatal error: gated squared relu not implemented");
1451
+ } else {
1452
+ cur = ggml_relu(ctx0, cur);
1453
+ cur = ggml_sqr(ctx0, cur);
1454
+ cb(cur, "ffn_moe_relu_sqr", il);
1455
+ } break;
1456
+ default:
1457
+ GGML_ABORT("fatal error");
1458
+ }
1459
+
1460
+ experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
1461
+ cb(experts, "ffn_moe_down", il);
1462
+
1463
+ if (down_exps_b) {
1464
+ experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
1465
+ cb(experts, "ffn_moe_down_biased", il);
1466
+ }
1467
+
1468
+ if (!weight_before_ffn) {
1469
+ experts = ggml_mul(ctx0, experts, weights);
1470
+ cb(cur, "ffn_moe_weighted", il);
1471
+ }
1472
+
1473
+ ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
1474
+
1475
+ assert(n_expert_used > 0);
1476
+
1477
+ // order the views before the adds
1478
+ for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
1479
+ cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
1480
+
1481
+ ggml_build_forward_expand(gf, cur_experts[i]);
1482
+ }
1483
+
1484
+ // aggregate experts
1485
+ // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
1486
+ // to avoid potentially a large number of add nodes during warmup
1487
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14753
1488
+ ggml_tensor * moe_out = cur_experts[0];
1489
+
1490
+ for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
1491
+ moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
1492
+ }
1493
+
1494
+ if (hparams.n_expert_used == 1) {
1495
+ // avoid returning a non-contiguous tensor
1496
+ moe_out = ggml_cont(ctx0, moe_out);
1497
+ }
1498
+
1499
+ cb(moe_out, "ffn_moe_out", il);
1500
+
1501
+ return moe_out;
1502
+ }
1503
+
1504
+ // input embeddings with optional lora
1505
+ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
1506
+ const int64_t n_embd_inp = hparams.n_embd_inp();
1507
+ const int64_t n_embd = hparams.n_embd;
1508
+
1509
+ assert(n_embd_inp >= n_embd);
1510
+
1511
+ auto inp = std::make_unique<llm_graph_input_embd>(n_embd_inp);
1512
+
1513
+ inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
1514
+ cb(inp->tokens, "inp_tokens", -1);
1515
+ ggml_set_input(inp->tokens);
1516
+ res->t_inp_tokens = inp->tokens;
1517
+
1518
+ inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens);
1519
+ cb(inp->embd, "inp_embd", -1);
1520
+ ggml_set_input(inp->embd);
1521
+
1522
+ // select one of the 2 inputs, based on the batch contents
1523
+ // ref: https://github.com/ggml-org/llama.cpp/pull/18550
1524
+ std::array<ggml_tensor *, 2> inps;
1525
+
1526
+ // token embeddings path (ubatch.token != nullptr)
1527
+ {
1528
+ auto & cur = inps[0];
1529
+
1530
+ cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
1531
+
1532
+ // apply lora for embedding tokens if needed
1533
+ for (const auto & lora : *loras) {
1534
+ llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
1535
+ if (lw == nullptr) {
1536
+ continue;
1537
+ }
1538
+
1539
+ const float adapter_scale = lora.second;
1540
+ const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
1541
+
1542
+ ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
1543
+ ctx0, lw->b, // non-transposed lora_b
1544
+ ggml_get_rows(ctx0, lw->a, inp->tokens)
1545
+ ), scale);
1546
+
1547
+ cur = ggml_add(ctx0, cur, inpL_delta);
1548
+ }
1549
+
1550
+ if (n_embd_inp != n_embd) {
1551
+ cur = ggml_pad(ctx0, cur, hparams.n_embd_inp() - n_embd, 0, 0, 0);
1552
+ }
1553
+ }
1554
+
1555
+ // vector embeddings path (ubatch.embd != nullptr)
1556
+ {
1557
+ auto & cur = inps[1];
1558
+
1559
+ cur = inp->embd;
1560
+ }
1561
+
1562
+ assert(ggml_are_same_shape (inps[0], inps[1]));
1563
+ assert(ggml_are_same_stride(inps[0], inps[1]));
1564
+
1565
+ ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1);
1566
+
1567
+ if (n_embd_inp != n_embd) {
1568
+ cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0);
1569
+ }
1570
+
1571
+ res->t_inp_embd = cur;
1572
+
1573
+ // For Granite architecture
1574
+ if (hparams.f_embedding_scale != 0.0f) {
1575
+ cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
1576
+ }
1577
+
1578
+ cb(cur, "embd", -1);
1579
+
1580
+ res->add_input(std::move(inp));
1581
+
1582
+ // make sure the produced embeddings are immediately materialized in the ggml graph
1583
+ // ref: https://github.com/ggml-org/llama.cpp/pull/18599
1584
+ ggml_build_forward_expand(gf, cur);
1585
+
1586
+ return cur;
1587
+ }
1588
+
1589
+ ggml_tensor * llm_graph_context::build_inp_pos() const {
1590
+ auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
1591
+
1592
+ auto & cur = inp->pos;
1593
+
1594
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
1595
+ ggml_set_input(cur);
1596
+
1597
+ res->add_input(std::move(inp));
1598
+
1599
+ return cur;
1600
+ }
1601
+
1602
+ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
1603
+ auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset);
1604
+
1605
+ auto & cur = inp->attn_scale;
1606
+
1607
+ // this need to be 1x1xN for broadcasting
1608
+ cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
1609
+ ggml_set_input(cur);
1610
+
1611
+ res->add_input(std::move(inp));
1612
+
1613
+ return cur;
1614
+ }
1615
+
1616
+ ggml_tensor * llm_graph_context::build_inp_out_ids() const {
1617
+ // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
1618
+ // but this would make the graph topology depend on the number of output tokens, which can interere with
1619
+ // features that require constant topology such as pipline parallelism
1620
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
1621
+ //if (n_outputs < n_tokens) {
1622
+ // return nullptr;
1623
+ //}
1624
+
1625
+ auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
1626
+
1627
+ auto & cur = inp->out_ids;
1628
+
1629
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
1630
+ ggml_set_input(cur);
1631
+
1632
+ res->add_input(std::move(inp));
1633
+
1634
+ return cur;
1635
+ }
1636
+
1637
+ ggml_tensor * llm_graph_context::build_inp_mean() const {
1638
+ auto inp = std::make_unique<llm_graph_input_mean>(cparams);
1639
+
1640
+ auto & cur = inp->mean;
1641
+
1642
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
1643
+ ggml_set_input(cur);
1644
+
1645
+ res->add_input(std::move(inp));
1646
+
1647
+ return cur;
1648
+ }
1649
+
1650
+ ggml_tensor * llm_graph_context::build_inp_cls() const {
1651
+ auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);
1652
+
1653
+ auto & cur = inp->cls;
1654
+
1655
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
1656
+ ggml_set_input(cur);
1657
+
1658
+ res->add_input(std::move(inp));
1659
+
1660
+ return cur;
1661
+ }
1662
+
1663
+ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
1664
+ auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
1665
+
1666
+ auto & cur = inp->cross_embd;
1667
+
1668
+ // if we have the output embeddings from the encoder, use them directly
1669
+ // TODO: needs more work to be correct, for now just use the tensor shape
1670
+ //if (cross->t_embd) {
1671
+ // cur = ggml_view_tensor(ctx0, cross->t_embd);
1672
+
1673
+ // return cur;
1674
+ //}
1675
+
1676
+ const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
1677
+ const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
1678
+
1679
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
1680
+ ggml_set_input(cur);
1681
+
1682
+ res->add_input(std::move(inp));
1683
+
1684
+ return cur;
1685
+ }
1686
+
1687
+ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
1688
+ auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);
1689
+
1690
+ auto & cur = inp->pos_bucket;
1691
+
1692
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
1693
+ ggml_set_input(cur);
1694
+
1695
+ res->add_input(std::move(inp));
1696
+
1697
+ return cur;
1698
+ }
1699
+
1700
+ ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
1701
+ const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
1702
+
1703
+ auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
1704
+
1705
+ const auto n_kv = mctx_cur->get_n_kv();
1706
+
1707
+ auto & cur = inp->pos_bucket;
1708
+
1709
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
1710
+ ggml_set_input(cur);
1711
+
1712
+ res->add_input(std::move(inp));
1713
+
1714
+ return cur;
1715
+ }
1716
+
1717
+ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
1718
+ ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]);
1719
+ cb(pos_bucket_1d, "pos_bucket_1d", -1);
1720
+
1721
+ ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
1722
+
1723
+ pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
1724
+ pos_bias = ggml_permute (ctx0, pos_bias, 2, 0, 1, 3);
1725
+ pos_bias = ggml_cont (ctx0, pos_bias);
1726
+
1727
+ cb(pos_bias, "pos_bias", -1);
1728
+
1729
+ return pos_bias;
1730
+ }
1731
+
1732
+ ggml_tensor * llm_graph_context::build_attn_mha(
1733
+ ggml_tensor * q,
1734
+ ggml_tensor * k,
1735
+ ggml_tensor * v,
1736
+ ggml_tensor * kq_b,
1737
+ ggml_tensor * kq_mask,
1738
+ ggml_tensor * sinks,
1739
+ ggml_tensor * v_mla,
1740
+ float kq_scale,
1741
+ int il) const {
1742
+ const bool v_trans = v->nb[1] > v->nb[2];
1743
+
1744
+ // split the batch into streams if needed
1745
+ const auto n_stream = k->ne[3];
1746
+
1747
+ q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0);
1748
+
1749
+ q = ggml_permute(ctx0, q, 0, 2, 1, 3);
1750
+ k = ggml_permute(ctx0, k, 0, 2, 1, 3);
1751
+ v = ggml_permute(ctx0, v, 0, 2, 1, 3);
1752
+
1753
+ ggml_tensor * cur;
1754
+
1755
+ const bool use_flash_attn = cparams.flash_attn && kq_b == nullptr;
1756
+ if (use_flash_attn) {
1757
+ GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
1758
+
1759
+ if (v_trans) {
1760
+ v = ggml_transpose(ctx0, v);
1761
+ }
1762
+
1763
+ // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
1764
+ if (k->type == GGML_TYPE_F32) {
1765
+ k = ggml_cast(ctx0, k, GGML_TYPE_F16);
1766
+ }
1767
+
1768
+ if (v->type == GGML_TYPE_F32) {
1769
+ v = ggml_cast(ctx0, v, GGML_TYPE_F16);
1770
+ }
1771
+
1772
+ cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
1773
+ hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
1774
+ cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
1775
+
1776
+ ggml_flash_attn_ext_add_sinks(cur, sinks);
1777
+ ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
1778
+
1779
+ if (v_mla) {
1780
+ #if 0
1781
+ // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
1782
+ // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
1783
+ cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
1784
+ cur = ggml_mul_mat(ctx0, v_mla, cur);
1785
+ #else
1786
+ // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
1787
+ // The permutations are noops and only change how the tensor data is interpreted.
1788
+ cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1789
+ cur = ggml_mul_mat(ctx0, v_mla, cur);
1790
+ cb(cur, "fattn_mla", il);
1791
+ cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1792
+ cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
1793
+ #endif
1794
+ }
1795
+
1796
+ cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
1797
+ } else {
1798
+ ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
1799
+ cb(kq, "kq", il);
1800
+
1801
+ // note: this op tends to require high floating point range
1802
+ // while for some models F16 is enough, for others it is not, so we default to F32 here
1803
+ ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
1804
+
1805
+ if (arch == LLM_ARCH_GROK) {
1806
+ // need to do the following:
1807
+ // multiply by attn_output_multiplier
1808
+ // and then :
1809
+ // kq = 30 * tanh(kq / 30)
1810
+ // before the softmax below
1811
+
1812
+ kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
1813
+ cb(kq, "kq_tanh", il);
1814
+ kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
1815
+ cb(kq, "kq_scaled", il);
1816
+ }
1817
+
1818
+ if (hparams.attn_soft_cap) {
1819
+ kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
1820
+ cb(kq, "kq_scaled_1", il);
1821
+ kq = ggml_tanh (ctx0, kq);
1822
+ cb(kq, "kq_tanh", il);
1823
+ kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
1824
+ cb(kq, "kq_scaled_2", il);
1825
+ }
1826
+
1827
+ if (kq_b) {
1828
+ kq = ggml_add(ctx0, kq, kq_b);
1829
+ cb(kq, "kq_plus_kq_b", il);
1830
+ }
1831
+
1832
+ kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
1833
+ ggml_soft_max_add_sinks(kq, sinks);
1834
+ cb(kq, "kq_soft_max", il);
1835
+
1836
+ if (!v_trans) {
1837
+ // note: avoid this branch
1838
+ v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
1839
+ cb(v, "v_cont", il);
1840
+ }
1841
+
1842
+ ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
1843
+ cb(kqv, "kqv", il);
1844
+
1845
+ // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
1846
+ if (v_mla) {
1847
+ kqv = ggml_mul_mat(ctx0, v_mla, kqv);
1848
+ cb(kqv, "kqv_mla", il);
1849
+ }
1850
+
1851
+ cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
1852
+
1853
+ // recombine streams
1854
+ cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
1855
+
1856
+ if (!cparams.offload_kqv) {
1857
+ // all nodes between the KV store and the attention output are run on the CPU
1858
+ ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
1859
+ }
1860
+ }
1861
+
1862
+ ggml_build_forward_expand(gf, cur);
1863
+
1864
+ return cur;
1865
+ }
1866
+
1867
+ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
1868
+ auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
1869
+
1870
+ // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
1871
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
1872
+ ggml_set_input(inp->self_kq_mask);
1873
+
1874
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1875
+
1876
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
1877
+ inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
1878
+ ggml_set_input(inp->self_kq_mask_swa);
1879
+
1880
+ inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1881
+ } else {
1882
+ inp->self_kq_mask_swa = nullptr;
1883
+ inp->self_kq_mask_swa_cnv = nullptr;
1884
+ }
1885
+
1886
+ return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
1887
+ }
1888
+
1889
+ ggml_tensor * llm_graph_context::build_attn(
1890
+ llm_graph_input_attn_no_cache * inp,
1891
+ ggml_tensor * wo,
1892
+ ggml_tensor * wo_b,
1893
+ ggml_tensor * q_cur,
1894
+ ggml_tensor * k_cur,
1895
+ ggml_tensor * v_cur,
1896
+ ggml_tensor * kq_b,
1897
+ ggml_tensor * sinks,
1898
+ ggml_tensor * v_mla,
1899
+ float kq_scale,
1900
+ int il) const {
1901
+ GGML_UNUSED(n_tokens);
1902
+
1903
+ // these nodes are added to the graph together so that they are not reordered
1904
+ // by doing so, the number of splits in the graph is reduced
1905
+ ggml_build_forward_expand(gf, q_cur);
1906
+ ggml_build_forward_expand(gf, k_cur);
1907
+ ggml_build_forward_expand(gf, v_cur);
1908
+
1909
+ const bool is_swa = hparams.is_swa(il);
1910
+
1911
+ const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1912
+
1913
+ // [TAG_NO_CACHE_PAD]
1914
+ // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
1915
+ // but it might not be worth it: https://github.com/ggml-org/llama.cpp/pull/15636
1916
+ //assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
1917
+
1918
+ ggml_tensor * q = q_cur;
1919
+ ggml_tensor * k = k_cur;
1920
+ ggml_tensor * v = v_cur;
1921
+
1922
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1923
+ cb(cur, "kqv_out", il);
1924
+
1925
+ if (wo) {
1926
+ cur = build_lora_mm(wo, cur);
1927
+ }
1928
+
1929
+ if (wo_b) {
1930
+ //cb(cur, "kqv_wo", il);
1931
+ }
1932
+
1933
+ if (wo_b) {
1934
+ cur = ggml_add(ctx0, cur, wo_b);
1935
+ }
1936
+
1937
+ return cur;
1938
+ }
1939
+
1940
+ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
1941
+ ggml_context * ctx0,
1942
+ const llama_ubatch & ubatch,
1943
+ const llama_hparams & hparams,
1944
+ const llama_cparams & cparams,
1945
+ const llama_kv_cache_context * mctx_cur) {
1946
+
1947
+ auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
1948
+
1949
+ {
1950
+ GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
1951
+
1952
+ inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
1953
+ inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
1954
+
1955
+ inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
1956
+
1957
+ ggml_set_input(inp->self_kq_mask);
1958
+
1959
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1960
+ }
1961
+
1962
+ return inp;
1963
+ }
1964
+
1965
+ llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
1966
+ const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
1967
+
1968
+ auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
1969
+
1970
+ return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
1971
+ }
1972
+
1973
+ ggml_tensor * llm_graph_context::build_attn(
1974
+ llm_graph_input_attn_kv * inp,
1975
+ ggml_tensor * wo,
1976
+ ggml_tensor * wo_b,
1977
+ ggml_tensor * q_cur,
1978
+ ggml_tensor * k_cur,
1979
+ ggml_tensor * v_cur,
1980
+ ggml_tensor * kq_b,
1981
+ ggml_tensor * sinks,
1982
+ ggml_tensor * v_mla, // TODO: remove
1983
+ float kq_scale,
1984
+ int il) const {
1985
+ GGML_ASSERT(v_mla == nullptr);
1986
+
1987
+ // these nodes are added to the graph together so that they are not reordered
1988
+ // by doing so, the number of splits in the graph is reduced
1989
+ // expand k later to enable rope fusion which directly writes into k-v cache
1990
+ ggml_build_forward_expand(gf, q_cur);
1991
+ ggml_build_forward_expand(gf, v_cur);
1992
+ ggml_build_forward_expand(gf, k_cur);
1993
+
1994
+ const auto * mctx_cur = inp->mctx;
1995
+
1996
+ // store to KV cache
1997
+ {
1998
+ const auto & k_idxs = inp->get_k_idxs();
1999
+ const auto & v_idxs = inp->get_v_idxs();
2000
+
2001
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
2002
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
2003
+ }
2004
+
2005
+ const auto & kq_mask = inp->get_kq_mask();
2006
+
2007
+ ggml_tensor * q = q_cur;
2008
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
2009
+ ggml_tensor * v = mctx_cur->get_v(ctx0, il);
2010
+
2011
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2012
+ cb(cur, "kqv_out", il);
2013
+
2014
+ if (wo) {
2015
+ cur = build_lora_mm(wo, cur);
2016
+ if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
2017
+ // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
2018
+ ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
2019
+ }
2020
+ }
2021
+
2022
+ if (wo_b) {
2023
+ cur = ggml_add(ctx0, cur, wo_b);
2024
+ }
2025
+
2026
+ return cur;
2027
+ }
2028
+
2029
+ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
2030
+ ggml_context * ctx0,
2031
+ const llama_ubatch & ubatch,
2032
+ const llama_hparams & hparams,
2033
+ const llama_cparams & cparams,
2034
+ const llama_kv_cache_context * mctx_cur) {
2035
+
2036
+ auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur);
2037
+
2038
+ {
2039
+ GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
2040
+
2041
+ inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
2042
+
2043
+ inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
2044
+ ggml_set_input(inp->self_kq_mask);
2045
+
2046
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
2047
+ }
2048
+
2049
+ return inp;
2050
+ }
2051
+
2052
+ llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const {
2053
+ const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
2054
+
2055
+ auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
2056
+
2057
+ return (llm_graph_input_attn_k *) res->add_input(std::move(inp));
2058
+ }
2059
+
2060
+ ggml_tensor * llm_graph_context::build_attn(
2061
+ llm_graph_input_attn_k * inp,
2062
+ ggml_tensor * wo,
2063
+ ggml_tensor * wo_b,
2064
+ ggml_tensor * q_cur,
2065
+ ggml_tensor * k_cur,
2066
+ ggml_tensor * v_cur,
2067
+ ggml_tensor * kq_b,
2068
+ ggml_tensor * sinks,
2069
+ ggml_tensor * v_mla,
2070
+ float kq_scale,
2071
+ int il) const {
2072
+ // these nodes are added to the graph together so that they are not reordered
2073
+ // by doing so, the number of splits in the graph is reduced
2074
+ // expand k later to enable rope fusion which directly writes into k-v cache
2075
+ ggml_build_forward_expand(gf, q_cur);
2076
+ ggml_build_forward_expand(gf, v_cur);
2077
+ ggml_build_forward_expand(gf, k_cur);
2078
+
2079
+ const auto * mctx_cur = inp->mctx;
2080
+
2081
+ // store to KV cache
2082
+ {
2083
+ const auto & k_idxs = inp->get_k_idxs();
2084
+
2085
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
2086
+ }
2087
+
2088
+ const auto & kq_mask = inp->get_kq_mask();
2089
+
2090
+ ggml_tensor * q = q_cur;
2091
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
2092
+ ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
2093
+
2094
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2095
+ cb(cur, "kqv_out", il);
2096
+
2097
+ if (wo) {
2098
+ cur = build_lora_mm(wo, cur);
2099
+ if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
2100
+ // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
2101
+ ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
2102
+ }
2103
+ }
2104
+
2105
+ if (wo_b) {
2106
+ cur = ggml_add(ctx0, cur, wo_b);
2107
+ }
2108
+
2109
+ return cur;
2110
+ }
2111
+
2112
+ ggml_tensor * llm_graph_context::build_attn(
2113
+ llm_graph_input_attn_kv_iswa * inp,
2114
+ ggml_tensor * wo,
2115
+ ggml_tensor * wo_b,
2116
+ ggml_tensor * q_cur,
2117
+ ggml_tensor * k_cur,
2118
+ ggml_tensor * v_cur,
2119
+ ggml_tensor * kq_b,
2120
+ ggml_tensor * sinks,
2121
+ ggml_tensor * v_mla,
2122
+ float kq_scale,
2123
+ int il) const {
2124
+ // these nodes are added to the graph together so that they are not reordered
2125
+ // by doing so, the number of splits in the graph is reduced
2126
+ ggml_build_forward_expand(gf, q_cur);
2127
+
2128
+ if (k_cur) {
2129
+ ggml_build_forward_expand(gf, k_cur);
2130
+ }
2131
+
2132
+ if (v_cur) {
2133
+ ggml_build_forward_expand(gf, v_cur);
2134
+ }
2135
+
2136
+ const auto * mctx_iswa = inp->mctx;
2137
+
2138
+ const bool is_swa = hparams.is_swa(il);
2139
+
2140
+ const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
2141
+
2142
+ // optionally store to KV cache
2143
+ if (k_cur) {
2144
+ const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
2145
+
2146
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
2147
+ }
2148
+
2149
+ if (v_cur) {
2150
+ const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
2151
+
2152
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
2153
+ }
2154
+
2155
+ const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
2156
+
2157
+ ggml_tensor * q = q_cur;
2158
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
2159
+ ggml_tensor * v = mctx_cur->get_v(ctx0, il);
2160
+
2161
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2162
+ cb(cur, "kqv_out", il);
2163
+
2164
+ if (wo) {
2165
+ cur = build_lora_mm(wo, cur);
2166
+ }
2167
+
2168
+ if (wo_b) {
2169
+ //cb(cur, "kqv_wo", il);
2170
+ }
2171
+
2172
+ if (wo_b) {
2173
+ cur = ggml_add(ctx0, cur, wo_b);
2174
+ }
2175
+
2176
+ return cur;
2177
+ }
2178
+
2179
+ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
2180
+ auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
2181
+
2182
+ const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
2183
+
2184
+ inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1);
2185
+ ggml_set_input(inp->cross_kq_mask);
2186
+
2187
+ inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
2188
+
2189
+ return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
2190
+ }
2191
+
2192
+ ggml_tensor * llm_graph_context::build_attn(
2193
+ llm_graph_input_attn_cross * inp,
2194
+ ggml_tensor * wo,
2195
+ ggml_tensor * wo_b,
2196
+ ggml_tensor * q_cur,
2197
+ ggml_tensor * k_cur,
2198
+ ggml_tensor * v_cur,
2199
+ ggml_tensor * kq_b,
2200
+ ggml_tensor * sinks,
2201
+ ggml_tensor * v_mla,
2202
+ float kq_scale,
2203
+ int il) const {
2204
+ // these nodes are added to the graph together so that they are not reordered
2205
+ // by doing so, the number of splits in the graph is reduced
2206
+ ggml_build_forward_expand(gf, q_cur);
2207
+ ggml_build_forward_expand(gf, k_cur);
2208
+ ggml_build_forward_expand(gf, v_cur);
2209
+
2210
+ const auto & kq_mask = inp->get_kq_mask_cross();
2211
+
2212
+ ggml_tensor * q = q_cur;
2213
+ ggml_tensor * k = k_cur;
2214
+ ggml_tensor * v = v_cur;
2215
+
2216
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2217
+ cb(cur, "kqv_out", il);
2218
+
2219
+ if (wo) {
2220
+ cur = build_lora_mm(wo, cur);
2221
+ }
2222
+
2223
+ if (wo_b) {
2224
+ //cb(cur, "kqv_wo", il);
2225
+ }
2226
+
2227
+ if (wo_b) {
2228
+ cur = ggml_add(ctx0, cur, wo_b);
2229
+ }
2230
+
2231
+ return cur;
2232
+ }
2233
+
2234
+ // TODO: maybe separate the inner implementation into a separate function
2235
+ // like with the non-sliding window equivalent
2236
+ // once sliding-window hybrid caches are a thing.
2237
+ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
2238
+ const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
2239
+
2240
+ auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
2241
+
2242
+ {
2243
+ inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
2244
+ inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
2245
+
2246
+ inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
2247
+ ggml_set_input(inp->self_kq_mask);
2248
+ ggml_set_name(inp->self_kq_mask, "self_kq_mask");
2249
+
2250
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
2251
+ ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
2252
+ }
2253
+
2254
+ {
2255
+ GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
2256
+
2257
+ inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
2258
+ inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
2259
+
2260
+ inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
2261
+ ggml_set_input(inp->self_kq_mask_swa);
2262
+ ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
2263
+
2264
+ inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
2265
+ ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
2266
+ }
2267
+
2268
+ return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
2269
+ }
2270
+
2271
+ ggml_tensor * llm_graph_context::build_rs(
2272
+ ggml_tensor * s,
2273
+ ggml_tensor * state_copy_main,
2274
+ ggml_tensor * state_copy_extra,
2275
+ int32_t state_size,
2276
+ int32_t n_seqs,
2277
+ uint32_t n_rs,
2278
+ uint32_t rs_head,
2279
+ uint32_t rs_size,
2280
+ int32_t rs_zero,
2281
+ const llm_graph_get_rows_fn & get_state_rows) const {
2282
+
2283
+ ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
2284
+
2285
+ // Clear a single state which will then be copied to the other cleared states.
2286
+ // Note that this is a no-op when the view is zero-sized.
2287
+ ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
2288
+ ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
2289
+
2290
+ // copy states
2291
+ // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
2292
+ // {state_size, rs_size} -> {state_size, n_seqs}
2293
+ ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
2294
+ ggml_build_forward_expand(gf, output_states);
2295
+
2296
+ // copy extra states which won't be changed further (between n_seqs and n_rs)
2297
+ ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
2298
+ ggml_build_forward_expand(gf,
2299
+ ggml_cpy(ctx0,
2300
+ states_extra,
2301
+ ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
2302
+
2303
+ return output_states;
2304
+ }
2305
+
2306
+ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
2307
+ ggml_context * ctx0,
2308
+ const llama_ubatch & ubatch,
2309
+ const llama_memory_recurrent_context * mctx_cur) {
2310
+
2311
+ auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
2312
+
2313
+ const int64_t n_rs = mctx_cur->get_n_rs();
2314
+ const int64_t n_seqs = ubatch.n_seqs;
2315
+
2316
+ inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
2317
+ ggml_set_input(inp->s_copy);
2318
+
2319
+ inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
2320
+ inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
2321
+
2322
+ inp->head = mctx_cur->get_head();
2323
+ inp->rs_z = mctx_cur->get_rs_z();
2324
+
2325
+ return inp;
2326
+ }
2327
+
2328
+ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
2329
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
2330
+
2331
+ auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
2332
+
2333
+ return (llm_graph_input_rs *) res->add_input(std::move(inp));
2334
+ }
2335
+
2336
+ ggml_tensor * llm_graph_context::build_rs(
2337
+ llm_graph_input_rs * inp,
2338
+ ggml_tensor * s,
2339
+ int32_t state_size,
2340
+ int32_t n_seqs,
2341
+ const llm_graph_get_rows_fn & get_state_rows) const {
2342
+ const auto * kv_state = inp->mctx;
2343
+
2344
+ return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
2345
+ kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
2346
+ get_state_rows);
2347
+ }
2348
+
2349
+ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
2350
+ llm_graph_input_rs * inp,
2351
+ const llama_ubatch & ubatch,
2352
+ int il) const {
2353
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
2354
+
2355
+ const auto token_shift_count = hparams.token_shift_count;
2356
+
2357
+ const int64_t n_seqs = ubatch.n_seqs;
2358
+
2359
+ ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
2360
+
2361
+ ggml_tensor * token_shift = build_rs(
2362
+ inp, token_shift_all,
2363
+ hparams.n_embd_r(), n_seqs);
2364
+
2365
+ token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
2366
+
2367
+ return token_shift;
2368
+ }
2369
+
2370
+ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
2371
+ ggml_tensor * token_shift,
2372
+ const llama_ubatch & ubatch,
2373
+ int il) const {
2374
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
2375
+
2376
+ const auto token_shift_count = hparams.token_shift_count;
2377
+ const auto n_embd = hparams.n_embd;
2378
+
2379
+ const int64_t n_seqs = ubatch.n_seqs;
2380
+
2381
+ const auto kv_head = mctx_cur->get_head();
2382
+
2383
+ return ggml_cpy(
2384
+ ctx0,
2385
+ ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
2386
+ ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
2387
+ );
2388
+ }
2389
+
2390
+ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
2391
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
2392
+
2393
+ auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
2394
+ auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
2395
+
2396
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2397
+
2398
+ return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
2399
+ }
2400
+
2401
+ llm_graph_input_mem_hybrid_k * llm_graph_context::build_inp_mem_hybrid_k() const {
2402
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
2403
+
2404
+ auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
2405
+ auto inp_attn = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
2406
+
2407
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid_k>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2408
+
2409
+ return (llm_graph_input_mem_hybrid_k *) res->add_input(std::move(inp));
2410
+ }
2411
+
2412
+ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
2413
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);
2414
+
2415
+ auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
2416
+
2417
+ // build iswa attention input
2418
+ const auto * attn_ctx = mctx_cur->get_attn();
2419
+
2420
+ auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
2421
+
2422
+ {
2423
+ inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
2424
+ inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
2425
+
2426
+ inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
2427
+ ggml_set_input(inp_attn->self_kq_mask);
2428
+
2429
+ inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
2430
+ }
2431
+
2432
+ {
2433
+ inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
2434
+ inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
2435
+
2436
+ inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
2437
+ ggml_set_input(inp_attn->self_kq_mask_swa);
2438
+
2439
+ inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
2440
+ }
2441
+
2442
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2443
+
2444
+ return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp));
2445
+ }
2446
+
2447
+ void llm_graph_context::build_dense_out(
2448
+ ggml_tensor * dense_2,
2449
+ ggml_tensor * dense_2_b,
2450
+ ggml_tensor * dense_3) const {
2451
+ if (!cparams.embeddings || !(dense_2 || dense_2_b || dense_3)) {
2452
+ return;
2453
+ }
2454
+ ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
2455
+ GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
2456
+
2457
+ if (dense_2) {
2458
+ cur = ggml_mul_mat(ctx0, dense_2, cur);
2459
+ }
2460
+ if (dense_2_b) {
2461
+ cur = ggml_add(ctx0, cur, dense_2_b);
2462
+ }
2463
+ if (dense_3) {
2464
+ cur = ggml_mul_mat(ctx0, dense_3, cur);
2465
+ }
2466
+ cb(cur, "result_embd_pooled", -1);
2467
+ res->t_embd_pooled = cur;
2468
+ ggml_build_forward_expand(gf, cur);
2469
+ }
2470
+
2471
+
2472
+ void llm_graph_context::build_pooling(
2473
+ ggml_tensor * cls,
2474
+ ggml_tensor * cls_b,
2475
+ ggml_tensor * cls_out,
2476
+ ggml_tensor * cls_out_b,
2477
+ ggml_tensor * cls_norm) const {
2478
+ if (!cparams.embeddings) {
2479
+ return;
2480
+ }
2481
+
2482
+ ggml_tensor * inp = res->t_embd;
2483
+
2484
+ //// find result_norm tensor for input
2485
+ //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
2486
+ // inp = ggml_graph_node(gf, i);
2487
+ // if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
2488
+ // break;
2489
+ // }
2490
+
2491
+ // inp = nullptr;
2492
+ //}
2493
+
2494
+ GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
2495
+
2496
+ ggml_tensor * cur;
2497
+
2498
+ switch (pooling_type) {
2499
+ case LLAMA_POOLING_TYPE_NONE:
2500
+ {
2501
+ cur = inp;
2502
+ } break;
2503
+ case LLAMA_POOLING_TYPE_MEAN:
2504
+ {
2505
+ ggml_tensor * inp_mean = build_inp_mean();
2506
+ cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
2507
+ } break;
2508
+ case LLAMA_POOLING_TYPE_CLS:
2509
+ case LLAMA_POOLING_TYPE_LAST:
2510
+ {
2511
+ ggml_tensor * inp_cls = build_inp_cls();
2512
+ cur = ggml_get_rows(ctx0, inp, inp_cls);
2513
+ } break;
2514
+ case LLAMA_POOLING_TYPE_RANK:
2515
+ {
2516
+ if (arch == LLM_ARCH_MODERN_BERT) {
2517
+ // modern bert gte reranker builds mean first then applies prediction head and classifier
2518
+ // https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modular_modernbert.py#L1404-1411
2519
+ ggml_tensor * inp_mean = build_inp_mean();
2520
+ cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
2521
+ } else {
2522
+ ggml_tensor * inp_cls = build_inp_cls();
2523
+ cur = ggml_get_rows(ctx0, inp, inp_cls);
2524
+ }
2525
+
2526
+ // classification head
2527
+ // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
2528
+ if (cls) {
2529
+ cur = ggml_mul_mat(ctx0, cls, cur);
2530
+ if (cls_b) {
2531
+ cur = ggml_add(ctx0, cur, cls_b);
2532
+ }
2533
+ if (arch == LLM_ARCH_MODERN_BERT) {
2534
+ cur = ggml_gelu(ctx0, cur);
2535
+ } else {
2536
+ cur = ggml_tanh(ctx0, cur);
2537
+ }
2538
+ if (cls_norm) {
2539
+ // head norm
2540
+ cur = build_norm(cur, cls_norm, NULL, LLM_NORM, -1);
2541
+ }
2542
+ }
2543
+
2544
+ // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
2545
+ // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
2546
+ // Single layer classification head (direct projection)
2547
+ // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
2548
+ if (cls_out) {
2549
+ cur = ggml_mul_mat(ctx0, cls_out, cur);
2550
+ if (cls_out_b) {
2551
+ cur = ggml_add(ctx0, cur, cls_out_b);
2552
+ }
2553
+ }
2554
+
2555
+ // softmax for qwen3 reranker
2556
+ if (arch == LLM_ARCH_QWEN3) {
2557
+ cur = ggml_soft_max(ctx0, cur);
2558
+ }
2559
+ } break;
2560
+ default:
2561
+ {
2562
+ GGML_ABORT("unknown pooling type");
2563
+ }
2564
+ }
2565
+
2566
+ cb(cur, "result_embd_pooled", -1);
2567
+ res->t_embd_pooled = cur;
2568
+
2569
+ ggml_build_forward_expand(gf, cur);
2570
+ }
2571
+
2572
+ void llm_graph_context::build_sampling() const {
2573
+ if (samplers.empty() || !res->t_logits) {
2574
+ return;
2575
+ }
2576
+
2577
+ std::array<ggml_tensor *, 2> outs;
2578
+ outs[0] = res->t_logits;
2579
+
2580
+ auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
2581
+ res->add_input(std::move(inp_sampling));
2582
+
2583
+ std::map<llama_seq_id, int32_t> seq_to_logit_row;
2584
+ int32_t logit_row_idx = 0;
2585
+
2586
+ for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
2587
+ if (ubatch.output[i]) {
2588
+ llama_seq_id seq_id = ubatch.seq_id[i][0];
2589
+ seq_to_logit_row[seq_id] = logit_row_idx;
2590
+ logit_row_idx++;
2591
+ }
2592
+ }
2593
+
2594
+ // res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1)
2595
+ GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
2596
+
2597
+ // add a dummy row of logits
2598
+ // this trick makes the graph static, regardless of which samplers are activated
2599
+ // this is important in order to minimize graph reallocations
2600
+ ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
2601
+
2602
+ for (const auto & [seq_id, sampler] : samplers) {
2603
+ const auto it = seq_to_logit_row.find(seq_id);
2604
+
2605
+ // inactive samplers always work on the first row
2606
+ const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0;
2607
+ const int i_out = it != seq_to_logit_row.end() ? 1 : 0;
2608
+
2609
+ ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
2610
+ ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
2611
+
2612
+ struct llama_sampler_data data = {
2613
+ /*.logits =*/ logits_seq,
2614
+ /*.probs =*/ nullptr,
2615
+ /*.sampled =*/ nullptr,
2616
+ /*.candidates =*/ nullptr,
2617
+ };
2618
+
2619
+ assert(sampler->iface->backend_apply);
2620
+ sampler->iface->backend_apply(sampler, ctx0, gf, &data);
2621
+
2622
+ if (data.sampled != nullptr) {
2623
+ res->t_sampled[seq_id] = data.sampled;
2624
+ outs[1] = data.sampled;
2625
+ ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2626
+ }
2627
+
2628
+ if (data.probs != nullptr) {
2629
+ res->t_sampled_probs[seq_id] = data.probs;
2630
+ outs[1] = data.probs;
2631
+ ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2632
+ }
2633
+
2634
+ if (data.logits != nullptr) {
2635
+ res->t_sampled_logits[seq_id] = data.logits;
2636
+ outs[1] = data.logits;
2637
+ ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2638
+ }
2639
+
2640
+ if (data.candidates != nullptr) {
2641
+ res->t_candidates[seq_id] = data.candidates;
2642
+ outs[1] = data.candidates;
2643
+ ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2644
+ }
2645
+ }
2646
+
2647
+ // TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
2648
+ /*
2649
+ for (const auto & [seq_id, sampler] : samplers) {
2650
+ if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
2651
+ ggml_tensor * selected_token = it->second;
2652
+ if (selected_token != nullptr) {
2653
+ llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
2654
+ }
2655
+ }
2656
+ }
2657
+ */
2658
+ }
2659
+
2660
// Maps the signed distance x - y between two positions to a relative-attention bucket index.
//
//   x, y          - the two positions; the distance used is x - y
//   n_buckets     - total number of buckets
//   bidirectional - when true the buckets are split in half: distances with x > y use the upper
//                   half, all others the lower half, and the magnitude of the distance is bucketed.
//                   when false only distances with x <= y are distinguished; x > y maps to bucket 0
//
// Within one direction, distances below max_exact (half of the per-direction bucket count) get a
// bucket each; larger distances are placed on a logarithmic scale relative to max_distance and
// clamped to the last bucket of that direction.
//
// NOTE: bucket counts so large that max_exact >= max_distance (128) are not handled here.
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
    // TODO move to hparams if a T5 variant appears that uses a different value
    const int64_t max_distance = 128;

    if (bidirectional) {
        n_buckets >>= 1;
    }

    const int64_t max_exact = n_buckets >> 1;

    int32_t relative_position = x - y;
    int32_t relative_bucket   = 0;

    if (bidirectional) {
        relative_bucket += (relative_position > 0) * n_buckets;
        relative_position = std::abs(relative_position);
    } else {
        relative_position = -std::min<int32_t>(relative_position, 0);
    }

    // degenerate configuration (fewer than two buckets per direction): the logarithmic formula
    // would divide by zero, so every distance collapses into the first bucket of its direction
    if (max_exact == 0) {
        return relative_bucket;
    }

    // small distances: one bucket per distance
    if (relative_position < max_exact) {
        return relative_bucket + relative_position;
    }

    // large distances: logarithmic bucket. Only evaluated here, where relative_position >= max_exact >= 1,
    // so logf() never sees 0 and the float -> int conversion is never applied to -inf/NaN
    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);

    return relative_bucket + relative_position_if_large;
}