whispercpp 1.3.4 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (891)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +158 -44
  4. data/ext/extconf.rb +3 -2
  5. data/ext/ruby_whisper.c +34 -6
  6. data/ext/ruby_whisper.h +67 -0
  7. data/ext/ruby_whisper_context.c +236 -144
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +12 -13
  10. data/ext/ruby_whisper_params.c +47 -24
  11. data/ext/ruby_whisper_segment.c +84 -20
  12. data/ext/ruby_whisper_token.c +371 -0
  13. data/ext/ruby_whisper_transcribe.cpp +5 -2
  14. data/ext/ruby_whisper_vad_context.c +122 -0
  15. data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +138 -0
  18. data/ext/ruby_whisper_vad_segments.c +105 -0
  19. data/ext/sources/CMakeLists.txt +4 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  22. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  23. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  24. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  25. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  26. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  27. data/ext/sources/examples/bench/bench.cpp +23 -18
  28. data/ext/sources/examples/cli/cli.cpp +129 -112
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  31. data/ext/sources/examples/miniaudio.h +4507 -2131
  32. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/server/server.cpp +28 -15
  34. data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
  35. data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
  36. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
  37. data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
  38. data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
  39. data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
  40. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  41. data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
  42. data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
  43. data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
  44. data/ext/sources/examples/talk-llama/llama-context.h +70 -23
  45. data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
  46. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  47. data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
  48. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  49. data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
  50. data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
  51. data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
  52. data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
  53. data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
  54. data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
  55. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
  56. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
  57. data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
  58. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  59. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  60. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  61. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  62. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
  63. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  64. data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
  65. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  66. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
  67. data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
  68. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
  69. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  70. data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
  71. data/ext/sources/examples/talk-llama/llama-model.h +112 -18
  72. data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
  73. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
  74. data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
  75. data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
  76. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
  77. data/ext/sources/examples/talk-llama/llama.cpp +802 -21
  78. data/ext/sources/examples/talk-llama/llama.h +210 -39
  79. data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
  80. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  81. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  82. data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
  83. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  84. data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
  85. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
  86. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
  87. data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
  88. data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
  89. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  90. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  91. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  92. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  93. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  94. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  95. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  96. data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
  97. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  98. data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
  99. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
  100. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  101. data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
  102. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  103. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
  104. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  105. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  106. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  107. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  108. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  109. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
  110. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  111. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  112. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  113. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  114. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  115. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  116. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  117. data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
  118. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  119. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  120. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
  121. data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
  122. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  123. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
  124. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  125. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
  126. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  127. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  128. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  129. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  130. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  131. data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
  132. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  133. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  134. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  135. data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
  136. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  137. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +704 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  156. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  158. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  159. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  160. data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
  161. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  162. data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
  163. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  166. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
  168. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  169. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
  171. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
  172. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
  173. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
  174. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  178. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  179. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
  180. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  181. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  182. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  183. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  184. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  185. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  186. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  187. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  188. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  189. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  190. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  191. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  192. data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
  193. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  194. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  195. data/ext/sources/ggml/CMakeLists.txt +90 -56
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +5 -2
  198. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  199. data/ext/sources/ggml/include/ggml-cpu.h +6 -0
  200. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  201. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  202. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  203. data/ext/sources/ggml/include/ggml-rpc.h +14 -12
  204. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +246 -21
  207. data/ext/sources/ggml/src/CMakeLists.txt +85 -11
  208. data/ext/sources/ggml/src/ggml-alloc.c +128 -50
  209. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  210. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  211. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  212. data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
  213. data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
  214. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
  215. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
  217. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
  219. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
  220. data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
  221. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
  222. data/ext/sources/ggml/src/ggml-common.h +11 -0
  223. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
  224. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
  225. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  226. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  227. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  228. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
  229. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
  230. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  232. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
  233. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  234. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  235. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  237. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
  238. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
  239. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  240. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
  242. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
  243. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
  245. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  246. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
  248. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  249. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
  250. data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
  251. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  252. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  253. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
  254. data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
  255. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  256. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  258. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
  259. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  260. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
  261. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  262. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  263. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  264. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
  265. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  266. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
  267. data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  269. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  270. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  271. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  273. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  275. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  276. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
  278. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
  285. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  287. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  288. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  289. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
  290. data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
  291. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
  292. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
  293. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
  294. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  295. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  296. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  297. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
  298. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
  299. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
  300. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
  301. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
  302. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  303. data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
  304. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
  305. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  306. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  307. data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
  308. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  309. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  310. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  311. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  312. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
  313. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  314. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  315. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
  316. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  317. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
  334. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
  335. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  336. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
  337. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
  338. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  339. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
  341. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
  342. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  343. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  344. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  345. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
  346. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  347. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  348. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
  349. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
  350. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
  351. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  352. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
  353. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  354. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  355. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
  356. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  357. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  358. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  359. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  360. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  361. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  362. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  363. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
  364. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
  365. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  366. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  367. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  368. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  369. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  370. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  371. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  372. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  373. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  374. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  375. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  376. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  377. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  378. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  379. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
  380. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
  381. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
  382. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
  383. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  384. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
  385. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  386. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  387. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
  388. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  389. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  390. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  391. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  392. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  393. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  394. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  395. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
  396. data/ext/sources/ggml/src/ggml-impl.h +129 -6
  397. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  398. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
  399. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  400. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
  401. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
  402. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
  403. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
  404. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
  405. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
  406. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
  407. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
  408. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
  409. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  410. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
  411. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
  412. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  413. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  414. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  415. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
  416. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  417. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  418. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  419. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  420. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  421. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  422. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  423. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  424. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  425. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  426. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  427. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  428. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  429. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  430. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  431. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  432. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  433. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  434. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  435. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  436. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  437. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  438. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  439. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  440. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  441. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  442. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  443. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  444. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  445. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  446. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  447. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  448. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  449. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  450. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  451. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  452. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  453. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  454. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  455. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  456. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
  457. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  458. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  459. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  460. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  461. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  462. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  463. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  464. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  465. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  466. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  467. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  468. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  469. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  470. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  471. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  472. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  473. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  474. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  475. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  476. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  477. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  478. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  479. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  480. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  481. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  482. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  483. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  484. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  485. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  486. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  487. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  488. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  489. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  490. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  491. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  492. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  493. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  494. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  495. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  496. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  497. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  498. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  499. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  500. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  501. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  502. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  503. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  504. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  505. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  506. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  507. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  508. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
  509. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
  510. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  511. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
  512. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
  513. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  514. data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
  515. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  516. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
  517. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  518. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  519. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  520. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  521. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  522. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
  523. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
  524. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  525. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  526. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  527. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  528. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  529. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  530. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  531. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  532. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
  534. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  535. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
  536. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  537. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  538. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  539. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  540. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  541. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  542. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
  543. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  544. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  545. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  547. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  548. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
  549. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  550. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  551. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  552. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  553. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  554. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  555. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  556. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  557. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  558. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  559. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  560. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  561. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  562. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  563. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  564. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  565. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  566. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  567. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  568. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  569. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  570. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  571. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  572. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  573. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  574. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  575. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  576. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  577. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  578. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  579. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  580. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  581. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  582. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  583. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  584. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  585. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  586. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  587. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  588. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  589. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  590. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  591. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  592. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  593. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  594. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  595. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  596. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  597. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  598. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  599. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  600. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  601. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
  602. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  603. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  604. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  605. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  606. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  607. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  608. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  609. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  610. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  611. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  612. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  613. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  614. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  615. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  616. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  617. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  618. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  619. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  620. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  621. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  622. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  623. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  624. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  625. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  626. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  627. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  628. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  629. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  630. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  631. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  632. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  633. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  634. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  635. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  636. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  637. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  638. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  639. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  640. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  641. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  642. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  643. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  644. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
  646. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  745. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  746. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  747. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  748. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  749. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  750. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  751. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
  752. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
  753. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  754. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
  755. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
  756. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  757. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  758. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  759. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  760. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  761. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  762. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  763. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  764. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  765. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  766. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  767. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  768. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  769. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  770. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
  771. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  772. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  773. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  774. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
  775. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  776. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
  777. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
  778. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
  779. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
  780. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
  781. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  782. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  783. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  784. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  785. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  786. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  787. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  788. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  789. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  790. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  791. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  792. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  793. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  794. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  795. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  796. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  798. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
  799. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  800. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  801. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  802. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  803. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  804. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  805. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  806. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  807. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  808. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  809. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  810. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  811. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  812. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  813. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  814. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  815. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
  816. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  817. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  818. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
  819. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
  820. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  821. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  822. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  823. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  824. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  825. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  826. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  827. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  828. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  829. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
  830. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  831. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  832. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  833. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
  834. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
  835. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
  836. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
  837. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  838. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  839. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  840. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  841. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  842. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  843. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  844. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  845. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  846. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  847. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  848. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
  849. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  850. data/ext/sources/ggml/src/ggml.c +590 -64
  851. data/ext/sources/ggml/src/gguf.cpp +229 -44
  852. data/ext/sources/include/whisper.h +1 -0
  853. data/ext/sources/src/CMakeLists.txt +3 -1
  854. data/ext/sources/src/whisper.cpp +106 -62
  855. data/ext/sources/tests/CMakeLists.txt +2 -2
  856. data/ext/sources/tests/test-vad-full.cpp +4 -2
  857. data/ext/sources/tests/test-vad.cpp +1 -1
  858. data/extsources.rb +1 -0
  859. data/lib/whisper/model/uri.rb +17 -18
  860. data/sig/whisper.rbs +162 -4
  861. data/test/test_context_params.rb +82 -0
  862. data/test/test_params.rb +16 -8
  863. data/test/test_segment.rb +0 -1
  864. data/test/test_token.rb +81 -0
  865. data/test/test_vad.rb +1 -1
  866. data/test/test_vad_context.rb +100 -0
  867. data/test/test_vad_segment.rb +19 -0
  868. data/test/test_vad_segments.rb +16 -0
  869. data/test/test_whisper.rb +27 -0
  870. data/whispercpp.gemspec +1 -1
  871. metadata +502 -37
  872. data/ext/sources/build-xcframework.sh +0 -571
  873. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
  874. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  875. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  876. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  877. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  878. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  879. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  880. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  881. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  882. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  883. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  884. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  885. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  886. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  887. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  888. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  889. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  890. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  891. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -3,12 +3,14 @@
3
3
  #include "ggml-cpu.h"
4
4
  #include "ggml-impl.h"
5
5
  #include "binary-ops.h"
6
+ #include "simd-gemm.h"
6
7
  #include "ggml.h"
7
8
  #include "unary-ops.h"
8
9
  #include "vec.h"
9
10
 
10
- #include <float.h>
11
11
  #include <algorithm>
12
+ #include <cfloat>
13
+ #include <cmath>
12
14
 
13
15
  // ggml_compute_forward_dup
14
16
 
