whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -7,8 +7,10 @@
7
7
  #include "unary-ops.h"
8
8
  #include "vec.h"
9
9
 
10
- #include <float.h>
10
+ #include <cfloat>
11
11
  #include <algorithm>
12
+ #include <cmath>
13
+ #include <functional>
12
14
 
13
15
  // ggml_compute_forward_dup
14
16
 
@@ -1394,6 +1396,56 @@ void ggml_compute_forward_sum(
1394
1396
  }
1395
1397
  }
1396
1398
 
1399
+ // ggml_compute_forward_cumsum
1400
+
1401
+ static void ggml_compute_forward_cumsum_f32(
1402
+ const ggml_compute_params * params,
1403
+ ggml_tensor * dst) {
1404
+
1405
+ const ggml_tensor * src0 = dst->src[0];
1406
+
1407
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
1408
+ GGML_ASSERT(dst->nb[0] == sizeof(float));
1409
+
1410
+ GGML_TENSOR_UNARY_OP_LOCALS
1411
+
1412
+ GGML_ASSERT(ne0 == ne00);
1413
+ GGML_ASSERT(ne1 == ne01);
1414
+ GGML_ASSERT(ne2 == ne02);
1415
+ GGML_ASSERT(ne3 == ne03);
1416
+
1417
+ const auto [ir0, ir1] = get_thread_range(params, src0);
1418
+
1419
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
1420
+ const int64_t i03 = ir/(ne02*ne01);
1421
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
1422
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
1423
+
1424
+ float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
1425
+ float * dst_row = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
1426
+
1427
+ ggml_vec_cumsum_f32(ne00, dst_row, src_row);
1428
+ }
1429
+ }
1430
+
1431
+ void ggml_compute_forward_cumsum(
1432
+ const ggml_compute_params * params,
1433
+ ggml_tensor * dst) {
1434
+
1435
+ const ggml_tensor * src0 = dst->src[0];
1436
+
1437
+ switch (src0->type) {
1438
+ case GGML_TYPE_F32:
1439
+ {
1440
+ ggml_compute_forward_cumsum_f32(params, dst);
1441
+ } break;
1442
+ default:
1443
+ {
1444
+ GGML_ABORT("fatal error");
1445
+ }
1446
+ }
1447
+ }
1448
+
1397
1449
  // ggml_compute_forward_sum_rows
1398
1450
 
1399
1451
  static void ggml_compute_forward_sum_rows_f32(
@@ -2140,6 +2192,83 @@ static void ggml_compute_forward_gelu(
2140
2192
  }
2141
2193
  }
2142
2194
 
2195
+ // ggml_compute_fill
2196
+
2197
+ static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
2198
+ const float c = ggml_get_op_params_f32(dst, 0);
2199
+
2200
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
2201
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
2202
+
2203
+ const auto [ir0, ir1] = get_thread_range(params, dst);
2204
+
2205
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
2206
+ const int64_t i03 = ir/(ne2*ne1);
2207
+ const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
2208
+ const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
2209
+
2210
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
2211
+
2212
+ ggml_vec_set_f32(ne0, dst_ptr, c);
2213
+ }
2214
+ }
2215
+
2216
+ void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
2217
+ ggml_compute_forward_fill_f32(params, dst);
2218
+ }
2219
+
2220
+ // ggml_compute_tri
2221
+
2222
+ static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
2223
+ const ggml_tensor * src0 = dst->src[0];
2224
+
2225
+ const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
2226
+
2227
+ GGML_ASSERT(ggml_is_contiguous(src0));
2228
+
2229
+ GGML_TENSOR_UNARY_OP_LOCALS
2230
+
2231
+ const auto [ir0, ir1] = get_thread_range(params, src0);
2232
+
2233
+ bool (*bipred)(int, int);
2234
+
2235
+ switch (ttype) {
2236
+ case GGML_TRI_TYPE_LOWER: bipred = [](int i, int r) { return i < r; }; break;
2237
+ case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
2238
+ case GGML_TRI_TYPE_UPPER: bipred = [](int i, int r) { return i > r; }; break;
2239
+ case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
2240
+ default: GGML_ABORT("invalid tri type");
2241
+ }
2242
+
2243
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
2244
+ const int64_t i03 = ir/(ne02*ne01);
2245
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
2246
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
2247
+
2248
+ const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
2249
+ float * dst_ptr = ( float *) (( char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
2250
+
2251
+ for (int i0 = 0; i0 < ne0; ++i0) {
2252
+ dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
2253
+ }
2254
+ }
2255
+ }
2256
+
2257
+ void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
2258
+ const ggml_tensor * src0 = dst->src[0];
2259
+
2260
+ switch (src0->type) {
2261
+ case GGML_TYPE_F32:
2262
+ {
2263
+ ggml_compute_forward_tri_f32(params, dst);
2264
+ } break;
2265
+ default:
2266
+ {
2267
+ GGML_ABORT("fatal error");
2268
+ }
2269
+ }
2270
+ }
2271
+
2143
2272
  // ggml_compute_forward_gelu_erf
2144
2273
 
