whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610):
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -3,14 +3,14 @@
3
3
  #include "ggml-cpu.h"
4
4
  #include "ggml-impl.h"
5
5
  #include "binary-ops.h"
6
+ #include "simd-gemm.h"
6
7
  #include "ggml.h"
7
8
  #include "unary-ops.h"
8
9
  #include "vec.h"
9
10
 
10
- #include <cfloat>
11
11
  #include <algorithm>
12
+ #include <cfloat>
12
13
  #include <cmath>
13
- #include <functional>
14
14
 
15
15
  // ggml_compute_forward_dup
16
16
 
@@ -375,7 +375,7 @@ static void ggml_compute_forward_dup_bytes(
375
375
  const size_t rs = ne00 * type_size;
376
376
 
377
377
  if (nb00 == type_size) {
378
- // src0 is contigous on first dimension, copy by rows
378
+ // src0 is contiguous on first dimension, copy by rows
379
379
  for (int64_t i03 = 0; i03 < ne03; i03++) {
380
380
  for (int64_t i02 = 0; i02 < ne02; i02++) {
381
381
  id += rs * ir0;
@@ -670,6 +670,7 @@ void ggml_compute_forward_add(
670
670
  case GGML_TYPE_Q5_1:
671
671
  case GGML_TYPE_Q8_0:
672
672
  case GGML_TYPE_MXFP4:
673
+ case GGML_TYPE_NVFP4:
673
674
  case GGML_TYPE_Q2_K:
674
675
  case GGML_TYPE_Q3_K:
675
676
  case GGML_TYPE_Q4_K:
@@ -1119,6 +1120,7 @@ void ggml_compute_forward_add1(
1119
1120
  case GGML_TYPE_Q8_0:
1120
1121
  case GGML_TYPE_Q8_1:
1121
1122
  case GGML_TYPE_MXFP4:
1123
+ case GGML_TYPE_NVFP4:
1122
1124
  case GGML_TYPE_Q2_K:
1123
1125
  case GGML_TYPE_Q3_K:
1124
1126
  case GGML_TYPE_Q4_K:
@@ -1247,6 +1249,7 @@ void ggml_compute_forward_acc(
1247
1249
  case GGML_TYPE_Q8_0:
1248
1250
  case GGML_TYPE_Q8_1:
1249
1251
  case GGML_TYPE_MXFP4:
1252
+ case GGML_TYPE_NVFP4:
1250
1253
  case GGML_TYPE_Q2_K:
1251
1254
  case GGML_TYPE_Q3_K:
1252
1255
  case GGML_TYPE_Q4_K:
@@ -1795,7 +1798,7 @@ void ggml_compute_forward_repeat(
1795
1798
  {
1796
1799
  ggml_compute_forward_repeat_f32(params, dst);
1797
1800
  } break;
1798
- // TODO: templateify the implemenation and support for I64
1801
+ // TODO: templateify the implementation and support for I64
1799
1802
  // ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
1800
1803
  //case GGML_TYPE_I64:
1801
1804
  // {
@@ -2097,10 +2100,14 @@ static void ggml_compute_forward_gelu_f32(
2097
2100
 
2098
2101
  const ggml_tensor * src0 = dst->src[0];
2099
2102
 
2100
- assert(ggml_is_contiguous_1(src0));
2101
- assert(ggml_is_contiguous_1(dst));
2103
+ assert(ggml_is_contiguous_rows(src0));
2102
2104
  assert(ggml_are_same_shape(src0, dst));
2103
2105
 
2106
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2107
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2108
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2109
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2110
+
2104
2111
  const int ith = params->ith;
2105
2112
  const int nth = params->nth;
2106
2113
 
@@ -2114,19 +2121,23 @@ static void ggml_compute_forward_gelu_f32(
2114
2121
  const int ir0 = dr*ith;
2115
2122
  const int ir1 = MIN(ir0 + dr, nr);
2116
2123
 
2117
- for (int i1 = ir0; i1 < ir1; i1++) {
2124
+ for (int ir = ir0; ir < ir1; ++ir) {
2125
+ const int i3 = ir/(ne02*ne01);
2126
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2127
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2128
+
2118
2129
  ggml_vec_gelu_f32(nc,
2119
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
2120
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
2130
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2131
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2121
2132
 
2122
2133
  #ifndef NDEBUG
2123
2134
  for (int k = 0; k < nc; k++) {
2124
- const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2135
+ const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
2125
2136
  GGML_UNUSED(x);
2126
2137
  assert(!isnan(x));
2127
2138
  assert(!isinf(x));
2128
2139
  }
2129
- #endif
2140
+ #endif // NDEBUG
2130
2141
  }
2131
2142
  }
2132
2143
 
@@ -2136,10 +2147,14 @@ static void ggml_compute_forward_gelu_f16(
2136
2147
 
2137
2148
  const ggml_tensor * src0 = dst->src[0];
2138
2149
 
2139
- assert(ggml_is_contiguous_1(src0));
2140
- assert(ggml_is_contiguous_1(dst));
2150
+ assert(ggml_is_contiguous_rows(src0));
2141
2151
  assert(ggml_are_same_shape(src0, dst));
2142
2152
 
2153
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2154
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2155
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2156
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2157
+
2143
2158
  const int ith = params->ith;
2144
2159
  const int nth = params->nth;
2145
2160
 
@@ -2153,20 +2168,24 @@ static void ggml_compute_forward_gelu_f16(
2153
2168
  const int ir0 = dr*ith;
2154
2169
  const int ir1 = MIN(ir0 + dr, nr);
2155
2170
 
2156
- for (int i1 = ir0; i1 < ir1; i1++) {
2171
+ for (int ir = ir0; ir < ir1; ++ir) {
2172
+ const int i3 = ir/(ne02*ne01);
2173
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2174
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2175
+
2157
2176
  ggml_vec_gelu_f16(nc,
2158
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2159
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2177
+ (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2178
+ (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2160
2179
 
2161
2180
  #ifndef NDEBUG
2162
2181
  for (int k = 0; k < nc; k++) {
2163
- const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2182
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
2164
2183
  const float v = GGML_CPU_FP16_TO_FP32(x);
2165
2184
  GGML_UNUSED(v);
2166
2185
  assert(!isnan(v));
2167
2186
  assert(!isinf(v));
2168
2187
  }
2169
- #endif
2188
+ #endif // NDEBUG
2170
2189
  }
2171
2190
  }
2172
2191
 
@@ -2277,10 +2296,14 @@ static void ggml_compute_forward_gelu_erf_f32(
2277
2296
 
2278
2297
  const ggml_tensor * src0 = dst->src[0];
2279
2298
 
2280
- assert(ggml_is_contiguous_1(src0));
2281
- assert(ggml_is_contiguous_1(dst));
2299
+ assert(ggml_is_contiguous_rows(src0));
2282
2300
  assert(ggml_are_same_shape(src0, dst));
2283
2301
 
2302
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2303
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2304
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2305
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2306
+
2284
2307
  const int ith = params->ith;
2285
2308
  const int nth = params->nth;
2286
2309
 
@@ -2294,19 +2317,23 @@ static void ggml_compute_forward_gelu_erf_f32(
2294
2317
  const int ir0 = dr*ith;
2295
2318
  const int ir1 = MIN(ir0 + dr, nr);
2296
2319
 
2297
- for (int i1 = ir0; i1 < ir1; i1++) {
2320
+ for (int ir = ir0; ir < ir1; ++ir) {
2321
+ const int i3 = ir/(ne02*ne01);
2322
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2323
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2324
+
2298
2325
  ggml_vec_gelu_erf_f32(nc,
2299
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
2300
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
2326
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2327
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2301
2328
 
2302
2329
  #ifndef NDEBUG
2303
2330
  for (int k = 0; k < nc; k++) {
2304
- const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2331
+ const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
2305
2332
  GGML_UNUSED(x);
2306
2333
  assert(!isnan(x));
2307
2334
  assert(!isinf(x));
2308
2335
  }
2309
- #endif
2336
+ #endif // NDEBUG
2310
2337
  }
2311
2338
  }
2312
2339
 
@@ -2316,10 +2343,14 @@ static void ggml_compute_forward_gelu_erf_f16(
2316
2343
 
2317
2344
  const ggml_tensor * src0 = dst->src[0];
2318
2345
 
2319
- assert(ggml_is_contiguous_1(src0));
2320
- assert(ggml_is_contiguous_1(dst));
2346
+ assert(ggml_is_contiguous_rows(src0));
2321
2347
  assert(ggml_are_same_shape(src0, dst));
2322
2348
 
2349
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2350
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2351
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2352
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2353
+
2323
2354
  const int ith = params->ith;
2324
2355
  const int nth = params->nth;
2325
2356
 
@@ -2333,20 +2364,24 @@ static void ggml_compute_forward_gelu_erf_f16(
2333
2364
  const int ir0 = dr*ith;
2334
2365
  const int ir1 = MIN(ir0 + dr, nr);
2335
2366
 
2336
- for (int i1 = ir0; i1 < ir1; i1++) {
2367
+ for (int ir = ir0; ir < ir1; ++ir) {
2368
+ const int i3 = ir/(ne02*ne01);
2369
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2370
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2371
+
2337
2372
  ggml_vec_gelu_erf_f16(nc,
2338
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2339
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2373
+ (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2374
+ (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2340
2375
 
2341
2376
  #ifndef NDEBUG
2342
2377
  for (int k = 0; k < nc; k++) {
2343
- const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2378
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
2344
2379
  const float v = GGML_CPU_FP16_TO_FP32(x);
2345
2380
  GGML_UNUSED(v);
2346
2381
  assert(!isnan(v));
2347
2382
  assert(!isinf(v));
2348
2383
  }
2349
- #endif
2384
+ #endif // NDEBUG
2350
2385
  }
2351
2386
  }
2352
2387
 
@@ -2380,10 +2415,14 @@ static void ggml_compute_forward_gelu_quick_f32(
2380
2415
 
2381
2416
  const ggml_tensor * src0 = dst->src[0];
2382
2417
 
2383
- assert(ggml_is_contiguous_1(src0));
2384
- assert(ggml_is_contiguous_1(dst));
2418
+ assert(ggml_is_contiguous_rows(src0));
2385
2419
  assert(ggml_are_same_shape(src0, dst));
2386
2420
 
2421
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2422
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2423
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2424
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2425
+
2387
2426
  const int ith = params->ith;
2388
2427
  const int nth = params->nth;
2389
2428
 
@@ -2397,19 +2436,23 @@ static void ggml_compute_forward_gelu_quick_f32(
2397
2436
  const int ir0 = dr*ith;
2398
2437
  const int ir1 = MIN(ir0 + dr, nr);
2399
2438
 
2400
- for (int i1 = ir0; i1 < ir1; i1++) {
2439
+ for (int ir = ir0; ir < ir1; ++ir) {
2440
+ const int i3 = ir/(ne02*ne01);
2441
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2442
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2443
+
2401
2444
  ggml_vec_gelu_quick_f32(nc,
2402
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
2403
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
2445
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2446
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2404
2447
 
2405
2448
  #ifndef NDEBUG
2406
2449
  for (int k = 0; k < nc; k++) {
2407
- const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2450
+ const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
2408
2451
  GGML_UNUSED(x);
2409
2452
  assert(!isnan(x));
2410
2453
  assert(!isinf(x));
2411
2454
  }
2412
- #endif
2455
+ #endif // NDEBUG
2413
2456
  }
2414
2457
  }
2415
2458
 
@@ -2419,10 +2462,14 @@ static void ggml_compute_forward_gelu_quick_f16(
2419
2462
 
2420
2463
  const ggml_tensor * src0 = dst->src[0];
2421
2464
 
2422
- assert(ggml_is_contiguous_1(src0));
2423
- assert(ggml_is_contiguous_1(dst));
2465
+ assert(ggml_is_contiguous_rows(src0));
2424
2466
  assert(ggml_are_same_shape(src0, dst));
2425
2467
 
2468
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2469
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2470
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2471
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2472
+
2426
2473
  const int ith = params->ith;
2427
2474
  const int nth = params->nth;
2428
2475
 
@@ -2436,20 +2483,24 @@ static void ggml_compute_forward_gelu_quick_f16(
2436
2483
  const int ir0 = dr*ith;
2437
2484
  const int ir1 = MIN(ir0 + dr, nr);
2438
2485
 
2439
- for (int i1 = ir0; i1 < ir1; i1++) {
2486
+ for (int ir = ir0; ir < ir1; ++ir) {
2487
+ const int i3 = ir/(ne02*ne01);
2488
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2489
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2490
+
2440
2491
  ggml_vec_gelu_quick_f16(nc,
2441
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2442
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2492
+ (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2493
+ (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2443
2494
 
2444
2495
  #ifndef NDEBUG
2445
2496
  for (int k = 0; k < nc; k++) {
2446
- const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2497
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
2447
2498
  const float v = GGML_CPU_FP16_TO_FP32(x);
2448
2499
  GGML_UNUSED(v);
2449
2500
  assert(!isnan(v));
2450
2501
  assert(!isinf(v));
2451
2502
  }
2452
- #endif
2503
+ #endif // NDEBUG
2453
2504
  }
2454
2505
  }
2455
2506
 
@@ -2483,10 +2534,14 @@ static void ggml_compute_forward_silu_f32(
2483
2534
 
2484
2535
  const ggml_tensor * src0 = dst->src[0];
2485
2536
 
2486
- assert(ggml_is_contiguous_1(src0));
2487
- assert(ggml_is_contiguous_1(dst));
2537
+ assert(ggml_is_contiguous_rows(src0));
2488
2538
  assert(ggml_are_same_shape(src0, dst));
2489
2539
 
2540
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2541
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2542
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2543
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2544
+
2490
2545
  const int ith = params->ith;
2491
2546
  const int nth = params->nth;
2492
2547
 
@@ -2500,19 +2555,23 @@ static void ggml_compute_forward_silu_f32(
2500
2555
  const int ir0 = dr*ith;
2501
2556
  const int ir1 = MIN(ir0 + dr, nr);
2502
2557
 
2503
- for (int i1 = ir0; i1 < ir1; i1++) {
2558
+ for (int ir = ir0; ir < ir1; ++ir) {
2559
+ const int i3 = ir/(ne02*ne01);
2560
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2561
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2562
+
2504
2563
  ggml_vec_silu_f32(nc,
2505
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
2506
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
2564
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2565
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2507
2566
 
2508
2567
  #ifndef NDEBUG
2509
2568
  for (int k = 0; k < nc; k++) {
2510
- const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
2569
+ const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
2511
2570
  GGML_UNUSED(x);
2512
2571
  assert(!isnan(x));
2513
2572
  assert(!isinf(x));
2514
2573
  }
2515
- #endif
2574
+ #endif // NDEBUG
2516
2575
  }
2517
2576
  }
2518
2577
 
@@ -2522,10 +2581,14 @@ static void ggml_compute_forward_silu_f16(
2522
2581
 
2523
2582
  const ggml_tensor * src0 = dst->src[0];
2524
2583
 
2525
- assert(ggml_is_contiguous_1(src0));
2526
- assert(ggml_is_contiguous_1(dst));
2584
+ assert(ggml_is_contiguous_rows(src0));
2527
2585
  assert(ggml_are_same_shape(src0, dst));
2528
2586
 
2587
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2588
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2589
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2590
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2591
+
2529
2592
  const int ith = params->ith;
2530
2593
  const int nth = params->nth;
2531
2594
 
@@ -2539,20 +2602,24 @@ static void ggml_compute_forward_silu_f16(
2539
2602
  const int ir0 = dr*ith;
2540
2603
  const int ir1 = MIN(ir0 + dr, nr);
2541
2604
 
2542
- for (int i1 = ir0; i1 < ir1; i1++) {
2605
+ for (int ir = ir0; ir < ir1; ++ir) {
2606
+ const int i3 = ir/(ne02*ne01);
2607
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2608
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2609
+
2543
2610
  ggml_vec_silu_f16(nc,
2544
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2545
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2611
+ (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2612
+ (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2546
2613
 
2547
2614
  #ifndef NDEBUG
2548
2615
  for (int k = 0; k < nc; k++) {
2549
- const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
2616
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
2550
2617
  const float v = GGML_CPU_FP16_TO_FP32(x);
2551
2618
  GGML_UNUSED(v);
2552
2619
  assert(!isnan(v));
2553
2620
  assert(!isinf(v));
2554
2621
  }
2555
- #endif
2622
+ #endif // NDEBUG
2556
2623
  }
2557
2624
  }
2558
2625
 
@@ -2702,7 +2769,7 @@ static void ggml_compute_forward_silu_back_f32(
2702
2769
  assert(!isnan(x));
2703
2770
  assert(!isinf(x));
2704
2771
  }
2705
- #endif
2772
+ #endif // NDEBUG
2706
2773
  }
2707
2774
  }
2708
2775
 
@@ -2738,7 +2805,7 @@ static void ggml_compute_forward_silu_back_f16(
2738
2805
  (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
2739
2806
  (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
2740
2807
 
2741
- #ifndef NDEBUG
2808
+ #ifndef NDEBUG
2742
2809
  for (int k = 0; k < nc; k++) {
2743
2810
  const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2744
2811
  const float v = GGML_CPU_FP16_TO_FP32(x);
@@ -2746,7 +2813,7 @@ static void ggml_compute_forward_silu_back_f16(
2746
2813
  assert(!isnan(v));
2747
2814
  assert(!isinf(v));
2748
2815
  }
2749
- #endif
2816
+ #endif // NDEBUG
2750
2817
  }
2751
2818
  }
2752
2819
 
@@ -2829,7 +2896,7 @@ static void ggml_compute_forward_reglu_f32(
2829
2896
  assert(!isnan(x));
2830
2897
  assert(!isinf(x));
2831
2898
  }
2832
- #endif
2899
+ #endif // NDEBUG
2833
2900
  }
2834
2901
  }
2835
2902
 
@@ -2889,7 +2956,7 @@ static void ggml_compute_forward_reglu_f16(
2889
2956
  assert(!isnan(v));
2890
2957
  assert(!isinf(v));
2891
2958
  }
2892
- #endif
2959
+ #endif // NDEBUG
2893
2960
  }
2894
2961
  }
2895
2962
 
@@ -2972,7 +3039,7 @@ static void ggml_compute_forward_geglu_f32(
2972
3039
  assert(!isnan(x));
2973
3040
  assert(!isinf(x));
2974
3041
  }
2975
- #endif
3042
+ #endif // NDEBUG
2976
3043
  }
2977
3044
  }
2978
3045
 
@@ -3032,7 +3099,7 @@ static void ggml_compute_forward_geglu_f16(
3032
3099
  assert(!isnan(v));
3033
3100
  assert(!isinf(v));
3034
3101
  }
3035
- #endif
3102
+ #endif // NDEBUG
3036
3103
  }
3037
3104
  }
3038
3105
 
@@ -3115,7 +3182,7 @@ static void ggml_compute_forward_swiglu_f32(
3115
3182
  assert(!isnan(x));
3116
3183
  assert(!isinf(x));
3117
3184
  }
3118
- #endif
3185
+ #endif // NDEBUG
3119
3186
  }
3120
3187
  }
3121
3188
 
@@ -3175,7 +3242,7 @@ static void ggml_compute_forward_swiglu_f16(
3175
3242
  assert(!isnan(v));
3176
3243
  assert(!isinf(v));
3177
3244
  }
3178
- #endif
3245
+ #endif // NDEBUG
3179
3246
  }
3180
3247
  }
3181
3248
 
@@ -3266,7 +3333,7 @@ static void ggml_compute_forward_swiglu_oai_f32(
3266
3333
  assert(!isnan(x));
3267
3334
  assert(!isinf(x));
3268
3335
  }
3269
- #endif
3336
+ #endif // NDEBUG
3270
3337
  }
3271
3338
  }
3272
3339
 
@@ -3345,7 +3412,7 @@ static void ggml_compute_forward_geglu_erf_f32(
3345
3412
  assert(!isnan(x));
3346
3413
  assert(!isinf(x));
3347
3414
  }
3348
- #endif
3415
+ #endif // NDEBUG
3349
3416
  }
3350
3417
  }
3351
3418
 
@@ -3405,7 +3472,7 @@ static void ggml_compute_forward_geglu_erf_f16(
3405
3472
  assert(!isnan(v));
3406
3473
  assert(!isinf(v));
3407
3474
  }
3408
- #endif
3475
+ #endif // NDEBUG
3409
3476
  }
3410
3477
  }
3411
3478
 
@@ -3488,7 +3555,7 @@ static void ggml_compute_forward_geglu_quick_f32(
3488
3555
  assert(!isnan(x));
3489
3556
  assert(!isinf(x));
3490
3557
  }
3491
- #endif
3558
+ #endif // NDEBUG
3492
3559
  }
3493
3560
  }
3494
3561
 
@@ -3548,7 +3615,7 @@ static void ggml_compute_forward_geglu_quick_f16(
3548
3615
  assert(!isnan(v));
3549
3616
  assert(!isinf(v));
3550
3617
  }
3551
- #endif
3618
+ #endif // NDEBUG
3552
3619
  }
3553
3620
  }
3554
3621
 
@@ -4270,6 +4337,7 @@ void ggml_compute_forward_out_prod(
4270
4337
  case GGML_TYPE_Q5_1:
4271
4338
  case GGML_TYPE_Q8_0:
4272
4339
  case GGML_TYPE_MXFP4:
4340
+ case GGML_TYPE_NVFP4:
4273
4341
  case GGML_TYPE_Q2_K:
4274
4342
  case GGML_TYPE_Q3_K:
4275
4343
  case GGML_TYPE_Q4_K:
@@ -4545,6 +4613,7 @@ void ggml_compute_forward_set(
4545
4613
  case GGML_TYPE_Q8_0:
4546
4614
  case GGML_TYPE_Q8_1:
4547
4615
  case GGML_TYPE_MXFP4:
4616
+ case GGML_TYPE_NVFP4:
4548
4617
  case GGML_TYPE_Q2_K:
4549
4618
  case GGML_TYPE_Q3_K:
4550
4619
  case GGML_TYPE_Q4_K:
@@ -4767,6 +4836,7 @@ void ggml_compute_forward_get_rows(
4767
4836
  case GGML_TYPE_Q8_0:
4768
4837
  case GGML_TYPE_Q8_1:
4769
4838
  case GGML_TYPE_MXFP4:
4839
+ case GGML_TYPE_NVFP4:
4770
4840
  case GGML_TYPE_Q2_K:
4771
4841
  case GGML_TYPE_Q3_K:
4772
4842
  case GGML_TYPE_Q4_K:
@@ -5239,7 +5309,7 @@ static void ggml_compute_forward_soft_max_f32(
5239
5309
  //printf("p[%d] = %f\n", i, p[i]);
5240
5310
  assert(!isnan(wp[i]));
5241
5311
  }
5242
- #endif
5312
+ #endif // NDEBUG
5243
5313
 
5244
5314
  float max = -INFINITY;
5245
5315
  ggml_vec_max_f32(ne00, &max, wp);
@@ -5264,7 +5334,7 @@ static void ggml_compute_forward_soft_max_f32(
5264
5334
  assert(!isnan(dp[i]));
5265
5335
  assert(!isinf(dp[i]));
5266
5336
  }
5267
- #endif
5337
+ #endif // NDEBUG
5268
5338
  }
5269
5339
  }
5270
5340
  }
@@ -5338,7 +5408,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32(
5338
5408
  assert(!isnan(dy[i]));
5339
5409
  assert(!isnan(y[i]));
5340
5410
  }
5341
- #endif
5411
+ #endif // NDEBUG
5342
5412
  // Jii = yi - yi*yi
5343
5413
  // Jij = -yi*yj
5344
5414
  // J = diag(y)-y.T*y
@@ -5371,7 +5441,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32(
5371
5441
  assert(!isnan(dx[i]));
5372
5442
  assert(!isinf(dx[i]));
5373
5443
  }
5374
- #endif
5444
+ #endif // NDEBUG
5375
5445
  }
5376
5446
  }
5377
5447
 
@@ -5491,6 +5561,7 @@ void ggml_compute_forward_clamp(
5491
5561
  case GGML_TYPE_Q8_0:
5492
5562
  case GGML_TYPE_Q8_1:
5493
5563
  case GGML_TYPE_MXFP4:
5564
+ case GGML_TYPE_NVFP4:
5494
5565
  case GGML_TYPE_Q2_K:
5495
5566
  case GGML_TYPE_Q3_K:
5496
5567
  case GGML_TYPE_Q4_K:
@@ -5739,28 +5810,33 @@ static void ggml_compute_forward_rope_flt(
5739
5810
 
5740
5811
  const int32_t * pos = (const int32_t *) src1->data;
5741
5812
 
5813
+ int64_t last_i2 = -1;
5814
+
5742
5815
  for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
5743
5816
  for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
5744
-
5745
- float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5746
- if (!mrope_used) {
5747
- const int64_t p = pos[i2];
5748
- ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5749
- }
5750
- else {
5751
- const int64_t p_t = pos[i2];
5752
- const int64_t p_h = pos[i2 + ne2];
5753
- const int64_t p_w = pos[i2 + ne2 * 2];
5754
- const int64_t p_e = pos[i2 + ne2 * 3];
5755
- ggml_mrope_cache_init(
5756
- p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
5757
- freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5758
- }
5759
-
5760
5817
  for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
5761
- if (ir++ < ir0) continue;
5818
+ if (ir++ < ir0) continue; // skip rows mapped to other threads
5762
5819
  if (ir > ir1) break;
5763
5820
 
5821
+ float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5822
+ if (last_i2 != i2) {
5823
+ if (!mrope_used) {
5824
+ const int64_t p = pos[i2];
5825
+ ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5826
+ }
5827
+ else {
5828
+ const int64_t p_t = pos[i2];
5829
+ const int64_t p_h = pos[i2 + ne2];
5830
+ const int64_t p_w = pos[i2 + ne2 * 2];
5831
+ const int64_t p_e = pos[i2 + ne2 * 3];
5832
+ ggml_mrope_cache_init(
5833
+ p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
5834
+ freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5835
+ }
5836
+
5837
+ last_i2 = i2;
5838
+ }
5839
+
5764
5840
  T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
5765
5841
  T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
5766
5842
 
@@ -6129,7 +6205,7 @@ static void ggml_compute_forward_im2col_f16(
6129
6205
  const ggml_tensor * src1 = dst->src[1];
6130
6206
 
6131
6207
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
6132
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
6208
+ GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
6133
6209
  GGML_ASSERT( dst->type == GGML_TYPE_F16);
6134
6210
 
6135
6211
  GGML_TENSOR_BINARY_OP_LOCALS;
@@ -6160,7 +6236,7 @@ static void ggml_compute_forward_im2col_f16(
6160
6236
  int ofs1 = is_2D ? nb12 : nb11;
6161
6237
 
6162
6238
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6163
- GGML_ASSERT(nb10 == sizeof(float));
6239
+ GGML_ASSERT(nb10 == ggml_type_size(src1->type));
6164
6240
 
6165
6241
  // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
6166
6242
  {
@@ -6173,7 +6249,12 @@ static void ggml_compute_forward_im2col_f16(
6173
6249
 
6174
6250
  // micro kernel
6175
6251
  ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6176
- const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
6252
+ const float * const src_data_f32 = src1->type == GGML_TYPE_F32
6253
+ ? (const float *)((const char *) src1->data + in*ofs0 + iic*ofs1)
6254
+ : nullptr; // [IH, IW]
6255
+ const ggml_fp16_t * const src_data_f16 = src1->type == GGML_TYPE_F16
6256
+ ? (const ggml_fp16_t *)((const char *) src1->data + in*ofs0 + iic*ofs1)
6257
+ : nullptr; // [IH, IW]
6177
6258
 
6178
6259
  for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
6179
6260
  for (int64_t ikw = 0; ikw < KW; ikw++) {
@@ -6183,7 +6264,11 @@ static void ggml_compute_forward_im2col_f16(
6183
6264
  if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6184
6265
  dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
6185
6266
  } else {
6186
- dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
6267
+ if (src_data_f32 != nullptr) {
6268
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data_f32[iih*IW + iiw]);
6269
+ } else {
6270
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = src_data_f16[iih*IW + iiw];
6271
+ }
6187
6272
  }
6188
6273
  }
6189
6274
  }
@@ -7110,12 +7195,13 @@ void ggml_compute_forward_conv_2d_dw(
7110
7195
  }
7111
7196
  }
7112
7197
 
7113
- // ggml_compute_forward_pool_1d_sk_p0
7114
-
7115
- static void ggml_compute_forward_pool_1d_sk_p0(
7198
+ // ggml_compute_forward_pool_1d_ksp
7199
+ static void ggml_compute_forward_pool_1d_ksp(
7116
7200
  const ggml_compute_params * params,
7117
7201
  const ggml_op_pool op,
7118
7202
  const int k,
7203
+ const int s,
7204
+ const int p,
7119
7205
  ggml_tensor * dst) {
7120
7206
 
7121
7207
  const ggml_tensor * src = dst->src[0];
@@ -7126,39 +7212,56 @@ static void ggml_compute_forward_pool_1d_sk_p0(
7126
7212
  return;
7127
7213
  }
7128
7214
 
7129
- const char * cdata = (const char *)src->data;
7130
- const char * const data_end = cdata + ggml_nbytes(src);
7131
- float * drow = (float *)dst->data;
7215
+ const int64_t IW = src->ne[0];
7216
+ const int64_t OW = dst->ne[0];
7132
7217
 
7133
- const int64_t rs = dst->ne[0];
7218
+ const int64_t nr = ggml_nrows(src);
7134
7219
 
7135
- while (cdata < data_end) {
7136
- const void * srow = (const void *)cdata;
7137
- int j = 0;
7138
- for (int64_t i = 0; i < rs; ++i) {
7220
+ for (int64_t ir = 0; ir < nr; ++ir) {
7221
+ const char * srow_bytes = (const char *) src->data + ir * src->nb[1];
7222
+ float * drow = (float *) (( char *) dst->data + ir * dst->nb[1]);
7223
+
7224
+ for (int64_t ow = 0; ow < OW; ++ow) {
7225
+ float res = 0;
7139
7226
  switch (op) {
7140
- case GGML_OP_POOL_AVG: drow[i] = 0; break;
7141
- case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
7227
+ case GGML_OP_POOL_AVG: res = 0.0f; break;
7228
+ case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
7142
7229
  case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7143
7230
  }
7231
+
7232
+ int count = 0;
7233
+ const int base = (int) ow * s - p;
7234
+
7144
7235
  for (int ki = 0; ki < k; ++ki) {
7145
- const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
7236
+ const int j = base + ki;
7237
+ if (j < 0 || j >= (int) IW) {
7238
+ continue;
7239
+ }
7240
+
7241
+ float v;
7242
+ if (src->type == GGML_TYPE_F32) {
7243
+ v = ((const float *) srow_bytes)[j];
7244
+ } else {
7245
+ v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
7246
+ }
7247
+
7146
7248
  switch (op) {
7147
- case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
7148
- case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
7149
- case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7249
+ case GGML_OP_POOL_AVG: res += v; break;
7250
+ case GGML_OP_POOL_MAX: res = std::max(v, res); break;
7251
+ case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7150
7252
  }
7151
- ++j;
7253
+
7254
+ ++count;
7152
7255
  }
7256
+
7153
7257
  switch (op) {
7154
- case GGML_OP_POOL_AVG: drow[i] /= k; break;
7155
- case GGML_OP_POOL_MAX: break;
7258
+ case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
7259
+ case GGML_OP_POOL_MAX: break;
7156
7260
  case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7157
7261
  }
7158
- }
7159
7262
 
7160
- cdata += src->nb[1];
7161
- drow += rs;
7263
+ drow[ow] = res;
7264
+ }
7162
7265
  }
7163
7266
  }
7164
7267
 
@@ -7173,10 +7276,8 @@ void ggml_compute_forward_pool_1d(
7173
7276
  const int k0 = opts[1];
7174
7277
  const int s0 = opts[2];
7175
7278
  const int p0 = opts[3];
7176
- GGML_ASSERT(p0 == 0); // padding not supported
7177
- GGML_ASSERT(k0 == s0); // only s = k supported
7178
7279
 
7179
- ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
7280
+ ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
7180
7281
  }
7181
7282
 
7182
7283
  // ggml_compute_forward_pool_2d
@@ -7194,6 +7295,7 @@ void ggml_compute_forward_pool_2d(
7194
7295
  }
7195
7296
 
7196
7297
  const int32_t * opts = (const int32_t *)dst->op_params;
7298
+
7197
7299
  ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
7198
7300
  const int k0 = opts[1];
7199
7301
  const int k1 = opts[2];
@@ -7217,11 +7319,13 @@ void ggml_compute_forward_pool_2d(
7217
7319
  while (cdata < data_end) {
7218
7320
  for (int oy = 0; oy < py; ++oy) {
7219
7321
  float * const drow = dplane + oy * px;
7322
+ float * const out = drow;
7323
+
7220
7324
  for (int ox = 0; ox < px; ++ox) {
7221
- float * const out = drow + ox;
7325
+ float res = 0;
7222
7326
  switch (op) {
7223
- case GGML_OP_POOL_AVG: *out = 0; break;
7224
- case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
7327
+ case GGML_OP_POOL_AVG: res = 0; break;
7328
+ case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
7225
7329
  case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7226
7330
  }
7227
7331
 
@@ -7229,24 +7333,32 @@ void ggml_compute_forward_pool_2d(
7229
7333
  const int iy = offset1 + oy * s1;
7230
7334
 
7231
7335
  for (int ky = 0; ky < k1; ++ky) {
7232
- if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
7336
+ if (iy + ky < 0 || iy + ky >= src->ne[1]) {
7337
+ continue;
7338
+ }
7339
+
7233
7340
  const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
7234
7341
  for (int kx = 0; kx < k0; ++kx) {
7235
7342
  int j = ix + kx;
7236
- if (j < 0 || j >= src->ne[0]) continue;
7343
+ if (j < 0 || j >= src->ne[0]) {
7344
+ continue;
7345
+ }
7346
+
7237
7347
  const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
7238
7348
  switch (op) {
7239
- case GGML_OP_POOL_AVG: *out += srow_j; break;
7240
- case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
7349
+ case GGML_OP_POOL_AVG: res += srow_j; break;
7350
+ case GGML_OP_POOL_MAX: res = std::max(srow_j, res); break;
7241
7351
  case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7242
7352
  }
7243
7353
  }
7244
7354
  }
7245
7355
  switch (op) {
7246
- case GGML_OP_POOL_AVG: *out /= ka; break;
7247
- case GGML_OP_POOL_MAX: break;
7356
+ case GGML_OP_POOL_AVG: res /= ka; break;
7357
+ case GGML_OP_POOL_MAX: break;
7248
7358
  case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7249
7359
  }
7360
+
7361
+ out[ox] = res;
7250
7362
  }
7251
7363
  }
7252
7364
 
@@ -7603,8 +7715,7 @@ static void ggml_compute_forward_pad_f32(
7603
7715
 
7604
7716
  const ggml_tensor * src0 = dst->src[0];
7605
7717
 
7606
- GGML_ASSERT(src0->nb[0] == sizeof(float));
7607
- GGML_ASSERT( dst->nb[0] == sizeof(float));
7718
+ assert(dst->nb[0] == sizeof(float));
7608
7719
 
7609
7720
  const int ith = params->ith;
7610
7721
  const int nth = params->nth;
@@ -8016,12 +8127,14 @@ void ggml_compute_forward_top_k(
8016
8127
  }
8017
8128
  }
8018
8129
 
8019
- // ggml_compute_forward_flash_attn_ext
8020
-
8021
8130
  static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8022
8131
  const ggml_compute_params * params,
8023
8132
  ggml_tensor * dst,
8024
- int ir0, int ir1) {
8133
+ int ir0, int ir1,
8134
+ int64_t ic_start, int64_t ic_end,
8135
+ float * partials, int64_t partial_stride) {
8136
+
8137
+ const bool write_partials = (partials != nullptr);
8025
8138
  const ggml_tensor * q = dst->src[0];
8026
8139
  const ggml_tensor * k = dst->src[1];
8027
8140
  const ggml_tensor * v = dst->src[2];
@@ -8098,7 +8211,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8098
8211
 
8099
8212
  int ith = params->ith;
8100
8213
 
8101
- // loop over n_batch and n_head
8102
8214
  for (int ir = ir0; ir < ir1; ++ir) {
8103
8215
  // q indices
8104
8216
  const int iq3 = ir/(neq2*neq1);
@@ -8138,7 +8250,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8138
8250
  // online softmax / attention
8139
8251
  // loop over n_kv and n_head_kv
8140
8252
  // ref: https://arxiv.org/pdf/2112.05682.pdf
8141
- for (int64_t ic = 0; ic < nek1; ++ic) {
8253
+
8254
+ for (int64_t ic = ic_start; ic < ic_end; ++ic) {
8142
8255
  const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
8143
8256
  if (mv == -INFINITY) {
8144
8257
  continue;
@@ -8211,8 +8324,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8211
8324
  }
8212
8325
  }
8213
8326
 
8214
- // sinks
8215
- if (sinks) {
8327
+ // sinks - apply only on the first kv-chunk
8328
+ if (sinks && ic_start == 0) {
8216
8329
  const float s = ((float *)((char *) sinks->data))[h];
8217
8330
 
8218
8331
  float ms = 1.0f;
@@ -8220,6 +8333,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8220
8333
 
8221
8334
  if (s > M) {
8222
8335
  ms = expf(M - s);
8336
+ M = s;
8223
8337
  ggml_vec_scale_f32(DV, VKQ32, ms);
8224
8338
  } else {
8225
8339
  vs = expf(s - M);
@@ -8228,20 +8342,386 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8228
8342
  S = S*ms + vs;
8229
8343
  }
8230
8344
 
8231
- // V /= S
8232
- const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
8233
- ggml_vec_scale_f32(DV, VKQ32, S_inv);
8345
+ if (write_partials) {
8346
+ // Write M, S, VKQ to partials for later reduction
8347
+ // partials layout: [M, S, VKQ[DV]] per query head
8348
+ float * partial = partials + ir * partial_stride;
8349
+ partial[0] = M;
8350
+ partial[1] = S;
8351
+ memcpy(partial + 2, VKQ32, DV * sizeof(float));
8352
+ } else {
8353
+ // V /= S
8354
+ const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
8355
+ ggml_vec_scale_f32(DV, VKQ32, S_inv);
8356
+
8357
+ // dst indices
8358
+ const int i1 = iq1;
8359
+ const int i2 = iq2;
8360
+ const int i3 = iq3;
8234
8361
 
8235
- // dst indices
8236
- const int i1 = iq1;
8237
- const int i2 = iq2;
8238
- const int i3 = iq3;
8362
+ // permute(0, 2, 1, 3)
8363
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
8364
+ }
8365
+ }
8366
+ }
8367
+
8368
+ static void ggml_compute_forward_flash_attn_ext_tiled(
8369
+ const ggml_compute_params * params,
8370
+ ggml_tensor * dst,
8371
+ int ir0, int ir1) {
8372
+ const ggml_tensor * q = dst->src[0];
8373
+ const ggml_tensor * k = dst->src[1];
8374
+ const ggml_tensor * v = dst->src[2];
8375
+ const ggml_tensor * mask = dst->src[3];
8376
+ const ggml_tensor * sinks = dst->src[4];
8377
+
8378
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
8379
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
8380
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
8381
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
8382
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
8383
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
8384
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
8385
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
8386
+
8387
+ const int64_t DK = nek0;
8388
+ const int64_t DV = nev0;
8389
+ const int64_t N = neq1;
8390
+
8391
+ GGML_ASSERT(ne0 == DV);
8392
+ GGML_ASSERT(ne2 == N);
8393
+
8394
+ // input tensor rows must be contiguous
8395
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
8396
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
8397
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
8398
+
8399
+ GGML_ASSERT(neq0 == DK);
8400
+ GGML_ASSERT(nek0 == DK);
8401
+ GGML_ASSERT(nev0 == DV);
8402
+
8403
+ GGML_ASSERT(neq1 == N);
8404
+
8405
+ // dst cannot be transposed or permuted
8406
+ GGML_ASSERT(nb0 == sizeof(float));
8407
+ GGML_ASSERT(nb0 <= nb1);
8408
+ GGML_ASSERT(nb1 <= nb2);
8409
+ GGML_ASSERT(nb2 <= nb3);
8410
+
8411
+ GGML_ASSERT(k->type == v->type);
8412
+ const ggml_type kv_type = k->type;
8413
+
8414
+
8415
+ // broadcast factors
8416
+ const int64_t rk2 = neq2/nek2;
8417
+ const int64_t rk3 = neq3/nek3;
8418
+
8419
+ const int64_t rv2 = neq2/nev2;
8420
+ const int64_t rv3 = neq3/nev3;
8421
+
8422
+ float scale = 1.0f;
8423
+ float max_bias = 0.0f;
8424
+ float logit_softcap = 0.0f;
8425
+
8426
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
8427
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
8428
+ memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
8429
+
8430
+ if (logit_softcap != 0) {
8431
+ scale /= logit_softcap;
8432
+ }
8433
+
8434
+ const uint32_t n_head = neq2;
8435
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
8436
+
8437
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
8438
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
8439
+
8440
+ int ith = params->ith;
8441
+
8442
+ static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
8443
+ static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
8444
+
8445
+ int ir = ir0;
8446
+ while (ir < ir1) {
8447
+ // q indices for the start of this tile
8448
+ const int iq3 = ir/(neq2*neq1);
8449
+ const int iq2 = (ir - iq3*neq2*neq1)/neq1;
8450
+ const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
8451
+
8452
+ // Number of valid rows in this tile:
8453
+ // - limited by tile size (Q_TILE_SZ)
8454
+ // - limited by chunk boundary (ir1 - ir)
8455
+ // - limited by head boundary (neq1 - iq1) to avoid crossing into next head
8456
+ const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
8457
+ GGML_ASSERT(tile_rows > 0);
8458
+
8459
+ const uint32_t h = iq2; // head index
8460
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
8461
+
8462
+ float S[Q_TILE_SZ];
8463
+ float M[Q_TILE_SZ];
8464
+
8465
+ for (int i = 0 ; i < Q_TILE_SZ; ++i) {
8466
+ S[i] = 0.;
8467
+ M[i] = -INFINITY;
8468
+ }
8239
8469
 
8240
- // original
8241
- //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
8470
+ // Per-thread scratch layout:
8471
+ // Q_q: Q_TILE_SZ * DK (converted Q tile F32 for GEMM, KV type for scalar)
8472
+ // KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
8473
+ // mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
8474
+ // VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
8475
+ // V32: KV_TILE_SZ * DV (F32 buffer for V tile)
8476
+ // K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
8477
+ float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);
8242
8478
 
8243
- // permute(0, 2, 1, 3)
8244
- memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
8479
+ void * Q_q = base;
8480
+ float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
8481
+ float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
8482
+ float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
8483
+ float * V32 = VKQ32 + Q_TILE_SZ * DV;
8484
+ float * K_f32 = V32 + KV_TILE_SZ * DV;
8485
+
8486
+ memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
8487
+ memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
8488
+
8489
+ // k indices
8490
+ const int ik3 = iq3 / rk3;
8491
+ const int ik2 = iq2 / rk2;
8492
+
8493
+ // v indices
8494
+ const int iv3 = iq3 / rv3;
8495
+ const int iv2 = iq2 / rv2;
8496
+
8497
+ {
8498
+ float * Q_f32 = (float *)Q_q;
8499
+ for (int tq = 0; tq < tile_rows; tq++) {
8500
+ const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
8501
+ memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
8502
+ }
8503
+ for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
8504
+ memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
8505
+ }
8506
+ }
8507
+
8508
+ memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
8509
+ memset(V32, 0, KV_TILE_SZ * DV * sizeof(float));
8510
+
8511
+ for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
8512
+ const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
8513
+
8514
+ // skip the tile entirely if all the masks are -inf
8515
+ if (mask) {
8516
+ bool can_skip = true;
8517
+ for (int tq = 0; tq < tile_rows; tq++) {
8518
+ const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
8519
+ for (int tk = 0; tk < kv_tile; tk++) {
8520
+ mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
8521
+ if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
8522
+ can_skip = false;
8523
+ }
8524
+ }
8525
+ // Pad remaining mask entries with -inf
8526
+ for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
8527
+ mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
8528
+ }
8529
+ }
8530
+
8531
+ if (can_skip) {
8532
+ continue;
8533
+ }
8534
+ }
8535
+
8536
+ // Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
8537
+ // Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
8538
+ for (int tk = 0; tk < kv_tile; tk++) {
8539
+ const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
8540
+ if (kv_type == GGML_TYPE_F16) {
8541
+ const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
8542
+ for (int64_t dk = 0; dk < DK; dk++) {
8543
+ K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
8544
+ }
8545
+ } else {
8546
+ const float * k_f32_src = (const float *)k_data;
8547
+ for (int64_t dk = 0; dk < DK; dk++) {
8548
+ K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
8549
+ }
8550
+ }
8551
+ }
8552
+ memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
8553
+ simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
8554
+ ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);
8555
+
8556
+ // Set padded KQ entries to -inf so softmax gives them zero weight
8557
+ if (kv_tile < KV_TILE_SZ) {
8558
+ for (int tq = 0; tq < Q_TILE_SZ; tq++) {
8559
+ for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
8560
+ KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
8561
+ }
8562
+ }
8563
+ }
8564
+
8565
+ if (logit_softcap != 0.0f) {
8566
+ ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
8567
+ ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
8568
+ }
8569
+
8570
+ if (mask) {
8571
+ ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
8572
+ }
8573
+
8574
+ bool skip[Q_TILE_SZ] = {};
8575
+
8576
+ for (int tq = 0; tq < Q_TILE_SZ; tq++) {
8577
+ float * kq_row = KQ + tq * KV_TILE_SZ;
8578
+
8579
+ float tile_max;
8580
+ ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);
8581
+
8582
+ if (tile_max == -INFINITY) {
8583
+ skip[tq] = true;
8584
+ continue;
8585
+ }
8586
+
8587
+ const float Mold = M[tq];
8588
+ const float Mnew = fmaxf(Mold, tile_max);
8589
+
8590
+ if (Mnew > Mold) {
8591
+ const float ms = expf(Mold - Mnew);
8592
+ ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
8593
+ S[tq] *= ms;
8594
+ }
8595
+ M[tq] = Mnew;
8596
+
8597
+
8598
+ S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
8599
+ }
8600
+
8601
+ // V accumulation: VKQ32 += softmax(KQ) * V
8602
+ // Pack V tile to contiguous F32, zero-padded
8603
+ for (int tk = 0; tk < kv_tile; tk++) {
8604
+ const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
8605
+ if (kv_type == GGML_TYPE_F16) {
8606
+ ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
8607
+ } else {
8608
+ memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
8609
+ }
8610
+ }
8611
+ for (int tq = 0; tq < Q_TILE_SZ; tq++) {
8612
+ if (skip[tq]) {
8613
+ memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
8614
+ }
8615
+ }
8616
+ simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
8617
+ }
8618
+
8619
+ // sinks (apply only to valid rows in the tile)
8620
+ if (sinks) {
8621
+ const float s = ((float *)((char *) sinks->data))[h];
8622
+
8623
+ for (int tq = 0; tq < tile_rows; tq++) {
8624
+ float ms = 1.0f;
8625
+ float vs = 1.0f;
8626
+
8627
+ if (s > M[tq]) {
8628
+ ms = expf(M[tq] - s);
8629
+ ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
8630
+ } else {
8631
+ vs = expf(s - M[tq]);
8632
+ }
8633
+
8634
+ S[tq] = S[tq] * ms + vs;
8635
+ }
8636
+ }
8637
+
8638
+ for (int tq = 0; tq < tile_rows; tq++) {
8639
+ // V /= S
8640
+ const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
8641
+ ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);
8642
+
8643
+ // dst indices
8644
+ const int i1 = iq1 + tq;
8645
+ const int i2 = iq2;
8646
+ const int i3 = iq3;
8647
+
8648
+ // permute(0, 2, 1, 3)
8649
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
8650
+ }
8651
+
8652
+ ir += tile_rows;
8653
+ }
8654
+ }
8655
+
8656
+ // Reduction function: combines partial results across KV chunks
8657
+ // Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]
8658
+ static void ggml_flash_attn_ext_reduce_partials(
8659
+ const ggml_compute_params * params,
8660
+ ggml_tensor * dst,
8661
+ const int64_t n_chunks,
8662
+ const int64_t chunk_size) {
8663
+
8664
+ const ggml_tensor * q = dst->src[0];
8665
+ const ggml_tensor * k = dst->src[1];
8666
+ const ggml_tensor * v = dst->src[2];
8667
+
8668
+ const int64_t DK = k->ne[0];
8669
+ const int64_t DV = v->ne[0];
8670
+ const int64_t nek1 = k->ne[1];
8671
+ const int64_t n_q_heads = q->ne[2];
8672
+
8673
+ const int ith = params->ith;
8674
+ const int nth = params->nth;
8675
+
8676
+ const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;
8677
+ float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread;
8678
+
8679
+ const int64_t partials_offset = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
8680
+ const int64_t partial_size = 2 + DV;
8681
+ const float * partials_base = (const float *) params->wdata + partials_offset;
8682
+
8683
+ // Output layout
8684
+ const int64_t ne1 = dst->ne[1];
8685
+ const int64_t ne2 = dst->ne[2];
8686
+ const size_t nb1 = dst->nb[1];
8687
+
8688
+ // Each thread reduces a subset of query heads
8689
+ for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
8690
+ float M_final = -INFINITY;
8691
+ float S_final = 0.0f;
8692
+ float * VKQ_final = thread_wdata;
8693
+ memset(VKQ_final, 0, DV * sizeof(float));
8694
+
8695
+ // Combine partials from all chunks
8696
+ for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
8697
+ const int64_t ic_start = chunk_idx * chunk_size;
8698
+ if (ic_start >= nek1) continue;
8699
+
8700
+ const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
8701
+ const float M_chunk = partial[0];
8702
+ const float S_chunk = partial[1];
8703
+ const float * VKQ_chunk = partial + 2;
8704
+
8705
+ if (S_chunk == 0.0f) continue;
8706
+
8707
+ const float M_new = fmaxf(M_final, M_chunk);
8708
+ const float scale_old = expf(M_final - M_new);
8709
+ const float scale_new = expf(M_chunk - M_new);
8710
+
8711
+ for (int64_t d = 0; d < DV; ++d) {
8712
+ VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;
8713
+ }
8714
+ S_final = S_final * scale_old + S_chunk * scale_new;
8715
+ M_final = M_new;
8716
+ }
8717
+
8718
+ // Normalize and write to output
8719
+ if (S_final != 0.0f) {
8720
+ const float S_inv = 1.0f / S_final;
8721
+ ggml_vec_scale_f32(DV, VKQ_final, S_inv);
8722
+ }
8723
+ // iq1=0, iq3=0 for decode
8724
+ memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1);
8245
8725
  }
8246
8726
  }
8247
8727
 
@@ -8266,6 +8746,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
8266
8746
  const int64_t DV = nev0;
8267
8747
  const int64_t N = neq1;
8268
8748
 
8749
+
8269
8750
  GGML_ASSERT(ne0 == DV);
8270
8751
  GGML_ASSERT(ne2 == N);
8271
8752
 
@@ -8286,47 +8767,92 @@ static void ggml_compute_forward_flash_attn_ext_f16(
8286
8767
  GGML_ASSERT(nb1 <= nb2);
8287
8768
  GGML_ASSERT(nb2 <= nb3);
8288
8769
 
8289
- // parallelize by q rows using ggml_vec_dot_f32
8290
-
8291
- // total rows in q
8292
- const int64_t nr = neq1*neq2*neq3;
8293
-
8294
- // rows per thread
8295
8770
  const int ith = params->ith;
8296
8771
  const int nth = params->nth;
8297
8772
 
8298
- // disable for NUMA
8299
- const bool disable_chunking = ggml_is_numa();
8773
+ // When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
8774
+ const bool use_ref = params->use_ref;
8300
8775
 
8301
- // 4x chunks per thread
8302
- int nth_scaled = nth * 4;
8303
- int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
8304
- int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
8776
+ const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
8777
+ const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
8305
8778
 
8306
- if (nth == 1 || nchunk < nth || disable_chunking) {
8307
- nchunk = nth;
8308
- }
8779
+ if (use_split_kv_path) {
8780
+ const int64_t chunk_size = (nek1 + nth - 1) / nth;
8309
8781
 
8310
- if (ith == 0) {
8311
- // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
8312
- ggml_threadpool_chunk_set(params->threadpool, nth);
8313
- }
8782
+ // Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
8783
+ const int64_t partial_size = 2 + DV;
8784
+ float * partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
8314
8785
 
8315
- ggml_barrier(params->threadpool);
8786
+ const int64_t ic_start = ith * chunk_size;
8787
+ const int64_t ic_end = std::min(ic_start + chunk_size, nek1);
8316
8788
 
8317
- // The number of elements in each chunk
8318
- const int64_t dr = (nr + nchunk - 1) / nchunk;
8789
+ const int64_t partial_stride = nth * partial_size;
8790
+ float * chunk_partials = partials_base + ith * partial_size;
8319
8791
 
8320
- // The first chunk comes from our thread_id, the rest will get auto-assigned.
8321
- int current_chunk = ith;
8792
+ if (ic_start < nek1) {
8793
+ for (int64_t q_head = 0; q_head < neq2; q_head++) {
8794
+ ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8795
+ params, dst, q_head, q_head + 1, ic_start, ic_end,
8796
+ chunk_partials, partial_stride);
8797
+ }
8798
+ } else {
8799
+ for (int64_t q_head = 0; q_head < neq2; q_head++) {
8800
+ float * q_partials = chunk_partials + q_head * partial_stride;
8801
+ q_partials[0] = -INFINITY; // M
8802
+ q_partials[1] = 0.0f; // S
8803
+ }
8804
+ }
8322
8805
 
8323
- while (current_chunk < nchunk) {
8324
- const int64_t ir0 = dr * current_chunk;
8325
- const int64_t ir1 = MIN(ir0 + dr, nr);
8806
+ ggml_barrier(params->threadpool);
8807
+ ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size);
8808
+ } else {
8326
8809
 
8327
- ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
8810
+ // total rows in q
8811
+ const int64_t nr = neq1*neq2*neq3;
8328
8812
 
8329
- current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
8813
+ // disable for NUMA
8814
+ const bool disable_chunking = ggml_is_numa();
8815
+
8816
+ // 4x chunks per thread
8817
+ int nth_scaled = nth * 4;
8818
+ int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
8819
+ int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
8820
+
8821
+ if (nth == 1 || nchunk < nth || disable_chunking) {
8822
+ nchunk = nth;
8823
+ }
8824
+
8825
+ if (ith == 0) {
8826
+ ggml_threadpool_chunk_set(params->threadpool, nth);
8827
+ }
8828
+
8829
+ ggml_barrier(params->threadpool);
8830
+
8831
+ const int64_t dr = (nr + nchunk - 1) / nchunk;
8832
+
8833
+ static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
8834
+ bool use_tiled = !use_ref &&
8835
+ (q->type == GGML_TYPE_F32 &&
8836
+ kv_is_f32_or_f16 &&
8837
+ k->type == v->type &&
8838
+ neq1 >= Q_TILE_SZ);
8839
+ #ifdef GGML_SIMD
8840
+ use_tiled &= (DV % GGML_F32_EPR == 0);
8841
+ #endif
8842
+ int current_chunk = ith;
8843
+
8844
+ while (current_chunk < nchunk) {
8845
+ const int64_t ir0 = dr * current_chunk;
8846
+ const int64_t ir1 = MIN(ir0 + dr, nr);
8847
+
8848
+ if (use_tiled) {
8849
+ ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
8850
+ } else {
8851
+ ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0);
8852
+ }
8853
+
8854
+ current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
8855
+ }
8330
8856
  }
8331
8857
  }
8332
8858
 
@@ -9107,7 +9633,7 @@ void ggml_compute_forward_win_unpart(
9107
9633
  }
9108
9634
  }
9109
9635
 
9110
- //gmml_compute_forward_unary
9636
+ //ggml_compute_forward_unary
9111
9637
 
9112
9638
  void ggml_compute_forward_unary(
9113
9639
  const ggml_compute_params * params,
@@ -9870,6 +10396,195 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s
9870
10396
  }
9871
10397
  }
9872
10398
 
10399
+ // ggml_compute_forward_gated_delta_net
10400
+ static void ggml_compute_forward_gated_delta_net_one_chunk(
10401
+ const ggml_compute_params * params,
10402
+ ggml_tensor * dst,
10403
+ int64_t ir0,
10404
+ int64_t ir1) {
10405
+
10406
+ ggml_tensor * src_q = dst->src[0];
10407
+ ggml_tensor * src_k = dst->src[1];
10408
+ ggml_tensor * src_v = dst->src[2];
10409
+ ggml_tensor * src_g = dst->src[3];
10410
+ ggml_tensor * src_beta = dst->src[4];
10411
+ ggml_tensor * src_state = dst->src[5];
10412
+
10413
+ const int64_t S_v = src_v->ne[0];
10414
+ const int64_t H = src_v->ne[1];
10415
+ const int64_t n_tokens = src_v->ne[2];
10416
+ const int64_t n_seqs = src_v->ne[3];
10417
+
10418
+ GGML_ASSERT(ggml_is_contiguous_rows(src_q));
10419
+ GGML_ASSERT(ggml_is_contiguous_rows(src_k));
10420
+ GGML_ASSERT(ggml_is_contiguous_rows(src_v));
10421
+ GGML_ASSERT(ggml_is_contiguous(src_g));
10422
+ GGML_ASSERT(ggml_is_contiguous(src_beta));
10423
+ GGML_ASSERT(ggml_is_contiguous(src_state));
10424
+
10425
+ GGML_ASSERT(src_g->ne[0] == 1 || src_g->ne[0] == S_v);
10426
+ GGML_ASSERT(src_beta->ne[0] == 1);
10427
+
10428
+ GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
10429
+ GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
10430
+ GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
10431
+ GGML_TENSOR_LOCALS(size_t, nbk, src_k, nb);
10432
+ GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
10433
+ GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
10434
+ GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne);
10435
+ GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb);
10436
+ GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
10437
+
10438
+ const bool kda = (neg0 == S_v);
10439
+
10440
+ // scratch layout per thread: [delta(S_v)]
10441
+ const int64_t scratch_per_thread = S_v;
10442
+ const int ith = params->ith;
10443
+
10444
+ float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;
10445
+
10446
+ // output layout: [attn_scores | new_states]
10447
+ // attn_scores: S_v * H * n_tokens * n_seqs floats
10448
+ // new_states: S_v * S_v * H * n_seqs floats
10449
+ const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
10450
+ float * attn_out_base = (float *)dst->data;
10451
+ float * state_out_base = (float *)dst->data + attn_score_elems;
10452
+
10453
+ const float * state_in_base = (const float *)src_state->data;
10454
+
10455
+ //const int64_t rq1 = nev1 / neq1;
10456
+ //const int64_t rk1 = nev1 / nek1;
10457
+ const int64_t rq3 = nev3 / neq3;
10458
+ const int64_t rk3 = nev3 / nek3;
10459
+
10460
+ const float scale = 1.0f / sqrtf((float) S_v);
10461
+
10462
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
10463
+ const int64_t iv1 = ir % H; // head_index
10464
+ const int64_t iv3 = ir / H; // sequence
10465
+
10466
+ const int64_t iq1 = iv1 % neq1;
10467
+ const int64_t ik1 = iv1 % nek1;
10468
+
10469
+ const int64_t iq3 = iv3 / rq3;
10470
+ const int64_t ik3 = iv3 / rk3;
10471
+
10472
+ float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v;
10473
+
10474
+ // copy input state into output buffer and operate in-place
10475
+ const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v;
10476
+ memcpy(s_out, s_in, S_v * S_v * sizeof(float));
10477
+
10478
+ // attn output pointer for first token of this (head, seq)
10479
+ float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v;
10480
+
10481
+ for (int64_t t = 0; t < n_tokens; t++) {
10482
+ const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1);
10483
+ const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1);
10484
+ const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
10485
+
10486
+ const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
10487
+ const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
10488
+
10489
+ // state is stored transposed: s_out[j*S_v + i] = S[i][j]
10490
+ // so row j of s_out = column j of S (contiguous access)
10491
+
10492
+ if (kda) {
10493
+ // precompute exp(g) into delta scratch (reused below)
10494
+ for (int64_t i = 0; i < S_v; ++i) {
10495
+ delta[i] = expf(g_d[i]);
10496
+ }
10497
+ // S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i])
10498
+ for (int64_t j = 0; j < S_v; ++j) {
10499
+ ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta);
10500
+ }
10501
+ } else {
10502
+ ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0]));
10503
+ }
10504
+
10505
+ // delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k)
10506
+ for (int64_t j = 0; j < S_v; ++j) {
10507
+ float sum = 0.0f;
10508
+ ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1);
10509
+ delta[j] = (v_d[j] - sum) * beta_val;
10510
+ }
10511
+
10512
+ // outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i]
10513
+ for (int64_t j = 0; j < S_v; ++j) {
10514
+ ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]);
10515
+ }
10516
+
10517
+ // attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q)
10518
+ for (int64_t j = 0; j < S_v; ++j) {
10519
+ float sum = 0.0f;
10520
+ ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1);
10521
+ attn_data[j] = sum * scale;
10522
+ }
10523
+
10524
+ attn_data += S_v * H; // advance to next token
10525
+ }
10526
+ }
10527
+ }
10528
+
10529
+
10530
+ static void ggml_compute_forward_gated_delta_net_f32(
10531
+ const ggml_compute_params * params,
10532
+ ggml_tensor * dst) {
10533
+
10534
+ ggml_tensor * V = dst->src[2];
10535
+ int64_t nr = V->ne[1] * V->ne[3];
10536
+
10537
+ // disable for NUMA
10538
+ const bool disable_chunking = ggml_is_numa();
10539
+
10540
+ int nth = params->nth;
10541
+ int ith = params->ith;
10542
+
10543
+ // 4x chunks per thread
10544
+ int nth_scaled = nth * 4;
10545
+ int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
10546
+ int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
10547
+
10548
+ if (nth == 1 || nchunk < nth || disable_chunking) {
10549
+ nchunk = nth;
10550
+ }
10551
+
10552
+ if (ith == 0) {
10553
+ ggml_threadpool_chunk_set(params->threadpool, nth);
10554
+ }
10555
+
10556
+ ggml_barrier(params->threadpool);
10557
+
10558
+ const int64_t dr = (nr + nchunk - 1) / nchunk;
10559
+
10560
+ int current_chunk = ith;
10561
+
10562
+ while (current_chunk < nchunk) {
10563
+ const int64_t ir0 = dr * current_chunk;
10564
+ const int64_t ir1 = MIN(ir0 + dr, nr);
10565
+
10566
+ ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1);
10567
+ current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
10568
+ }
10569
+ }
10570
+
10571
+ void ggml_compute_forward_gated_delta_net(
10572
+ const ggml_compute_params * params,
10573
+ ggml_tensor * dst) {
10574
+ const ggml_tensor * src0 = dst->src[0];
10575
+
10576
+ switch (src0->type) {
10577
+ case GGML_TYPE_F32:
10578
+ {
10579
+ ggml_compute_forward_gated_delta_net_f32(params, dst);
10580
+ } break;
10581
+ default:
10582
+ {
10583
+ GGML_ABORT("fatal error");
10584
+ }
10585
+ }
10586
+ }
10587
+
9873
10588
  // ggml_compute_forward_rwkv_wkv7