@@ -373,7 +375,7 @@ static void ggml_compute_forward_dup_bytes(
373
375
  const size_t rs = ne00 * type_size;
374
376
 
375
377
  if (nb00 == type_size) {
376
- // src0 is contigous on first dimension, copy by rows
378
+ // src0 is contiguous on first dimension, copy by rows
377
379
  for (int64_t i03 = 0; i03 < ne03; i03++) {
378
380
  for (int64_t i02 = 0; i02 < ne02; i02++) {
379
381
  id += rs * ir0;
@@ -668,6 +670,7 @@ void ggml_compute_forward_add(
668
670
  case GGML_TYPE_Q5_1:
669
671
  case GGML_TYPE_Q8_0:
670
672
  case GGML_TYPE_MXFP4:
673
+ case GGML_TYPE_NVFP4:
671
674
  case GGML_TYPE_Q2_K:
672
675
  case GGML_TYPE_Q3_K:
673
676
  case GGML_TYPE_Q4_K:
@@ -1117,6 +1120,7 @@ void ggml_compute_forward_add1(
1117
1120
  case GGML_TYPE_Q8_0:
1118
1121
  case GGML_TYPE_Q8_1:
1119
1122
  case GGML_TYPE_MXFP4:
1123
+ case GGML_TYPE_NVFP4:
1120
1124
  case GGML_TYPE_Q2_K:
1121
1125
  case GGML_TYPE_Q3_K:
1122
1126
  case GGML_TYPE_Q4_K:
@@ -1245,6 +1249,7 @@ void ggml_compute_forward_acc(
1245
1249
  case GGML_TYPE_Q8_0:
1246
1250
  case GGML_TYPE_Q8_1:
1247
1251
  case GGML_TYPE_MXFP4:
1252
+ case GGML_TYPE_NVFP4:
1248
1253
  case GGML_TYPE_Q2_K:
1249
1254
  case GGML_TYPE_Q3_K:
1250
1255
  case GGML_TYPE_Q4_K:
@@ -1394,6 +1399,56 @@ void ggml_compute_forward_sum(
1394
1399
  }
1395
1400
  }
1396
1401
 
1402
+ // ggml_compute_forward_cumsum
1403
+
1404
+ static void ggml_compute_forward_cumsum_f32(
1405
+ const ggml_compute_params * params,
1406
+ ggml_tensor * dst) {
1407
+
1408
+ const ggml_tensor * src0 = dst->src[0];
1409
+
1410
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
1411
+ GGML_ASSERT(dst->nb[0] == sizeof(float));
1412
+
1413
+ GGML_TENSOR_UNARY_OP_LOCALS
1414
+
1415
+ GGML_ASSERT(ne0 == ne00);
1416
+ GGML_ASSERT(ne1 == ne01);
1417
+ GGML_ASSERT(ne2 == ne02);
1418
+ GGML_ASSERT(ne3 == ne03);
1419
+
1420
+ const auto [ir0, ir1] = get_thread_range(params, src0);
1421
+
1422
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
1423
+ const int64_t i03 = ir/(ne02*ne01);
1424
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
1425
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
1426
+
1427
+ float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
1428
+ float * dst_row = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
1429
+
1430
+ ggml_vec_cumsum_f32(ne00, dst_row, src_row);
1431
+ }
1432
+ }
1433
+
1434
+ void ggml_compute_forward_cumsum(
1435
+ const ggml_compute_params * params,
1436
+ ggml_tensor * dst) {
1437
+
1438
+ const ggml_tensor * src0 = dst->src[0];
1439
+
1440
+ switch (src0->type) {
1441
+ case GGML_TYPE_F32:
1442
+ {
1443
+ ggml_compute_forward_cumsum_f32(params, dst);
1444
+ } break;
1445
+ default:
1446
+ {
1447
+ GGML_ABORT("fatal error");
1448
+ }
1449
+ }
1450
+ }
1451
+
1397
1452
  // ggml_compute_forward_sum_rows
1398
1453
 
1399
1454
  static void ggml_compute_forward_sum_rows_f32(
@@ -1743,7 +1798,7 @@ void ggml_compute_forward_repeat(
1743
1798
  {
1744
1799
  ggml_compute_forward_repeat_f32(params, dst);
1745
1800
  } break;
1746
- // TODO: templateify the implemenation and support for I64
1801
+ // TODO: templateify the implementation and support for I64
1747
1802
  // ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
1748
1803
  //case GGML_TYPE_I64:
1749
1804
  // {
@@ -2045,10 +2100,14 @@ static void ggml_compute_forward_gelu_f32(
2045
2100
 
2046
2101
  const ggml_tensor * src0 = dst->src[0];
2047
2102
 
2048
- assert(ggml_is_contiguous_1(src0));
2049
- assert(ggml_is_contiguous_1(dst));
2103
+ assert(ggml_is_contiguous_rows(src0));
2050
2104
  assert(ggml_are_same_shape(src0, dst));
2051
2105
 
2106
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2107
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2108
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2109
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2110
+
2052
2111
  const int ith = params->ith;
2053
2112
  const int nth = params->nth;
2054
2113
 
@@ -2062,19 +2121,23 @@ static void ggml_compute_forward_gelu_f32(
2062
2121
  const int ir0 = dr*ith;
2063
2122
  const int ir1 = MIN(ir0 + dr, nr);
2064
2123
 
2065
- for (int i1 = ir0; i1 < ir1; i1++) {
2124
+ for (int ir = ir0; ir < ir1; ++ir) {
2125
+ const int i3 = ir/(ne02*ne01);
2126
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2127
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2128
+
2066
2129
  ggml_vec_gelu_f32(nc,
2067
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
2068
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
2130
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2131
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2069
2132
 
2070
2133
  #ifndef NDEBUG
2071
2134
  for (int k = 0; k < nc; k++) {
2072
- const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2135
+ const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
2073
2136
  GGML_UNUSED(x);
2074
2137
  assert(!isnan(x));
2075
2138
  assert(!isinf(x));
2076
2139
  }
2077
- #endif
2140
+ #endif // NDEBUG
2078
2141
  }
2079
2142
  }
2080
2143
 
@@ -2084,10 +2147,14 @@ static void ggml_compute_forward_gelu_f16(
2084
2147
 
2085
2148
  const ggml_tensor * src0 = dst->src[0];
2086
2149
 
2087
- assert(ggml_is_contiguous_1(src0));
2088
- assert(ggml_is_contiguous_1(dst));
2150
+ assert(ggml_is_contiguous_rows(src0));
2089
2151
  assert(ggml_are_same_shape(src0, dst));
2090
2152
 
2153
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2154
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2155
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2156
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2157
+
2091
2158
  const int ith = params->ith;
2092
2159
  const int nth = params->nth;
2093
2160
 
@@ -2101,20 +2168,24 @@ static void ggml_compute_forward_gelu_f16(
2101
2168
  const int ir0 = dr*ith;
2102
2169
  const int ir1 = MIN(ir0 + dr, nr);
2103
2170
 
2104
- for (int i1 = ir0; i1 < ir1; i1++) {
2171
+ for (int ir = ir0; ir < ir1; ++ir) {
2172
+ const int i3 = ir/(ne02*ne01);
2173
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2174
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2175
+
2105
2176
  ggml_vec_gelu_f16(nc,
2106
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2107
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2177
+ (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2178
+ (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2108
2179
 
2109
2180
  #ifndef NDEBUG
2110
2181
  for (int k = 0; k < nc; k++) {
2111
- const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2182
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
2112
2183
  const float v = GGML_CPU_FP16_TO_FP32(x);
2113
2184
  GGML_UNUSED(v);
2114
2185
  assert(!isnan(v));
2115
2186
  assert(!isinf(v));
2116
2187
  }
2117
- #endif
2188
+ #endif // NDEBUG
2118
2189
  }
2119
2190
  }
2120
2191
 
@@ -2140,6 +2211,83 @@ static void ggml_compute_forward_gelu(
2140
2211
  }
2141
2212
  }
2142
2213
 
2214
+ // ggml_compute_fill
2215
+
2216
+ static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
2217
+ const float c = ggml_get_op_params_f32(dst, 0);
2218
+
2219
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
2220
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
2221
+
2222
+ const auto [ir0, ir1] = get_thread_range(params, dst);
2223
+
2224
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
2225
+ const int64_t i03 = ir/(ne2*ne1);
2226
+ const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
2227
+ const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
2228
+
2229
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
2230
+
2231
+ ggml_vec_set_f32(ne0, dst_ptr, c);
2232
+ }
2233
+ }
2234
+
2235
+ void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
2236
+ ggml_compute_forward_fill_f32(params, dst);
2237
+ }
2238
+
2239
+ // ggml_compute_tri
2240
+
2241
+ static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
2242
+ const ggml_tensor * src0 = dst->src[0];
2243
+
2244
+ const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
2245
+
2246
+ GGML_ASSERT(ggml_is_contiguous(src0));
2247
+
2248
+ GGML_TENSOR_UNARY_OP_LOCALS
2249
+
2250
+ const auto [ir0, ir1] = get_thread_range(params, src0);
2251
+
2252
+ bool (*bipred)(int, int);
2253
+
2254
+ switch (ttype) {
2255
+ case GGML_TRI_TYPE_LOWER: bipred = [](int i, int r) { return i < r; }; break;
2256
+ case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
2257
+ case GGML_TRI_TYPE_UPPER: bipred = [](int i, int r) { return i > r; }; break;
2258
+ case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
2259
+ default: GGML_ABORT("invalid tri type");
2260
+ }
2261
+
2262
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
2263
+ const int64_t i03 = ir/(ne02*ne01);
2264
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
2265
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
2266
+
2267
+ const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
2268
+ float * dst_ptr = ( float *) (( char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
2269
+
2270
+ for (int i0 = 0; i0 < ne0; ++i0) {
2271
+ dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
2272
+ }
2273
+ }
2274
+ }
2275
+
2276
+ void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
2277
+ const ggml_tensor * src0 = dst->src[0];
2278
+
2279
+ switch (src0->type) {
2280
+ case GGML_TYPE_F32:
2281
+ {
2282
+ ggml_compute_forward_tri_f32(params, dst);
2283
+ } break;
2284
+ default:
2285
+ {
2286
+ GGML_ABORT("fatal error");
2287
+ }
2288
+ }
2289
+ }
2290
+
2143
2291
  // ggml_compute_forward_gelu_erf
2144
2292
 
2145
2293
  static void ggml_compute_forward_gelu_erf_f32(
@@ -2148,10 +2296,14 @@ static void ggml_compute_forward_gelu_erf_f32(
2148
2296
 
2149
2297
  const ggml_tensor * src0 = dst->src[0];
2150
2298
 
2151
- assert(ggml_is_contiguous_1(src0));
2152
- assert(ggml_is_contiguous_1(dst));
2299
+ assert(ggml_is_contiguous_rows(src0));
2153
2300
  assert(ggml_are_same_shape(src0, dst));
2154
2301
 
2302
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2303
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2304
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2305
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2306
+
2155
2307
  const int ith = params->ith;
2156
2308
  const int nth = params->nth;
2157
2309
 
@@ -2165,19 +2317,23 @@ static void ggml_compute_forward_gelu_erf_f32(
2165
2317
  const int ir0 = dr*ith;
2166
2318
  const int ir1 = MIN(ir0 + dr, nr);
2167
2319
 
2168
- for (int i1 = ir0; i1 < ir1; i1++) {
2320
+ for (int ir = ir0; ir < ir1; ++ir) {
2321
+ const int i3 = ir/(ne02*ne01);
2322
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2323
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2324
+
2169
2325
  ggml_vec_gelu_erf_f32(nc,
2170
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
2171
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
2326
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2327
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2172
2328
 
2173
2329
  #ifndef NDEBUG
2174
2330
  for (int k = 0; k < nc; k++) {
2175
- const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2331
+ const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
2176
2332
  GGML_UNUSED(x);
2177
2333
  assert(!isnan(x));
2178
2334
  assert(!isinf(x));
2179
2335
  }
2180
- #endif
2336
+ #endif // NDEBUG
2181
2337
  }
2182
2338
  }
2183
2339
 
@@ -2187,10 +2343,14 @@ static void ggml_compute_forward_gelu_erf_f16(
2187
2343
 
2188
2344
  const ggml_tensor * src0 = dst->src[0];
2189
2345
 
2190
- assert(ggml_is_contiguous_1(src0));
2191
- assert(ggml_is_contiguous_1(dst));
2346
+ assert(ggml_is_contiguous_rows(src0));
2192
2347
  assert(ggml_are_same_shape(src0, dst));
2193
2348
 
2349
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2350
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2351
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2352
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2353
+
2194
2354
  const int ith = params->ith;
2195
2355
  const int nth = params->nth;
2196
2356
 
@@ -2204,20 +2364,24 @@ static void ggml_compute_forward_gelu_erf_f16(
2204
2364
  const int ir0 = dr*ith;
2205
2365
  const int ir1 = MIN(ir0 + dr, nr);
2206
2366
 
2207
- for (int i1 = ir0; i1 < ir1; i1++) {
2367
+ for (int ir = ir0; ir < ir1; ++ir) {
2368
+ const int i3 = ir/(ne02*ne01);
2369
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2370
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2371
+
2208
2372
  ggml_vec_gelu_erf_f16(nc,
2209
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2210
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2373
+ (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2374
+ (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2211
2375
 
2212
2376
  #ifndef NDEBUG
2213
2377
  for (int k = 0; k < nc; k++) {
2214
- const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2378
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
2215
2379
  const float v = GGML_CPU_FP16_TO_FP32(x);
2216
2380
  GGML_UNUSED(v);
2217
2381
  assert(!isnan(v));
2218
2382
  assert(!isinf(v));
2219
2383
  }
2220
- #endif
2384
+ #endif // NDEBUG
2221
2385
  }
2222
2386
  }
2223
2387
 
@@ -2251,10 +2415,14 @@ static void ggml_compute_forward_gelu_quick_f32(
2251
2415
 
2252
2416
  const ggml_tensor * src0 = dst->src[0];
2253
2417
 
2254
- assert(ggml_is_contiguous_1(src0));
2255
- assert(ggml_is_contiguous_1(dst));
2418
+ assert(ggml_is_contiguous_rows(src0));
2256
2419
  assert(ggml_are_same_shape(src0, dst));
2257
2420
 
2421
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2422
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2423
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2424
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2425
+
2258
2426
  const int ith = params->ith;
2259
2427
  const int nth = params->nth;
2260
2428
 
@@ -2268,19 +2436,23 @@ static void ggml_compute_forward_gelu_quick_f32(
2268
2436
  const int ir0 = dr*ith;
2269
2437
  const int ir1 = MIN(ir0 + dr, nr);
2270
2438
 
2271
- for (int i1 = ir0; i1 < ir1; i1++) {
2439
+ for (int ir = ir0; ir < ir1; ++ir) {
2440
+ const int i3 = ir/(ne02*ne01);
2441
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2442
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2443
+
2272
2444
  ggml_vec_gelu_quick_f32(nc,
2273
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
2274
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
2445
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2446
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2275
2447
 
2276
2448
  #ifndef NDEBUG
2277
2449
  for (int k = 0; k < nc; k++) {
2278
- const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2450
+ const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
2279
2451
  GGML_UNUSED(x);
2280
2452
  assert(!isnan(x));
2281
2453
  assert(!isinf(x));
2282
2454
  }
2283
- #endif
2455
+ #endif // NDEBUG
2284
2456
  }
2285
2457
  }
2286
2458
 
@@ -2290,10 +2462,14 @@ static void ggml_compute_forward_gelu_quick_f16(
2290
2462
 
2291
2463
  const ggml_tensor * src0 = dst->src[0];
2292
2464
 
2293
- assert(ggml_is_contiguous_1(src0));
2294
- assert(ggml_is_contiguous_1(dst));
2465
+ assert(ggml_is_contiguous_rows(src0));
2295
2466
  assert(ggml_are_same_shape(src0, dst));
2296
2467
 
2468
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2469
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2470
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2471
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2472
+
2297
2473
  const int ith = params->ith;
2298
2474
  const int nth = params->nth;
2299
2475
 
@@ -2307,20 +2483,24 @@ static void ggml_compute_forward_gelu_quick_f16(
2307
2483
  const int ir0 = dr*ith;
2308
2484
  const int ir1 = MIN(ir0 + dr, nr);
2309
2485
 
2310
- for (int i1 = ir0; i1 < ir1; i1++) {
2486
+ for (int ir = ir0; ir < ir1; ++ir) {
2487
+ const int i3 = ir/(ne02*ne01);
2488
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2489
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2490
+
2311
2491
  ggml_vec_gelu_quick_f16(nc,
2312
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2313
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2492
+ (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2493
+ (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2314
2494
 
2315
2495
  #ifndef NDEBUG
2316
2496
  for (int k = 0; k < nc; k++) {
2317
- const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2497
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
2318
2498
  const float v = GGML_CPU_FP16_TO_FP32(x);
2319
2499
  GGML_UNUSED(v);
2320
2500
  assert(!isnan(v));
2321
2501
  assert(!isinf(v));
2322
2502
  }
2323
- #endif
2503
+ #endif // NDEBUG
2324
2504
  }
2325
2505
  }
2326
2506
 
@@ -2354,10 +2534,14 @@ static void ggml_compute_forward_silu_f32(
2354
2534
 
2355
2535
  const ggml_tensor * src0 = dst->src[0];
2356
2536
 
2357
- assert(ggml_is_contiguous_1(src0));
2358
- assert(ggml_is_contiguous_1(dst));
2537
+ assert(ggml_is_contiguous_rows(src0));
2359
2538
  assert(ggml_are_same_shape(src0, dst));
2360
2539
 
2540
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2541
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2542
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2543
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2544
+
2361
2545
  const int ith = params->ith;
2362
2546
  const int nth = params->nth;
2363
2547
 
@@ -2371,19 +2555,23 @@ static void ggml_compute_forward_silu_f32(
2371
2555
  const int ir0 = dr*ith;
2372
2556
  const int ir1 = MIN(ir0 + dr, nr);
2373
2557
 
2374
- for (int i1 = ir0; i1 < ir1; i1++) {
2558
+ for (int ir = ir0; ir < ir1; ++ir) {
2559
+ const int i3 = ir/(ne02*ne01);
2560
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2561
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2562
+
2375
2563
  ggml_vec_silu_f32(nc,
2376
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
2377
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
2564
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2565
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2378
2566
 
2379
2567
  #ifndef NDEBUG
2380
2568
  for (int k = 0; k < nc; k++) {
2381
- const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
2569
+ const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
2382
2570
  GGML_UNUSED(x);
2383
2571
  assert(!isnan(x));
2384
2572
  assert(!isinf(x));
2385
2573
  }
2386
- #endif
2574
+ #endif // NDEBUG
2387
2575
  }
2388
2576
  }
2389
2577
 
@@ -2393,10 +2581,14 @@ static void ggml_compute_forward_silu_f16(
2393
2581
 
2394
2582
  const ggml_tensor * src0 = dst->src[0];
2395
2583
 
2396
- assert(ggml_is_contiguous_1(src0));
2397
- assert(ggml_is_contiguous_1(dst));
2584
+ assert(ggml_is_contiguous_rows(src0));
2398
2585
  assert(ggml_are_same_shape(src0, dst));
2399
2586
 
2587
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2588
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2589
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2590
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2591
+
2400
2592
  const int ith = params->ith;
2401
2593
  const int nth = params->nth;
2402
2594
 
@@ -2410,20 +2602,24 @@ static void ggml_compute_forward_silu_f16(
2410
2602
  const int ir0 = dr*ith;
2411
2603
  const int ir1 = MIN(ir0 + dr, nr);
2412
2604
 
2413
- for (int i1 = ir0; i1 < ir1; i1++) {
2605
+ for (int ir = ir0; ir < ir1; ++ir) {
2606
+ const int i3 = ir/(ne02*ne01);
2607
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2608
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2609
+
2414
2610
  ggml_vec_silu_f16(nc,
2415
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2416
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2611
+ (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2612
+ (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2417
2613
 
2418
2614
  #ifndef NDEBUG
2419
2615
  for (int k = 0; k < nc; k++) {
2420
- const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
2616
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
2421
2617
  const float v = GGML_CPU_FP16_TO_FP32(x);
2422
2618
  GGML_UNUSED(v);
2423
2619
  assert(!isnan(v));
2424
2620
  assert(!isinf(v));
2425
2621
  }
2426
- #endif
2622
+ #endif // NDEBUG
2427
2623
  }
2428
2624
  }
2429
2625
 
@@ -2573,7 +2769,7 @@ static void ggml_compute_forward_silu_back_f32(
2573
2769
  assert(!isnan(x));
2574
2770
  assert(!isinf(x));
2575
2771
  }
2576
- #endif
2772
+ #endif // NDEBUG
2577
2773
  }
2578
2774
  }
2579
2775
 
@@ -2609,7 +2805,7 @@ static void ggml_compute_forward_silu_back_f16(
2609
2805
  (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
2610
2806
  (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
2611
2807
 
2612
- #ifndef NDEBUG
2808
+ #ifndef NDEBUG
2613
2809
  for (int k = 0; k < nc; k++) {
2614
2810
  const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2615
2811
  const float v = GGML_CPU_FP16_TO_FP32(x);
@@ -2617,7 +2813,7 @@ static void ggml_compute_forward_silu_back_f16(
2617
2813
  assert(!isnan(v));
2618
2814
  assert(!isinf(v));
2619
2815
  }
2620
- #endif
2816
+ #endif // NDEBUG
2621
2817
  }
2622
2818
  }
2623
2819
 
@@ -2700,7 +2896,7 @@ static void ggml_compute_forward_reglu_f32(
2700
2896
  assert(!isnan(x));
2701
2897
  assert(!isinf(x));
2702
2898
  }
2703
- #endif
2899
+ #endif // NDEBUG
2704
2900
  }
2705
2901
  }
2706
2902
 
@@ -2760,7 +2956,7 @@ static void ggml_compute_forward_reglu_f16(
2760
2956
  assert(!isnan(v));
2761
2957
  assert(!isinf(v));
2762
2958
  }
2763
- #endif
2959
+ #endif // NDEBUG
2764
2960
  }
2765
2961
  }
2766
2962
 
@@ -2843,7 +3039,7 @@ static void ggml_compute_forward_geglu_f32(
2843
3039
  assert(!isnan(x));
2844
3040
  assert(!isinf(x));
2845
3041
  }
2846
- #endif
3042
+ #endif // NDEBUG
2847
3043
  }
2848
3044
  }
2849
3045
 
@@ -2903,7 +3099,7 @@ static void ggml_compute_forward_geglu_f16(
2903
3099
  assert(!isnan(v));
2904
3100
  assert(!isinf(v));
2905
3101
  }
2906
- #endif
3102
+ #endif // NDEBUG
2907
3103
  }
2908
3104
  }
2909
3105
 
@@ -2986,7 +3182,7 @@ static void ggml_compute_forward_swiglu_f32(
2986
3182
  assert(!isnan(x));
2987
3183
  assert(!isinf(x));
2988
3184
  }
2989
- #endif
3185
+ #endif // NDEBUG
2990
3186
  }
2991
3187
  }
2992
3188
 
@@ -3046,7 +3242,7 @@ static void ggml_compute_forward_swiglu_f16(
3046
3242
  assert(!isnan(v));
3047
3243
  assert(!isinf(v));
3048
3244
  }
3049
- #endif
3245
+ #endif // NDEBUG
3050
3246
  }
3051
3247
  }
3052
3248
 
@@ -3137,7 +3333,7 @@ static void ggml_compute_forward_swiglu_oai_f32(
3137
3333
  assert(!isnan(x));
3138
3334
  assert(!isinf(x));
3139
3335
  }
3140
- #endif
3336
+ #endif // NDEBUG
3141
3337
  }
3142
3338
  }
3143
3339
 
@@ -3216,7 +3412,7 @@ static void ggml_compute_forward_geglu_erf_f32(
3216
3412
  assert(!isnan(x));
3217
3413
  assert(!isinf(x));
3218
3414
  }
3219
- #endif
3415
+ #endif // NDEBUG
3220
3416
  }
3221
3417
  }
3222
3418
 
@@ -3276,7 +3472,7 @@ static void ggml_compute_forward_geglu_erf_f16(
3276
3472
  assert(!isnan(v));
3277
3473
  assert(!isinf(v));
3278
3474
  }
3279
- #endif
3475
+ #endif // NDEBUG
3280
3476
  }
3281
3477
  }
3282
3478
 
@@ -3359,7 +3555,7 @@ static void ggml_compute_forward_geglu_quick_f32(
3359
3555
  assert(!isnan(x));
3360
3556
  assert(!isinf(x));
3361
3557
  }
3362
- #endif
3558
+ #endif // NDEBUG
3363
3559
  }
3364
3560
  }
3365
3561
 
@@ -3419,7 +3615,7 @@ static void ggml_compute_forward_geglu_quick_f16(
3419
3615
  assert(!isnan(v));
3420
3616
  assert(!isinf(v));
3421
3617
  }
3422
- #endif
3618
+ #endif // NDEBUG
3423
3619
  }
3424
3620
  }
3425
3621
 
@@ -3467,31 +3663,27 @@ static void ggml_compute_forward_norm_f32(
3467
3663
 
3468
3664
  GGML_ASSERT(eps >= 0.0f);
3469
3665
 
3470
- // TODO: optimize
3471
3666
  for (int64_t i03 = 0; i03 < ne03; i03++) {
3472
3667
  for (int64_t i02 = 0; i02 < ne02; i02++) {
3473
3668
  for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
3474
3669
  const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3475
3670
 
3476
- ggml_float sum = 0.0;
3477
- for (int64_t i00 = 0; i00 < ne00; i00++) {
3478
- sum += (ggml_float)x[i00];
3479
- }
3480
-
3671
+ float sum = 0.0;
3672
+ ggml_vec_sum_f32(ne00, &sum, x);
3481
3673
  float mean = sum/ne00;
3482
3674
 
3483
3675
  float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3676
+ float variance = 0;
3484
3677
 
3485
- ggml_float sum2 = 0.0;
3486
- for (int64_t i00 = 0; i00 < ne00; i00++) {
3487
- float v = x[i00] - mean;
3488
- y[i00] = v;
3489
- sum2 += (ggml_float)(v*v);
3490
- }
3678
+ #ifdef GGML_USE_ACCELERATE
3679
+ mean = -mean;
3680
+ vDSP_vsadd(x, 1, &mean, y, 1, ne00);
3681
+ vDSP_measqv(y, 1, &variance, ne00);
3682
+ #else
3683
+ variance = ggml_vec_cvar_f32(ne00, y, x, mean);
3684
+ #endif //GGML_USE_ACCELERATE
3491
3685
 
3492
- float variance = sum2/ne00;
3493
3686
  const float scale = 1.0f/sqrtf(variance + eps);
3494
-
3495
3687
  ggml_vec_scale_f32(ne00, y, scale);
3496
3688
  }
3497
3689
  }
@@ -4145,6 +4337,7 @@ void ggml_compute_forward_out_prod(
4145
4337
  case GGML_TYPE_Q5_1:
4146
4338
  case GGML_TYPE_Q8_0:
4147
4339
  case GGML_TYPE_MXFP4:
4340
+ case GGML_TYPE_NVFP4:
4148
4341
  case GGML_TYPE_Q2_K:
4149
4342
  case GGML_TYPE_Q3_K:
4150
4343
  case GGML_TYPE_Q4_K:
@@ -4420,6 +4613,7 @@ void ggml_compute_forward_set(
4420
4613
  case GGML_TYPE_Q8_0:
4421
4614
  case GGML_TYPE_Q8_1:
4422
4615
  case GGML_TYPE_MXFP4:
4616
+ case GGML_TYPE_NVFP4:
4423
4617
  case GGML_TYPE_Q2_K:
4424
4618
  case GGML_TYPE_Q3_K:
4425
4619
  case GGML_TYPE_Q4_K:
@@ -4459,46 +4653,6 @@ void ggml_compute_forward_cont(
4459
4653
  ggml_compute_forward_dup(params, dst);
4460
4654
  }
4461
4655
 
4462
- // ggml_compute_forward_reshape
4463
-
4464
- void ggml_compute_forward_reshape(
4465
- const ggml_compute_params * params,
4466
- ggml_tensor * dst) {
4467
- // NOP
4468
- GGML_UNUSED(params);
4469
- GGML_UNUSED(dst);
4470
- }
4471
-
4472
- // ggml_compute_forward_view
4473
-
4474
- void ggml_compute_forward_view(
4475
- const ggml_compute_params * params,
4476
- ggml_tensor * dst) {
4477
- // NOP
4478
- GGML_UNUSED(params);
4479
- GGML_UNUSED(dst);
4480
- }
4481
-
4482
- // ggml_compute_forward_permute
4483
-
4484
- void ggml_compute_forward_permute(
4485
- const ggml_compute_params * params,
4486
- ggml_tensor * dst) {
4487
- // NOP
4488
- GGML_UNUSED(params);
4489
- GGML_UNUSED(dst);
4490
- }
4491
-
4492
- // ggml_compute_forward_transpose
4493
-
4494
- void ggml_compute_forward_transpose(
4495
- const ggml_compute_params * params,
4496
- ggml_tensor * dst) {
4497
- // NOP
4498
- GGML_UNUSED(params);
4499
- GGML_UNUSED(dst);
4500
- }
4501
-
4502
4656
  // ggml_compute_forward_get_rows
4503
4657
 
4504
4658
  static void ggml_compute_forward_get_rows_q(
@@ -4682,6 +4836,7 @@ void ggml_compute_forward_get_rows(
4682
4836
  case GGML_TYPE_Q8_0:
4683
4837
  case GGML_TYPE_Q8_1:
4684
4838
  case GGML_TYPE_MXFP4:
4839
+ case GGML_TYPE_NVFP4:
4685
4840
  case GGML_TYPE_Q2_K:
4686
4841
  case GGML_TYPE_Q3_K:
4687
4842
  case GGML_TYPE_Q4_K:
@@ -5154,7 +5309,7 @@ static void ggml_compute_forward_soft_max_f32(
5154
5309
  //printf("p[%d] = %f\n", i, p[i]);
5155
5310
  assert(!isnan(wp[i]));
5156
5311
  }
5157
- #endif
5312
+ #endif // NDEBUG
5158
5313
 
5159
5314
  float max = -INFINITY;
5160
5315
  ggml_vec_max_f32(ne00, &max, wp);
@@ -5179,7 +5334,7 @@ static void ggml_compute_forward_soft_max_f32(
5179
5334
  assert(!isnan(dp[i]));
5180
5335
  assert(!isinf(dp[i]));
5181
5336
  }
5182
- #endif
5337
+ #endif // NDEBUG
5183
5338
  }
5184
5339
  }
5185
5340
  }
@@ -5253,7 +5408,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32(
5253
5408
  assert(!isnan(dy[i]));
5254
5409
  assert(!isnan(y[i]));
5255
5410
  }
5256
- #endif
5411
+ #endif // NDEBUG
5257
5412
  // Jii = yi - yi*yi
5258
5413
  // Jij = -yi*yj
5259
5414
  // J = diag(y)-y.T*y
@@ -5286,7 +5441,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32(
5286
5441
  assert(!isnan(dx[i]));
5287
5442
  assert(!isinf(dx[i]));
5288
5443
  }
5289
- #endif
5444
+ #endif // NDEBUG
5290
5445
  }
5291
5446
  }
5292
5447
 
@@ -5406,6 +5561,7 @@ void ggml_compute_forward_clamp(
5406
5561
  case GGML_TYPE_Q8_0:
5407
5562
  case GGML_TYPE_Q8_1:
5408
5563
  case GGML_TYPE_MXFP4:
5564
+ case GGML_TYPE_NVFP4:
5409
5565
  case GGML_TYPE_Q2_K:
5410
5566
  case GGML_TYPE_Q3_K:
5411
5567
  case GGML_TYPE_Q4_K:
@@ -5478,7 +5634,7 @@ static void ggml_rope_cache_init(
5478
5634
  }
5479
5635
 
5480
5636
  static void ggml_mrope_cache_init(
5481
- float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
5637
+ float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
5482
5638
  float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
5483
5639
  float * cache, float sin_sign, float theta_scale) {
5484
5640
  // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
@@ -5513,14 +5669,26 @@ static void ggml_mrope_cache_init(
5513
5669
  }
5514
5670
 
5515
5671
  float theta = theta_t;
5516
- if (sector >= sections[0] && sector < sec_w) {
5517
- theta = theta_h;
5518
- }
5519
- else if (sector >= sec_w && sector < sec_w + sections[2]) {
5520
- theta = theta_w;
5521
- }
5522
- else if (sector >= sec_w + sections[2]) {
5523
- theta = theta_e;
5672
+ if (is_imrope) { // qwen3vl apply interleaved mrope
5673
+ if (sector % 3 == 1 && sector < 3 * sections[1]) {
5674
+ theta = theta_h;
5675
+ } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
5676
+ theta = theta_w;
5677
+ } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
5678
+ theta = theta_t;
5679
+ } else {
5680
+ theta = theta_e;
5681
+ }
5682
+ } else {
5683
+ if (sector >= sections[0] && sector < sec_w) {
5684
+ theta = theta_h;
5685
+ }
5686
+ else if (sector >= sec_w && sector < sec_w + sections[2]) {
5687
+ theta = theta_w;
5688
+ }
5689
+ else if (sector >= sec_w + sections[2]) {
5690
+ theta = theta_e;
5691
+ }
5524
5692
  }
5525
5693
 
5526
5694
  rope_yarn(
@@ -5535,7 +5703,28 @@ static void ggml_mrope_cache_init(
5535
5703
  }
5536
5704
  }
5537
5705
 
5538
- static void ggml_compute_forward_rope_f32(
5706
+
5707
+ template<typename T>
5708
+ static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
5709
+ for (int64_t i0 = 0; i0 < n; i0 += 2) {
5710
+ const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
5711
+
5712
+ const float cos_theta = cache[i0 + 0];
5713
+ const float sin_theta = cache[i0 + 1];
5714
+
5715
+ const T * const src = src_data + ic;
5716
+ T * dst = dst_data + ic;
5717
+
5718
+ const float x0 = type_conversion_table<T>::to_f32(src[0]);
5719
+ const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
5720
+
5721
+ dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
5722
+ dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
5723
+ }
5724
+ }
5725
+
5726
+ template<typename T> //float or ggml_fp16_t
5727
+ static void ggml_compute_forward_rope_flt(
5539
5728
  const ggml_compute_params * params,
5540
5729
  ggml_tensor * dst,
5541
5730
  const bool forward) {
@@ -5544,6 +5733,9 @@ static void ggml_compute_forward_rope_f32(
5544
5733
  const ggml_tensor * src1 = dst->src[1];
5545
5734
  const ggml_tensor * src2 = dst->src[2];
5546
5735
 
5736
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
5737
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
5738
+
5547
5739
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5548
5740
  int sections[4];
5549
5741
 
@@ -5566,7 +5758,8 @@ static void ggml_compute_forward_rope_f32(
5566
5758
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
5567
5759
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
5568
5760
 
5569
- GGML_ASSERT(nb00 == sizeof(float));
5761
+ GGML_ASSERT(nb0 == nb00);
5762
+ GGML_ASSERT(nb0 == sizeof(T));
5570
5763
 
5571
5764
  const int ith = params->ith;
5572
5765
  const int nth = params->nth;
@@ -5591,11 +5784,11 @@ static void ggml_compute_forward_rope_f32(
5591
5784
  float corr_dims[2];
5592
5785
  ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
5593
5786
 
5594
- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
5595
- const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
5787
+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
5788
+ const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
5596
5789
  const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
5597
5790
 
5598
- if (is_mrope) {
5791
+ if (mrope_used) {
5599
5792
  GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
5600
5793
  }
5601
5794
 
@@ -5617,290 +5810,63 @@ static void ggml_compute_forward_rope_f32(
5617
5810
 
5618
5811
  const int32_t * pos = (const int32_t *) src1->data;
5619
5812
 
5813
+ int64_t last_i2 = -1;
5814
+
5620
5815
  for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
5621
5816
  for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
5817
+ for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
5818
+ if (ir++ < ir0) continue; // skip rows mapped to other threads
5819
+ if (ir > ir1) break;
5622
5820
 
5623
- float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5624
- if (!is_mrope) {
5625
- const int64_t p = pos[i2];
5626
- ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5627
- }
5628
- else {
5629
- const int64_t p_t = pos[i2];
5630
- const int64_t p_h = pos[i2 + ne2];
5631
- const int64_t p_w = pos[i2 + ne2 * 2];
5632
- const int64_t p_e = pos[i2 + ne2 * 3];
5633
- ggml_mrope_cache_init(
5634
- p_t, p_h, p_w, p_e, sections, is_vision,
5635
- freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5636
- }
5637
-
5638
- for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
5639
- if (ir++ < ir0) continue;
5640
- if (ir > ir1) break;
5641
-
5642
- if (is_neox || is_mrope) {
5643
- if (is_vision){
5644
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5645
- const int64_t ic = i0/2;
5646
-
5647
- const float cos_theta = cache[i0 + 0];
5648
- const float sin_theta = cache[i0 + 1];
5649
-
5650
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5651
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5652
-
5653
- const float x0 = src[0];
5654
- const float x1 = src[n_dims];
5655
-
5656
- dst_data[0] = x0*cos_theta - x1*sin_theta;
5657
- dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
5658
- }
5659
- } else {
5660
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5661
- const int64_t ic = i0/2;
5662
-
5663
- const float cos_theta = cache[i0 + 0];
5664
- const float sin_theta = cache[i0 + 1];
5665
-
5666
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5667
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5668
-
5669
- const float x0 = src[0];
5670
- const float x1 = src[n_dims/2];
5671
-
5672
- dst_data[0] = x0*cos_theta - x1*sin_theta;
5673
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
5674
- }
5675
- }
5676
- } else {
5677
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5678
- const float cos_theta = cache[i0 + 0];
5679
- const float sin_theta = cache[i0 + 1];
5680
-
5681
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5682
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5683
-
5684
- const float x0 = src[0];
5685
- const float x1 = src[1];
5686
-
5687
- dst_data[0] = x0*cos_theta - x1*sin_theta;
5688
- dst_data[1] = x0*sin_theta + x1*cos_theta;
5821
+ float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5822
+ if (last_i2 != i2) {
5823
+ if (!mrope_used) {
5824
+ const int64_t p = pos[i2];
5825
+ ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5689
5826
  }
5690
- }
5691
-
5692
- if (is_vision) {
5693
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5694
- const int64_t ic = i0/2;
5695
-
5696
- const float cos_theta = cache[i0 + 0];
5697
- const float sin_theta = cache[i0 + 1];
5698
-
5699
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5700
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5701
-
5702
- const float x0 = src[0];
5703
- const float x1 = src[n_dims];
5704
-
5705
- dst_data[0] = x0*cos_theta - x1*sin_theta;
5706
- dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
5827
+ else {
5828
+ const int64_t p_t = pos[i2];
5829
+ const int64_t p_h = pos[i2 + ne2];
5830
+ const int64_t p_w = pos[i2 + ne2 * 2];
5831
+ const int64_t p_e = pos[i2 + ne2 * 3];
5832
+ ggml_mrope_cache_init(
5833
+ p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
5834
+ freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5707
5835
  }
5708
- } else {
5709
- // fill the remain channels with data from src tensor
5710
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5711
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5712
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5713
5836
 
5714
- dst_data[0] = src[0];
5715
- dst_data[1] = src[1];
5716
- }
5837
+ last_i2 = i2;
5717
5838
  }
5718
- }
5719
- }
5720
- }
5721
- }
5722
-
5723
- // TODO: deduplicate f16/f32 code
5724
- static void ggml_compute_forward_rope_f16(
5725
- const ggml_compute_params * params,
5726
- ggml_tensor * dst,
5727
- const bool forward) {
5728
-
5729
- const ggml_tensor * src0 = dst->src[0];
5730
- const ggml_tensor * src1 = dst->src[1];
5731
- const ggml_tensor * src2 = dst->src[2];
5732
-
5733
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5734
- int sections[4];
5735
-
5736
- //const int n_past = ((int32_t *) dst->op_params)[0];
5737
- const int n_dims = ((int32_t *) dst->op_params)[1];
5738
- const int mode = ((int32_t *) dst->op_params)[2];
5739
- //const int n_ctx = ((int32_t *) dst->op_params)[3];
5740
- const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
5741
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
5742
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
5743
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
5744
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
5745
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
5746
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
5747
- memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
5748
-
5749
-
5750
- GGML_TENSOR_UNARY_OP_LOCALS
5751
-
5752
- //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
5753
- //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
5754
-
5755
- GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
5756
-
5757
- const int ith = params->ith;
5758
- const int nth = params->nth;
5759
-
5760
- const int nr = ggml_nrows(dst);
5761
-
5762
- GGML_ASSERT(n_dims <= ne0);
5763
- GGML_ASSERT(n_dims % 2 == 0);
5764
-
5765
- // rows per thread
5766
- const int dr = (nr + nth - 1)/nth;
5767
-
5768
- // row range for this thread
5769
- const int ir0 = dr*ith;
5770
- const int ir1 = MIN(ir0 + dr, nr);
5771
-
5772
- // row index used to determine which thread to use
5773
- int ir = 0;
5774
-
5775
- const float theta_scale = powf(freq_base, -2.0f/n_dims);
5776
-
5777
- float corr_dims[2];
5778
- ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
5779
-
5780
- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
5781
- const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
5782
- const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
5783
-
5784
- if (is_mrope) {
5785
- GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
5786
- }
5787
-
5788
- if (is_vision) {
5789
- GGML_ASSERT(n_dims == ne0/2);
5790
- }
5791
-
5792
- const float * freq_factors = NULL;
5793
- if (src2 != NULL) {
5794
- GGML_ASSERT(src2->type == GGML_TYPE_F32);
5795
- GGML_ASSERT(src2->ne[0] >= n_dims / 2);
5796
- freq_factors = (const float *) src2->data;
5797
- }
5798
-
5799
- // backward process uses inverse rotation by cos and sin.
5800
- // cos and sin build a rotation matrix, where the inverse is the transpose.
5801
- // this essentially just switches the sign of sin.
5802
- const float sin_sign = forward ? 1.0f : -1.0f;
5803
-
5804
- const int32_t * pos = (const int32_t *) src1->data;
5805
-
5806
- for (int64_t i3 = 0; i3 < ne3; i3++) {
5807
- for (int64_t i2 = 0; i2 < ne2; i2++) {
5808
-
5809
- float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5810
- if (!is_mrope) {
5811
- const int64_t p = pos[i2];
5812
- ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5813
- }
5814
- else {
5815
- const int64_t p_t = pos[i2];
5816
- const int64_t p_h = pos[i2 + ne2];
5817
- const int64_t p_w = pos[i2 + ne2 * 2];
5818
- const int64_t p_e = pos[i2 + ne2 * 3];
5819
- ggml_mrope_cache_init(
5820
- p_t, p_h, p_w, p_e, sections, is_vision,
5821
- freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5822
- }
5823
-
5824
- for (int64_t i1 = 0; i1 < ne1; i1++) {
5825
- if (ir++ < ir0) continue;
5826
- if (ir > ir1) break;
5827
-
5828
- if (is_neox || is_mrope) {
5829
- if (is_vision) {
5830
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5831
- const int64_t ic = i0/2;
5832
-
5833
- const float cos_theta = cache[i0 + 0];
5834
- const float sin_theta = cache[i0 + 1];
5835
-
5836
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5837
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5838
-
5839
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
5840
- const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
5841
5839
 
5842
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5843
- dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5844
- }
5845
- } else {
5846
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5847
- const int64_t ic = i0/2;
5848
-
5849
- const float cos_theta = cache[i0 + 0];
5850
- const float sin_theta = cache[i0 + 1];
5851
-
5852
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5853
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5854
-
5855
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
5856
- const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
5857
-
5858
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5859
- dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5860
- }
5861
- }
5862
- } else {
5863
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5864
- const float cos_theta = cache[i0 + 0];
5865
- const float sin_theta = cache[i0 + 1];
5866
-
5867
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5868
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5869
-
5870
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
5871
- const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
5872
-
5873
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5874
- dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5875
- }
5840
+ T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
5841
+ T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
5842
+
5843
+ switch (mode) {
5844
+ case GGML_ROPE_TYPE_NORMAL:
5845
+ rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
5846
+ break;
5847
+ case GGML_ROPE_TYPE_NEOX:
5848
+ case GGML_ROPE_TYPE_MROPE:
5849
+ case GGML_ROPE_TYPE_IMROPE:
5850
+ rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
5851
+ break;
5852
+ case GGML_ROPE_TYPE_VISION:
5853
+ rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
5854
+ break;
5855
+ default:
5856
+ GGML_ABORT("rope type not supported");
5876
5857
  }
5877
5858
 
5878
- if (is_vision) {
5879
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5880
- const int64_t ic = i0/2;
5881
-
5882
- const float cos_theta = cache[i0 + 0];
5883
- const float sin_theta = cache[i0 + 1];
5884
-
5885
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5886
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5887
-
5888
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
5889
- const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
5890
-
5891
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5892
- dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5893
- }
5894
- } else {
5859
+ if (!is_vision) {
5860
+ // fill the remain channels with data from src tensor
5895
5861
  for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5896
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5897
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5862
+ const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5863
+ T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5898
5864
 
5899
5865
  dst_data[0] = src[0];
5900
5866
  dst_data[1] = src[1];
5901
5867
  }
5902
5868
  }
5903
- }
5869
+ } //attn-heads
5904
5870
  }
5905
5871
  }
5906
5872
  }
@@ -5914,11 +5880,11 @@ void ggml_compute_forward_rope(
5914
5880
  switch (src0->type) {
5915
5881
  case GGML_TYPE_F16:
5916
5882
  {
5917
- ggml_compute_forward_rope_f16(params, dst, true);
5883
+ ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
5918
5884
  } break;
5919
5885
  case GGML_TYPE_F32:
5920
5886
  {
5921
- ggml_compute_forward_rope_f32(params, dst, true);
5887
+ ggml_compute_forward_rope_flt<float>(params, dst, true);
5922
5888
  } break;
5923
5889
  default:
5924
5890
  {
@@ -5938,11 +5904,11 @@ void ggml_compute_forward_rope_back(
5938
5904
  switch (src0->type) {
5939
5905
  case GGML_TYPE_F16:
5940
5906
  {
5941
- ggml_compute_forward_rope_f16(params, dst, false);
5907
+ ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
5942
5908
  } break;
5943
5909
  case GGML_TYPE_F32:
5944
5910
  {
5945
- ggml_compute_forward_rope_f32(params, dst, false);
5911
+ ggml_compute_forward_rope_flt<float>(params, dst, false);
5946
5912
  } break;
5947
5913
  default:
5948
5914
  {
@@ -6239,7 +6205,7 @@ static void ggml_compute_forward_im2col_f16(
6239
6205
  const ggml_tensor * src1 = dst->src[1];
6240
6206
 
6241
6207
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
6242
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
6208
+ GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
6243
6209
  GGML_ASSERT( dst->type == GGML_TYPE_F16);
6244
6210
 
6245
6211
  GGML_TENSOR_BINARY_OP_LOCALS;
@@ -6270,7 +6236,7 @@ static void ggml_compute_forward_im2col_f16(
6270
6236
  int ofs1 = is_2D ? nb12 : nb11;
6271
6237
 
6272
6238
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6273
- GGML_ASSERT(nb10 == sizeof(float));
6239
+ GGML_ASSERT(nb10 == ggml_type_size(src1->type));
6274
6240
 
6275
6241
  // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
6276
6242
  {
@@ -6283,7 +6249,12 @@ static void ggml_compute_forward_im2col_f16(
6283
6249
 
6284
6250
  // micro kernel
6285
6251
  ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6286
- const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
6252
+ const float * const src_data_f32 = src1->type == GGML_TYPE_F32
6253
+ ? (const float *)((const char *) src1->data + in*ofs0 + iic*ofs1)
6254
+ : nullptr; // [IH, IW]
6255
+ const ggml_fp16_t * const src_data_f16 = src1->type == GGML_TYPE_F16
6256
+ ? (const ggml_fp16_t *)((const char *) src1->data + in*ofs0 + iic*ofs1)
6257
+ : nullptr; // [IH, IW]
6287
6258
 
6288
6259
  for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
6289
6260
  for (int64_t ikw = 0; ikw < KW; ikw++) {
@@ -6293,7 +6264,11 @@ static void ggml_compute_forward_im2col_f16(
6293
6264
  if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6294
6265
  dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
6295
6266
  } else {
6296
- dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
6267
+ if (src_data_f32 != nullptr) {
6268
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data_f32[iih*IW + iiw]);
6269
+ } else {
6270
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = src_data_f16[iih*IW + iiw];
6271
+ }
6297
6272
  }
6298
6273
  }
6299
6274
  }
@@ -6493,7 +6468,7 @@ static void ggml_compute_forward_im2col_3d_f16(
6493
6468
  const int64_t iih = ioh*s1 + ikh*d1 - p1;
6494
6469
  const int64_t iid = iod*s2 + ikd*d2 - p2;
6495
6470
 
6496
- if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
6471
+ if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6497
6472
  dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
6498
6473
  } else {
6499
6474
  const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
@@ -6664,8 +6639,13 @@ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params
6664
6639
  ggml_compute_forward_mul_mat(params, &dst);
6665
6640
  }
6666
6641
 
6642
+ static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
6643
+ return (coord + size) % size; // adding size avoids negative number weirdness
6644
+ }
6645
+
6667
6646
  // ggml_compute_forward_conv_2d
6668
6647
 
6648
+
6669
6649
  static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
6670
6650
  const ggml_tensor * kernel, // [KW, KH, IC, OC]
6671
6651
  const ggml_tensor * src, // [W, H, C, N]
@@ -7074,7 +7054,11 @@ static void ggml_compute_forward_conv_2d_dw_cwhn(
7074
7054
  const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
7075
7055
 
7076
7056
  #ifdef GGML_SIMD
7077
- const int64_t pkg_size = GGML_F32_EPR;
7057
+ #if defined(__ARM_FEATURE_SVE)
7058
+ const int64_t pkg_size = svcntw();
7059
+ #else
7060
+ const int64_t pkg_size = GGML_F32_EPR;
7061
+ #endif
7078
7062
  const int64_t pkg_count = c / pkg_size;
7079
7063
  const int64_t c_pkg_end = pkg_count * pkg_size;
7080
7064
  #else
@@ -7211,12 +7195,13 @@ void ggml_compute_forward_conv_2d_dw(
7211
7195
  }
7212
7196
  }
7213
7197
 
7214
- // ggml_compute_forward_pool_1d_sk_p0
7215
-
7216
- static void ggml_compute_forward_pool_1d_sk_p0(
7198
+ // ggml_compute_forward_pool_1d_ksp
7199
+ static void ggml_compute_forward_pool_1d_ksp(
7217
7200
  const ggml_compute_params * params,
7218
7201
  const ggml_op_pool op,
7219
7202
  const int k,
7203
+ const int s,
7204
+ const int p,
7220
7205
  ggml_tensor * dst) {
7221
7206
 
7222
7207
  const ggml_tensor * src = dst->src[0];
@@ -7227,39 +7212,56 @@ static void ggml_compute_forward_pool_1d_sk_p0(
7227
7212
  return;
7228
7213
  }
7229
7214
 
7230
- const char * cdata = (const char *)src->data;
7231
- const char * const data_end = cdata + ggml_nbytes(src);
7232
- float * drow = (float *)dst->data;
7215
+ const int64_t IW = src->ne[0];
7216
+ const int64_t OW = dst->ne[0];
7233
7217
 
7234
- const int64_t rs = dst->ne[0];
7218
+ const int64_t nr = ggml_nrows(src);
7235
7219
 
7236
- while (cdata < data_end) {
7237
- const void * srow = (const void *)cdata;
7238
- int j = 0;
7239
- for (int64_t i = 0; i < rs; ++i) {
7220
+ for (int64_t ir = 0; ir < nr; ++ir) {
7221
+ const char * srow_bytes = (const char *) src->data + ir * src->nb[1];
7222
+ float * drow = (float *) (( char *) dst->data + ir * dst->nb[1]);
7223
+
7224
+ for (int64_t ow = 0; ow < OW; ++ow) {
7225
+ float res = 0;
7240
7226
  switch (op) {
7241
- case GGML_OP_POOL_AVG: drow[i] = 0; break;
7242
- case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
7227
+ case GGML_OP_POOL_AVG: res = 0.0f; break;
7228
+ case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
7243
7229
  case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7244
7230
  }
7231
+
7232
+ int count = 0;
7233
+ const int base = (int) ow * s - p;
7234
+
7245
7235
  for (int ki = 0; ki < k; ++ki) {
7246
- const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
7236
+ const int j = base + ki;
7237
+ if (j < 0 || j >= (int) IW) {
7238
+ continue;
7239
+ }
7240
+
7241
+ float v;
7242
+ if (src->type == GGML_TYPE_F32) {
7243
+ v = ((const float *) srow_bytes)[j];
7244
+ } else {
7245
+ v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
7246
+ }
7247
+
7247
7248
  switch (op) {
7248
- case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
7249
- case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
7250
- case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7249
+ case GGML_OP_POOL_AVG: res += v; break;
7250
+ case GGML_OP_POOL_MAX: res = std::max(v, res); break;
7251
+ case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7251
7252
  }
7252
- ++j;
7253
+
7254
+ ++count;
7253
7255
  }
7256
+
7254
7257
  switch (op) {
7255
- case GGML_OP_POOL_AVG: drow[i] /= k; break;
7256
- case GGML_OP_POOL_MAX: break;
7258
+ case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
7259
+ case GGML_OP_POOL_MAX: break;
7257
7260
  case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7258
7261
  }
7259
- }
7260
7262
 
7261
- cdata += src->nb[1];
7262
- drow += rs;
7263
+ drow[ow] = res;
7264
+ }
7263
7265
  }
7264
7266
  }
7265
7267
 
@@ -7274,10 +7276,8 @@ void ggml_compute_forward_pool_1d(
7274
7276
  const int k0 = opts[1];
7275
7277
  const int s0 = opts[2];
7276
7278
  const int p0 = opts[3];
7277
- GGML_ASSERT(p0 == 0); // padding not supported
7278
- GGML_ASSERT(k0 == s0); // only s = k supported
7279
7279
 
7280
- ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
7280
+ ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
7281
7281
  }
7282
7282
 
7283
7283
  // ggml_compute_forward_pool_2d
@@ -7295,6 +7295,7 @@ void ggml_compute_forward_pool_2d(
7295
7295
  }
7296
7296
 
7297
7297
  const int32_t * opts = (const int32_t *)dst->op_params;
7298
+
7298
7299
  ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
7299
7300
  const int k0 = opts[1];
7300
7301
  const int k1 = opts[2];
@@ -7318,11 +7319,13 @@ void ggml_compute_forward_pool_2d(
7318
7319
  while (cdata < data_end) {
7319
7320
  for (int oy = 0; oy < py; ++oy) {
7320
7321
  float * const drow = dplane + oy * px;
7322
+ float * const out = drow;
7323
+
7321
7324
  for (int ox = 0; ox < px; ++ox) {
7322
- float * const out = drow + ox;
7325
+ float res = 0;
7323
7326
  switch (op) {
7324
- case GGML_OP_POOL_AVG: *out = 0; break;
7325
- case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
7327
+ case GGML_OP_POOL_AVG: res = 0; break;
7328
+ case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
7326
7329
  case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7327
7330
  }
7328
7331
 
@@ -7330,24 +7333,32 @@ void ggml_compute_forward_pool_2d(
7330
7333
  const int iy = offset1 + oy * s1;
7331
7334
 
7332
7335
  for (int ky = 0; ky < k1; ++ky) {
7333
- if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
7336
+ if (iy + ky < 0 || iy + ky >= src->ne[1]) {
7337
+ continue;
7338
+ }
7339
+
7334
7340
  const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
7335
7341
  for (int kx = 0; kx < k0; ++kx) {
7336
7342
  int j = ix + kx;
7337
- if (j < 0 || j >= src->ne[0]) continue;
7343
+ if (j < 0 || j >= src->ne[0]) {
7344
+ continue;
7345
+ }
7346
+
7338
7347
  const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
7339
7348
  switch (op) {
7340
- case GGML_OP_POOL_AVG: *out += srow_j; break;
7341
- case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
7349
+ case GGML_OP_POOL_AVG: res += srow_j; break;
7350
+ case GGML_OP_POOL_MAX: res = std::max(srow_j, res); break;
7342
7351
  case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7343
7352
  }
7344
7353
  }
7345
7354
  }
7346
7355
  switch (op) {
7347
- case GGML_OP_POOL_AVG: *out /= ka; break;
7348
- case GGML_OP_POOL_MAX: break;
7356
+ case GGML_OP_POOL_AVG: res /= ka; break;
7357
+ case GGML_OP_POOL_MAX: break;
7349
7358
  case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7350
7359
  }
7360
+
7361
+ out[ox] = res;
7351
7362
  }
7352
7363
  }
7353
7364
 
@@ -7497,10 +7508,17 @@ static void ggml_compute_forward_upscale_f32(
7497
7508
  float sf1 = (float)ne1/src0->ne[1];
7498
7509
  float sf2 = (float)ne2/src0->ne[2];
7499
7510
  float sf3 = (float)ne3/src0->ne[3];
7511
+ float pixel_offset = 0.5f;
7500
7512
 
7501
7513
  const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
7502
7514
  const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
7503
7515
 
7516
+ if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7517
+ pixel_offset = 0.0f;
7518
+ sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
7519
+ sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
7520
+ }
7521
+
7504
7522
  if (mode == GGML_SCALE_MODE_NEAREST) {
7505
7523
  for (int64_t i3 = 0; i3 < ne3; i3++) {
7506
7524
  const int64_t i03 = i3 / sf3;
@@ -7519,14 +7537,66 @@ static void ggml_compute_forward_upscale_f32(
7519
7537
  }
7520
7538
  }
7521
7539
  }
7522
- } else if (mode == GGML_SCALE_MODE_BILINEAR) {
7523
- float pixel_offset = 0.5f;
7524
- if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7525
- pixel_offset = 0.0f;
7526
- sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
7527
- sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
7528
- }
7540
+ } else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {
7541
+ // Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
7542
+ // https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
7543
+ auto triangle_filter = [](float x) -> float {
7544
+ return std::max(1.0f - fabsf(x), 0.0f);
7545
+ };
7546
+
7547
+ // support and invscale, minimum 1 pixel for bilinear
7548
+ const float support1 = std::max(1.0f, 1.0f / sf1);
7549
+ const float invscale1 = 1.0f / support1;
7550
+ const float support0 = std::max(1.0f, 1.0f / sf0);
7551
+ const float invscale0 = 1.0f / support0;
7552
+
7553
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
7554
+ const int64_t i03 = i3 / sf3;
7555
+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7556
+ const int64_t i02 = i2 / sf2;
7557
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
7558
+ const float y = ((float) i1 + pixel_offset) / sf1;
7559
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
7560
+ const float x = ((float) i0 + pixel_offset) / sf0;
7561
+
7562
+ // the range of source pixels that contribute
7563
+ const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);
7564
+ const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);
7565
+ const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);
7566
+ const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);
7567
+
7568
+ // bilinear filter with antialiasing
7569
+ float val = 0.0f;
7570
+ float total_weight = 0.0f;
7571
+
7572
+ for (int64_t sy = y_min; sy < y_max; sy++) {
7573
+ const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
7574
+
7575
+ for (int64_t sx = x_min; sx < x_max; sx++) {
7576
+ const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
7577
+ const float weight = weight_x * weight_y;
7578
+
7579
+ if (weight <= 0.0f) {
7580
+ continue;
7581
+ }
7582
+
7583
+ const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);
7584
+ val += pixel * weight;
7585
+ total_weight += weight;
7586
+ }
7587
+ }
7529
7588
 
7589
+ if (total_weight > 0.0f) {
7590
+ val /= total_weight;
7591
+ }
7592
+
7593
+ float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7594
+ *dst_ptr = val;
7595
+ }
7596
+ }
7597
+ }
7598
+ }
7599
+ } else if (mode == GGML_SCALE_MODE_BILINEAR) {
7530
7600
  for (int64_t i3 = 0; i3 < ne3; i3++) {
7531
7601
  const int64_t i03 = i3 / sf3;
7532
7602
  for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
@@ -7561,6 +7631,51 @@ static void ggml_compute_forward_upscale_f32(
7561
7631
 
7562
7632
  const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
7563
7633
 
7634
+ float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7635
+ *y_dst = val;
7636
+ }
7637
+ }
7638
+ }
7639
+ }
7640
+ } else if (mode == GGML_SCALE_MODE_BICUBIC) {
7641
+ // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
7642
+ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
7643
+ auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
7644
+ auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
7645
+ auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
7646
+ const float w0 = weight2(x + 1);
7647
+ const float w1 = weight1(x + 0);
7648
+ const float w2 = weight1(1 - x);
7649
+ const float w3 = weight2(2 - x);
7650
+ return p0*w0 + p1*w1 + p2*w2 + p3*w3;
7651
+ };
7652
+
7653
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
7654
+ const int64_t i03 = i3 / sf3;
7655
+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7656
+ const int64_t i02 = i2 / sf2;
7657
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
7658
+ const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
7659
+ const int64_t y0 = (int64_t)floorf(y);
7660
+ const float dy = y - (float)y0;
7661
+
7662
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
7663
+ const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
7664
+ const int64_t x0 = (int64_t)floorf(x);
7665
+ const float dx = x - (float)x0;
7666
+
7667
+ auto p = [=](int64_t x_off, int64_t y_off) -> float {
7668
+ int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
7669
+ int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
7670
+ return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7671
+ };
7672
+
7673
+ const float val = bicubic(
7674
+ bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
7675
+ bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
7676
+ bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
7677
+ bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
7678
+
7564
7679
  float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7565
7680
  *y_dst = val;
7566
7681
  }
@@ -7593,14 +7708,14 @@ void ggml_compute_forward_upscale(
7593
7708
 
7594
7709
  // ggml_compute_forward_pad
7595
7710
 
7711
+ template<bool circular_t>
7596
7712
  static void ggml_compute_forward_pad_f32(
7597
7713
  const ggml_compute_params * params,
7598
7714
  ggml_tensor * dst) {
7599
7715
 
7600
7716
  const ggml_tensor * src0 = dst->src[0];
7601
7717
 
7602
- GGML_ASSERT(src0->nb[0] == sizeof(float));
7603
- GGML_ASSERT( dst->nb[0] == sizeof(float));
7718
+ assert(dst->nb[0] == sizeof(float));
7604
7719
 
7605
7720
  const int ith = params->ith;
7606
7721
  const int nth = params->nth;
@@ -7617,23 +7732,40 @@ static void ggml_compute_forward_pad_f32(
7617
7732
  const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
7618
7733
  const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
7619
7734
 
7620
-
7621
7735
  // TODO: optimize
7622
7736
 
7623
7737
  for (int64_t i2 = 0; i2 < ne2; ++i2) {
7624
7738
  for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
7625
7739
  for (int64_t i0 = 0; i0 < ne0; ++i0) {
7626
7740
  for (int64_t i3 = 0; i3 < ne3; ++i3) {
7627
- const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
7628
- if ((i0 >= lp0 && i0 < ne0 - rp0) \
7629
- && (i1 >= lp1 && i1 < ne1 - rp1) \
7630
- && (i2 >= lp2 && i2 < ne2 - rp2) \
7631
- && (i3 >= lp3 && i3 < ne3 - rp3)) {
7632
- const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
7741
+ // circular means wrap around on a torus, so x and y loop around
7742
+ if constexpr (circular_t) {
7743
+ const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
7744
+ const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);
7745
+ const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);
7746
+ const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);
7747
+ const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);
7748
+
7749
+ const int64_t src_idx =
7750
+ src_i3*nb03 +
7751
+ src_i2*nb02 +
7752
+ src_i1*nb01 +
7753
+ src_i0*nb00;
7754
+
7633
7755
  const float * src_ptr = (const float *)((char *) src0->data + src_idx);
7634
7756
  dst_ptr[dst_idx] = *src_ptr;
7635
7757
  } else {
7636
- dst_ptr[dst_idx] = 0;
7758
+ const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
7759
+ if ((i0 >= lp0 && i0 < ne0 - rp0) \
7760
+ && (i1 >= lp1 && i1 < ne1 - rp1) \
7761
+ && (i2 >= lp2 && i2 < ne2 - rp2) \
7762
+ && (i3 >= lp3 && i3 < ne3 - rp3)) {
7763
+ const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
7764
+ const float * src_ptr = (const float *)((char *) src0->data + src_idx);
7765
+ dst_ptr[dst_idx] = *src_ptr;
7766
+ } else {
7767
+ dst_ptr[dst_idx] = 0;
7768
+ }
7637
7769
  }
7638
7770
  }
7639
7771
  }
@@ -7641,16 +7773,20 @@ static void ggml_compute_forward_pad_f32(
7641
7773
  }
7642
7774
  }
7643
7775
 
7776
+
7644
7777
  void ggml_compute_forward_pad(
7645
7778
  const ggml_compute_params * params,
7646
7779
  ggml_tensor * dst) {
7647
-
7648
7780
  const ggml_tensor * src0 = dst->src[0];
7649
-
7781
+ const bool circular = (bool) ggml_get_op_params_i32(dst, 8);
7650
7782
  switch (src0->type) {
7651
7783
  case GGML_TYPE_F32:
7652
7784
  {
7653
- ggml_compute_forward_pad_f32(params, dst);
7785
+ if (circular) {
7786
+ ggml_compute_forward_pad_f32<true>(params, dst);
7787
+ } else {
7788
+ ggml_compute_forward_pad_f32<false>(params, dst);
7789
+ }
7654
7790
  } break;
7655
7791
  default:
7656
7792
  {
@@ -7854,6 +7990,18 @@ void ggml_compute_forward_timestep_embedding(
7854
7990
 
7855
7991
  // ggml_compute_forward_argsort
7856
7992
 
7993
+ template<enum ggml_sort_order order>
7994
+ struct cmp_argsort {
7995
+ const float * data;
7996
+ bool operator()(int32_t a, int32_t b) const {
7997
+ if constexpr (order == GGML_SORT_ORDER_ASC) {
7998
+ return data[a] < data[b];
7999
+ } else {
8000
+ return data[a] > data[b];
8001
+ }
8002
+ }
8003
+ };
8004
+
7857
8005
  static void ggml_compute_forward_argsort_f32(
7858
8006
  const ggml_compute_params * params,
7859
8007
  ggml_tensor * dst) {
@@ -7872,23 +8020,25 @@ static void ggml_compute_forward_argsort_f32(
7872
8020
  ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
7873
8021
 
7874
8022
  for (int64_t i = ith; i < nr; i += nth) {
7875
- int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
7876
8023
  const float * src_data = (float *)((char *) src0->data + i*nb01);
7877
8024
 
8025
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
8026
+
7878
8027
  for (int64_t j = 0; j < ne0; j++) {
7879
8028
  dst_data[j] = j;
7880
8029
  }
7881
8030
 
7882
- // C doesn't have a functional sort, so we do a bubble sort instead
7883
- for (int64_t j = 0; j < ne0; j++) {
7884
- for (int64_t k = j + 1; k < ne0; k++) {
7885
- if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
7886
- (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
7887
- int32_t tmp = dst_data[j];
7888
- dst_data[j] = dst_data[k];
7889
- dst_data[k] = tmp;
7890
- }
7891
- }
8031
+ switch (order) {
8032
+ case GGML_SORT_ORDER_ASC:
8033
+ std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
8034
+ break;
8035
+
8036
+ case GGML_SORT_ORDER_DESC:
8037
+ std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
8038
+ break;
8039
+
8040
+ default:
8041
+ GGML_ABORT("invalid sort order");
7892
8042
  }
7893
8043
  }
7894
8044
  }
@@ -7911,12 +8061,80 @@ void ggml_compute_forward_argsort(
7911
8061
  }
7912
8062
  }
7913
8063
 
7914
- // ggml_compute_forward_flash_attn_ext
8064
+ // ggml_compute_forward_top_k
7915
8065
 
7916
- static void ggml_compute_forward_flash_attn_ext_f16(
8066
// Comparator for top-k selection: ranks indices so that larger source
// values come first (descending order over the referenced floats).
struct cmp_top_k {
    const float * data; // row of values the int32 indices refer to
    bool operator()(int32_t a, int32_t b) const {
        // equivalent to data[a] > data[b]
        return data[b] < data[a];
    }
};
8072
+
8073
// For each row of src0 (F32), write the indices of the ne0 largest values
// into the corresponding row of dst (int32). Rows are distributed across
// threads round-robin; each thread uses its own int32 scratch region in
// params->wdata to rank the full row before copying out the top-k prefix.
static void ggml_compute_forward_top_k_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    GGML_TENSOR_UNARY_OP_LOCALS

    // dst holds int32 indices; the assert relies on sizeof(float) == sizeof(int32_t)
    GGML_ASSERT(nb0 == sizeof(float));

    const int ith = params->ith;
    const int nth = params->nth;

    const int64_t nr = ggml_nrows(src0);

    // number of indices to keep per row == dst row length
    const int top_k = ne0;

    // per-thread scratch: ne00 indices plus cache-line padding to avoid false sharing
    int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;

    for (int64_t i = ith; i < nr; i += nth) {
        const float * src_data = (float *)((char *) src0->data + i*nb01);

        // start from the identity permutation of the row
        for (int64_t j = 0; j < ne00; j++) {
            tmp[j] = j;
        }

        // only the first top_k positions need to be correctly ordered
        std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});

        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);

        std::copy(tmp, tmp + top_k, dst_data);

        // emphasize that the order is not important
        // (deliberately perturbs the output so callers don't rely on sortedness)
        if (top_k > 1) {
            std::swap(dst_data[0], dst_data[1]);
        }
    }
}
8111
+
8112
+ void ggml_compute_forward_top_k(
8113
+ const ggml_compute_params * params,
8114
+ ggml_tensor * dst) {
8115
+
8116
+ const ggml_tensor * src0 = dst->src[0];
8117
+
8118
+ switch (src0->type) {
8119
+ case GGML_TYPE_F32:
8120
+ {
8121
+ ggml_compute_forward_top_k_f32(params, dst);
8122
+ } break;
8123
+ default:
8124
+ {
8125
+ GGML_ABORT("fatal error");
8126
+ }
8127
+ }
8128
+ }
8129
+
8130
+ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
7917
8131
  const ggml_compute_params * params,
7918
- ggml_tensor * dst) {
8132
+ ggml_tensor * dst,
8133
+ int ir0, int ir1,
8134
+ int64_t ic_start, int64_t ic_end,
8135
+ float * partials, int64_t partial_stride) {
7919
8136
 
8137
+ const bool write_partials = (partials != nullptr);
7920
8138
  const ggml_tensor * q = dst->src[0];
7921
8139
  const ggml_tensor * k = dst->src[1];
7922
8140
  const ggml_tensor * v = dst->src[2];
@@ -7932,9 +8150,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
7932
8150
  GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
7933
8151
  GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
7934
8152
 
7935
- const int ith = params->ith;
7936
- const int nth = params->nth;
7937
-
7938
8153
  const int64_t DK = nek0;
7939
8154
  const int64_t DV = nev0;
7940
8155
  const int64_t N = neq1;
@@ -7968,16 +8183,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
7968
8183
 
7969
8184
  // parallelize by q rows using ggml_vec_dot_f32
7970
8185
 
7971
- // total rows in q
7972
- const int nr = neq1*neq2*neq3;
7973
-
7974
- // rows per thread
7975
- const int dr = (nr + nth - 1)/nth;
7976
-
7977
- // row range for this thread
7978
- const int ir0 = dr*ith;
7979
- const int ir1 = MIN(ir0 + dr, nr);
7980
-
7981
8186
  float scale = 1.0f;
7982
8187
  float max_bias = 0.0f;
7983
8188
  float logit_softcap = 0.0f;
@@ -8004,7 +8209,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
8004
8209
  GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
8005
8210
  GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
8006
8211
 
8007
- // loop over n_batch and n_head
8212
+ int ith = params->ith;
8213
+
8008
8214
  for (int ir = ir0; ir < ir1; ++ir) {
8009
8215
  // q indices
8010
8216
  const int iq3 = ir/(neq2*neq1);
@@ -8044,7 +8250,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
8044
8250
  // online softmax / attention
8045
8251
  // loop over n_kv and n_head_kv
8046
8252
  // ref: https://arxiv.org/pdf/2112.05682.pdf
8047
- for (int64_t ic = 0; ic < nek1; ++ic) {
8253
+
8254
+ for (int64_t ic = ic_start; ic < ic_end; ++ic) {
8048
8255
  const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
8049
8256
  if (mv == -INFINITY) {
8050
8257
  continue;
@@ -8117,8 +8324,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
8117
8324
  }
8118
8325
  }
8119
8326
 
8120
- // sinks
8121
- if (sinks) {
8327
+ // sinks - apply only on the first kv-chunk
8328
+ if (sinks && ic_start == 0) {
8122
8329
  const float s = ((float *)((char *) sinks->data))[h];
8123
8330
 
8124
8331
  float ms = 1.0f;
@@ -8126,6 +8333,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
8126
8333
 
8127
8334
  if (s > M) {
8128
8335
  ms = expf(M - s);
8336
+ M = s;
8129
8337
  ggml_vec_scale_f32(DV, VKQ32, ms);
8130
8338
  } else {
8131
8339
  vs = expf(s - M);
@@ -8134,20 +8342,517 @@ static void ggml_compute_forward_flash_attn_ext_f16(
8134
8342
  S = S*ms + vs;
8135
8343
  }
8136
8344
 
8137
- // V /= S
8138
- const float S_inv = 1.0f/S;
8139
- ggml_vec_scale_f32(DV, VKQ32, S_inv);
8345
+ if (write_partials) {
8346
+ // Write M, S, VKQ to partials for later reduction
8347
+ // partials layout: [M, S, VKQ[DV]] per query head
8348
+ float * partial = partials + ir * partial_stride;
8349
+ partial[0] = M;
8350
+ partial[1] = S;
8351
+ memcpy(partial + 2, VKQ32, DV * sizeof(float));
8352
+ } else {
8353
+ // V /= S
8354
+ const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
8355
+ ggml_vec_scale_f32(DV, VKQ32, S_inv);
8140
8356
 
8141
- // dst indices
8142
- const int i1 = iq1;
8143
- const int i2 = iq2;
8144
- const int i3 = iq3;
8357
+ // dst indices
8358
+ const int i1 = iq1;
8359
+ const int i2 = iq2;
8360
+ const int i3 = iq3;
8361
+
8362
+ // permute(0, 2, 1, 3)
8363
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
8364
+ }
8365
+ }
8366
+ }
8367
+
8368
// Tiled flash-attention forward pass over the q-row range [ir0, ir1).
// Q rows are processed in tiles of ggml_fa_tile_config::Q against KV tiles of
// ggml_fa_tile_config::KV, using online softmax with a per-row running max M[]
// and running sum S[]. KQ and the V accumulation are computed with simd_gemm
// over zero/-inf padded tiles. Assumes the mask, when present, is stored as
// ggml_fp16_t (it is read through that type below).
static void ggml_compute_forward_flash_attn_ext_tiled(
        const ggml_compute_params * params,
        ggml_tensor * dst,
        int ir0, int ir1) {
    const ggml_tensor * q = dst->src[0];
    const ggml_tensor * k = dst->src[1];
    const ggml_tensor * v = dst->src[2];
    const ggml_tensor * mask = dst->src[3];
    const ggml_tensor * sinks = dst->src[4];

    GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
    GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
    GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
    GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
    GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
    GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

    const int64_t DK = nek0; // K head dim
    const int64_t DV = nev0; // V head dim
    const int64_t N = neq1;  // number of query positions

    GGML_ASSERT(ne0 == DV);
    GGML_ASSERT(ne2 == N);

    // input tensor rows must be contiguous
    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
    GGML_ASSERT(nbv0 == ggml_type_size(v->type));

    GGML_ASSERT(neq0 == DK);
    GGML_ASSERT(nek0 == DK);
    GGML_ASSERT(nev0 == DV);

    GGML_ASSERT(neq1 == N);

    // dst cannot be transposed or permuted
    GGML_ASSERT(nb0 == sizeof(float));
    GGML_ASSERT(nb0 <= nb1);
    GGML_ASSERT(nb1 <= nb2);
    GGML_ASSERT(nb2 <= nb3);

    GGML_ASSERT(k->type == v->type);
    const ggml_type kv_type = k->type;

    // broadcast factors (grouped-query attention: several q heads per kv head)
    const int64_t rk2 = neq2/nek2;
    const int64_t rk3 = neq3/nek3;

    const int64_t rv2 = neq2/nev2;
    const int64_t rv3 = neq3/nev3;

    float scale = 1.0f;
    float max_bias = 0.0f;
    float logit_softcap = 0.0f;

    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));

    if (logit_softcap != 0) {
        // softcap folds the scale into the pre-tanh multiplier
        scale /= logit_softcap;
    }

    // ALiBi slope parameters
    const uint32_t n_head = neq2;
    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));

    const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    int ith = params->ith;

    static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
    static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;

    int ir = ir0;
    while (ir < ir1) {
        // q indices for the start of this tile
        const int iq3 = ir/(neq2*neq1);
        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);

        // Number of valid rows in this tile:
        // - limited by tile size (Q_TILE_SZ)
        // - limited by chunk boundary (ir1 - ir)
        // - limited by head boundary (neq1 - iq1) to avoid crossing into next head
        const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
        GGML_ASSERT(tile_rows > 0);

        const uint32_t h = iq2; // head index
        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;

        // online-softmax state per query row in the tile
        float S[Q_TILE_SZ]; // running sum of exp(score - M)
        float M[Q_TILE_SZ]; // running max score

        for (int i = 0 ; i < Q_TILE_SZ; ++i) {
            S[i] = 0.;
            M[i] = -INFINITY;
        }

        // Per-thread scratch layout:
        // Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)
        // KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
        // mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
        // VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
        // V32: KV_TILE_SZ * DV (F32 buffer for V tile)
        // K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
        float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);

        void * Q_q = base;
        float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
        float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
        float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
        float * V32 = VKQ32 + Q_TILE_SZ * DV;
        float * K_f32 = V32 + KV_TILE_SZ * DV;

        memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
        memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));

        // k indices
        const int ik3 = iq3 / rk3;
        const int ik2 = iq2 / rk2;

        // v indices
        const int iv3 = iq3 / rv3;
        const int iv2 = iq2 / rv2;

        // Load the Q tile as F32; zero-pad rows beyond tile_rows so the GEMM
        // always sees a full Q_TILE_SZ x DK operand.
        {
            float * Q_f32 = (float *)Q_q;
            for (int tq = 0; tq < tile_rows; tq++) {
                const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
                memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
            }
            for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
                memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
            }
        }

        memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
        memset(V32, 0, KV_TILE_SZ * DV * sizeof(float));

        for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
            const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);

            // skip the tile entirely if all the masks are -inf
            if (mask) {
                bool can_skip = true;
                for (int tq = 0; tq < tile_rows; tq++) {
                    const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
                    for (int tk = 0; tk < kv_tile; tk++) {
                        mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
                        if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
                            can_skip = false;
                        }
                    }
                    // Pad remaining mask entries with -inf
                    for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
                        mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
                    }
                }

                if (can_skip) {
                    continue;
                }
            }

            // Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
            // Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
            for (int tk = 0; tk < kv_tile; tk++) {
                const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
                if (kv_type == GGML_TYPE_F16) {
                    const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
                    for (int64_t dk = 0; dk < DK; dk++) {
                        K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
                    }
                } else {
                    const float * k_f32_src = (const float *)k_data;
                    for (int64_t dk = 0; dk < DK; dk++) {
                        K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
                    }
                }
            }
            // KQ = (Q * K^T) * scale
            memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
            simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
            ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);

            // Set padded KQ entries to -inf so softmax gives them zero weight
            if (kv_tile < KV_TILE_SZ) {
                for (int tq = 0; tq < Q_TILE_SZ; tq++) {
                    for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
                        KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
                    }
                }
            }

            // optional logit softcap: logit_softcap * tanh(score / logit_softcap)
            if (logit_softcap != 0.0f) {
                ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
                ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
            }

            if (mask) {
                ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
            }

            bool skip[Q_TILE_SZ] = {};

            // online softmax update per query row
            for (int tq = 0; tq < Q_TILE_SZ; tq++) {
                float * kq_row = KQ + tq * KV_TILE_SZ;

                float tile_max;
                ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);

                if (tile_max == -INFINITY) {
                    // fully-masked row: contributes nothing this KV tile
                    skip[tq] = true;
                    continue;
                }

                const float Mold = M[tq];
                const float Mnew = fmaxf(Mold, tile_max);

                if (Mnew > Mold) {
                    // rescale previous accumulator to the new running max
                    const float ms = expf(Mold - Mnew);
                    ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
                    S[tq] *= ms;
                }
                M[tq] = Mnew;


                S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
            }

            // V accumulation: VKQ32 += softmax(KQ) * V
            // Pack V tile to contiguous F32, zero-padded
            for (int tk = 0; tk < kv_tile; tk++) {
                const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
                if (kv_type == GGML_TYPE_F16) {
                    ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
                } else {
                    memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
                }
            }
            // zero skipped rows so they add nothing in the GEMM below
            for (int tq = 0; tq < Q_TILE_SZ; tq++) {
                if (skip[tq]) {
                    memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
                }
            }
            simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
        }

        // sinks (apply only to valid rows in the tile)
        if (sinks) {
            const float s = ((float *)((char *) sinks->data))[h];

            for (int tq = 0; tq < tile_rows; tq++) {
                float ms = 1.0f;
                float vs = 1.0f;

                if (s > M[tq]) {
                    ms = expf(M[tq] - s);
                    ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
                } else {
                    vs = expf(s - M[tq]);
                }

                S[tq] = S[tq] * ms + vs;
            }
        }

        // normalize and write out the finished rows of this tile
        for (int tq = 0; tq < tile_rows; tq++) {
            // V /= S (S == 0 means fully-masked row; emit zeros instead of NaN)
            const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
            ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);

            // dst indices
            const int i1 = iq1 + tq;
            const int i2 = iq2;
            const int i3 = iq3;

            // permute(0, 2, 1, 3)
            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
        }

        ir += tile_rows;
    }
}
8655
+
8656
// Reduction function: combines partial results across KV chunks
// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]
// Each partial holds [M (running max), S (softmax denominator), VKQ[DV]].
// Chunks are merged with the standard online-softmax combine rule: both
// accumulators are rescaled to the common max before summing. Writes the
// normalized result to dst (decode case: iq1 = 0, iq3 = 0).
static void ggml_flash_attn_ext_reduce_partials(
        const ggml_compute_params * params,
        ggml_tensor * dst,
        const int64_t n_chunks,
        const int64_t chunk_size) {

    const ggml_tensor * q = dst->src[0];
    const ggml_tensor * k = dst->src[1];
    const ggml_tensor * v = dst->src[2];

    const int64_t DK = k->ne[0];
    const int64_t DV = v->ne[0];
    const int64_t nek1 = k->ne[1];
    const int64_t n_q_heads = q->ne[2];

    const int ith = params->ith;
    const int nth = params->nth;

    // per-thread scratch (sized to match the layout used when the partials
    // were produced; the partials array lives right after all thread scratch)
    const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;
    float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread;

    const int64_t partials_offset = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
    const int64_t partial_size = 2 + DV;
    const float * partials_base = (const float *) params->wdata + partials_offset;

    // Output layout
    const int64_t ne1 = dst->ne[1];
    const int64_t ne2 = dst->ne[2];
    const size_t nb1 = dst->nb[1];

    // Each thread reduces a subset of query heads
    for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
        float M_final = -INFINITY;
        float S_final = 0.0f;
        float * VKQ_final = thread_wdata;
        memset(VKQ_final, 0, DV * sizeof(float));

        // Combine partials from all chunks
        for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
            const int64_t ic_start = chunk_idx * chunk_size;
            // chunk starts past the end of the KV sequence: nothing was written
            if (ic_start >= nek1) continue;

            const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
            const float M_chunk = partial[0];
            const float S_chunk = partial[1];
            const float * VKQ_chunk = partial + 2;

            // empty/fully-masked chunk contributes nothing
            if (S_chunk == 0.0f) continue;

            // rescale both sides to the shared max, then accumulate
            const float M_new = fmaxf(M_final, M_chunk);
            const float scale_old = expf(M_final - M_new);
            const float scale_new = expf(M_chunk - M_new);

            for (int64_t d = 0; d < DV; ++d) {
                VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;
            }
            S_final = S_final * scale_old + S_chunk * scale_new;
            M_final = M_new;
        }

        // Normalize and write to output
        if (S_final != 0.0f) {
            const float S_inv = 1.0f / S_final;
            ggml_vec_scale_f32(DV, VKQ_final, S_inv);
        }
        // iq1=0, iq3=0 for decode
        memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1);
    }
}
8727
+
8728
// Flash-attention forward dispatcher. Chooses between three execution paths:
// 1. split-KV ("flash decoding"): single query row (neq1 == 1, neq3 == 1),
//    F32/F16 KV, long KV (>= 512) — each thread computes one KV chunk for
//    every head, writes [M, S, VKQ] partials, then all threads barrier and
//    reduce the partials;
// 2. tiled path: F32 Q, F32/F16 KV, enough query rows for a Q tile;
// 3. the one-chunk reference path otherwise (also forced by params->use_ref).
// Paths 2/3 self-schedule q-row chunks via the threadpool chunk counter.
static void ggml_compute_forward_flash_attn_ext_f16(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * q = dst->src[0];
    const ggml_tensor * k = dst->src[1];
    const ggml_tensor * v = dst->src[2];

    GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
    GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
    GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
    GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
    GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
    GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

    const int64_t DK = nek0;
    const int64_t DV = nev0;
    const int64_t N = neq1;


    GGML_ASSERT(ne0 == DV);
    GGML_ASSERT(ne2 == N);

    // input tensor rows must be contiguous
    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
    GGML_ASSERT(nbv0 == ggml_type_size(v->type));

    GGML_ASSERT(neq0 == DK);
    GGML_ASSERT(nek0 == DK);
    GGML_ASSERT(nev0 == DV);

    GGML_ASSERT(neq1 == N);

    // dst cannot be transposed or permuted
    GGML_ASSERT(nb0 == sizeof(float));
    GGML_ASSERT(nb0 <= nb1);
    GGML_ASSERT(nb1 <= nb2);
    GGML_ASSERT(nb2 <= nb3);

    const int ith = params->ith;
    const int nth = params->nth;

    // When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
    const bool use_ref = params->use_ref;

    const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
    const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;

    if (use_split_kv_path) {
        // KV sequence split evenly across threads; chunk index == thread index
        const int64_t chunk_size = (nek1 + nth - 1) / nth;

        // Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
        const int64_t partial_size = 2 + DV;
        float * partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);

        const int64_t ic_start = ith * chunk_size;
        const int64_t ic_end = std::min(ic_start + chunk_size, nek1);

        const int64_t partial_stride = nth * partial_size;
        float * chunk_partials = partials_base + ith * partial_size;

        if (ic_start < nek1) {
            // process this thread's KV chunk for every query head
            for (int64_t q_head = 0; q_head < neq2; q_head++) {
                ggml_compute_forward_flash_attn_ext_f16_one_chunk(
                    params, dst, q_head, q_head + 1, ic_start, ic_end,
                    chunk_partials, partial_stride);
            }
        } else {
            // this thread has no KV rows: mark its partials as empty so the
            // reduction skips them
            for (int64_t q_head = 0; q_head < neq2; q_head++) {
                float * q_partials = chunk_partials + q_head * partial_stride;
                q_partials[0] = -INFINITY; // M
                q_partials[1] = 0.0f; // S
            }
        }

        // all partials must be written before any thread reduces
        ggml_barrier(params->threadpool);
        ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size);
    } else {

        // total rows in q
        const int64_t nr = neq1*neq2*neq3;

        // disable for NUMA
        const bool disable_chunking = ggml_is_numa();

        // 4x chunks per thread
        int nth_scaled = nth * 4;
        int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
        int64_t nchunk = (nr + chunk_size - 1) / chunk_size;

        if (nth == 1 || nchunk < nth || disable_chunking) {
            nchunk = nth;
        }

        if (ith == 0) {
            // Always start the next-chunk counter at nth (threads 0..nth-1
            // take their own index as the first chunk)
            ggml_threadpool_chunk_set(params->threadpool, nth);
        }

        ggml_barrier(params->threadpool);

        // rows per chunk
        const int64_t dr = (nr + nchunk - 1) / nchunk;

        static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
        bool use_tiled = !use_ref &&
                         (q->type == GGML_TYPE_F32 &&
                          kv_is_f32_or_f16 &&
                          k->type == v->type &&
                          neq1 >= Q_TILE_SZ);
#ifdef GGML_SIMD
        // tiled path requires DV to be a multiple of the SIMD register width
        use_tiled &= (DV % GGML_F32_EPR == 0);
#endif
        int current_chunk = ith;

        // work-stealing loop: grab the next unprocessed chunk atomically
        while (current_chunk < nchunk) {
            const int64_t ir0 = dr * current_chunk;
            const int64_t ir1 = MIN(ir0 + dr, nr);

            if (use_tiled) {
                ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
            } else {
                ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0);
            }

            current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
        }
    }
}
8153
8858
 