2145
2274
  static void ggml_compute_forward_gelu_erf_f32(
@@ -3467,31 +3596,27 @@ static void ggml_compute_forward_norm_f32(
3467
3596
 
3468
3597
  GGML_ASSERT(eps >= 0.0f);
3469
3598
 
3470
- // TODO: optimize
3471
3599
  for (int64_t i03 = 0; i03 < ne03; i03++) {
3472
3600
  for (int64_t i02 = 0; i02 < ne02; i02++) {
3473
3601
  for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
3474
3602
  const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3475
3603
 
3476
- ggml_float sum = 0.0;
3477
- for (int64_t i00 = 0; i00 < ne00; i00++) {
3478
- sum += (ggml_float)x[i00];
3479
- }
3480
-
3604
+ float sum = 0.0;
3605
+ ggml_vec_sum_f32(ne00, &sum, x);
3481
3606
  float mean = sum/ne00;
3482
3607
 
3483
3608
  float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3609
+ float variance = 0;
3484
3610
 
3485
- ggml_float sum2 = 0.0;
3486
- for (int64_t i00 = 0; i00 < ne00; i00++) {
3487
- float v = x[i00] - mean;
3488
- y[i00] = v;
3489
- sum2 += (ggml_float)(v*v);
3490
- }
3611
+ #ifdef GGML_USE_ACCELERATE
3612
+ mean = -mean;
3613
+ vDSP_vsadd(x, 1, &mean, y, 1, ne00);
3614
+ vDSP_measqv(y, 1, &variance, ne00);
3615
+ #else
3616
+ variance = ggml_vec_cvar_f32(ne00, y, x, mean);
3617
+ #endif //GGML_USE_ACCELERATE
3491
3618
 
3492
- float variance = sum2/ne00;
3493
3619
  const float scale = 1.0f/sqrtf(variance + eps);
3494
-
3495
3620
  ggml_vec_scale_f32(ne00, y, scale);
3496
3621
  }
3497
3622
  }
@@ -4459,46 +4584,6 @@ void ggml_compute_forward_cont(
4459
4584
  ggml_compute_forward_dup(params, dst);
4460
4585
  }
4461
4586
 
4462
- // ggml_compute_forward_reshape
4463
-
4464
- void ggml_compute_forward_reshape(
4465
- const ggml_compute_params * params,
4466
- ggml_tensor * dst) {
4467
- // NOP
4468
- GGML_UNUSED(params);
4469
- GGML_UNUSED(dst);
4470
- }
4471
-
4472
- // ggml_compute_forward_view
4473
-
4474
- void ggml_compute_forward_view(
4475
- const ggml_compute_params * params,
4476
- ggml_tensor * dst) {
4477
- // NOP
4478
- GGML_UNUSED(params);
4479
- GGML_UNUSED(dst);
4480
- }
4481
-
4482
- // ggml_compute_forward_permute
4483
-
4484
- void ggml_compute_forward_permute(
4485
- const ggml_compute_params * params,
4486
- ggml_tensor * dst) {
4487
- // NOP
4488
- GGML_UNUSED(params);
4489
- GGML_UNUSED(dst);
4490
- }
4491
-
4492
- // ggml_compute_forward_transpose
4493
-
4494
- void ggml_compute_forward_transpose(
4495
- const ggml_compute_params * params,
4496
- ggml_tensor * dst) {
4497
- // NOP
4498
- GGML_UNUSED(params);
4499
- GGML_UNUSED(dst);
4500
- }
4501
-
4502
4587
  // ggml_compute_forward_get_rows
4503
4588
 
4504
4589
  static void ggml_compute_forward_get_rows_q(
@@ -5478,7 +5563,7 @@ static void ggml_rope_cache_init(
5478
5563
  }
5479
5564
 
5480
5565
  static void ggml_mrope_cache_init(
5481
- float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
5566
+ float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
5482
5567
  float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
5483
5568
  float * cache, float sin_sign, float theta_scale) {
5484
5569
  // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
@@ -5513,14 +5598,26 @@ static void ggml_mrope_cache_init(
5513
5598
  }
5514
5599
 
5515
5600
  float theta = theta_t;
5516
- if (sector >= sections[0] && sector < sec_w) {
5517
- theta = theta_h;
5518
- }
5519
- else if (sector >= sec_w && sector < sec_w + sections[2]) {
5520
- theta = theta_w;
5521
- }
5522
- else if (sector >= sec_w + sections[2]) {
5523
- theta = theta_e;
5601
+ if (is_imrope) { // qwen3vl apply interleaved mrope
5602
+ if (sector % 3 == 1 && sector < 3 * sections[1]) {
5603
+ theta = theta_h;
5604
+ } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
5605
+ theta = theta_w;
5606
+ } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
5607
+ theta = theta_t;
5608
+ } else {
5609
+ theta = theta_e;
5610
+ }
5611
+ } else {
5612
+ if (sector >= sections[0] && sector < sec_w) {
5613
+ theta = theta_h;
5614
+ }
5615
+ else if (sector >= sec_w && sector < sec_w + sections[2]) {
5616
+ theta = theta_w;
5617
+ }
5618
+ else if (sector >= sec_w + sections[2]) {
5619
+ theta = theta_e;
5620
+ }
5524
5621
  }
5525
5622
 
5526
5623
  rope_yarn(
@@ -5535,7 +5632,28 @@ static void ggml_mrope_cache_init(
5535
5632
  }
5536
5633
  }
5537
5634
 
5538
- static void ggml_compute_forward_rope_f32(
5635
+
5636
+ template<typename T>
5637
+ static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
5638
+ for (int64_t i0 = 0; i0 < n; i0 += 2) {
5639
+ const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
5640
+
5641
+ const float cos_theta = cache[i0 + 0];
5642
+ const float sin_theta = cache[i0 + 1];
5643
+
5644
+ const T * const src = src_data + ic;
5645
+ T * dst = dst_data + ic;
5646
+
5647
+ const float x0 = type_conversion_table<T>::to_f32(src[0]);
5648
+ const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
5649
+
5650
+ dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
5651
+ dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
5652
+ }
5653
+ }
5654
+
5655
+ template<typename T> //float or ggml_fp16_t
5656
+ static void ggml_compute_forward_rope_flt(
5539
5657
  const ggml_compute_params * params,
5540
5658
  ggml_tensor * dst,
5541
5659
  const bool forward) {
@@ -5544,6 +5662,9 @@ static void ggml_compute_forward_rope_f32(
5544
5662
  const ggml_tensor * src1 = dst->src[1];
5545
5663
  const ggml_tensor * src2 = dst->src[2];
5546
5664
 
5665
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
5666
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
5667
+
5547
5668
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5548
5669
  int sections[4];
5549
5670
 
@@ -5566,7 +5687,8 @@ static void ggml_compute_forward_rope_f32(
5566
5687
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
5567
5688
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
5568
5689
 
5569
- GGML_ASSERT(nb00 == sizeof(float));
5690
+ GGML_ASSERT(nb0 == nb00);
5691
+ GGML_ASSERT(nb0 == sizeof(T));
5570
5692
 
5571
5693
  const int ith = params->ith;
5572
5694
  const int nth = params->nth;
@@ -5591,11 +5713,11 @@ static void ggml_compute_forward_rope_f32(
5591
5713
  float corr_dims[2];
5592
5714
  ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
5593
5715
 
5594
- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
5595
- const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
5716
+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
5717
+ const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
5596
5718
  const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
5597
5719
 
5598
- if (is_mrope) {
5720
+ if (mrope_used) {
5599
5721
  GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
5600
5722
  }
5601
5723
 
@@ -5621,7 +5743,7 @@ static void ggml_compute_forward_rope_f32(
5621
5743
  for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
5622
5744
 
5623
5745
  float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5624
- if (!is_mrope) {
5746
+ if (!mrope_used) {
5625
5747
  const int64_t p = pos[i2];
5626
5748
  ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5627
5749
  }
@@ -5631,7 +5753,7 @@ static void ggml_compute_forward_rope_f32(
5631
5753
  const int64_t p_w = pos[i2 + ne2 * 2];
5632
5754
  const int64_t p_e = pos[i2 + ne2 * 3];
5633
5755
  ggml_mrope_cache_init(
5634
- p_t, p_h, p_w, p_e, sections, is_vision,
5756
+ p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
5635
5757
  freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5636
5758
  }
5637
5759
 
@@ -5639,347 +5761,115 @@ static void ggml_compute_forward_rope_f32(
5639
5761
  if (ir++ < ir0) continue;
5640
5762
  if (ir > ir1) break;
5641
5763
 
5642
- if (is_neox || is_mrope) {
5643
- if (is_vision){
5644
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5645
- const int64_t ic = i0/2;
5646
-
5647
- const float cos_theta = cache[i0 + 0];
5648
- const float sin_theta = cache[i0 + 1];
5649
-
5650
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5651
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5652
-
5653
- const float x0 = src[0];
5654
- const float x1 = src[n_dims];
5655
-
5656
- dst_data[0] = x0*cos_theta - x1*sin_theta;
5657
- dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
5658
- }
5659
- } else {
5660
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5661
- const int64_t ic = i0/2;
5662
-
5663
- const float cos_theta = cache[i0 + 0];
5664
- const float sin_theta = cache[i0 + 1];
5665
-
5666
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5667
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5668
-
5669
- const float x0 = src[0];
5670
- const float x1 = src[n_dims/2];
5671
-
5672
- dst_data[0] = x0*cos_theta - x1*sin_theta;
5673
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
5674
- }
5675
- }
5676
- } else {
5677
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5678
- const float cos_theta = cache[i0 + 0];
5679
- const float sin_theta = cache[i0 + 1];
5680
-
5681
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5682
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5683
-
5684
- const float x0 = src[0];
5685
- const float x1 = src[1];
5686
-
5687
- dst_data[0] = x0*cos_theta - x1*sin_theta;
5688
- dst_data[1] = x0*sin_theta + x1*cos_theta;
5689
- }
5764
+ T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
5765
+ T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
5766
+
5767
+ switch (mode) {
5768
+ case GGML_ROPE_TYPE_NORMAL:
5769
+ rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
5770
+ break;
5771
+ case GGML_ROPE_TYPE_NEOX:
5772
+ case GGML_ROPE_TYPE_MROPE:
5773
+ case GGML_ROPE_TYPE_IMROPE:
5774
+ rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
5775
+ break;
5776
+ case GGML_ROPE_TYPE_VISION:
5777
+ rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
5778
+ break;
5779
+ default:
5780
+ GGML_ABORT("rope type not supported");
5690
5781
  }
5691
5782
 
5692
- if (is_vision) {
5693
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5694
- const int64_t ic = i0/2;
5695
-
5696
- const float cos_theta = cache[i0 + 0];
5697
- const float sin_theta = cache[i0 + 1];
5698
-
5699
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5700
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5701
-
5702
- const float x0 = src[0];
5703
- const float x1 = src[n_dims];
5704
-
5705
- dst_data[0] = x0*cos_theta - x1*sin_theta;
5706
- dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
5707
- }
5708
- } else {
5783
+ if (!is_vision) {
5709
5784
  // fill the remain channels with data from src tensor
5710
5785
  for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5711
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5712
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5786
+ const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5787
+ T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5713
5788
 
5714
5789
  dst_data[0] = src[0];
5715
5790
  dst_data[1] = src[1];
5716
5791
  }
5717
5792
  }
5718
- }
5793
+ } //attn-heads
5719
5794
  }
5720
5795
  }
5721
5796
  }
5722
5797
 
5723
- // TODO: deduplicate f16/f32 code
5724
- static void ggml_compute_forward_rope_f16(
5798
+ void ggml_compute_forward_rope(
5725
5799
  const ggml_compute_params * params,
5726
- ggml_tensor * dst,
5727
- const bool forward) {
5800
+ ggml_tensor * dst) {
5728
5801
 
5729
5802
  const ggml_tensor * src0 = dst->src[0];
5730
- const ggml_tensor * src1 = dst->src[1];
5731
- const ggml_tensor * src2 = dst->src[2];
5732
5803
 
5733
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5734
- int sections[4];
5804
+ switch (src0->type) {
5805
+ case GGML_TYPE_F16:
5806
+ {
5807
+ ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
5808
+ } break;
5809
+ case GGML_TYPE_F32:
5810
+ {
5811
+ ggml_compute_forward_rope_flt<float>(params, dst, true);
5812
+ } break;
5813
+ default:
5814
+ {
5815
+ GGML_ABORT("fatal error");
5816
+ }
5817
+ }
5818
+ }
5735
5819
 
5736
- //const int n_past = ((int32_t *) dst->op_params)[0];
5737
- const int n_dims = ((int32_t *) dst->op_params)[1];
5738
- const int mode = ((int32_t *) dst->op_params)[2];
5739
- //const int n_ctx = ((int32_t *) dst->op_params)[3];
5740
- const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
5741
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
5742
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
5743
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
5744
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
5745
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
5746
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
5747
- memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
5820
+ // ggml_compute_forward_rope_back
5748
5821
 
5822
+ void ggml_compute_forward_rope_back(
5823
+ const ggml_compute_params * params,
5824
+ ggml_tensor * dst) {
5749
5825
 
5750
- GGML_TENSOR_UNARY_OP_LOCALS
5826
+ const ggml_tensor * src0 = dst->src[0];
5751
5827
 
5752
- //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
5753
- //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
5828
+ switch (src0->type) {
5829
+ case GGML_TYPE_F16:
5830
+ {
5831
+ ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
5832
+ } break;
5833
+ case GGML_TYPE_F32:
5834
+ {
5835
+ ggml_compute_forward_rope_flt<float>(params, dst, false);
5836
+ } break;
5837
+ default:
5838
+ {
5839
+ GGML_ABORT("fatal error");
5840
+ }
5841
+ }
5842
+ }
5754
5843
 
5755
- GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
5844
+ // ggml_compute_forward_conv_transpose_1d
5756
5845
 
5757
- const int ith = params->ith;
5758
- const int nth = params->nth;
5846
+ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
5847
+ const ggml_compute_params * params,
5848
+ ggml_tensor * dst) {
5759
5849
 
5760
- const int nr = ggml_nrows(dst);
5850
+ const ggml_tensor * src0 = dst->src[0];
5851
+ const ggml_tensor * src1 = dst->src[1];
5761
5852
 
5762
- GGML_ASSERT(n_dims <= ne0);
5763
- GGML_ASSERT(n_dims % 2 == 0);
5853
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
5854
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
5855
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
5764
5856
 
5765
- // rows per thread
5766
- const int dr = (nr + nth - 1)/nth;
5857
+ GGML_TENSOR_BINARY_OP_LOCALS
5767
5858
 
5768
- // row range for this thread
5769
- const int ir0 = dr*ith;
5770
- const int ir1 = MIN(ir0 + dr, nr);
5859
+ const int ith = params->ith;
5860
+ const int nth = params->nth;
5771
5861
 
5772
- // row index used to determine which thread to use
5773
- int ir = 0;
5862
+ const int nk = ne00*ne01*ne02;
5774
5863
 
5775
- const float theta_scale = powf(freq_base, -2.0f/n_dims);
5864
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
5865
+ GGML_ASSERT(nb10 == sizeof(float));
5776
5866
 
5777
- float corr_dims[2];
5778
- ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
5867
+ if (ith == 0) {
5868
+ memset(params->wdata, 0, params->wsize);
5779
5869
 
5780
- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
5781
- const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
5782
- const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
5783
-
5784
- if (is_mrope) {
5785
- GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
5786
- }
5787
-
5788
- if (is_vision) {
5789
- GGML_ASSERT(n_dims == ne0/2);
5790
- }
5791
-
5792
- const float * freq_factors = NULL;
5793
- if (src2 != NULL) {
5794
- GGML_ASSERT(src2->type == GGML_TYPE_F32);
5795
- GGML_ASSERT(src2->ne[0] >= n_dims / 2);
5796
- freq_factors = (const float *) src2->data;
5797
- }
5798
-
5799
- // backward process uses inverse rotation by cos and sin.
5800
- // cos and sin build a rotation matrix, where the inverse is the transpose.
5801
- // this essentially just switches the sign of sin.
5802
- const float sin_sign = forward ? 1.0f : -1.0f;
5803
-
5804
- const int32_t * pos = (const int32_t *) src1->data;
5805
-
5806
- for (int64_t i3 = 0; i3 < ne3; i3++) {
5807
- for (int64_t i2 = 0; i2 < ne2; i2++) {
5808
-
5809
- float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5810
- if (!is_mrope) {
5811
- const int64_t p = pos[i2];
5812
- ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5813
- }
5814
- else {
5815
- const int64_t p_t = pos[i2];
5816
- const int64_t p_h = pos[i2 + ne2];
5817
- const int64_t p_w = pos[i2 + ne2 * 2];
5818
- const int64_t p_e = pos[i2 + ne2 * 3];
5819
- ggml_mrope_cache_init(
5820
- p_t, p_h, p_w, p_e, sections, is_vision,
5821
- freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5822
- }
5823
-
5824
- for (int64_t i1 = 0; i1 < ne1; i1++) {
5825
- if (ir++ < ir0) continue;
5826
- if (ir > ir1) break;
5827
-
5828
- if (is_neox || is_mrope) {
5829
- if (is_vision) {
5830
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5831
- const int64_t ic = i0/2;
5832
-
5833
- const float cos_theta = cache[i0 + 0];
5834
- const float sin_theta = cache[i0 + 1];
5835
-
5836
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5837
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5838
-
5839
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
5840
- const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
5841
-
5842
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5843
- dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5844
- }
5845
- } else {
5846
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5847
- const int64_t ic = i0/2;
5848
-
5849
- const float cos_theta = cache[i0 + 0];
5850
- const float sin_theta = cache[i0 + 1];
5851
-
5852
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5853
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5854
-
5855
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
5856
- const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
5857
-
5858
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5859
- dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5860
- }
5861
- }
5862
- } else {
5863
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5864
- const float cos_theta = cache[i0 + 0];
5865
- const float sin_theta = cache[i0 + 1];
5866
-
5867
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5868
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5869
-
5870
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
5871
- const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
5872
-
5873
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5874
- dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5875
- }
5876
- }
5877
-
5878
- if (is_vision) {
5879
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5880
- const int64_t ic = i0/2;
5881
-
5882
- const float cos_theta = cache[i0 + 0];
5883
- const float sin_theta = cache[i0 + 1];
5884
-
5885
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5886
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5887
-
5888
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
5889
- const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
5890
-
5891
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5892
- dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5893
- }
5894
- } else {
5895
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5896
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5897
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5898
-
5899
- dst_data[0] = src[0];
5900
- dst_data[1] = src[1];
5901
- }
5902
- }
5903
- }
5904
- }
5905
- }
5906
- }
5907
-
5908
- void ggml_compute_forward_rope(
5909
- const ggml_compute_params * params,
5910
- ggml_tensor * dst) {
5911
-
5912
- const ggml_tensor * src0 = dst->src[0];
5913
-
5914
- switch (src0->type) {
5915
- case GGML_TYPE_F16:
5916
- {
5917
- ggml_compute_forward_rope_f16(params, dst, true);
5918
- } break;
5919
- case GGML_TYPE_F32:
5920
- {
5921
- ggml_compute_forward_rope_f32(params, dst, true);
5922
- } break;
5923
- default:
5924
- {
5925
- GGML_ABORT("fatal error");
5926
- }
5927
- }
5928
- }
5929
-
5930
- // ggml_compute_forward_rope_back
5931
-
5932
- void ggml_compute_forward_rope_back(
5933
- const ggml_compute_params * params,
5934
- ggml_tensor * dst) {
5935
-
5936
- const ggml_tensor * src0 = dst->src[0];
5937
-
5938
- switch (src0->type) {
5939
- case GGML_TYPE_F16:
5940
- {
5941
- ggml_compute_forward_rope_f16(params, dst, false);
5942
- } break;
5943
- case GGML_TYPE_F32:
5944
- {
5945
- ggml_compute_forward_rope_f32(params, dst, false);
5946
- } break;
5947
- default:
5948
- {
5949
- GGML_ABORT("fatal error");
5950
- }
5951
- }
5952
- }
5953
-
5954
- // ggml_compute_forward_conv_transpose_1d
5955
-
5956
- static void ggml_compute_forward_conv_transpose_1d_f16_f32(
5957
- const ggml_compute_params * params,
5958
- ggml_tensor * dst) {
5959
-
5960
- const ggml_tensor * src0 = dst->src[0];
5961
- const ggml_tensor * src1 = dst->src[1];
5962
-
5963
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
5964
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
5965
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
5966
-
5967
- GGML_TENSOR_BINARY_OP_LOCALS
5968
-
5969
- const int ith = params->ith;
5970
- const int nth = params->nth;
5971
-
5972
- const int nk = ne00*ne01*ne02;
5973
-
5974
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
5975
- GGML_ASSERT(nb10 == sizeof(float));
5976
-
5977
- if (ith == 0) {
5978
- memset(params->wdata, 0, params->wsize);
5979
-
5980
- // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
5981
- {
5982
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
5870
+ // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
5871
+ {
5872
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
5983
5873
 
5984
5874
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5985
5875
  for (int64_t i01 = 0; i01 < ne01; i01++) {
@@ -6493,7 +6383,7 @@ static void ggml_compute_forward_im2col_3d_f16(
6493
6383
  const int64_t iih = ioh*s1 + ikh*d1 - p1;
6494
6384
  const int64_t iid = iod*s2 + ikd*d2 - p2;
6495
6385
 
6496
- if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
6386
+ if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6497
6387
  dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
6498
6388
  } else {
6499
6389
  const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
@@ -6664,8 +6554,13 @@ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params
6664
6554
  ggml_compute_forward_mul_mat(params, &dst);
6665
6555
  }
6666
6556
 
6557
+ static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
6558
+ return (coord + size) % size; // adding size avoids negative number weirdness
6559
+ }
6560
+
6667
6561
  // ggml_compute_forward_conv_2d
6668
6562
 
6563
+
6669
6564
  static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
6670
6565
  const ggml_tensor * kernel, // [KW, KH, IC, OC]
6671
6566
  const ggml_tensor * src, // [W, H, C, N]
@@ -7074,7 +6969,11 @@ static void ggml_compute_forward_conv_2d_dw_cwhn(
7074
6969
  const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
7075
6970
 
7076
6971
  #ifdef GGML_SIMD
7077
- const int64_t pkg_size = GGML_F32_EPR;
6972
+ #if defined(__ARM_FEATURE_SVE)
6973
+ const int64_t pkg_size = svcntw();
6974
+ #else
6975
+ const int64_t pkg_size = GGML_F32_EPR;
6976
+ #endif
7078
6977
  const int64_t pkg_count = c / pkg_size;
7079
6978
  const int64_t c_pkg_end = pkg_count * pkg_size;
7080
6979
  #else
@@ -7497,10 +7396,17 @@ static void ggml_compute_forward_upscale_f32(
7497
7396
  float sf1 = (float)ne1/src0->ne[1];
7498
7397
  float sf2 = (float)ne2/src0->ne[2];
7499
7398
  float sf3 = (float)ne3/src0->ne[3];
7399
+ float pixel_offset = 0.5f;
7500
7400
 
7501
7401
  const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
7502
7402
  const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
7503
7403
 
7404
+ if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7405
+ pixel_offset = 0.0f;
7406
+ sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
7407
+ sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
7408
+ }
7409
+
7504
7410
  if (mode == GGML_SCALE_MODE_NEAREST) {
7505
7411
  for (int64_t i3 = 0; i3 < ne3; i3++) {
7506
7412
  const int64_t i03 = i3 / sf3;
@@ -7519,14 +7425,66 @@ static void ggml_compute_forward_upscale_f32(
7519
7425
  }
7520
7426
  }
7521
7427
  }
7522
- } else if (mode == GGML_SCALE_MODE_BILINEAR) {
7523
- float pixel_offset = 0.5f;
7524
- if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7525
- pixel_offset = 0.0f;
7526
- sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
7527
- sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
7528
- }
7428
+ } else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {
7429
+ // Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
7430
+ // https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
7431
+ auto triangle_filter = [](float x) -> float {
7432
+ return std::max(1.0f - fabsf(x), 0.0f);
7433
+ };
7434
+
7435
+ // support and invscale, minimum 1 pixel for bilinear
7436
+ const float support1 = std::max(1.0f, 1.0f / sf1);
7437
+ const float invscale1 = 1.0f / support1;
7438
+ const float support0 = std::max(1.0f, 1.0f / sf0);
7439
+ const float invscale0 = 1.0f / support0;
7529
7440
 
7441
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
7442
+ const int64_t i03 = i3 / sf3;
7443
+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7444
+ const int64_t i02 = i2 / sf2;
7445
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
7446
+ const float y = ((float) i1 + pixel_offset) / sf1;
7447
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
7448
+ const float x = ((float) i0 + pixel_offset) / sf0;
7449
+
7450
+ // the range of source pixels that contribute
7451
+ const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);
7452
+ const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);
7453
+ const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);
7454
+ const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);
7455
+
7456
+ // bilinear filter with antialiasing
7457
+ float val = 0.0f;
7458
+ float total_weight = 0.0f;
7459
+
7460
+ for (int64_t sy = y_min; sy < y_max; sy++) {
7461
+ const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
7462
+
7463
+ for (int64_t sx = x_min; sx < x_max; sx++) {
7464
+ const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
7465
+ const float weight = weight_x * weight_y;
7466
+
7467
+ if (weight <= 0.0f) {
7468
+ continue;
7469
+ }
7470
+
7471
+ const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);
7472
+ val += pixel * weight;
7473
+ total_weight += weight;
7474
+ }
7475
+ }
7476
+
7477
+ if (total_weight > 0.0f) {
7478
+ val /= total_weight;
7479
+ }
7480
+
7481
+ float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7482
+ *dst_ptr = val;
7483
+ }
7484
+ }
7485
+ }
7486
+ }
7487
+ } else if (mode == GGML_SCALE_MODE_BILINEAR) {
7530
7488
  for (int64_t i3 = 0; i3 < ne3; i3++) {
7531
7489
  const int64_t i03 = i3 / sf3;
7532
7490
  for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
@@ -7561,6 +7519,51 @@ static void ggml_compute_forward_upscale_f32(
7561
7519
 
7562
7520
  const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
7563
7521
 
7522
+ float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7523
+ *y_dst = val;
7524
+ }
7525
+ }
7526
+ }
7527
+ }
7528
+ } else if (mode == GGML_SCALE_MODE_BICUBIC) {
7529
+ // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
7530
+ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
7531
+ auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
7532
+ auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
7533
+ auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
7534
+ const float w0 = weight2(x + 1);
7535
+ const float w1 = weight1(x + 0);
7536
+ const float w2 = weight1(1 - x);
7537
+ const float w3 = weight2(2 - x);
7538
+ return p0*w0 + p1*w1 + p2*w2 + p3*w3;
7539
+ };
7540
+
7541
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
7542
+ const int64_t i03 = i3 / sf3;
7543
+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7544
+ const int64_t i02 = i2 / sf2;
7545
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
7546
+ const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
7547
+ const int64_t y0 = (int64_t)floorf(y);
7548
+ const float dy = y - (float)y0;
7549
+
7550
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
7551
+ const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
7552
+ const int64_t x0 = (int64_t)floorf(x);
7553
+ const float dx = x - (float)x0;
7554
+
7555
+ auto p = [=](int64_t x_off, int64_t y_off) -> float {
7556
+ int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
7557
+ int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
7558
+ return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7559
+ };
7560
+
7561
+ const float val = bicubic(
7562
+ bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
7563
+ bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
7564
+ bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
7565
+ bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
7566
+
7564
7567
  float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7565
7568
  *y_dst = val;
7566
7569
  }
