whispercpp 1.3.6 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (828)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/README.md +38 -5
  5. data/Rakefile +18 -3
  6. data/ext/dependencies.rb +10 -4
  7. data/ext/dependencies_for_windows.rb +17 -0
  8. data/ext/extconf.rb +20 -8
  9. data/ext/options.rb +54 -14
  10. data/ext/options_for_windows.rb +51 -0
  11. data/ext/ruby_whisper.c +36 -42
  12. data/ext/ruby_whisper.h +135 -0
  13. data/ext/ruby_whisper_context.c +107 -28
  14. data/ext/ruby_whisper_log_queue.c +180 -0
  15. data/ext/ruby_whisper_log_settable.h +47 -0
  16. data/ext/ruby_whisper_parakeet.c +49 -0
  17. data/ext/ruby_whisper_parakeet_context.c +304 -0
  18. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  19. data/ext/ruby_whisper_parakeet_model.c +84 -0
  20. data/ext/ruby_whisper_parakeet_params.c +548 -0
  21. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  22. data/ext/ruby_whisper_parakeet_token.c +188 -0
  23. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  24. data/ext/ruby_whisper_params.c +256 -65
  25. data/ext/ruby_whisper_segment.c +6 -6
  26. data/ext/ruby_whisper_transcribe.cpp +42 -15
  27. data/ext/sources/CMakeLists.txt +41 -3
  28. data/ext/sources/CMakePresets.json +95 -0
  29. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  30. data/ext/sources/cmake/parakeet.pc.in +10 -0
  31. data/ext/sources/cmake/whisper.pc.in +1 -1
  32. data/ext/sources/examples/CMakeLists.txt +4 -2
  33. data/ext/sources/examples/bench/bench.cpp +1 -1
  34. data/ext/sources/examples/cli/cli.cpp +43 -9
  35. data/ext/sources/examples/common-ggml.cpp +2 -0
  36. data/ext/sources/examples/common-whisper.cpp +139 -67
  37. data/ext/sources/examples/common-whisper.h +11 -0
  38. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  39. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  40. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  41. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  42. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  43. data/ext/sources/examples/server/server.cpp +199 -163
  44. data/ext/sources/ggml/CMakeLists.txt +21 -13
  45. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  46. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  47. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  48. data/ext/sources/ggml/include/ggml-backend.h +72 -10
  49. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  50. data/ext/sources/ggml/include/ggml-rpc.h +3 -3
  51. data/ext/sources/ggml/include/ggml.h +101 -9
  52. data/ext/sources/ggml/include/gguf.h +10 -2
  53. data/ext/sources/ggml/src/CMakeLists.txt +22 -5
  54. data/ext/sources/ggml/src/ggml-alloc.c +5 -1
  55. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  56. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  57. data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
  58. data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
  59. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
  60. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
  61. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
  62. data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
  63. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
  64. data/ext/sources/ggml/src/ggml-common.h +11 -0
  65. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
  66. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
  67. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
  68. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
  69. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
  70. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  71. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  72. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
  73. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
  74. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
  75. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  76. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
  77. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
  78. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
  79. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  80. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
  81. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
  82. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  83. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
  84. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
  85. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
  86. data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
  87. data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
  88. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  89. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
  90. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
  91. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
  92. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  93. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  94. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  95. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  96. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  97. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  98. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  99. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  100. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  101. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  102. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  103. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  104. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  105. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  106. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  107. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
  108. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  109. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
  110. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  111. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  112. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
  113. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
  114. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  115. data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
  116. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  117. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  118. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  119. data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
  120. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  121. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
  122. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  123. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
  124. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
  125. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
  129. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
  130. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  131. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  132. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  133. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
  134. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  135. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
  136. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  137. data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
  138. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
  139. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
  140. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  141. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
  142. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
  143. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
  144. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
  145. data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
  146. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  147. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
  148. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  149. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
  150. data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
  151. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  152. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  153. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  154. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  155. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  156. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
  157. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  158. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  159. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  160. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  161. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
  162. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  163. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  164. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  165. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
  166. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  167. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  168. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  169. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  170. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  171. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  172. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  173. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  174. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  176. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  177. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  178. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  179. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  191. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
  192. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
  193. data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
  194. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  195. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
  196. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  197. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
  198. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  199. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
  200. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
  201. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
  202. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
  203. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
  204. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
  205. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  206. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  207. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
  208. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  209. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  210. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  211. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
  212. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  213. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
  214. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
  215. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
  216. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
  217. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
  218. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  219. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  220. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  221. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  222. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  223. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  224. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  225. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  226. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
  227. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
  228. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  229. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
  230. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
  231. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
  232. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
  233. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  235. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
  254. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
  255. data/ext/sources/ggml/src/ggml-impl.h +6 -1
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
  259. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
  260. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
  261. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
  262. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
  263. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
  264. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  265. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
  266. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
  322. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
  323. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
  324. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
  325. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
  326. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
  327. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  328. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
  329. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
  330. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  331. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
  332. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
  333. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
  334. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
  335. data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
  336. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  337. data/ext/sources/ggml/src/ggml-quants.c +289 -114
  338. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  339. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  340. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  341. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  342. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  343. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
  344. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
  345. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
  346. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  347. data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
  348. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
  349. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
  350. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  351. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  352. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  353. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  354. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  355. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  356. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
  357. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
  358. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  359. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  360. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
  361. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
  362. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
  363. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
  364. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
  365. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  366. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  367. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
  368. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
  369. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  370. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  371. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
  372. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  373. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  374. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  375. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  376. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  377. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
  378. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  379. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  380. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  381. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  382. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  383. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  384. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  385. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  386. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  387. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
  388. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
  389. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
  390. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
  391. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
  392. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
  393. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
  394. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
  395. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
  396. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
  397. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
  398. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
  399. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
  400. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
  401. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
  402. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
  403. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
  404. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
  405. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
  406. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
  407. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
  408. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
  409. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
  410. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
  411. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
  412. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
  413. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
  414. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
  415. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
  416. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
  417. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
  418. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
  420. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
  421. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
  422. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
  423. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  424. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  425. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  426. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
  427. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
  428. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
  429. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
  430. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
  431. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
  432. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
  433. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
  434. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
  484. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  485. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
  486. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
  487. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  488. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  489. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
  490. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
  491. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
  492. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  493. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
  494. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
  495. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  496. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  497. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  498. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  499. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  500. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  501. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
  502. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  503. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  504. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
  505. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  506. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  507. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  508. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
  509. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
  510. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
  511. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  512. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  513. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  514. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  515. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  516. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  517. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  518. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
  519. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  520. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
  521. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  522. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  523. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  524. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  525. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  526. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
  527. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  528. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
  529. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
  530. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
  531. data/ext/sources/ggml/src/ggml.c +110 -28
  532. data/ext/sources/ggml/src/gguf.cpp +173 -28
  533. data/ext/sources/include/parakeet.h +342 -0
  534. data/ext/sources/include/whisper.h +10 -0
  535. data/ext/sources/media/matmul.png +0 -0
  536. data/ext/sources/src/CMakeLists.txt +23 -0
  537. data/ext/sources/src/parakeet-arch.h +188 -0
  538. data/ext/sources/src/parakeet.cpp +3838 -0
  539. data/ext/sources/src/whisper.cpp +56 -12
  540. data/extsources.rb +26 -10
  541. data/lib/whisper/log_settable.rb +36 -0
  542. data/lib/whisper/model/uri.rb +13 -1
  543. data/lib/whisper/output.rb +74 -0
  544. data/sig/whisper.rbs +411 -62
  545. data/test/helper.rb +2 -0
  546. data/test/jfk_reader/jfk_reader.c +50 -7
  547. data/test/test_callback.rb +1 -0
  548. data/test/test_package.rb +6 -5
  549. data/test/test_parakeet.rb +28 -0
  550. data/test/test_parakeet_callback.rb +107 -0
  551. data/test/test_parakeet_context.rb +116 -0
  552. data/test/test_parakeet_context_params.rb +24 -0
  553. data/test/test_parakeet_model.rb +21 -0
  554. data/test/test_parakeet_params.rb +78 -0
  555. data/test/test_parakeet_segment.rb +42 -0
  556. data/test/test_parakeet_token.rb +73 -0
  557. data/test/test_params.rb +2 -0
  558. data/test/test_vad_segment.rb +1 -1
  559. data/test/test_whisper.rb +24 -6
  560. data/whispercpp.gemspec +2 -2
  561. metadata +215 -281
  562. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  563. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  564. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  565. data/ext/sources/bindings/javascript/package.json +0 -26
  566. data/ext/sources/bindings/javascript/whisper.js +0 -19
  567. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  568. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  569. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  570. data/ext/sources/examples/addon.node/index.js +0 -59
  571. data/ext/sources/examples/addon.node/package.json +0 -16
  572. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  573. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  574. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  575. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  576. data/ext/sources/examples/coi-serviceworker.js +0 -146
  577. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  578. data/ext/sources/examples/command/command.cpp +0 -802
  579. data/ext/sources/examples/command/commands.txt +0 -9
  580. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  581. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  582. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  583. data/ext/sources/examples/generate-karaoke.sh +0 -57
  584. data/ext/sources/examples/helpers.js +0 -191
  585. data/ext/sources/examples/livestream.sh +0 -112
  586. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  587. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  588. data/ext/sources/examples/lsp/whisper.vim +0 -362
  589. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  590. data/ext/sources/examples/python/whisper_processor.py +0 -54
  591. data/ext/sources/examples/server/bench.js +0 -29
  592. data/ext/sources/examples/server.py +0 -120
  593. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  594. data/ext/sources/examples/stream/stream.cpp +0 -437
  595. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  596. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  597. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  598. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  599. data/ext/sources/examples/sycl/build.sh +0 -22
  600. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  601. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  602. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
  603. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  604. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
  605. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
  606. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
  607. data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
  608. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
  609. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  610. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
  611. data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
  612. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
  613. data/ext/sources/examples/talk-llama/llama-context.h +0 -359
  614. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  615. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
  616. data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
  617. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  618. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  619. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
  620. data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
  621. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
  622. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
  623. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  624. data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
  625. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  626. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  627. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
  628. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  629. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
  630. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
  631. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  632. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
  633. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
  634. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  635. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  636. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
  637. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  638. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  639. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  640. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
  641. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  642. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
  643. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
  644. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
  645. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
  646. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
  647. data/ext/sources/examples/talk-llama/llama-model.h +0 -597
  648. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
  649. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  650. data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
  651. data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
  652. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
  653. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
  654. data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
  655. data/ext/sources/examples/talk-llama/llama.h +0 -1573
  656. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
  657. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  658. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  659. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
  660. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  661. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
  662. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
  663. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
  664. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
  665. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
  666. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  667. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  668. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  669. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  670. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  671. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  672. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  673. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
  674. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  675. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
  676. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
  677. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
  678. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
  679. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  680. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
  681. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  682. data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
  683. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
  684. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  685. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  686. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
  687. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  688. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  689. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  690. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  691. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  692. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  693. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  694. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
  695. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  696. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  697. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
  698. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
  699. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  700. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
  701. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  702. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
  703. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  704. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  705. data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
  706. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  707. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
  708. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
  709. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  710. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  711. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  712. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
  713. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  714. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
  715. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
  716. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
  717. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
  718. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
  719. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  720. data/ext/sources/examples/talk-llama/models/models.h +0 -704
  721. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
  722. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  723. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
  724. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  725. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  726. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  727. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  728. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  729. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  730. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  731. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  732. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
  733. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  734. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  735. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  736. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  737. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
  738. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  739. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
  740. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  741. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  742. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  743. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  744. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
  745. data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
  746. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
  747. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
  748. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
  749. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
  750. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
  751. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  752. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  753. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
  754. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  755. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  756. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
  757. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  758. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  759. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  760. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  761. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  762. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  763. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  764. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
  765. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  766. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  767. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  768. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  769. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  770. data/ext/sources/examples/talk-llama/speak +0 -40
  771. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  772. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  773. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  774. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  775. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  776. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
  777. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  778. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  779. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  780. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  781. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  782. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  783. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  784. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  785. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  786. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  787. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  788. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  789. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  790. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  791. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
  792. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  793. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  794. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
  795. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
  796. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  798. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
  799. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  800. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
  801. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  802. data/ext/sources/tests/CMakeLists.txt +0 -112
  803. data/ext/sources/tests/earnings21/eval.mk +0 -58
  804. data/ext/sources/tests/earnings21/eval.py +0 -68
  805. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  806. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  807. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  808. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  809. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  810. data/ext/sources/tests/en-0-ref.txt +0 -1
  811. data/ext/sources/tests/en-1-ref.txt +0 -1
  812. data/ext/sources/tests/en-2-ref.txt +0 -1
  813. data/ext/sources/tests/es-0-ref.txt +0 -1
  814. data/ext/sources/tests/librispeech/eval.mk +0 -39
  815. data/ext/sources/tests/librispeech/eval.py +0 -47
  816. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  817. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  818. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  819. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  820. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  821. data/ext/sources/tests/run-tests.sh +0 -130
  822. data/ext/sources/tests/test-c.c +0 -3
  823. data/ext/sources/tests/test-vad-full.cpp +0 -56
  824. data/ext/sources/tests/test-vad.cpp +0 -83
  825. data/ext/sources/tests/test-whisper.js +0 -58
  826. data/lib/whisper/context.rb +0 -15
  827. data/lib/whisper/segment.rb +0 -58
  828. /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