@@ -8637,7 +9342,7 @@ static void ggml_compute_forward_ssm_scan_f32(
8637
9342
  // n_head
8638
9343
  for (int h = ih0; h < ih1; ++h) {
8639
9344
  // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8640
- const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
9345
+ const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
8641
9346
  const float dA = expf(dt_soft_plus * A[h]);
8642
9347
  const int g = h / (nh / ng); // repeat_interleave
8643
9348
 
@@ -8734,7 +9439,7 @@ static void ggml_compute_forward_ssm_scan_f32(
8734
9439
  // n_head
8735
9440
  for (int h = ih0; h < ih1; ++h) {
8736
9441
  // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8737
- const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
9442
+ const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
8738
9443
  const int g = h / (nh / ng); // repeat_interleave
8739
9444
 
8740
9445
  // dim
@@ -8928,7 +9633,7 @@ void ggml_compute_forward_win_unpart(
8928
9633
  }
8929
9634
  }
8930
9635
 
8931
- //gmml_compute_forward_unary
9636
+ //ggml_compute_forward_unary
8932
9637
 
8933
9638
  void ggml_compute_forward_unary(
8934
9639
  const ggml_compute_params * params,
@@ -8997,6 +9702,34 @@ void ggml_compute_forward_unary(
8997
9702
  {
8998
9703
  ggml_compute_forward_exp(params, dst);
8999
9704
  } break;
9705
+ case GGML_UNARY_OP_FLOOR:
9706
+ {
9707
+ ggml_compute_forward_floor(params, dst);
9708
+ } break;
9709
+ case GGML_UNARY_OP_CEIL:
9710
+ {
9711
+ ggml_compute_forward_ceil(params, dst);
9712
+ } break;
9713
+ case GGML_UNARY_OP_ROUND:
9714
+ {
9715
+ ggml_compute_forward_round(params, dst);
9716
+ } break;
9717
+ case GGML_UNARY_OP_TRUNC:
9718
+ {
9719
+ ggml_compute_forward_trunc(params, dst);
9720
+ } break;
9721
+ case GGML_UNARY_OP_XIELU:
9722
+ {
9723
+ ggml_compute_forward_xielu(params, dst);
9724
+ } break;
9725
+ case GGML_UNARY_OP_EXPM1:
9726
+ {
9727
+ ggml_compute_forward_expm1(params, dst);
9728
+ } break;
9729
+ case GGML_UNARY_OP_SOFTPLUS:
9730
+ {
9731
+ ggml_compute_forward_softplus(params, dst);
9732
+ } break;
9000
9733
  default:
9001
9734
  {
9002
9735
  GGML_ABORT("fatal error");
@@ -9593,6 +10326,265 @@ void ggml_compute_forward_gla(
9593
10326
  }
9594
10327
  }
9595
10328
 
10329
+ static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
10330
+ const struct ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
10331
+ const struct ggml_tensor * src1 = dst->src[1]; // B (RHS)
10332
+
10333
+ GGML_TENSOR_BINARY_OP_LOCALS;
10334
+
10335
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
10336
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
10337
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
10338
+
10339
+ GGML_ASSERT(ne00 == ne01); // A must be square
10340
+ GGML_ASSERT(ne0 == ne10); // solution cols == B cols
10341
+ GGML_ASSERT(ne1 == ne11); // solution rows == B rows
10342
+
10343
+ GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
10344
+ GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
10345
+
10346
+ const int ith = params->ith;
10347
+ const int nth = params->nth;
10348
+
10349
+ const int64_t k = ne10; // number of RHS columns
10350
+ const int64_t n = ne11; // A is n×n
10351
+ const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
10352
+
10353
+ // chunks per thread
10354
+ const int64_t dr = (nr + nth - 1)/nth;
10355
+
10356
+ // chunk range for this thread
10357
+ const int64_t ir0 = dr*ith;
10358
+ const int64_t ir1 = MIN(ir0 + dr, nr);
10359
+
10360
+ const float * A = (const float *) src0->data; // [n, n, B1, B2]
10361
+ const float * B = (const float *) src1->data; // [n, k, B1, B2]
10362
+ float * X = ( float *) dst->data; // [n, k, B1, B2]
10363
+
10364
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
10365
+ const int64_t i03 = ir/(ne02*k);
10366
+ const int64_t i02 = (ir - i03*ne02*k)/k;
10367
+ const int64_t i01 = (ir - i03*ne02*k - i02*k);
10368
+
10369
+ const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
10370
+ const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
10371
+
10372
+ float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
10373
+
10374
+ for (int64_t i00 = 0; i00 < n; ++i00) {
10375
+ float sum = 0.0f;
10376
+ for (int64_t t = 0; t < i00; ++t) {
10377
+ sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
10378
+ }
10379
+
10380
+ const float diag = A_batch[i00 * n + i00];
10381
+ assert(diag != 0.0f && "Zero diagonal in triangular matrix");
10382
+
10383
+ X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
10384
+ }
10385
+ }
10386
+ }
10387
+
10388
+ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
10389
+ const ggml_tensor * src0 = dst->src[0];
10390
+ const ggml_tensor * src1 = dst->src[1];
10391
+
10392
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
10393
+ ggml_compute_forward_solve_tri_f32(params, dst);
10394
+ } else {
10395
+ GGML_ABORT("fatal error");
10396
+ }
10397
+ }
10398
+
10399
+ // ggml_compute_forward_gated_delta_net
10400
+ static void ggml_compute_forward_gated_delta_net_one_chunk(
10401
+ const ggml_compute_params * params,
10402
+ ggml_tensor * dst,
10403
+ int64_t ir0,
10404
+ int64_t ir1) {
10405
+
10406
+ ggml_tensor * src_q = dst->src[0];
10407
+ ggml_tensor * src_k = dst->src[1];
10408
+ ggml_tensor * src_v = dst->src[2];
10409
+ ggml_tensor * src_g = dst->src[3];
10410
+ ggml_tensor * src_beta = dst->src[4];
10411
+ ggml_tensor * src_state = dst->src[5];
10412
+
10413
+ const int64_t S_v = src_v->ne[0];
10414
+ const int64_t H = src_v->ne[1];
10415
+ const int64_t n_tokens = src_v->ne[2];
10416
+ const int64_t n_seqs = src_v->ne[3];
10417
+
10418
+ GGML_ASSERT(ggml_is_contiguous_rows(src_q));
10419
+ GGML_ASSERT(ggml_is_contiguous_rows(src_k));
10420
+ GGML_ASSERT(ggml_is_contiguous_rows(src_v));
10421
+ GGML_ASSERT(ggml_is_contiguous(src_g));
10422
+ GGML_ASSERT(ggml_is_contiguous(src_beta));
10423
+ GGML_ASSERT(ggml_is_contiguous(src_state));
10424
+
10425
+ GGML_ASSERT(src_g->ne[0] == 1 || src_g->ne[0] == S_v);
10426
+ GGML_ASSERT(src_beta->ne[0] == 1);
10427
+
10428
+ GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
10429
+ GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
10430
+ GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
10431
+ GGML_TENSOR_LOCALS(size_t, nbk, src_k, nb);
10432
+ GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
10433
+ GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
10434
+ GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne);
10435
+ GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb);
10436
+ GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
10437
+
10438
+ const bool kda = (neg0 == S_v);
10439
+
10440
+ // scratch layout per thread: [delta(S_v)]
10441
+ const int64_t scratch_per_thread = S_v;
10442
+ const int ith = params->ith;
10443
+
10444
+ float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;
10445
+
10446
+ // output layout: [attn_scores | new_states]
10447
+ // attn_scores: S_v * H * n_tokens * n_seqs floats
10448
+ // new_states: S_v * S_v * H * n_seqs floats
10449
+ const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
10450
+ float * attn_out_base = (float *)dst->data;
10451
+ float * state_out_base = (float *)dst->data + attn_score_elems;
10452
+
10453
+ const float * state_in_base = (const float *)src_state->data;
10454
+
10455
+ //const int64_t rq1 = nev1 / neq1;
10456
+ //const int64_t rk1 = nev1 / nek1;
10457
+ const int64_t rq3 = nev3 / neq3;
10458
+ const int64_t rk3 = nev3 / nek3;
10459
+
10460
+ const float scale = 1.0f / sqrtf((float) S_v);
10461
+
10462
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
10463
+ const int64_t iv1 = ir % H; // head_index
10464
+ const int64_t iv3 = ir / H; // sequence
10465
+
10466
+ const int64_t iq1 = iv1 % neq1;
10467
+ const int64_t ik1 = iv1 % nek1;
10468
+
10469
+ const int64_t iq3 = iv3 / rq3;
10470
+ const int64_t ik3 = iv3 / rk3;
10471
+
10472
+ float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v;
10473
+
10474
+ // copy input state into output buffer and operate in-place
10475
+ const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v;
10476
+ memcpy(s_out, s_in, S_v * S_v * sizeof(float));
10477
+
10478
+ // attn output pointer for first token of this (head, seq)
10479
+ float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v;
10480
+
10481
+ for (int64_t t = 0; t < n_tokens; t++) {
10482
+ const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1);
10483
+ const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1);
10484
+ const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
10485
+
10486
+ const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
10487
+ const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
10488
+
10489
+ // state is stored transposed: s_out[j*S_v + i] = S[i][j]
10490
+ // so row j of s_out = column j of S (contiguous access)
10491
+
10492
+ if (kda) {
10493
+ // precompute exp(g) into delta scratch (reused below)
10494
+ for (int64_t i = 0; i < S_v; ++i) {
10495
+ delta[i] = expf(g_d[i]);
10496
+ }
10497
+ // S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i])
10498
+ for (int64_t j = 0; j < S_v; ++j) {
10499
+ ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta);
10500
+ }
10501
+ } else {
10502
+ ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0]));
10503
+ }
10504
+
10505
+ // delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k)
10506
+ for (int64_t j = 0; j < S_v; ++j) {
10507
+ float sum = 0.0f;
10508
+ ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1);
10509
+ delta[j] = (v_d[j] - sum) * beta_val;
10510
+ }
10511
+
10512
+ // outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i]
10513
+ for (int64_t j = 0; j < S_v; ++j) {
10514
+ ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]);
10515
+ }
10516
+
10517
+ // attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q)
10518
+ for (int64_t j = 0; j < S_v; ++j) {
10519
+ float sum = 0.0f;
10520
+ ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1);
10521
+ attn_data[j] = sum * scale;
10522
+ }
10523
+
10524
+ attn_data += S_v * H; // advance to next token
10525
+ }
10526
+ }
10527
+ }
10528
+
10529
+
10530
+ static void ggml_compute_forward_gated_delta_net_f32(
10531
+ const ggml_compute_params * params,
10532
+ ggml_tensor * dst) {
10533
+
10534
+ ggml_tensor * V = dst->src[2];
10535
+ int64_t nr = V->ne[1] * V->ne[3];
10536
+
10537
+ // disable for NUMA
10538
+ const bool disable_chunking = ggml_is_numa();
10539
+
10540
+ int nth = params->nth;
10541
+ int ith = params->ith;
10542
+
10543
+ // 4x chunks per thread
10544
+ int nth_scaled = nth * 4;
10545
+ int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
10546
+ int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
10547
+
10548
+ if (nth == 1 || nchunk < nth || disable_chunking) {
10549
+ nchunk = nth;
10550
+ }
10551
+
10552
+ if (ith == 0) {
10553
+ ggml_threadpool_chunk_set(params->threadpool, nth);
10554
+ }
10555
+
10556
+ ggml_barrier(params->threadpool);
10557
+
10558
+ const int64_t dr = (nr + nchunk - 1) / nchunk;
10559
+
10560
+ int current_chunk = ith;
10561
+
10562
+ while (current_chunk < nchunk) {
10563
+ const int64_t ir0 = dr * current_chunk;
10564
+ const int64_t ir1 = MIN(ir0 + dr, nr);
10565
+
10566
+ ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1);
10567
+ current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
10568
+ }
10569
+ }
10570
+
10571
+ void ggml_compute_forward_gated_delta_net(
10572
+ const ggml_compute_params * params,
10573
+ ggml_tensor * dst) {
10574
+ const ggml_tensor * src0 = dst->src[0];
10575
+
10576
+ switch (src0->type) {
10577
+ case GGML_TYPE_F32:
10578
+ {
10579
+ ggml_compute_forward_gated_delta_net_f32(params, dst);
10580
+ } break;
10581
+ default:
10582
+ {
10583
+ GGML_ABORT("fatal error");
10584
+ }
10585
+ }
10586
+ }
10587
+
9596
10588
  // ggml_compute_forward_rwkv_wkv7