@@ -7593,6 +7596,7 @@ void ggml_compute_forward_upscale(
7593
7596
 
7594
7597
  // ggml_compute_forward_pad
7595
7598
 
7599
+ template<bool circular_t>
7596
7600
  static void ggml_compute_forward_pad_f32(
7597
7601
  const ggml_compute_params * params,
7598
7602
  ggml_tensor * dst) {
@@ -7617,23 +7621,40 @@ static void ggml_compute_forward_pad_f32(
7617
7621
  const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
7618
7622
  const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
7619
7623
 
7620
-
7621
7624
  // TODO: optimize
7622
7625
 
7623
7626
  for (int64_t i2 = 0; i2 < ne2; ++i2) {
7624
7627
  for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
7625
7628
  for (int64_t i0 = 0; i0 < ne0; ++i0) {
7626
7629
  for (int64_t i3 = 0; i3 < ne3; ++i3) {
7627
- const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
7628
- if ((i0 >= lp0 && i0 < ne0 - rp0) \
7629
- && (i1 >= lp1 && i1 < ne1 - rp1) \
7630
- && (i2 >= lp2 && i2 < ne2 - rp2) \
7631
- && (i3 >= lp3 && i3 < ne3 - rp3)) {
7632
- const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
7630
+ // circular means wrap around on a torus, so x and y loop around
7631
+ if constexpr (circular_t) {
7632
+ const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
7633
+ const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);
7634
+ const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);
7635
+ const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);
7636
+ const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);
7637
+
7638
+ const int64_t src_idx =
7639
+ src_i3*nb03 +
7640
+ src_i2*nb02 +
7641
+ src_i1*nb01 +
7642
+ src_i0*nb00;
7643
+
7633
7644
  const float * src_ptr = (const float *)((char *) src0->data + src_idx);
7634
7645
  dst_ptr[dst_idx] = *src_ptr;
7635
7646
  } else {
7636
- dst_ptr[dst_idx] = 0;
7647
+ const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
7648
+ if ((i0 >= lp0 && i0 < ne0 - rp0) \
7649
+ && (i1 >= lp1 && i1 < ne1 - rp1) \
7650
+ && (i2 >= lp2 && i2 < ne2 - rp2) \
7651
+ && (i3 >= lp3 && i3 < ne3 - rp3)) {
7652
+ const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
7653
+ const float * src_ptr = (const float *)((char *) src0->data + src_idx);
7654
+ dst_ptr[dst_idx] = *src_ptr;
7655
+ } else {
7656
+ dst_ptr[dst_idx] = 0;
7657
+ }
7637
7658
  }
7638
7659
  }
7639
7660
  }
@@ -7641,16 +7662,20 @@ static void ggml_compute_forward_pad_f32(
7641
7662
  }
7642
7663
  }