9874
10589
 
9875
10590
  static void ggml_compute_forward_rwkv_wkv7_f32(
@@ -10195,7 +10910,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
10195
10910
  assert(!isnan(s0[i]));
10196
10911
  assert(!isnan(s1[i]));
10197
10912
  }
10198
- #endif
10913
+ #endif // NDEBUG
10199
10914
 
10200
10915
  float max = -INFINITY;
10201
10916
  ggml_vec_max_f32(nc, &max, s0);
@@ -10214,7 +10929,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
10214
10929
  assert(!isnan(st[i]));
10215
10930
  assert(!isinf(st[i]));
10216
10931
  }
10217
- #endif
10932
+ #endif // NDEBUG
10218
10933
  }
10219
10934
  sums[ith] = sum_thread;
10220
10935
  ggml_barrier(params->threadpool);
@@ -10287,7 +11002,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
10287
11002
  assert(!isnan(s0[i]));
10288
11003
  assert(!isnan(s1[i]));
10289
11004
  }
10290
- #endif
11005
+ #endif // NDEBUG
10291
11006
 
10292
11007
  // soft_max
10293
11008
  float max = -INFINITY;
@@ -10305,7 +11020,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
10305
11020
  assert(!isnan(ds0[i]));
10306
11021
  assert(!isinf(ds0[i]));
10307
11022
  }
10308
- #endif
11023
+ #endif // NDEBUG
10309
11024
  }
10310
11025
  }
10311
11026