9597
10589
 
9598
10590
  static void ggml_compute_forward_rwkv_wkv7_f32(
@@ -9918,7 +10910,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
9918
10910
  assert(!isnan(s0[i]));
9919
10911
  assert(!isnan(s1[i]));
9920
10912
  }
9921
- #endif
10913
+ #endif // NDEBUG
9922
10914
 
9923
10915
  float max = -INFINITY;
9924
10916
  ggml_vec_max_f32(nc, &max, s0);
@@ -9937,7 +10929,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
9937
10929
  assert(!isnan(st[i]));
9938
10930
  assert(!isinf(st[i]));
9939
10931
  }
9940
- #endif
10932
+ #endif // NDEBUG
9941
10933
  }
9942
10934
  sums[ith] = sum_thread;
9943
10935
  ggml_barrier(params->threadpool);
@@ -10010,7 +11002,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
10010
11002
  assert(!isnan(s0[i]));
10011
11003
  assert(!isnan(s1[i]));
10012
11004
  }
10013
- #endif
11005
+ #endif // NDEBUG
10014
11006
 
10015
11007
  // soft_max
10016
11008
  float max = -INFINITY;
@@ -10028,7 +11020,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
10028
11020
  assert(!isnan(ds0[i]));
10029
11021
  assert(!isinf(ds0[i]));
10030
11022
  }
10031
- #endif
11023
+ #endif // NDEBUG
10032
11024
  }
10033
11025
  }
10034
11026