7643
7664
 
7665
+
7644
7666
  void ggml_compute_forward_pad(
7645
7667
  const ggml_compute_params * params,
7646
7668
  ggml_tensor * dst) {
7647
-
7648
7669
  const ggml_tensor * src0 = dst->src[0];
7649
-
7670
+ const bool circular = (bool) ggml_get_op_params_i32(dst, 8);
7650
7671
  switch (src0->type) {
7651
7672
  case GGML_TYPE_F32:
7652
7673
  {
7653
- ggml_compute_forward_pad_f32(params, dst);
7674
+ if (circular) {
7675
+ ggml_compute_forward_pad_f32<true>(params, dst);
7676
+ } else {
7677
+ ggml_compute_forward_pad_f32<false>(params, dst);
7678
+ }
7654
7679
  } break;
7655
7680
  default:
7656
7681
  {
@@ -7854,6 +7879,18 @@ void ggml_compute_forward_timestep_embedding(
7854
7879
 
7855
7880
  // ggml_compute_forward_argsort
7856
7881
 
7882
+ template<enum ggml_sort_order order>
7883
+ struct cmp_argsort {
7884
+ const float * data;
7885
+ bool operator()(int32_t a, int32_t b) const {
7886
+ if constexpr (order == GGML_SORT_ORDER_ASC) {
7887
+ return data[a] < data[b];
7888
+ } else {
7889
+ return data[a] > data[b];
7890
+ }
7891
+ }
7892
+ };
7893
+
7857
7894
  static void ggml_compute_forward_argsort_f32(
7858
7895
  const ggml_compute_params * params,
7859
7896
  ggml_tensor * dst) {
@@ -7872,23 +7909,25 @@ static void ggml_compute_forward_argsort_f32(
7872
7909
  ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
7873
7910
 
7874
7911
  for (int64_t i = ith; i < nr; i += nth) {
7875
- int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
7876
7912
  const float * src_data = (float *)((char *) src0->data + i*nb01);
7877
7913
 
7914
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
7915
+
7878
7916
  for (int64_t j = 0; j < ne0; j++) {
7879
7917
  dst_data[j] = j;
7880
7918
  }
7881
7919
 
7882
- // C doesn't have a functional sort, so we do a bubble sort instead
7883
- for (int64_t j = 0; j < ne0; j++) {
7884
- for (int64_t k = j + 1; k < ne0; k++) {
7885
- if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
7886
- (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
7887
- int32_t tmp = dst_data[j];
7888
- dst_data[j] = dst_data[k];
7889
- dst_data[k] = tmp;
7890
- }
7891
- }
7920
+ switch (order) {
7921
+ case GGML_SORT_ORDER_ASC:
7922
+ std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
7923
+ break;
7924
+
7925
+ case GGML_SORT_ORDER_DESC:
7926
+ std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
7927
+ break;
7928
+
7929
+ default:
7930
+ GGML_ABORT("invalid sort order");
7892
7931
  }
7893
7932
  }
7894
7933
  }
@@ -7911,12 +7950,78 @@ void ggml_compute_forward_argsort(
7911
7950
  }
7912
7951
  }
7913
7952
 
7953
+ // ggml_compute_forward_top_k
7954
+
7955
+ struct cmp_top_k {
7956
+ const float * data;
7957
+ bool operator()(int32_t a, int32_t b) const {
7958
+ return data[a] > data[b];
7959
+ }
7960
+ };
7961
+
7962
+ static void ggml_compute_forward_top_k_f32(
7963
+ const ggml_compute_params * params,
7964
+ ggml_tensor * dst) {
7965
+
7966
+ const ggml_tensor * src0 = dst->src[0];
7967
+
7968
+ GGML_TENSOR_UNARY_OP_LOCALS
7969
+
7970
+ GGML_ASSERT(nb0 == sizeof(float));
7971
+
7972
+ const int ith = params->ith;
7973
+ const int nth = params->nth;
7974
+
7975
+ const int64_t nr = ggml_nrows(src0);
7976
+
7977
+ const int top_k = ne0;
7978
+
7979
+ int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
7980
+
7981
+ for (int64_t i = ith; i < nr; i += nth) {
7982
+ const float * src_data = (float *)((char *) src0->data + i*nb01);
7983
+
7984
+ for (int64_t j = 0; j < ne00; j++) {
7985
+ tmp[j] = j;
7986
+ }
7987
+
7988
+ std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
7989
+
7990
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
7991
+
7992
+ std::copy(tmp, tmp + top_k, dst_data);
7993
+
7994
+ // emphasize that the order is not important
7995
+ if (top_k > 1) {
7996
+ std::swap(dst_data[0], dst_data[1]);
7997
+ }
7998
+ }
7999
+ }
8000
+
8001
+ void ggml_compute_forward_top_k(
8002
+ const ggml_compute_params * params,
8003
+ ggml_tensor * dst) {
8004
+
8005
+ const ggml_tensor * src0 = dst->src[0];
8006
+
8007
+ switch (src0->type) {
8008
+ case GGML_TYPE_F32:
8009
+ {
8010
+ ggml_compute_forward_top_k_f32(params, dst);
8011
+ } break;
8012
+ default:
8013
+ {
8014
+ GGML_ABORT("fatal error");
8015
+ }
8016
+ }
8017
+ }
8018
+
7914
8019
  // ggml_compute_forward_flash_attn_ext
7915
8020
 
7916
- static void ggml_compute_forward_flash_attn_ext_f16(
8021
+ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
7917
8022
  const ggml_compute_params * params,
7918
- ggml_tensor * dst) {
7919
-
8023
+ ggml_tensor * dst,
8024
+ int ir0, int ir1) {
7920
8025
  const ggml_tensor * q = dst->src[0];
7921
8026
  const ggml_tensor * k = dst->src[1];
7922
8027
  const ggml_tensor * v = dst->src[2];
@@ -7932,9 +8037,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
7932
8037
  GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
7933
8038
  GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
7934
8039
 
7935
- const int ith = params->ith;
7936
- const int nth = params->nth;
7937
-
7938
8040
  const int64_t DK = nek0;
7939
8041
  const int64_t DV = nev0;
7940
8042
  const int64_t N = neq1;
@@ -7968,16 +8070,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
7968
8070
 
7969
8071
  // parallelize by q rows using ggml_vec_dot_f32
7970
8072
 
7971
- // total rows in q
7972
- const int nr = neq1*neq2*neq3;
7973
-
7974
- // rows per thread
7975
- const int dr = (nr + nth - 1)/nth;
7976
-
7977
- // row range for this thread
7978
- const int ir0 = dr*ith;
7979
- const int ir1 = MIN(ir0 + dr, nr);
7980
-
7981
8073
  float scale = 1.0f;
7982
8074
  float max_bias = 0.0f;
7983
8075
  float logit_softcap = 0.0f;
@@ -8004,6 +8096,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
8004
8096
  GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
8005
8097
  GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
8006
8098
 
8099
+ int ith = params->ith;
8100
+
8007
8101
  // loop over n_batch and n_head
8008
8102
  for (int ir = ir0; ir < ir1; ++ir) {
8009
8103
  // q indices
@@ -8135,7 +8229,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
8135
8229
  }
8136
8230
 
8137
8231
  // V /= S
8138
- const float S_inv = 1.0f/S;
8232
+ const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
8139
8233
  ggml_vec_scale_f32(DV, VKQ32, S_inv);
8140
8234
 
8141
8235
  // dst indices
@@ -8151,6 +8245,91 @@ static void ggml_compute_forward_flash_attn_ext_f16(
8151
8245
  }
8152
8246
  }
8153
8247
 
8248
+ static void ggml_compute_forward_flash_attn_ext_f16(
8249
+ const ggml_compute_params * params,
8250
+ ggml_tensor * dst) {
8251
+
8252
+ const ggml_tensor * q = dst->src[0];
8253
+ const ggml_tensor * k = dst->src[1];
8254
+ const ggml_tensor * v = dst->src[2];
8255
+
8256
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
8257
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
8258
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
8259
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
8260
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
8261
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
8262
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
8263
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
8264
+
8265
+ const int64_t DK = nek0;
8266
+ const int64_t DV = nev0;
8267
+ const int64_t N = neq1;
8268
+
8269
+ GGML_ASSERT(ne0 == DV);
8270
+ GGML_ASSERT(ne2 == N);
8271
+
8272
+ // input tensor rows must be contiguous
8273
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
8274
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
8275
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
8276
+
8277
+ GGML_ASSERT(neq0 == DK);
8278
+ GGML_ASSERT(nek0 == DK);
8279
+ GGML_ASSERT(nev0 == DV);
8280
+
8281
+ GGML_ASSERT(neq1 == N);
8282
+
8283
+ // dst cannot be transposed or permuted
8284
+ GGML_ASSERT(nb0 == sizeof(float));
8285
+ GGML_ASSERT(nb0 <= nb1);
8286
+ GGML_ASSERT(nb1 <= nb2);
8287
+ GGML_ASSERT(nb2 <= nb3);
8288
+
8289
+ // parallelize by q rows using ggml_vec_dot_f32
8290
+
8291
+ // total rows in q
8292
+ const int64_t nr = neq1*neq2*neq3;
8293
+
8294
+ // rows per thread
8295
+ const int ith = params->ith;
8296
+ const int nth = params->nth;
8297
+
8298
+ // disable for NUMA
8299
+ const bool disable_chunking = ggml_is_numa();
8300
+
8301
+ // 4x chunks per thread
8302
+ int nth_scaled = nth * 4;
8303
+ int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
8304
+ int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
8305
+
8306
+ if (nth == 1 || nchunk < nth || disable_chunking) {
8307
+ nchunk = nth;
8308
+ }
8309
+
8310
+ if (ith == 0) {
8311
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
8312
+ ggml_threadpool_chunk_set(params->threadpool, nth);
8313
+ }
8314
+
8315
+ ggml_barrier(params->threadpool);
8316
+
8317
+ // The number of elements in each chunk
8318
+ const int64_t dr = (nr + nchunk - 1) / nchunk;
8319
+
8320
+ // The first chunk comes from our thread_id, the rest will get auto-assigned.
8321
+ int current_chunk = ith;
8322
+
8323
+ while (current_chunk < nchunk) {
8324
+ const int64_t ir0 = dr * current_chunk;
8325
+ const int64_t ir1 = MIN(ir0 + dr, nr);
8326
+
8327
+ ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
8328
+
8329
+ current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
8330
+ }
8331
+ }
8332
+
8154
8333
  void ggml_compute_forward_flash_attn_ext(
8155
8334
  const ggml_compute_params * params,
8156
8335
  ggml_tensor * dst) {
@@ -8637,7 +8816,7 @@ static void ggml_compute_forward_ssm_scan_f32(
8637
8816
  // n_head
8638
8817
  for (int h = ih0; h < ih1; ++h) {
8639
8818
  // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8640
- const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
8819
+ const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
8641
8820
  const float dA = expf(dt_soft_plus * A[h]);
8642
8821
  const int g = h / (nh / ng); // repeat_interleave
8643
8822
 
@@ -8734,7 +8913,7 @@ static void ggml_compute_forward_ssm_scan_f32(
8734
8913
  // n_head
8735
8914
  for (int h = ih0; h < ih1; ++h) {
8736
8915
  // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8737
- const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
8916
+ const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
8738
8917
  const int g = h / (nh / ng); // repeat_interleave
8739
8918
 
8740
8919
  // dim
@@ -8997,6 +9176,34 @@ void ggml_compute_forward_unary(
8997
9176
  {
8998
9177
  ggml_compute_forward_exp(params, dst);
8999
9178
  } break;
9179
+ case GGML_UNARY_OP_FLOOR:
9180
+ {
9181
+ ggml_compute_forward_floor(params, dst);
9182
+ } break;
9183
+ case GGML_UNARY_OP_CEIL:
9184
+ {
9185
+ ggml_compute_forward_ceil(params, dst);
9186
+ } break;
9187
+ case GGML_UNARY_OP_ROUND:
9188
+ {
9189
+ ggml_compute_forward_round(params, dst);
9190
+ } break;
9191
+ case GGML_UNARY_OP_TRUNC:
9192
+ {
9193
+ ggml_compute_forward_trunc(params, dst);
9194
+ } break;
9195
+ case GGML_UNARY_OP_XIELU:
9196
+ {
9197
+ ggml_compute_forward_xielu(params, dst);
9198
+ } break;
9199
+ case GGML_UNARY_OP_EXPM1:
9200
+ {
9201
+ ggml_compute_forward_expm1(params, dst);
9202
+ } break;
9203
+ case GGML_UNARY_OP_SOFTPLUS:
9204
+ {
9205
+ ggml_compute_forward_softplus(params, dst);
9206
+ } break;
9000
9207
  default:
9001
9208
  {
9002
9209
  GGML_ABORT("fatal error");
@@ -9593,6 +9800,76 @@ void ggml_compute_forward_gla(
9593
9800
  }
9594
9801
  }
9595
9802
 
9803
+ static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
9804
+ const struct ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
9805
+ const struct ggml_tensor * src1 = dst->src[1]; // B (RHS)
9806
+
9807
+ GGML_TENSOR_BINARY_OP_LOCALS;
9808
+
9809
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
9810
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
9811
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
9812
+
9813
+ GGML_ASSERT(ne00 == ne01); // A must be square
9814
+ GGML_ASSERT(ne0 == ne10); // solution cols == B cols
9815
+ GGML_ASSERT(ne1 == ne11); // solution rows == B rows
9816
+
9817
+ GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
9818
+ GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
9819
+
9820
+ const int ith = params->ith;
9821
+ const int nth = params->nth;
9822
+
9823
+ const int64_t k = ne10; // number of RHS columns
9824
+ const int64_t n = ne11; // A is n×n
9825
+ const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
9826
+
9827
+ // chunks per thread
9828
+ const int64_t dr = (nr + nth - 1)/nth;
9829
+
9830
+ // chunk range for this thread
9831
+ const int64_t ir0 = dr*ith;
9832
+ const int64_t ir1 = MIN(ir0 + dr, nr);
9833
+
9834
+ const float * A = (const float *) src0->data; // [n, n, B1, B2]
9835
+ const float * B = (const float *) src1->data; // [n, k, B1, B2]
9836
+ float * X = ( float *) dst->data; // [n, k, B1, B2]
9837
+
9838
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
9839
+ const int64_t i03 = ir/(ne02*k);
9840
+ const int64_t i02 = (ir - i03*ne02*k)/k;
9841
+ const int64_t i01 = (ir - i03*ne02*k - i02*k);
9842
+
9843
+ const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
9844
+ const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
9845
+
9846
+ float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
9847
+
9848
+ for (int64_t i00 = 0; i00 < n; ++i00) {
9849
+ float sum = 0.0f;
9850
+ for (int64_t t = 0; t < i00; ++t) {
9851
+ sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
9852
+ }
9853
+
9854
+ const float diag = A_batch[i00 * n + i00];
9855
+ assert(diag != 0.0f && "Zero diagonal in triangular matrix");
9856
+
9857
+ X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
9858
+ }
9859
+ }
9860
+ }
9861
+
9862
+ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
9863
+ const ggml_tensor * src0 = dst->src[0];
9864
+ const ggml_tensor * src1 = dst->src[1];
9865
+
9866
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
9867
+ ggml_compute_forward_solve_tri_f32(params, dst);
9868
+ } else {
9869
+ GGML_ABORT("fatal error");
9870
+ }
9871
+ }
9872
+
9596
9873
  // ggml_compute_forward_rwkv_wkv7
9597
9874
 
9598
9875
  static void ggml_compute_forward_rwkv_wkv7_f32(