@@ -1,2735 +0,0 @@
1
- #include "llama-graph.h"
2
-
3
- #include "llama-impl.h"
4
- #include "llama-batch.h"
5
- #include "llama-cparams.h"
6
-
7
- #include "llama-kv-cache.h"
8
- #include "llama-kv-cache-iswa.h"
9
- #include "llama-memory-hybrid.h"
10
- #include "llama-memory-hybrid-iswa.h"
11
- #include "llama-memory-recurrent.h"
12
-
13
- #include <cassert>
14
- #include <cmath>
15
- #include <cstring>
16
- #include <numeric>
17
- #include <sstream>
18
- #include <unordered_set>
19
-
20
- // dedup helpers
21
-
22
- static ggml_tensor * build_kq_mask(
23
- ggml_context * ctx,
24
- const llama_kv_cache_context * mctx,
25
- const llama_ubatch & ubatch,
26
- const llama_cparams & cparams) {
27
- const auto n_kv = mctx->get_n_kv();
28
- const auto n_tokens = ubatch.n_tokens;
29
- const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
30
-
31
- return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
32
- }
33
-
34
- static bool can_reuse_kq_mask(
35
- ggml_tensor * kq_mask,
36
- const llama_kv_cache_context * mctx,
37
- const llama_ubatch & ubatch,
38
- const llama_cparams & cparams) {
39
- const auto n_kv = mctx->get_n_kv();
40
- const auto n_tokens = ubatch.n_tokens;
41
- const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
42
-
43
- bool res = true;
44
-
45
- res &= (kq_mask->ne[0] == n_kv);
46
- res &= (kq_mask->ne[1] == n_tokens/n_stream);
47
- res &= (kq_mask->ne[2] == 1);
48
- res &= (kq_mask->ne[3] == n_stream);
49
-
50
- return res;
51
- }
52
-
53
- // impl
54
-
55
- void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
56
- if (ubatch->token) {
57
- const int64_t n_tokens = ubatch->n_tokens;
58
-
59
- ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
60
- }
61
-
62
- if (ubatch->embd) {
63
- GGML_ASSERT(n_embd == embd->ne[0]);
64
-
65
- const int64_t n_tokens = ubatch->n_tokens;
66
-
67
- ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
68
- }
69
- }
70
-
71
- bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
72
- bool res = true;
73
-
74
- res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
75
- res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
76
-
77
- return res;
78
- }
79
-
80
- void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
81
- if (ubatch->pos && pos) {
82
- const int64_t n_tokens = ubatch->n_tokens;
83
-
84
- if (ubatch->token && n_pos_per_embd == 4) {
85
- // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
86
- // the 3 first dims are the same, and 4th dim is all 0
87
- std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
88
- // copy the first dimension
89
- for (int i = 0; i < n_tokens; ++i) {
90
- pos_data[ i] = ubatch->pos[i];
91
- pos_data[ n_tokens + i] = ubatch->pos[i];
92
- pos_data[2 * n_tokens + i] = ubatch->pos[i];
93
- pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
94
- }
95
- ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
96
- } else {
97
- ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
98
- }
99
- }
100
- }
101
-
102
- bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
103
- bool res = true;
104
-
105
- res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
106
-
107
- return res;
108
- }
109
-
110
- void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
111
- if (ubatch->pos && attn_scale) {
112
- const int64_t n_tokens = ubatch->n_tokens;
113
-
114
- GGML_ASSERT(f_attn_temp_scale != 0.0f);
115
- GGML_ASSERT(n_attn_temp_floor_scale != 0);
116
-
117
- std::vector<float> attn_scale_data(n_tokens, 0.0f);
118
- for (int i = 0; i < n_tokens; ++i) {
119
- const float pos = ubatch->pos[i];
120
- attn_scale_data[i] = std::log(
121
- std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0
122
- ) * f_attn_temp_scale + 1.0;
123
- }
124
-
125
- ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
126
- }
127
- }
128
-
129
- void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
130
- if (pos_bucket) {
131
- const int64_t n_tokens = ubatch->n_tokens;
132
-
133
- GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
134
- GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
135
-
136
- int32_t * data = (int32_t *) pos_bucket->data;
137
-
138
- for (int j = 0; j < n_tokens; ++j) {
139
- for (int i = 0; i < n_tokens; ++i) {
140
- data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
141
- }
142
- }
143
- }
144
- }
145
-
146
- void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
147
- if (pos_bucket) {
148
- mctx->set_input_pos_bucket(pos_bucket, ubatch);
149
- }
150
- }
151
-
152
- void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
153
- GGML_ASSERT(out_ids);
154
-
155
- const int64_t n_tokens = ubatch->n_tokens;
156
-
157
- GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
158
- int32_t * data = (int32_t *) out_ids->data;
159
-
160
- if (n_outputs == n_tokens) {
161
- for (int i = 0; i < n_tokens; ++i) {
162
- data[i] = i;
163
- }
164
-
165
- return;
166
- }
167
-
168
- GGML_ASSERT(ubatch->output);
169
-
170
- int n_outputs = 0;
171
-
172
- for (int i = 0; i < n_tokens; ++i) {
173
- if (ubatch->output[i]) {
174
- data[n_outputs++] = i;
175
- }
176
- }
177
- }
178
-
179
- bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
180
- bool res = true;
181
-
182
- res &= n_outputs == params.n_outputs;
183
-
184
- return res;
185
- }
186
-
187
- void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
188
- if (cparams.embeddings &&
189
- (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN ||
190
- cparams.pooling_type == LLAMA_POOLING_TYPE_RANK )) {
191
-
192
- const int64_t n_tokens = ubatch->n_tokens;
193
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
194
- const int64_t n_seqs_unq = ubatch->n_seqs_unq;
195
-
196
- GGML_ASSERT(mean);
197
- GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
198
-
199
- float * data = (float *) mean->data;
200
- memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
201
-
202
- std::vector<uint64_t> sums(n_seqs_unq, 0);
203
- for (int i = 0; i < n_tokens; i += n_seq_tokens) {
204
- for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
205
- const llama_seq_id seq_id = ubatch->seq_id[i][s];
206
- const int32_t seq_idx = ubatch->seq_idx[seq_id];
207
-
208
- sums[seq_idx] += ubatch->n_seq_tokens;
209
- }
210
- }
211
-
212
- std::vector<float> div(n_seqs_unq, 0.0f);
213
- for (int s = 0; s < n_seqs_unq; ++s) {
214
- const uint64_t sum = sums[s];
215
- if (sum > 0) {
216
- div[s] = 1.0f/float(sum);
217
- }
218
- }
219
-
220
- for (int i = 0; i < n_tokens; i += n_seq_tokens) {
221
- for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
222
- const llama_seq_id seq_id = ubatch->seq_id[i][s];
223
- const int32_t seq_idx = ubatch->seq_idx[seq_id];
224
-
225
- for (int j = 0; j < n_seq_tokens; ++j) {
226
- data[seq_idx*n_tokens + i + j] = div[seq_idx];
227
- }
228
- }
229
- }
230
- }
231
- }
232
-
233
- void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
234
- const int64_t n_tokens = ubatch->n_tokens;
235
- const int64_t n_seqs_unq = ubatch->n_seqs_unq;
236
-
237
- if (cparams.embeddings && (
238
- cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
239
- cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
240
- cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
241
- )) {
242
- GGML_ASSERT(cls);
243
- GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
244
-
245
- uint32_t * data = (uint32_t *) cls->data;
246
- memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
247
-
248
- std::vector<int> target_pos(n_seqs_unq, -1);
249
- std::vector<int> target_row(n_seqs_unq, -1);
250
-
251
- const bool last = (
252
- cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
253
- (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL)) // qwen3 reranking & embedding models use last token
254
- );
255
-
256
- for (int i = 0; i < n_tokens; ++i) {
257
- const llama_pos pos = ubatch->pos[i];
258
-
259
- for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
260
- const llama_seq_id seq_id = ubatch->seq_id[i][s];
261
- const int32_t seq_idx = ubatch->seq_idx[seq_id];
262
-
263
- if (
264
- (target_pos[seq_idx] == -1) ||
265
- ( last && pos >= target_pos[seq_idx]) ||
266
- (!last && pos < target_pos[seq_idx])
267
- ) {
268
- target_pos[seq_idx] = pos;
269
- target_row[seq_idx] = i;
270
- }
271
- }
272
- }
273
-
274
- for (int s = 0; s < n_seqs_unq; ++s) {
275
- if (target_row[s] >= 0) {
276
- data[s] = target_row[s];
277
- }
278
- }
279
- }
280
- }
281
-
282
- void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
283
- GGML_UNUSED(ubatch);
284
-
285
- const int64_t n_rs = mctx->get_n_rs();
286
-
287
- if (s_copy) {
288
- GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
289
- int32_t * data = (int32_t *) s_copy->data;
290
-
291
- // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
292
- for (uint32_t i = 0; i < n_rs; ++i) {
293
- data[i] = mctx->s_copy(i);
294
- }
295
- }
296
- }
297
-
298
- bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
299
- const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
300
-
301
- this->mctx = mctx;
302
-
303
- bool res = true;
304
-
305
- res &= s_copy->ne[0] == mctx->get_n_rs();
306
-
307
- res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
308
- res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
309
-
310
- res &= head == mctx->get_head();
311
- res &= rs_z == mctx->get_rs_z();
312
-
313
- return res;
314
- }
315
-
316
- void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
317
- GGML_UNUSED(ubatch);
318
-
319
- if (cross_embd && !cross->v_embd.empty()) {
320
- assert(cross_embd->type == GGML_TYPE_F32);
321
-
322
- ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd));
323
- }
324
- }
325
-
326
- static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
327
- LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
328
- const char * swa_type_str = "unknown";
329
-
330
- switch (swa_type) {
331
- case LLAMA_SWA_TYPE_NONE: swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
332
- case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
333
- case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
334
- case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
335
- };
336
-
337
- LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
338
- LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
339
- LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
340
-
341
- LLAMA_LOG_DEBUG(" ");
342
- for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
343
- LLAMA_LOG_DEBUG("%2d", j);
344
- }
345
- LLAMA_LOG_DEBUG("\n");
346
-
347
- for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
348
- LLAMA_LOG_DEBUG(" %2d ", i);
349
- for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
350
- float val = data[i * n_kv + j];
351
- if (val == -INFINITY) {
352
- LLAMA_LOG_DEBUG(" ∞");
353
- } else {
354
- LLAMA_LOG_DEBUG(" 0");
355
- }
356
- }
357
- LLAMA_LOG_DEBUG("\n");
358
- }
359
- }
360
-
361
- void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
362
- const int64_t n_kv = ubatch->n_tokens;
363
- const int64_t n_tokens = ubatch->n_tokens;
364
-
365
- const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
366
- for (int i1 = 0; i1 < n_tokens; ++i1) {
367
- const llama_seq_id s1 = ubatch->seq_id[i1][0];
368
- const llama_pos p1 = ubatch->pos[i1];
369
-
370
- const uint64_t idst = i1*n_kv;
371
-
372
- for (int i0 = 0; i0 < n_tokens; ++i0) {
373
- const llama_seq_id s0 = ubatch->seq_id[i0][0];
374
- const llama_pos p0 = ubatch->pos[i0];
375
-
376
- // mask different sequences
377
- if (s0 != s1) {
378
- continue;
379
- }
380
-
381
- // mask future tokens
382
- if (cparams.causal_attn && p0 > p1) {
383
- continue;
384
- }
385
-
386
- // apply SWA if any
387
- if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
388
- continue;
389
- }
390
-
391
- data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
392
- }
393
- }
394
- };
395
-
396
- {
397
- GGML_ASSERT(self_kq_mask);
398
- GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
399
-
400
- float * data = (float *) self_kq_mask->data;
401
-
402
- std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
403
-
404
- fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
405
-
406
- if (debug) {
407
- print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
408
- }
409
- }
410
-
411
- if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
412
- GGML_ASSERT(self_kq_mask_swa);
413
- GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
414
-
415
- float * data = (float *) self_kq_mask_swa->data;
416
-
417
- std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
418
-
419
- fill_mask(data, hparams.n_swa, hparams.swa_type);
420
-
421
- if (debug) {
422
- print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
423
- }
424
- }
425
- }
426
-
427
- void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
428
- mctx->set_input_k_idxs(self_k_idxs, ubatch);
429
- mctx->set_input_v_idxs(self_v_idxs, ubatch);
430
-
431
- mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
432
- }
433
-
434
- bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
435
- const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
436
-
437
- this->mctx = mctx;
438
-
439
- bool res = true;
440
-
441
- res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
442
- //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
443
-
444
- res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
445
-
446
- return res;
447
- }
448
-
449
- void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
450
- mctx->set_input_k_idxs(self_k_idxs, ubatch);
451
-
452
- mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
453
- }
454
-
455
- bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
456
- const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
457
-
458
- this->mctx = mctx;
459
-
460
- bool res = true;
461
-
462
- res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
463
-
464
- res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
465
-
466
- return res;
467
- }
468
-
469
- void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
470
- mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
471
- mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
472
-
473
- mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
474
-
475
- mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
476
- mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
477
-
478
- mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
479
- }
480
-
481
- bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
482
- const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
483
-
484
- this->mctx = mctx;
485
-
486
- bool res = true;
487
-
488
- res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
489
- //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
490
-
491
- res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
492
- //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
493
-
494
- res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
495
- res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
496
-
497
- return res;
498
- }
499
-
500
- void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
501
- GGML_ASSERT(cross_kq_mask);
502
-
503
- const int64_t n_enc = cross_kq_mask->ne[0];
504
- const int64_t n_tokens = ubatch->n_tokens;
505
-
506
- GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
507
- GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
508
-
509
- float * data = (float *) cross_kq_mask->data;
510
-
511
- for (int i = 0; i < n_tokens; ++i) {
512
- GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first");
513
- for (int j = 0; j < n_enc; ++j) {
514
- float f = -INFINITY;
515
-
516
- for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
517
- const llama_seq_id seq_id = ubatch->seq_id[i][s];
518
-
519
- if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
520
- f = 0.0f;
521
- }
522
- }
523
-
524
- data[i*n_enc + j] = f;
525
- }
526
- }
527
- }
528
-
529
- void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
530
- mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
531
- mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
532
-
533
- mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
534
-
535
- const int64_t n_rs = mctx->get_recr()->get_n_rs();
536
-
537
- if (inp_rs->s_copy) {
538
- GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
539
- int32_t * data = (int32_t *) inp_rs->s_copy->data;
540
-
541
- // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
542
- for (uint32_t i = 0; i < n_rs; ++i) {
543
- data[i] = mctx->get_recr()->s_copy(i);
544
- }
545
- }
546
- }
547
-
548
- bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
549
- const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
550
-
551
- this->mctx = mctx;
552
-
553
- bool res = true;
554
-
555
- res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
556
- //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
557
-
558
- res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
559
-
560
- res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
561
-
562
- res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
563
- res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
564
-
565
- res &= inp_rs->head == mctx->get_recr()->get_head();
566
- res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
567
-
568
- return res;
569
- }
570
-
571
- // TODO: Hybrid input classes are a bit redundant.
572
- // Instead of creating a hybrid input, the graph can simply create 2 separate inputs.
573
- // Refactoring is required in the future.
574
- void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) {
575
- mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
576
-
577
- mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
578
-
579
- const int64_t n_rs = mctx->get_recr()->get_n_rs();
580
-
581
- if (inp_rs->s_copy) {
582
- GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
583
- int32_t * data = (int32_t *) inp_rs->s_copy->data;
584
-
585
- // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
586
- for (uint32_t i = 0; i < n_rs; ++i) {
587
- data[i] = mctx->get_recr()->s_copy(i);
588
- }
589
- }
590
- }
591
-
592
- bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) {
593
- const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
594
-
595
- this->mctx = mctx;
596
-
597
- bool res = true;
598
-
599
- res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
600
-
601
- res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
602
-
603
- res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
604
-
605
- res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
606
- res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
607
-
608
- res &= inp_rs->head == mctx->get_recr()->get_head();
609
- res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
610
-
611
- return res;
612
- }
613
-
614
- void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
615
- const auto * attn_ctx = mctx->get_attn();
616
-
617
- // base tensors may not be allocated if there are no non-SWA attention layers
618
- if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
619
- attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
620
- attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
621
-
622
- attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
623
- }
624
-
625
- // swa tensors may not be allocated if there are no SWA attention layers
626
- if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
627
- attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
628
- attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
629
-
630
- attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
631
- }
632
-
633
- const int64_t n_rs = mctx->get_recr()->get_n_rs();
634
-
635
- if (inp_rs->s_copy) {
636
- GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
637
- int32_t * data = (int32_t *) inp_rs->s_copy->data;
638
-
639
- // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
640
- for (uint32_t i = 0; i < n_rs; ++i) {
641
- data[i] = mctx->get_recr()->s_copy(i);
642
- }
643
- }
644
- }
645
-
646
- bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) {
647
- const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx);
648
-
649
- this->mctx = mctx;
650
-
651
- bool res = true;
652
-
653
- const auto * attn_ctx = mctx->get_attn();
654
-
655
- // base tensors may not be allocated if there are no non-SWA attention layers
656
- if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
657
- res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
658
- //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
659
-
660
- res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
661
- }
662
-
663
- // swa tensors may not be allocated if there are no SWA attention layers
664
- if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
665
- res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
666
- //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
667
-
668
- res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
669
- }
670
-
671
- res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
672
-
673
- res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
674
- res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
675
-
676
- res &= inp_rs->head == mctx->get_recr()->get_head();
677
- res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
678
-
679
- return res;
680
- }
681
-
682
- void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
683
- // set the inputs only for the active samplers in the current ubatch
684
- std::unordered_set<llama_seq_id> active_samplers;
685
- for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
686
- if (ubatch->output[i]) {
687
- llama_seq_id seq_id = ubatch->seq_id[i][0];
688
- active_samplers.insert(seq_id);
689
- }
690
- }
691
-
692
- for (auto seq_id : active_samplers) {
693
- if (samplers.find(seq_id) == samplers.end()) {
694
- continue;
695
- }
696
-
697
- auto & sampler = samplers[seq_id];
698
-
699
- if (sampler->iface->backend_set_input) {
700
- sampler->iface->backend_set_input(sampler);
701
- }
702
- }
703
- }
704
-
705
- bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
706
- if (samplers.size() != params.samplers.size()) {
707
- return false;
708
- }
709
-
710
- for (const auto & [seq_id, sampler] : params.samplers) {
711
- if (samplers[seq_id] != sampler) {
712
- return false;
713
- }
714
- }
715
-
716
- return true;
717
- }
718
-
719
- //
720
- // llm_graph_result
721
- //
722
-
723
- llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
724
- reset();
725
-
726
- const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
727
- debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
728
- }
729
-
730
- int64_t llm_graph_result::get_max_nodes() const {
731
- return max_nodes;
732
- }
733
-
734
- void llm_graph_result::reset() {
735
- t_inp_tokens = nullptr;
736
- t_inp_embd = nullptr;
737
- t_logits = nullptr;
738
- t_embd = nullptr;
739
- t_embd_pooled = nullptr;
740
- t_sampled.clear();
741
- t_sampled_probs.clear();
742
- t_sampled_logits.clear();
743
- t_candidates.clear();
744
-
745
- params = {};
746
-
747
- inputs.clear();
748
-
749
- buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
750
-
751
- ggml_init_params params = {
752
- /*.mem_size =*/ buf_compute_meta.size(),
753
- /*.mem_buffer =*/ buf_compute_meta.data(),
754
- /*.no_alloc =*/ true,
755
- };
756
-
757
- ctx_compute.reset(ggml_init(params));
758
-
759
- gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
760
- }
761
-
762
- void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
763
- for (auto & input : inputs) {
764
- input->set_input(ubatch);
765
- }
766
- }
767
-
768
- void llm_graph_result::set_outputs() {
769
- if (t_logits != nullptr) {
770
- ggml_set_output(t_logits);
771
- }
772
- if (t_embd != nullptr) {
773
- ggml_set_output(t_embd);
774
- }
775
- if (t_embd_pooled != nullptr) {
776
- ggml_set_output(t_embd_pooled);
777
- }
778
- for (auto & [seq_id, t] : t_sampled) {
779
- if (t != nullptr) {
780
- ggml_set_output(t);
781
- }
782
- }
783
- for (auto & [seq_id, t] : t_sampled_probs) {
784
- if (t != nullptr) {
785
- ggml_set_output(t);
786
- }
787
- }
788
- for (auto & [seq_id, t] : t_sampled_logits) {
789
- if (t != nullptr) {
790
- ggml_set_output(t);
791
- }
792
- }
793
- for (auto & [seq_id, t] : t_candidates) {
794
- if (t != nullptr) {
795
- ggml_set_output(t);
796
- }
797
- }
798
- }
799
-
800
- bool llm_graph_result::can_reuse(const llm_graph_params & params) {
801
- if (!this->params.allow_reuse(params)) {
802
- if (debug > 1) {
803
- LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
804
- }
805
-
806
- return false;
807
- }
808
-
809
- if (debug > 1) {
810
- LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
811
- }
812
-
813
- bool res = true;
814
-
815
- for (auto & input : inputs) {
816
- const bool cur = input->can_reuse(params);
817
-
818
- if (debug > 1) {
819
- LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
820
- }
821
-
822
- res = res && cur;
823
- }
824
-
825
- if (debug > 0) {
826
- LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
827
- }
828
-
829
- return res;
830
- }
831
-
832
- llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
833
- inputs.emplace_back(std::move(input));
834
- return inputs.back().get();
835
- }
836
-
837
- void llm_graph_result::set_params(const llm_graph_params & params) {
838
- this->params = params;
839
- }
840
-
841
- //
842
- // llm_graph_context
843
- //
844
-
845
- llm_graph_context::llm_graph_context(const llm_graph_params & params) :
846
- arch (params.arch),
847
- hparams (params.hparams),
848
- cparams (params.cparams),
849
- ubatch (params.ubatch),
850
- n_embd (hparams.n_embd),
851
- n_layer (hparams.n_layer),
852
- n_rot (hparams.n_rot()),
853
- n_ctx (cparams.n_ctx),
854
- n_head (hparams.n_head()),
855
- n_head_kv (hparams.n_head_kv()),
856
- n_embd_head_k (hparams.n_embd_head_k()),
857
- n_embd_k_gqa (hparams.n_embd_k_gqa()),
858
- n_embd_head_v (hparams.n_embd_head_v()),
859
- n_embd_v_gqa (hparams.n_embd_v_gqa()),
860
- n_expert (hparams.n_expert),
861
- n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
862
- freq_base (cparams.rope_freq_base),
863
- freq_scale (cparams.rope_freq_scale),
864
- ext_factor (cparams.yarn_ext_factor),
865
- attn_factor (cparams.yarn_attn_factor),
866
- beta_fast (cparams.yarn_beta_fast),
867
- beta_slow (cparams.yarn_beta_slow),
868
- norm_eps (hparams.f_norm_eps),
869
- norm_rms_eps (hparams.f_norm_rms_eps),
870
- n_tokens (ubatch.n_tokens),
871
- n_outputs (params.n_outputs),
872
- n_ctx_orig (cparams.n_ctx_orig_yarn),
873
- pooling_type (cparams.pooling_type),
874
- rope_type (hparams.rope_type),
875
- sched (params.sched),
876
- backend_cpu (params.backend_cpu),
877
- cvec (params.cvec),
878
- loras (params.loras),
879
- mctx (params.mctx),
880
- cross (params.cross),
881
- samplers (params.samplers),
882
- cb_func (params.cb),
883
- res (params.res),
884
- ctx0 (res->get_ctx()),
885
- gf (res->get_gf()) {
886
- res->set_params(params);
887
- }
888
-
889
- void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
890
- if (cb_func) {
891
- cb_func(ubatch, cur, name, il);
892
- }
893
- }
894
-
895
- ggml_tensor * llm_graph_context::build_cvec(
896
- ggml_tensor * cur,
897
- int il) const {
898
- return cvec->apply_to(ctx0, cur, il);
899
- }
900
-
901
- ggml_tensor * llm_graph_context::build_lora_mm(
902
- ggml_tensor * w,
903
- ggml_tensor * cur,
904
- ggml_tensor * w_s) const {
905
- ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
906
-
907
- for (const auto & lora : *loras) {
908
- llama_adapter_lora_weight * lw = lora.first->get_weight(w);
909
- if (lw == nullptr) {
910
- continue;
911
- }
912
-
913
- const float adapter_scale = lora.second;
914
- const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
915
-
916
- ggml_tensor * ab_cur = ggml_mul_mat(
917
- ctx0, lw->b,
918
- ggml_mul_mat(ctx0, lw->a, cur)
919
- );
920
-
921
- ab_cur = ggml_scale(ctx0, ab_cur, scale);
922
- res = ggml_add(ctx0, res, ab_cur);
923
- }
924
-
925
- if (w_s) {
926
- res = ggml_mul(ctx0, res, w_s);
927
- }
928
-
929
- return res;
930
- }
931
-
932
- ggml_tensor * llm_graph_context::build_lora_mm_id(
933
- ggml_tensor * w, // ggml_tensor * as
934
- ggml_tensor * cur, // ggml_tensor * b
935
- ggml_tensor * ids) const {
936
- ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
937
- for (const auto & lora : *loras) {
938
- llama_adapter_lora_weight * lw = lora.first->get_weight(w);
939
- if (lw == nullptr) {
940
- continue;
941
- }
942
-
943
- const float alpha = lora.first->alpha;
944
- const float rank = (float) lw->b->ne[0];
945
- const float scale = alpha ? lora.second * alpha / rank : lora.second;
946
-
947
- ggml_tensor * ab_cur = ggml_mul_mat_id(
948
- ctx0, lw->b,
949
- ggml_mul_mat_id(ctx0, lw->a, cur, ids),
950
- ids
951
- );
952
-
953
- ab_cur = ggml_scale(ctx0, ab_cur, scale);
954
- res = ggml_add(ctx0, res, ab_cur);
955
- }
956
-
957
- return res;
958
- }
959
-
960
- ggml_tensor * llm_graph_context::build_norm(
961
- ggml_tensor * cur,
962
- ggml_tensor * mw,
963
- ggml_tensor * mb,
964
- llm_norm_type type,
965
- int il) const {
966
- switch (type) {
967
- case LLM_NORM: cur = ggml_norm (ctx0, cur, hparams.f_norm_eps); break;
968
- case LLM_NORM_RMS: cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break;
969
- case LLM_NORM_GROUP:
970
- {
971
- cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]);
972
- cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
973
- cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[2]);
974
- } break;
975
- }
976
-
977
- if (mw || mb) {
978
- cb(cur, "norm", il);
979
- }
980
-
981
- if (mw) {
982
- cur = ggml_mul(ctx0, cur, mw);
983
- if (mb) {
984
- cb(cur, "norm_w", il);
985
- }
986
- }
987
-
988
- if (mb) {
989
- cur = ggml_add(ctx0, cur, mb);
990
- }
991
-
992
- return cur;
993
- }
994
-
995
- ggml_tensor * llm_graph_context::build_ffn(
996
- ggml_tensor * cur,
997
- ggml_tensor * up,
998
- ggml_tensor * up_b,
999
- ggml_tensor * up_s,
1000
- ggml_tensor * gate,
1001
- ggml_tensor * gate_b,
1002
- ggml_tensor * gate_s,
1003
- ggml_tensor * down,
1004
- ggml_tensor * down_b,
1005
- ggml_tensor * down_s,
1006
- ggml_tensor * act_scales,
1007
- llm_ffn_op_type type_op,
1008
- llm_ffn_gate_type type_gate,
1009
- int il) const {
1010
- ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur;
1011
- cb(tmp, "ffn_up", il);
1012
-
1013
- if (up_b) {
1014
- tmp = ggml_add(ctx0, tmp, up_b);
1015
- cb(tmp, "ffn_up_b", il);
1016
- }
1017
-
1018
- if (up_s) {
1019
- tmp = ggml_mul(ctx0, tmp, up_s);
1020
- cb(tmp, "ffn_up_s", il);
1021
- }
1022
-
1023
- if (gate) {
1024
- switch (type_gate) {
1025
- case LLM_FFN_SEQ:
1026
- {
1027
- cur = build_lora_mm(gate, tmp);
1028
- cb(cur, "ffn_gate", il);
1029
- } break;
1030
- case LLM_FFN_PAR:
1031
- {
1032
- cur = build_lora_mm(gate, cur);
1033
- cb(cur, "ffn_gate", il);
1034
- } break;
1035
- }
1036
-
1037
- if (gate_b) {
1038
- cur = ggml_add(ctx0, cur, gate_b);
1039
- cb(cur, "ffn_gate_b", il);
1040
- }
1041
-
1042
- if (gate_s) {
1043
- cur = ggml_mul(ctx0, cur, gate_s);
1044
- cb(cur, "ffn_gate_s", il);
1045
- }
1046
-
1047
- } else {
1048
- cur = tmp;
1049
- }
1050
-
1051
- switch (type_op) {
1052
- case LLM_FFN_SILU:
1053
- if (gate && type_gate == LLM_FFN_PAR) {
1054
- // Step35: HF clamps gate (after SiLU) and up before multiplication
1055
- if (arch == LLM_ARCH_STEP35 && il >= 0) {
1056
- const float limit = hparams.swiglu_clamp_shexp[il];
1057
- constexpr float eps = 1e-6f;
1058
- if (limit > eps) {
1059
- ggml_tensor * gate_act = ggml_silu(ctx0, cur);
1060
- cb(gate_act, "ffn_silu", il);
1061
- gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
1062
- cb(gate_act, "ffn_silu_clamped", il);
1063
-
1064
- tmp = ggml_clamp(ctx0, tmp, -limit, limit);
1065
- cb(tmp, "ffn_up_clamped", il);
1066
-
1067
- cur = ggml_mul(ctx0, gate_act, tmp);
1068
- cb(cur, "ffn_swiglu_limited", il);
1069
- type_gate = LLM_FFN_SEQ;
1070
- break;
1071
- }
1072
- }
1073
-
1074
- cur = ggml_swiglu_split(ctx0, cur, tmp);
1075
- cb(cur, "ffn_swiglu", il);
1076
- type_gate = LLM_FFN_SEQ;
1077
- } else {
1078
- cur = ggml_silu(ctx0, cur);
1079
- cb(cur, "ffn_silu", il);
1080
- } break;
1081
- case LLM_FFN_GELU:
1082
- if (gate && type_gate == LLM_FFN_PAR) {
1083
- cur = ggml_geglu_split(ctx0, cur, tmp);
1084
- cb(cur, "ffn_geglu", il);
1085
- type_gate = LLM_FFN_SEQ;
1086
- } else {
1087
- cur = ggml_gelu(ctx0, cur);
1088
- cb(cur, "ffn_gelu", il);
1089
- if (act_scales != NULL) {
1090
- cur = ggml_div(ctx0, cur, act_scales);
1091
- cb(cur, "ffn_act", il);
1092
- }
1093
- } break;
1094
- case LLM_FFN_RELU:
1095
- if (gate && type_gate == LLM_FFN_PAR) {
1096
- cur = ggml_reglu_split(ctx0, cur, tmp);
1097
- cb(cur, "ffn_reglu", il);
1098
- type_gate = LLM_FFN_SEQ;
1099
- } else {
1100
- cur = ggml_relu(ctx0, cur);
1101
- cb(cur, "ffn_relu", il);
1102
- } break;
1103
- case LLM_FFN_RELU_SQR:
1104
- {
1105
- cur = ggml_relu(ctx0, cur);
1106
- cb(cur, "ffn_relu", il);
1107
-
1108
- cur = ggml_sqr(ctx0, cur);
1109
- cb(cur, "ffn_sqr(relu)", il);
1110
- } break;
1111
- case LLM_FFN_SWIGLU:
1112
- {
1113
- cur = ggml_swiglu(ctx0, cur);
1114
- cb(cur, "ffn_swiglu", il);
1115
- } break;
1116
- case LLM_FFN_GEGLU:
1117
- {
1118
- cur = ggml_geglu(ctx0, cur);
1119
- cb(cur, "ffn_geglu", il);
1120
- } break;
1121
- case LLM_FFN_REGLU:
1122
- {
1123
- cur = ggml_reglu(ctx0, cur);
1124
- cb(cur, "ffn_reglu", il);
1125
- } break;
1126
- default:
1127
- GGML_ABORT("fatal error");
1128
- }
1129
-
1130
- if (gate && type_gate == LLM_FFN_PAR) {
1131
- cur = ggml_mul(ctx0, cur, tmp);
1132
- cb(cur, "ffn_gate_par", il);
1133
- }
1134
-
1135
- if (down) {
1136
- cur = build_lora_mm(down, cur);
1137
- if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
1138
- // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
1139
- ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1140
- }
1141
- }
1142
-
1143
- if (down_b) {
1144
- cb(cur, "ffn_down", il);
1145
- }
1146
-
1147
- if (down_b) {
1148
- cur = ggml_add(ctx0, cur, down_b);
1149
- }
1150
-
1151
- if (down_s) {
1152
- cur = ggml_mul(ctx0, cur, down_s);
1153
- cb(cur, "ffn_down_s", il);
1154
- }
1155
-
1156
- return cur;
1157
- }
1158
-
1159
- ggml_tensor * llm_graph_context::build_moe_ffn(
1160
- ggml_tensor * cur,
1161
- ggml_tensor * gate_inp,
1162
- ggml_tensor * up_exps,
1163
- ggml_tensor * gate_exps,
1164
- ggml_tensor * down_exps,
1165
- ggml_tensor * exp_probs_b,
1166
- int64_t n_expert,
1167
- int64_t n_expert_used,
1168
- llm_ffn_op_type type_op,
1169
- bool norm_w,
1170
- float w_scale,
1171
- llama_expert_gating_func_type gating_op,
1172
- int il,
1173
- ggml_tensor * probs_in,
1174
- ggml_tensor * gate_up_exps,
1175
- ggml_tensor * up_exps_s,
1176
- ggml_tensor * gate_exps_s,
1177
- ggml_tensor * down_exps_s) const {
1178
- return build_moe_ffn(
1179
- cur,
1180
- gate_inp, /* gate_inp_b */ nullptr,
1181
- up_exps, /* up_exps_b */ nullptr,
1182
- gate_exps, /* gate_exps_b */ nullptr,
1183
- down_exps, /* down_exps_b */ nullptr,
1184
- exp_probs_b,
1185
- n_expert,
1186
- n_expert_used,
1187
- type_op,
1188
- norm_w,
1189
- w_scale,
1190
- gating_op,
1191
- il,
1192
- probs_in,
1193
- gate_up_exps,
1194
- /* gate_up_exps_b */ nullptr,
1195
- up_exps_s,
1196
- gate_exps_s,
1197
- down_exps_s
1198
- );
1199
- }
1200
-
1201
- ggml_tensor * llm_graph_context::build_moe_ffn(
1202
- ggml_tensor * cur,
1203
- ggml_tensor * gate_inp,
1204
- ggml_tensor * gate_inp_b,
1205
- ggml_tensor * up_exps,
1206
- ggml_tensor * up_exps_b,
1207
- ggml_tensor * gate_exps,
1208
- ggml_tensor * gate_exps_b,
1209
- ggml_tensor * down_exps,
1210
- ggml_tensor * down_exps_b,
1211
- ggml_tensor * exp_probs_b,
1212
- int64_t n_expert,
1213
- int64_t n_expert_used,
1214
- llm_ffn_op_type type_op,
1215
- bool norm_w,
1216
- float w_scale,
1217
- llama_expert_gating_func_type gating_op,
1218
- int il,
1219
- ggml_tensor * probs_in,
1220
- ggml_tensor * gate_up_exps,
1221
- ggml_tensor * gate_up_exps_b,
1222
- ggml_tensor * up_exps_s,
1223
- ggml_tensor * gate_exps_s,
1224
- ggml_tensor * down_exps_s) const {
1225
- const int64_t n_embd = cur->ne[0];
1226
- const int64_t n_tokens = cur->ne[1];
1227
- const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
1228
-
1229
- ggml_tensor * logits = nullptr;
1230
-
1231
- if (probs_in == nullptr) {
1232
- logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
1233
- cb(logits, "ffn_moe_logits", il);
1234
- } else {
1235
- logits = probs_in;
1236
- }
1237
-
1238
- if (gate_inp_b) {
1239
- logits = ggml_add(ctx0, logits, gate_inp_b);
1240
- cb(logits, "ffn_moe_logits_biased", il);
1241
- }
1242
-
1243
- ggml_tensor * probs = nullptr;
1244
- switch (gating_op) {
1245
- case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
1246
- {
1247
- probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
1248
- } break;
1249
- case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
1250
- {
1251
- probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
1252
- } break;
1253
- case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
1254
- {
1255
- probs = logits; // [n_expert, n_tokens]
1256
- } break;
1257
- default:
1258
- GGML_ABORT("fatal error");
1259
- }
1260
- cb(probs, "ffn_moe_probs", il);
1261
-
1262
- // add experts selection bias - introduced in DeepSeek V3
1263
- // leave probs unbiased as it's later used to get expert weights
1264
- ggml_tensor * selection_probs = probs;
1265
- if (exp_probs_b != nullptr) {
1266
- selection_probs = ggml_add(ctx0, probs, exp_probs_b);
1267
- cb(selection_probs, "ffn_moe_probs_biased", il);
1268
- }
1269
-
1270
- // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
1271
- // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
1272
- if (arch == LLM_ARCH_LLAMA4) {
1273
- selection_probs = logits;
1274
- }
1275
-
1276
- if (arch == LLM_ARCH_GROVEMOE) {
1277
- selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
1278
- cb(selection_probs, "ffn_moe_probs_biased", il);
1279
- }
1280
-
1281
- // select top n_group_used expert groups
1282
- // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
1283
- if (hparams.n_expert_groups > 1 && n_tokens > 0) {
1284
- const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
1285
-
1286
- // organize experts into n_expert_groups
1287
- ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
1288
-
1289
- ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
1290
- group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
1291
-
1292
- // get top n_group_used expert groups
1293
- group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
1294
- group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
1295
-
1296
- ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
1297
- cb(expert_groups, "ffn_moe_group_topk", il);
1298
-
1299
- // mask out the other groups
1300
- selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
1301
- selection_probs = ggml_set_rows(ctx0, ggml_fill(ctx0, selection_groups, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
1302
- selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
1303
- cb(selection_probs, "ffn_moe_probs_masked", il);
1304
- }
1305
-
1306
- // select experts
1307
- ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
1308
- cb(selected_experts->src[0], "ffn_moe_argsort", il);
1309
- cb(selected_experts, "ffn_moe_topk", il);
1310
-
1311
- if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
1312
- // TODO: Use scalar div instead when/if implemented
1313
- ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
1314
- selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
1315
- probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
1316
- } else {
1317
- probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
1318
- }
1319
-
1320
- ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
1321
- cb(weights, "ffn_moe_weights", il);
1322
-
1323
-
1324
- if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
1325
- weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
1326
- weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
1327
- weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
1328
- cb(weights, "ffn_moe_weights_softmax", il);
1329
- }
1330
-
1331
- if (norm_w) {
1332
- weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
1333
-
1334
- ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
1335
- cb(weights_sum, "ffn_moe_weights_sum", il);
1336
-
1337
- // Avoid division by zero, clamp to smallest number representable by F16
1338
- weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY);
1339
- cb(weights_sum, "ffn_moe_weights_sum_clamped", il);
1340
-
1341
- weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
1342
- cb(weights, "ffn_moe_weights_norm", il);
1343
-
1344
- weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
1345
- }
1346
- if (w_scale != 0.0f && w_scale != 1.0f) {
1347
- weights = ggml_scale(ctx0, weights, w_scale);
1348
- cb(weights, "ffn_moe_weights_scaled", il);
1349
- }
1350
-
1351
- //call early so that topk-moe can be used
1352
- ggml_build_forward_expand(gf, weights);
1353
-
1354
- cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
1355
-
1356
- if (weight_before_ffn) {
1357
- // repeat cur to [n_embd, n_expert_used, n_tokens]
1358
- ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
1359
- cur = ggml_mul(ctx0, repeated, weights);
1360
- cb(cur, "ffn_moe_weighted", il);
1361
- }
1362
-
1363
- ggml_tensor * up = nullptr;
1364
- ggml_tensor * experts = nullptr;
1365
-
1366
- if (gate_up_exps) {
1367
- // merged gate_up path: one mul_mat_id, then split into gate and up views
1368
- ggml_tensor * gate_up = build_lora_mm_id(gate_up_exps, cur, selected_experts); // [n_ff*2, n_expert_used, n_tokens]
1369
- cb(gate_up, "ffn_moe_gate_up", il);
1370
-
1371
- if (gate_up_exps_b) {
1372
- gate_up = ggml_add_id(ctx0, gate_up, gate_up_exps_b, selected_experts);
1373
- cb(gate_up, "ffn_moe_gate_up_biased", il);
1374
- }
1375
-
1376
- // apply per-expert scale2 to merged gate_up (use up_exps_s since gate and up are fused)
1377
- if (up_exps_s) {
1378
- ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1);
1379
- s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
1380
- s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
1381
- gate_up = ggml_mul(ctx0, gate_up, s);
1382
- cb(gate_up, "ffn_moe_gate_up_scaled", il);
1383
- }
1384
-
1385
- const int64_t n_ff = gate_up->ne[0] / 2;
1386
- cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0);
1387
- cb(cur, "ffn_moe_gate", il);
1388
- up = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], n_ff * gate_up->nb[0]);
1389
- cb(up, "ffn_moe_up", il);
1390
- } else {
1391
- // separate gate and up path
1392
- up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
1393
- cb(up, "ffn_moe_up", il);
1394
-
1395
- if (up_exps_b) {
1396
- up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
1397
- cb(up, "ffn_moe_up_biased", il);
1398
- }
1399
-
1400
- // apply per-expert scale2 to up
1401
- if (up_exps_s) {
1402
- ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1);
1403
- s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
1404
- s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
1405
- up = ggml_mul(ctx0, up, s);
1406
- cb(up, "ffn_moe_up_scaled", il);
1407
- }
1408
-
1409
- if (gate_exps) {
1410
- cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
1411
- cb(cur, "ffn_moe_gate", il);
1412
- } else {
1413
- cur = up;
1414
- }
1415
-
1416
- if (gate_exps_b) {
1417
- cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
1418
- cb(cur, "ffn_moe_gate_biased", il);
1419
- }
1420
-
1421
- // apply per-expert scale2 to gate
1422
- if (gate_exps_s) {
1423
- ggml_tensor * s = ggml_reshape_3d(ctx0, gate_exps_s, 1, n_expert, 1);
1424
- s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
1425
- s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
1426
- cur = ggml_mul(ctx0, cur, s);
1427
- cb(cur, "ffn_moe_gate_scaled", il);
1428
- }
1429
- }
1430
-
1431
- const bool has_gate = gate_exps || gate_up_exps;
1432
-
1433
- switch (type_op) {
1434
- case LLM_FFN_SILU:
1435
- if (gate_exps) {
1436
- // Step35: per-layer clamp for routed experts
1437
- if (arch == LLM_ARCH_STEP35 && il >= 0) {
1438
- const float limit = hparams.swiglu_clamp_exp[il];
1439
- constexpr float eps = 1e-6f;
1440
- if (limit > eps) {
1441
- ggml_tensor * gate_act = ggml_silu(ctx0, cur);
1442
- cb(gate_act, "ffn_moe_silu", il);
1443
- gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
1444
- cb(gate_act, "ffn_moe_silu_clamped", il);
1445
-
1446
- up = ggml_clamp(ctx0, up, -limit, limit);
1447
- cb(up, "ffn_moe_up_clamped", il);
1448
-
1449
- cur = ggml_mul(ctx0, gate_act, up);
1450
- cb(cur, "ffn_moe_swiglu_limited", il);
1451
- break;
1452
- }
1453
- }
1454
- }
1455
-
1456
- if (has_gate) {
1457
- cur = ggml_swiglu_split(ctx0, cur, up);
1458
- cb(cur, "ffn_moe_swiglu", il);
1459
- } else {
1460
- cur = ggml_silu(ctx0, cur);
1461
- cb(cur, "ffn_moe_silu", il);
1462
- } break;
1463
- case LLM_FFN_GELU:
1464
- if (has_gate) {
1465
- cur = ggml_geglu_split(ctx0, cur, up);
1466
- cb(cur, "ffn_moe_geglu", il);
1467
- } else {
1468
- cur = ggml_gelu(ctx0, cur);
1469
- cb(cur, "ffn_moe_gelu", il);
1470
- } break;
1471
- case LLM_FFN_SWIGLU_OAI_MOE:
1472
- {
1473
- // TODO: move to hparams?
1474
- constexpr float alpha = 1.702f;
1475
- constexpr float limit = 7.0f;
1476
- cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
1477
- cb(cur, "ffn_moe_swiglu_oai", il);
1478
- } break;
1479
- case LLM_FFN_RELU:
1480
- if (has_gate) {
1481
- cur = ggml_reglu_split(ctx0, cur, up);
1482
- cb(cur, "ffn_moe_reglu", il);
1483
- } else {
1484
- cur = ggml_relu(ctx0, cur);
1485
- cb(cur, "ffn_moe_relu", il);
1486
- } break;
1487
- case LLM_FFN_RELU_SQR:
1488
- if (has_gate) {
1489
- // TODO: add support for gated squared relu
1490
- GGML_ABORT("fatal error: gated squared relu not implemented");
1491
- } else {
1492
- cur = ggml_relu(ctx0, cur);
1493
- cur = ggml_sqr(ctx0, cur);
1494
- cb(cur, "ffn_moe_relu_sqr", il);
1495
- } break;
1496
- default:
1497
- GGML_ABORT("fatal error");
1498
- }
1499
-
1500
- experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
1501
- cb(experts, "ffn_moe_down", il);
1502
-
1503
- if (down_exps_b) {
1504
- experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
1505
- cb(experts, "ffn_moe_down_biased", il);
1506
- }
1507
-
1508
- // apply per-expert scale2 to down
1509
- if (down_exps_s) {
1510
- ggml_tensor * s = ggml_reshape_3d(ctx0, down_exps_s, 1, n_expert, 1);
1511
- s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
1512
- s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
1513
- experts = ggml_mul(ctx0, experts, s);
1514
- cb(experts, "ffn_moe_down_scaled", il);
1515
- }
1516
-
1517
- if (!weight_before_ffn) {
1518
- experts = ggml_mul(ctx0, experts, weights);
1519
- cb(cur, "ffn_moe_weighted", il);
1520
- }
1521
-
1522
- ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
1523
-
1524
- assert(n_expert_used > 0);
1525
-
1526
- // order the views before the adds
1527
- for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
1528
- cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
1529
-
1530
- ggml_build_forward_expand(gf, cur_experts[i]);
1531
- }
1532
-
1533
- // aggregate experts
1534
- // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
1535
- // to avoid potentially a large number of add nodes during warmup
1536
- // ref: https://github.com/ggml-org/llama.cpp/pull/14753
1537
- ggml_tensor * moe_out = cur_experts[0];
1538
-
1539
- for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
1540
- moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
1541
- }
1542
-
1543
- if (hparams.n_expert_used == 1) {
1544
- // avoid returning a non-contiguous tensor
1545
- moe_out = ggml_cont(ctx0, moe_out);
1546
- }
1547
-
1548
- cb(moe_out, "ffn_moe_out", il);
1549
-
1550
- return moe_out;
1551
- }
1552
-
1553
- // input embeddings with optional lora
1554
- ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
1555
- const int64_t n_embd_inp = hparams.n_embd_inp();
1556
- const int64_t n_embd = hparams.n_embd;
1557
-
1558
- assert(n_embd_inp >= n_embd);
1559
-
1560
- auto inp = std::make_unique<llm_graph_input_embd>(n_embd_inp);
1561
-
1562
- inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
1563
- cb(inp->tokens, "inp_tokens", -1);
1564
- ggml_set_input(inp->tokens);
1565
- res->t_inp_tokens = inp->tokens;
1566
-
1567
- inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens);
1568
- cb(inp->embd, "inp_embd", -1);
1569
- ggml_set_input(inp->embd);
1570
-
1571
- // select one of the 2 inputs, based on the batch contents
1572
- // ref: https://github.com/ggml-org/llama.cpp/pull/18550
1573
- std::array<ggml_tensor *, 2> inps;
1574
-
1575
- // token embeddings path (ubatch.token != nullptr)
1576
- {
1577
- auto & cur = inps[0];
1578
-
1579
- cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
1580
-
1581
- // apply lora for embedding tokens if needed
1582
- for (const auto & lora : *loras) {
1583
- llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
1584
- if (lw == nullptr) {
1585
- continue;
1586
- }
1587
-
1588
- const float adapter_scale = lora.second;
1589
- const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
1590
-
1591
- ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
1592
- ctx0, lw->b, // non-transposed lora_b
1593
- ggml_get_rows(ctx0, lw->a, inp->tokens)
1594
- ), scale);
1595
-
1596
- cur = ggml_add(ctx0, cur, inpL_delta);
1597
- }
1598
-
1599
- if (n_embd_inp != n_embd) {
1600
- cur = ggml_pad(ctx0, cur, hparams.n_embd_inp() - n_embd, 0, 0, 0);
1601
- }
1602
- }
1603
-
1604
- // vector embeddings path (ubatch.embd != nullptr)
1605
- {
1606
- auto & cur = inps[1];
1607
-
1608
- cur = inp->embd;
1609
- }
1610
-
1611
- assert(ggml_are_same_shape (inps[0], inps[1]));
1612
- assert(ggml_are_same_stride(inps[0], inps[1]));
1613
-
1614
- ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1);
1615
-
1616
- if (n_embd_inp != n_embd) {
1617
- cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0);
1618
- }
1619
-
1620
- res->t_inp_embd = cur;
1621
-
1622
- // For Granite architecture
1623
- if (hparams.f_embedding_scale != 0.0f) {
1624
- cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
1625
- }
1626
-
1627
- cb(cur, "embd", -1);
1628
-
1629
- res->add_input(std::move(inp));
1630
-
1631
- // make sure the produced embeddings are immediately materialized in the ggml graph
1632
- // ref: https://github.com/ggml-org/llama.cpp/pull/18599
1633
- ggml_build_forward_expand(gf, cur);
1634
-
1635
- return cur;
1636
- }
1637
-
1638
- ggml_tensor * llm_graph_context::build_inp_pos() const {
1639
- auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
1640
-
1641
- auto & cur = inp->pos;
1642
-
1643
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
1644
- ggml_set_input(cur);
1645
-
1646
- res->add_input(std::move(inp));
1647
-
1648
- return cur;
1649
- }
1650
-
1651
- ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
1652
- auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset);
1653
-
1654
- auto & cur = inp->attn_scale;
1655
-
1656
- // this need to be 1x1xN for broadcasting
1657
- cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
1658
- ggml_set_input(cur);
1659
- ggml_set_name(cur, "attn_scale");
1660
-
1661
- res->add_input(std::move(inp));
1662
-
1663
- return cur;
1664
- }
1665
-
1666
- ggml_tensor * llm_graph_context::build_inp_out_ids() const {
1667
- // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
1668
- // but this would make the graph topology depend on the number of output tokens, which can interere with
1669
- // features that require constant topology such as pipeline parallelism
1670
- // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
1671
- //if (n_outputs < n_tokens) {
1672
- // return nullptr;
1673
- //}
1674
-
1675
- auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
1676
-
1677
- auto & cur = inp->out_ids;
1678
-
1679
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
1680
- ggml_set_input(cur);
1681
-
1682
- res->add_input(std::move(inp));
1683
-
1684
- return cur;
1685
- }
1686
-
1687
- ggml_tensor * llm_graph_context::build_inp_mean() const {
1688
- auto inp = std::make_unique<llm_graph_input_mean>(cparams);
1689
-
1690
- auto & cur = inp->mean;
1691
-
1692
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
1693
- ggml_set_input(cur);
1694
-
1695
- res->add_input(std::move(inp));
1696
-
1697
- return cur;
1698
- }
1699
-
1700
- ggml_tensor * llm_graph_context::build_inp_cls() const {
1701
- auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);
1702
-
1703
- auto & cur = inp->cls;
1704
-
1705
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
1706
- ggml_set_input(cur);
1707
-
1708
- res->add_input(std::move(inp));
1709
-
1710
- return cur;
1711
- }
1712
-
1713
- ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
1714
- auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
1715
-
1716
- auto & cur = inp->cross_embd;
1717
-
1718
- // if we have the output embeddings from the encoder, use them directly
1719
- // TODO: needs more work to be correct, for now just use the tensor shape
1720
- //if (cross->t_embd) {
1721
- // cur = ggml_view_tensor(ctx0, cross->t_embd);
1722
-
1723
- // return cur;
1724
- //}
1725
-
1726
- const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
1727
- const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
1728
-
1729
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
1730
- ggml_set_input(cur);
1731
-
1732
- res->add_input(std::move(inp));
1733
-
1734
- return cur;
1735
- }
1736
-
1737
- ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
1738
- auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);
1739
-
1740
- auto & cur = inp->pos_bucket;
1741
-
1742
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
1743
- ggml_set_input(cur);
1744
-
1745
- res->add_input(std::move(inp));
1746
-
1747
- return cur;
1748
- }
1749
-
1750
- ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
1751
- const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
1752
-
1753
- auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
1754
-
1755
- const auto n_kv = mctx_cur->get_n_kv();
1756
-
1757
- auto & cur = inp->pos_bucket;
1758
-
1759
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
1760
- ggml_set_input(cur);
1761
-
1762
- res->add_input(std::move(inp));
1763
-
1764
- return cur;
1765
- }
1766
-
1767
- ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
1768
- ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]);
1769
- cb(pos_bucket_1d, "pos_bucket_1d", -1);
1770
-
1771
- ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
1772
-
1773
- pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
1774
- pos_bias = ggml_permute (ctx0, pos_bias, 2, 0, 1, 3);
1775
- pos_bias = ggml_cont (ctx0, pos_bias);
1776
-
1777
- cb(pos_bias, "pos_bias", -1);
1778
-
1779
- return pos_bias;
1780
- }
1781
-
1782
- ggml_tensor * llm_graph_context::build_attn_mha(
1783
- ggml_tensor * q,
1784
- ggml_tensor * k,
1785
- ggml_tensor * v,
1786
- ggml_tensor * kq_b,
1787
- ggml_tensor * kq_mask,
1788
- ggml_tensor * sinks,
1789
- ggml_tensor * v_mla,
1790
- float kq_scale,
1791
- int il) const {
1792
- const bool v_trans = v->nb[1] > v->nb[2];
1793
-
1794
- // split the batch into streams if needed
1795
- const auto n_stream = k->ne[3];
1796
-
1797
- q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0);
1798
-
1799
- q = ggml_permute(ctx0, q, 0, 2, 1, 3);
1800
- k = ggml_permute(ctx0, k, 0, 2, 1, 3);
1801
- v = ggml_permute(ctx0, v, 0, 2, 1, 3);
1802
-
1803
- ggml_tensor * cur;
1804
-
1805
- const bool use_flash_attn = cparams.flash_attn && kq_b == nullptr;
1806
- if (use_flash_attn) {
1807
- GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
1808
-
1809
- if (v_trans) {
1810
- v = ggml_transpose(ctx0, v);
1811
- }
1812
-
1813
- // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
1814
- if (k->type == GGML_TYPE_F32) {
1815
- k = ggml_cast(ctx0, k, GGML_TYPE_F16);
1816
- }
1817
-
1818
- if (v->type == GGML_TYPE_F32) {
1819
- v = ggml_cast(ctx0, v, GGML_TYPE_F16);
1820
- }
1821
-
1822
- cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
1823
- hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
1824
- cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
1825
-
1826
- ggml_flash_attn_ext_add_sinks(cur, sinks);
1827
- ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
1828
-
1829
- if (v_mla) {
1830
- #if 0
1831
- // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
1832
- // However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient.
1833
- cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
1834
- cur = ggml_mul_mat(ctx0, v_mla, cur);
1835
- #else
1836
- // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
1837
- // The permutations are noops and only change how the tensor data is interpreted.
1838
- cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1839
- cur = ggml_mul_mat(ctx0, v_mla, cur);
1840
- cb(cur, "fattn_mla", il);
1841
- cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1842
- cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
1843
- #endif
1844
- }
1845
-
1846
- cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
1847
- } else {
1848
- ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
1849
- cb(kq, "kq", il);
1850
-
1851
- // note: this op tends to require high floating point range
1852
- // while for some models F16 is enough, for others it is not, so we default to F32 here
1853
- ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
1854
-
1855
- if (arch == LLM_ARCH_GROK) {
1856
- // need to do the following:
1857
- // multiply by attn_output_multiplier
1858
- // and then :
1859
- // kq = 30 * tanh(kq / 30)
1860
- // before the softmax below
1861
-
1862
- kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
1863
- cb(kq, "kq_tanh", il);
1864
- kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
1865
- cb(kq, "kq_scaled", il);
1866
- }
1867
-
1868
- if (hparams.attn_soft_cap) {
1869
- kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
1870
- cb(kq, "kq_scaled_1", il);
1871
- kq = ggml_tanh (ctx0, kq);
1872
- cb(kq, "kq_tanh", il);
1873
- kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
1874
- cb(kq, "kq_scaled_2", il);
1875
- }
1876
-
1877
- if (kq_b) {
1878
- kq = ggml_add(ctx0, kq, kq_b);
1879
- cb(kq, "kq_plus_kq_b", il);
1880
- }
1881
-
1882
- kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
1883
- ggml_soft_max_add_sinks(kq, sinks);
1884
- cb(kq, "kq_soft_max", il);
1885
-
1886
- if (!v_trans) {
1887
- // note: avoid this branch
1888
- v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
1889
- cb(v, "v_cont", il);
1890
- }
1891
-
1892
- ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
1893
- cb(kqv, "kqv", il);
1894
-
1895
- // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
1896
- if (v_mla) {
1897
- kqv = ggml_mul_mat(ctx0, v_mla, kqv);
1898
- cb(kqv, "kqv_mla", il);
1899
- }
1900
-
1901
- cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
1902
-
1903
- // recombine streams
1904
- cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
1905
-
1906
- if (!cparams.offload_kqv) {
1907
- // all nodes between the KV store and the attention output are run on the CPU
1908
- ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
1909
- }
1910
- }
1911
-
1912
- ggml_build_forward_expand(gf, cur);
1913
-
1914
- return cur;
1915
- }
1916
-
1917
- llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
1918
- auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
1919
-
1920
- // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
1921
- inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
1922
- ggml_set_input(inp->self_kq_mask);
1923
-
1924
- inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1925
-
1926
- if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
1927
- inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
1928
- ggml_set_input(inp->self_kq_mask_swa);
1929
-
1930
- inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1931
- } else {
1932
- inp->self_kq_mask_swa = nullptr;
1933
- inp->self_kq_mask_swa_cnv = nullptr;
1934
- }
1935
-
1936
- return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
1937
- }
1938
-
1939
- ggml_tensor * llm_graph_context::build_attn(
1940
- llm_graph_input_attn_no_cache * inp,
1941
- ggml_tensor * wo,
1942
- ggml_tensor * wo_b,
1943
- ggml_tensor * q_cur,
1944
- ggml_tensor * k_cur,
1945
- ggml_tensor * v_cur,
1946
- ggml_tensor * kq_b,
1947
- ggml_tensor * sinks,
1948
- ggml_tensor * v_mla,
1949
- float kq_scale,
1950
- int il) const {
1951
- GGML_UNUSED(n_tokens);
1952
-
1953
- // these nodes are added to the graph together so that they are not reordered
1954
- // by doing so, the number of splits in the graph is reduced
1955
- ggml_build_forward_expand(gf, q_cur);
1956
- ggml_build_forward_expand(gf, k_cur);
1957
- ggml_build_forward_expand(gf, v_cur);
1958
-
1959
- const bool is_swa = hparams.is_swa(il);
1960
-
1961
- const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1962
-
1963
- // [TAG_NO_CACHE_PAD]
1964
- // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
1965
- // but it might not be worth it: https://github.com/ggml-org/llama.cpp/pull/15636
1966
- //assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
1967
-
1968
- ggml_tensor * q = q_cur;
1969
- ggml_tensor * k = k_cur;
1970
- ggml_tensor * v = v_cur;
1971
-
1972
- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1973
- cb(cur, "kqv_out", il);
1974
-
1975
- if (wo) {
1976
- cur = build_lora_mm(wo, cur);
1977
- }
1978
-
1979
- if (wo_b) {
1980
- //cb(cur, "kqv_wo", il);
1981
- }
1982
-
1983
- if (wo_b) {
1984
- cur = ggml_add(ctx0, cur, wo_b);
1985
- }
1986
-
1987
- return cur;
1988
- }
1989
-
1990
- static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
1991
- ggml_context * ctx0,
1992
- const llama_ubatch & ubatch,
1993
- const llama_hparams & hparams,
1994
- const llama_cparams & cparams,
1995
- const llama_kv_cache_context * mctx_cur) {
1996
-
1997
- auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
1998
-
1999
- {
2000
- GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
2001
-
2002
- inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
2003
- inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
2004
-
2005
- inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
2006
-
2007
- ggml_set_input(inp->self_kq_mask);
2008
-
2009
- inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
2010
- }
2011
-
2012
- return inp;
2013
- }
2014
-
2015
- llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
2016
- const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
2017
-
2018
- auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
2019
-
2020
- return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
2021
- }
2022
-
2023
- ggml_tensor * llm_graph_context::build_attn(
2024
- llm_graph_input_attn_kv * inp,
2025
- ggml_tensor * wo,
2026
- ggml_tensor * wo_b,
2027
- ggml_tensor * q_cur,
2028
- ggml_tensor * k_cur,
2029
- ggml_tensor * v_cur,
2030
- ggml_tensor * kq_b,
2031
- ggml_tensor * sinks,
2032
- ggml_tensor * v_mla, // TODO: remove
2033
- float kq_scale,
2034
- int il) const {
2035
- GGML_ASSERT(v_mla == nullptr);
2036
-
2037
- // these nodes are added to the graph together so that they are not reordered
2038
- // by doing so, the number of splits in the graph is reduced
2039
- // expand k later to enable rope fusion which directly writes into k-v cache
2040
- ggml_build_forward_expand(gf, q_cur);
2041
- ggml_build_forward_expand(gf, v_cur);
2042
- ggml_build_forward_expand(gf, k_cur);
2043
-
2044
- const auto * mctx_cur = inp->mctx;
2045
-
2046
- // store to KV cache
2047
- {
2048
- const auto & k_idxs = inp->get_k_idxs();
2049
- const auto & v_idxs = inp->get_v_idxs();
2050
-
2051
- ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
2052
- ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
2053
- }
2054
-
2055
- const auto & kq_mask = inp->get_kq_mask();
2056
-
2057
- ggml_tensor * q = q_cur;
2058
- ggml_tensor * k = mctx_cur->get_k(ctx0, il);
2059
- ggml_tensor * v = mctx_cur->get_v(ctx0, il);
2060
-
2061
- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2062
- cb(cur, "kqv_out", il);
2063
-
2064
- if (wo) {
2065
- cur = build_lora_mm(wo, cur);
2066
- if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
2067
- // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
2068
- ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
2069
- }
2070
- }
2071
-
2072
- if (wo_b) {
2073
- cur = ggml_add(ctx0, cur, wo_b);
2074
- }
2075
-
2076
- return cur;
2077
- }
2078
-
2079
- static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
2080
- ggml_context * ctx0,
2081
- const llama_ubatch & ubatch,
2082
- const llama_hparams & hparams,
2083
- const llama_cparams & cparams,
2084
- const llama_kv_cache_context * mctx_cur) {
2085
-
2086
- auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur);
2087
-
2088
- {
2089
- GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
2090
-
2091
- inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
2092
-
2093
- inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
2094
- ggml_set_input(inp->self_kq_mask);
2095
-
2096
- inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
2097
- }
2098
-
2099
- return inp;
2100
- }
2101
-
2102
- llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const {
2103
- const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
2104
-
2105
- auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
2106
-
2107
- return (llm_graph_input_attn_k *) res->add_input(std::move(inp));
2108
- }
2109
-
2110
- ggml_tensor * llm_graph_context::build_attn(
2111
- llm_graph_input_attn_k * inp,
2112
- ggml_tensor * wo,
2113
- ggml_tensor * wo_b,
2114
- ggml_tensor * q_cur,
2115
- ggml_tensor * k_cur,
2116
- ggml_tensor * v_cur,
2117
- ggml_tensor * kq_b,
2118
- ggml_tensor * sinks,
2119
- ggml_tensor * v_mla,
2120
- float kq_scale,
2121
- int il) const {
2122
- // these nodes are added to the graph together so that they are not reordered
2123
- // by doing so, the number of splits in the graph is reduced
2124
- // expand k later to enable rope fusion which directly writes into k-v cache
2125
- ggml_build_forward_expand(gf, q_cur);
2126
- ggml_build_forward_expand(gf, v_cur);
2127
- ggml_build_forward_expand(gf, k_cur);
2128
-
2129
- const auto * mctx_cur = inp->mctx;
2130
-
2131
- // store to KV cache
2132
- {
2133
- const auto & k_idxs = inp->get_k_idxs();
2134
-
2135
- ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
2136
- }
2137
-
2138
- const auto & kq_mask = inp->get_kq_mask();
2139
-
2140
- ggml_tensor * q = q_cur;
2141
- ggml_tensor * k = mctx_cur->get_k(ctx0, il);
2142
- ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
2143
-
2144
- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2145
- cb(cur, "kqv_out", il);
2146
-
2147
- if (wo) {
2148
- cur = build_lora_mm(wo, cur);
2149
- if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
2150
- // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
2151
- ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
2152
- }
2153
- }
2154
-
2155
- if (wo_b) {
2156
- cur = ggml_add(ctx0, cur, wo_b);
2157
- }
2158
-
2159
- return cur;
2160
- }
2161
-
2162
- ggml_tensor * llm_graph_context::build_attn(
2163
- llm_graph_input_attn_kv_iswa * inp,
2164
- ggml_tensor * wo,
2165
- ggml_tensor * wo_b,
2166
- ggml_tensor * q_cur,
2167
- ggml_tensor * k_cur,
2168
- ggml_tensor * v_cur,
2169
- ggml_tensor * kq_b,
2170
- ggml_tensor * sinks,
2171
- ggml_tensor * v_mla,
2172
- float kq_scale,
2173
- int il) const {
2174
- // these nodes are added to the graph together so that they are not reordered
2175
- // by doing so, the number of splits in the graph is reduced
2176
- ggml_build_forward_expand(gf, q_cur);
2177
-
2178
- if (k_cur) {
2179
- ggml_build_forward_expand(gf, k_cur);
2180
- }
2181
-
2182
- if (v_cur) {
2183
- ggml_build_forward_expand(gf, v_cur);
2184
- }
2185
-
2186
- const auto * mctx_iswa = inp->mctx;
2187
-
2188
- const bool is_swa = hparams.is_swa(il);
2189
-
2190
- const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
2191
-
2192
- // optionally store to KV cache
2193
- if (k_cur) {
2194
- const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
2195
-
2196
- ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
2197
- }
2198
-
2199
- if (v_cur) {
2200
- const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
2201
-
2202
- ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
2203
- }
2204
-
2205
- const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
2206
-
2207
- ggml_tensor * q = q_cur;
2208
- ggml_tensor * k = mctx_cur->get_k(ctx0, il);
2209
- ggml_tensor * v = mctx_cur->get_v(ctx0, il);
2210
-
2211
- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2212
- cb(cur, "kqv_out", il);
2213
-
2214
- if (wo) {
2215
- cur = build_lora_mm(wo, cur);
2216
- }
2217
-
2218
- if (wo_b) {
2219
- //cb(cur, "kqv_wo", il);
2220
- }
2221
-
2222
- if (wo_b) {
2223
- cur = ggml_add(ctx0, cur, wo_b);
2224
- }
2225
-
2226
- return cur;
2227
- }
2228
-
2229
- llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
2230
- auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
2231
-
2232
- const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
2233
-
2234
- inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1);
2235
- ggml_set_input(inp->cross_kq_mask);
2236
-
2237
- inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
2238
-
2239
- return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
2240
- }
2241
-
2242
- ggml_tensor * llm_graph_context::build_attn(
2243
- llm_graph_input_attn_cross * inp,
2244
- ggml_tensor * wo,
2245
- ggml_tensor * wo_b,
2246
- ggml_tensor * q_cur,
2247
- ggml_tensor * k_cur,
2248
- ggml_tensor * v_cur,
2249
- ggml_tensor * kq_b,
2250
- ggml_tensor * sinks,
2251
- ggml_tensor * v_mla,
2252
- float kq_scale,
2253
- int il) const {
2254
- // these nodes are added to the graph together so that they are not reordered
2255
- // by doing so, the number of splits in the graph is reduced
2256
- ggml_build_forward_expand(gf, q_cur);
2257
- ggml_build_forward_expand(gf, k_cur);
2258
- ggml_build_forward_expand(gf, v_cur);
2259
-
2260
- const auto & kq_mask = inp->get_kq_mask_cross();
2261
-
2262
- ggml_tensor * q = q_cur;
2263
- ggml_tensor * k = k_cur;
2264
- ggml_tensor * v = v_cur;
2265
-
2266
- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2267
- cb(cur, "kqv_out", il);
2268
-
2269
- if (wo) {
2270
- cur = build_lora_mm(wo, cur);
2271
- }
2272
-
2273
- if (wo_b) {
2274
- //cb(cur, "kqv_wo", il);
2275
- }
2276
-
2277
- if (wo_b) {
2278
- cur = ggml_add(ctx0, cur, wo_b);
2279
- }
2280
-
2281
- return cur;
2282
- }
2283
-
2284
- // TODO: maybe separate the inner implementation into a separate function
2285
- // like with the non-sliding window equivalent
2286
- // once sliding-window hybrid caches are a thing.
2287
- llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
2288
- const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
2289
-
2290
- auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
2291
-
2292
- {
2293
- inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
2294
- inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
2295
-
2296
- inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
2297
- ggml_set_input(inp->self_kq_mask);
2298
- ggml_set_name(inp->self_kq_mask, "self_kq_mask");
2299
-
2300
- inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
2301
- ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
2302
- }
2303
-
2304
- {
2305
- GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
2306
-
2307
- inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
2308
- inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
2309
-
2310
- inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
2311
- ggml_set_input(inp->self_kq_mask_swa);
2312
- ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
2313
-
2314
- inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
2315
- ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
2316
- }
2317
-
2318
- return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
2319
- }
2320
-
2321
- ggml_tensor * llm_graph_context::build_rs(
2322
- ggml_tensor * s,
2323
- ggml_tensor * state_copy_main,
2324
- ggml_tensor * state_copy_extra,
2325
- int32_t state_size,
2326
- int32_t n_seqs,
2327
- uint32_t n_rs,
2328
- uint32_t rs_head,
2329
- uint32_t rs_size,
2330
- int32_t rs_zero,
2331
- const llm_graph_get_rows_fn & get_state_rows) const {
2332
-
2333
- ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
2334
-
2335
- // Clear a single state which will then be copied to the other cleared states.
2336
- // Note that this is a no-op when the view is zero-sized.
2337
- ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
2338
- ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
2339
-
2340
- // copy states
2341
- // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
2342
- // {state_size, rs_size} -> {state_size, n_seqs}
2343
- ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
2344
- ggml_build_forward_expand(gf, output_states);
2345
-
2346
- // copy extra states which won't be changed further (between n_seqs and n_rs)
2347
- ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
2348
- ggml_build_forward_expand(gf,
2349
- ggml_cpy(ctx0,
2350
- states_extra,
2351
- ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
2352
-
2353
- return output_states;
2354
- }
2355
-
2356
- static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
2357
- ggml_context * ctx0,
2358
- const llama_ubatch & ubatch,
2359
- const llama_memory_recurrent_context * mctx_cur) {
2360
-
2361
- auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
2362
-
2363
- const int64_t n_rs = mctx_cur->get_n_rs();
2364
- const int64_t n_seqs = ubatch.n_seqs;
2365
-
2366
- inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
2367
- ggml_set_input(inp->s_copy);
2368
-
2369
- inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
2370
- inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
2371
-
2372
- inp->head = mctx_cur->get_head();
2373
- inp->rs_z = mctx_cur->get_rs_z();
2374
-
2375
- return inp;
2376
- }
2377
-
2378
- llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
2379
- const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
2380
-
2381
- auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
2382
-
2383
- return (llm_graph_input_rs *) res->add_input(std::move(inp));
2384
- }
2385
-
2386
- ggml_tensor * llm_graph_context::build_rs(
2387
- llm_graph_input_rs * inp,
2388
- ggml_tensor * s,
2389
- int32_t state_size,
2390
- int32_t n_seqs,
2391
- const llm_graph_get_rows_fn & get_state_rows) const {
2392
- const auto * kv_state = inp->mctx;
2393
-
2394
- return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
2395
- kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
2396
- get_state_rows);
2397
- }
2398
-
2399
- ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
2400
- llm_graph_input_rs * inp,
2401
- const llama_ubatch & ubatch,
2402
- int il) const {
2403
- const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
2404
-
2405
- const auto token_shift_count = hparams.token_shift_count;
2406
-
2407
- const int64_t n_seqs = ubatch.n_seqs;
2408
-
2409
- ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
2410
-
2411
- ggml_tensor * token_shift = build_rs(
2412
- inp, token_shift_all,
2413
- hparams.n_embd_r(), n_seqs);
2414
-
2415
- token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
2416
-
2417
- return token_shift;
2418
- }
2419
-
2420
- ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
2421
- ggml_tensor * token_shift,
2422
- const llama_ubatch & ubatch,
2423
- int il) const {
2424
- const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
2425
-
2426
- const auto token_shift_count = hparams.token_shift_count;
2427
- const auto n_embd = hparams.n_embd;
2428
-
2429
- const int64_t n_seqs = ubatch.n_seqs;
2430
-
2431
- const auto kv_head = mctx_cur->get_head();
2432
-
2433
- return ggml_cpy(
2434
- ctx0,
2435
- ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
2436
- ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
2437
- );
2438
- }
2439
-
2440
- llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
2441
- const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
2442
-
2443
- auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
2444
- auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
2445
-
2446
- auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2447
-
2448
- return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
2449
- }
2450
-
2451
- llm_graph_input_mem_hybrid_k * llm_graph_context::build_inp_mem_hybrid_k() const {
2452
- const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
2453
-
2454
- auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
2455
- auto inp_attn = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
2456
-
2457
- auto inp = std::make_unique<llm_graph_input_mem_hybrid_k>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2458
-
2459
- return (llm_graph_input_mem_hybrid_k *) res->add_input(std::move(inp));
2460
- }
2461
-
2462
- llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
2463
- const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);
2464
-
2465
- auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
2466
-
2467
- // build iswa attention input
2468
- const auto * attn_ctx = mctx_cur->get_attn();
2469
-
2470
- auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
2471
-
2472
- {
2473
- inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
2474
- inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
2475
-
2476
- inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
2477
- ggml_set_input(inp_attn->self_kq_mask);
2478
-
2479
- inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
2480
- }
2481
-
2482
- {
2483
- inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
2484
- inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
2485
-
2486
- inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
2487
- ggml_set_input(inp_attn->self_kq_mask_swa);
2488
-
2489
- inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
2490
- }
2491
-
2492
- auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2493
-
2494
- return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp));
2495
- }
2496
-
2497
- void llm_graph_context::build_dense_out(
2498
- ggml_tensor * dense_2,
2499
- ggml_tensor * dense_2_b,
2500
- ggml_tensor * dense_3) const {
2501
- if (!cparams.embeddings || !(dense_2 || dense_2_b || dense_3)) {
2502
- return;
2503
- }
2504
- ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
2505
- GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
2506
-
2507
- if (dense_2) {
2508
- cur = ggml_mul_mat(ctx0, dense_2, cur);
2509
- }
2510
- if (dense_2_b) {
2511
- cur = ggml_add(ctx0, cur, dense_2_b);
2512
- }
2513
- if (dense_3) {
2514
- cur = ggml_mul_mat(ctx0, dense_3, cur);
2515
- }
2516
- cb(cur, "result_embd_pooled", -1);
2517
- res->t_embd_pooled = cur;
2518
- ggml_build_forward_expand(gf, cur);
2519
- }
2520
-
2521
-
2522
- void llm_graph_context::build_pooling(
2523
- ggml_tensor * cls,
2524
- ggml_tensor * cls_b,
2525
- ggml_tensor * cls_out,
2526
- ggml_tensor * cls_out_b,
2527
- ggml_tensor * cls_norm) const {
2528
- if (!cparams.embeddings) {
2529
- return;
2530
- }
2531
-
2532
- ggml_tensor * inp = res->t_embd;
2533
-
2534
- //// find result_norm tensor for input
2535
- //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
2536
- // inp = ggml_graph_node(gf, i);
2537
- // if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
2538
- // break;
2539
- // }
2540
-
2541
- // inp = nullptr;
2542
- //}
2543
-
2544
- GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
2545
-
2546
- ggml_tensor * cur;
2547
-
2548
- switch (pooling_type) {
2549
- case LLAMA_POOLING_TYPE_NONE:
2550
- {
2551
- cur = inp;
2552
- } break;
2553
- case LLAMA_POOLING_TYPE_MEAN:
2554
- {
2555
- ggml_tensor * inp_mean = build_inp_mean();
2556
- cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
2557
- } break;
2558
- case LLAMA_POOLING_TYPE_CLS:
2559
- case LLAMA_POOLING_TYPE_LAST:
2560
- {
2561
- ggml_tensor * inp_cls = build_inp_cls();
2562
- cur = ggml_get_rows(ctx0, inp, inp_cls);
2563
- } break;
2564
- case LLAMA_POOLING_TYPE_RANK:
2565
- {
2566
- if (arch == LLM_ARCH_MODERN_BERT) {
2567
- // modern bert gte reranker builds mean first then applies prediction head and classifier
2568
- // https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modular_modernbert.py#L1404-1411
2569
- ggml_tensor * inp_mean = build_inp_mean();
2570
- cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
2571
- } else {
2572
- ggml_tensor * inp_cls = build_inp_cls();
2573
- cur = ggml_get_rows(ctx0, inp, inp_cls);
2574
- }
2575
-
2576
- // classification head
2577
- // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
2578
- if (cls) {
2579
- cur = ggml_mul_mat(ctx0, cls, cur);
2580
- if (cls_b) {
2581
- cur = ggml_add(ctx0, cur, cls_b);
2582
- }
2583
- if (arch == LLM_ARCH_MODERN_BERT) {
2584
- cur = ggml_gelu(ctx0, cur);
2585
- } else {
2586
- cur = ggml_tanh(ctx0, cur);
2587
- }
2588
- if (cls_norm) {
2589
- // head norm
2590
- cur = build_norm(cur, cls_norm, NULL, LLM_NORM, -1);
2591
- }
2592
- }
2593
-
2594
- // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
2595
- // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
2596
- // Single layer classification head (direct projection)
2597
- // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
2598
- if (cls_out) {
2599
- cur = ggml_mul_mat(ctx0, cls_out, cur);
2600
- if (cls_out_b) {
2601
- cur = ggml_add(ctx0, cur, cls_out_b);
2602
- }
2603
- }
2604
-
2605
- // softmax for qwen3 reranker
2606
- if (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL) {
2607
- cur = ggml_soft_max(ctx0, cur);
2608
- }
2609
- } break;
2610
- default:
2611
- {
2612
- GGML_ABORT("unknown pooling type");
2613
- }
2614
- }
2615
-
2616
- cb(cur, "result_embd_pooled", -1);
2617
- res->t_embd_pooled = cur;
2618
-
2619
- ggml_build_forward_expand(gf, cur);
2620
- }
2621
-
2622
- void llm_graph_context::build_sampling() const {
2623
- if (samplers.empty() || !res->t_logits) {
2624
- return;
2625
- }
2626
-
2627
- std::array<ggml_tensor *, 2> outs;
2628
- outs[0] = res->t_logits;
2629
-
2630
- auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
2631
- res->add_input(std::move(inp_sampling));
2632
-
2633
- std::map<llama_seq_id, int32_t> seq_to_logit_row;
2634
- int32_t logit_row_idx = 0;
2635
-
2636
- for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
2637
- if (ubatch.output[i]) {
2638
- llama_seq_id seq_id = ubatch.seq_id[i][0];
2639
- seq_to_logit_row[seq_id] = logit_row_idx;
2640
- logit_row_idx++;
2641
- }
2642
- }
2643
-
2644
- // res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1)
2645
- GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
2646
-
2647
- // add a dummy row of logits
2648
- // this trick makes the graph static, regardless of which samplers are activated
2649
- // this is important in order to minimize graph reallocations
2650
- ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
2651
-
2652
- for (const auto & [seq_id, sampler] : samplers) {
2653
- const auto it = seq_to_logit_row.find(seq_id);
2654
-
2655
- // inactive samplers always work on the first row
2656
- const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0;
2657
- const int i_out = it != seq_to_logit_row.end() ? 1 : 0;
2658
-
2659
- ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
2660
- ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
2661
-
2662
- struct llama_sampler_data data = {
2663
- /*.logits =*/ logits_seq,
2664
- /*.probs =*/ nullptr,
2665
- /*.sampled =*/ nullptr,
2666
- /*.candidates =*/ nullptr,
2667
- };
2668
-
2669
- assert(sampler->iface->backend_apply);
2670
- sampler->iface->backend_apply(sampler, ctx0, gf, &data);
2671
-
2672
- if (data.sampled != nullptr) {
2673
- res->t_sampled[seq_id] = data.sampled;
2674
- outs[1] = data.sampled;
2675
- ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2676
- }
2677
-
2678
- if (data.probs != nullptr) {
2679
- res->t_sampled_probs[seq_id] = data.probs;
2680
- outs[1] = data.probs;
2681
- ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2682
- }
2683
-
2684
- if (data.logits != nullptr) {
2685
- res->t_sampled_logits[seq_id] = data.logits;
2686
- outs[1] = data.logits;
2687
- ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2688
- }
2689
-
2690
- if (data.candidates != nullptr) {
2691
- res->t_candidates[seq_id] = data.candidates;
2692
- outs[1] = data.candidates;
2693
- ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2694
- }
2695
- }
2696
-
2697
- // TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
2698
- /*
2699
- for (const auto & [seq_id, sampler] : samplers) {
2700
- if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
2701
- ggml_tensor * selected_token = it->second;
2702
- if (selected_token != nullptr) {
2703
- llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
2704
- }
2705
- }
2706
- }
2707
- */
2708
- }
2709
-
2710
- int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
2711
- // TODO move to hparams if a T5 variant appears that uses a different value
2712
- const int64_t max_distance = 128;
2713
-
2714
- if (bidirectional) {
2715
- n_buckets >>= 1;
2716
- }
2717
-
2718
- const int64_t max_exact = n_buckets >> 1;
2719
-
2720
- int32_t relative_position = x - y;
2721
- int32_t relative_bucket = 0;
2722
-
2723
- if (bidirectional) {
2724
- relative_bucket += (relative_position > 0) * n_buckets;
2725
- relative_position = std::abs(relative_position);
2726
- } else {
2727
- relative_position = -std::min<int32_t>(relative_position, 0);
2728
- }
2729
-
2730
- int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
2731
- relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
2732
- relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
2733
-
2734
- return relative_bucket;
2735
- }