whispercpp 1.3.3 → 1.3.5

This diff reflects the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (963)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +79 -25
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/CMakeLists.txt +1 -0
  23. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  24. data/ext/sources/examples/addon.node/index.js +7 -5
  25. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  26. data/ext/sources/examples/bench/bench.cpp +26 -16
  27. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  28. data/ext/sources/examples/cli/cli.cpp +122 -111
  29. data/ext/sources/examples/command/command.cpp +26 -24
  30. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  31. data/ext/sources/examples/common-ggml.cpp +2 -0
  32. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  34. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  35. data/ext/sources/examples/server/server.cpp +34 -24
  36. data/ext/sources/examples/server.py +6 -1
  37. data/ext/sources/examples/stream/stream.cpp +4 -2
  38. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  39. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  40. data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
  41. data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
  42. data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
  43. data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
  44. data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
  45. data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
  46. data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
  47. data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
  48. data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
  49. data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
  50. data/ext/sources/examples/talk-llama/llama-context.h +99 -36
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
  52. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  53. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  54. data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
  55. data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
  56. data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
  57. data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
  58. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  59. data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
  60. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
  61. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  62. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
  63. data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
  64. data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
  65. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
  66. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  67. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
  68. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
  69. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  70. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  71. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  72. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
  73. data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
  74. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  75. data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
  76. data/ext/sources/examples/talk-llama/llama-model.h +104 -12
  77. data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
  78. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
  79. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  80. data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
  81. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
  82. data/ext/sources/examples/talk-llama/llama.cpp +794 -12
  83. data/ext/sources/examples/talk-llama/llama.h +246 -190
  84. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  85. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  86. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  88. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  89. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  90. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  91. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  92. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  93. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  94. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  95. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  96. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  97. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  98. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  99. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  100. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  101. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  102. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  103. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  104. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  105. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  106. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  107. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  108. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  109. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  110. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  111. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  112. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  113. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  114. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  115. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  116. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  117. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  118. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  119. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  120. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  121. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  122. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  123. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  124. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  125. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  126. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  127. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  128. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  129. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  130. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  131. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  132. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  133. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  134. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  135. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  136. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  137. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  156. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  158. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  159. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  160. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  161. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  162. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  163. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  166. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  168. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  169. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  171. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  172. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  173. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  174. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  178. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  179. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  180. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  181. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  182. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  183. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  184. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  185. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  186. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  187. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  188. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  189. data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
  190. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  191. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  192. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  193. data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
  194. data/ext/sources/ggml/CMakeLists.txt +135 -79
  195. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +21 -2
  198. data/ext/sources/ggml/include/ggml-cpu.h +2 -1
  199. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  200. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  201. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  202. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  203. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  204. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +406 -23
  207. data/ext/sources/ggml/src/CMakeLists.txt +99 -13
  208. data/ext/sources/ggml/src/ggml-alloc.c +368 -161
  209. data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
  210. data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
  211. data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
  212. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  213. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
  214. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  215. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  217. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
  219. data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
  220. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
  221. data/ext/sources/ggml/src/ggml-common.h +17 -0
  222. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
  223. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  224. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  225. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
  226. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
  227. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
  228. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  229. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  230. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  232. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  233. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  234. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  235. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
  237. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
  238. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  239. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
  240. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
  242. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
  243. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
  245. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  246. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  248. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
  249. data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
  250. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  251. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  252. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
  253. data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
  254. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
  255. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  256. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  258. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  259. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  260. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  261. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  262. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  263. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
  264. data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
  265. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
  266. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  267. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  268. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  269. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  270. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  271. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  272. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  273. data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
  274. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  275. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  276. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  278. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
  279. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  280. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
  281. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  282. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  283. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  284. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  286. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  287. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
  289. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
  290. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  291. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
  292. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  293. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
  294. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  295. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
  296. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  297. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  298. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  300. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
  301. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  302. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
  304. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
  305. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
  307. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  308. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  309. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
  310. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
  311. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
  312. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
  313. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
  314. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  315. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  316. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  317. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  318. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
  320. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  321. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  322. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
  323. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  324. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  325. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  326. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
  328. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  329. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  330. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
  331. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  332. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  333. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  334. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  335. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
  337. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  338. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  339. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
  340. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
  341. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  342. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  407. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  408. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
  409. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
  410. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  411. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  413. data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
  414. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
  415. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
  416. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  417. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
  418. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
  419. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
  420. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  421. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  422. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  423. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  424. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  425. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  426. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  427. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  428. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  429. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  430. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  431. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  432. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  433. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  434. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  435. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  436. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  437. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  438. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  439. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  440. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  441. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  442. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  443. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  444. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  445. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  446. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  447. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  448. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  449. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  450. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  451. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
  452. data/ext/sources/ggml/src/ggml-impl.h +186 -15
  453. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  454. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  455. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  456. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  457. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
  458. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
  459. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
  460. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
  461. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
  462. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
  463. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  464. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
  465. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
  466. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
  467. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
  468. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
  469. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  470. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  471. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  472. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  473. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
  474. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  475. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  476. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  477. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  478. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  479. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  480. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  481. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  482. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  483. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  484. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  485. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  486. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  487. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  488. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  489. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  521. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  522. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  523. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  524. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
  525. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  526. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  527. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  530. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  531. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
  532. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  533. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
  534. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  535. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  536. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
  537. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  538. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  539. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  540. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
  541. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
  542. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  543. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  544. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  545. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
  546. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  547. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  548. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  549. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
  550. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
  551. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  552. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  553. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  554. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  555. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  556. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  557. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  558. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  559. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  560. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  561. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  562. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  563. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
  564. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  565. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  566. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  567. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  568. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
  569. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  570. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  571. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  572. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  573. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
  574. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  575. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  576. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
  577. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  578. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  579. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
  580. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  581. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  745. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  746. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  747. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
  748. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  749. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  750. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  751. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  752. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  753. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
  754. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  755. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  756. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  757. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  758. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  759. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  760. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  761. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  762. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  763. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  764. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  765. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  766. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  767. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  768. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  769. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  770. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  771. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  772. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  773. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  774. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  775. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  776. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  777. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  778. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  779. data/ext/sources/ggml/src/ggml.c +901 -129
  780. data/ext/sources/ggml/src/gguf.cpp +8 -1
  781. data/ext/sources/include/whisper.h +1 -0
  782. data/ext/sources/src/CMakeLists.txt +3 -1
  783. data/ext/sources/src/whisper.cpp +124 -81
  784. data/ext/sources/tests/CMakeLists.txt +8 -1
  785. data/ext/sources/tests/test-vad-full.cpp +7 -5
  786. data/ext/sources/tests/test-vad.cpp +3 -3
  787. data/extsources.rb +1 -0
  788. data/lib/whisper/model/uri.rb +17 -18
  789. data/sig/whisper.rbs +126 -2
  790. data/test/test_params.rb +24 -8
  791. data/test/test_segment.rb +0 -1
  792. data/test/test_token.rb +70 -0
  793. data/test/test_vad.rb +1 -1
  794. data/test/test_vad_context.rb +50 -0
  795. data/test/test_vad_segment.rb +19 -0
  796. data/test/test_vad_segments.rb +16 -0
  797. data/test/test_whisper.rb +8 -1
  798. data/whispercpp.gemspec +1 -1
  799. metadata +439 -179
  800. data/ext/sources/build-xcframework.sh +0 -547
  801. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  802. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  803. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  804. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  805. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  806. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  807. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  808. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  809. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  810. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  811. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  812. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  813. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  814. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  815. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  816. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  817. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  818. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  819. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  820. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  821. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  822. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  823. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  824. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  825. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  826. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  827. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
  828. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
  829. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  830. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  831. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  832. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  833. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  834. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  835. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  836. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  837. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  838. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  839. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  840. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  841. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  842. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  843. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  844. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  845. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  846. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  847. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  848. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  849. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  850. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  851. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  852. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  853. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  854. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  855. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  856. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  857. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  858. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  859. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  860. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  861. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  862. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  863. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  864. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  865. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  866. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  867. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  868. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  869. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  870. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  871. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  872. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  873. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  874. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  875. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  876. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  877. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  878. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  879. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  880. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  881. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  882. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  883. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  884. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  885. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  886. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  887. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  888. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  889. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  890. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  891. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  892. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  893. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  894. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  895. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  896. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  897. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  898. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  899. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  900. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  901. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  902. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  903. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  904. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  905. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  906. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  907. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  908. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  909. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  910. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  911. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  912. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  913. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  914. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  915. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  916. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  917. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  918. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  919. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  920. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  921. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  922. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  923. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  924. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  925. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  926. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  927. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  928. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  929. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  930. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  931. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  932. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  933. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  934. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  935. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  936. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  937. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  938. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  939. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  940. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  941. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  942. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  943. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  944. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  945. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  946. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  947. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  948. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  949. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  950. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  951. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  952. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  953. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  954. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
  955. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
  956. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
  957. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
  958. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
  959. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  960. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  961. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  962. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  963. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -7,7 +7,10 @@
  #include "unary-ops.h"
  #include "vec.h"

- #include <float.h>
+ #include <cfloat>
+ #include <algorithm>
+ #include <cmath>
+ #include <functional>

  // ggml_compute_forward_dup

@@ -40,13 +43,15 @@ static void ggml_compute_forward_dup_same_cont(
  }
  }

- static void ggml_compute_forward_dup_f16(
+ template<typename src_t, typename dst_t>
+ static void ggml_compute_forward_dup_flt(
  const ggml_compute_params * params,
  ggml_tensor * dst) {

  const ggml_tensor * src0 = dst->src[0];

  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+ GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type));

  GGML_TENSOR_UNARY_OP_LOCALS

@@ -61,6 +66,7 @@ static void ggml_compute_forward_dup_f16(
  const int ir0 = dr * ith;
  const int ir1 = MIN(ir0 + dr, nr);

+ // case: type & row size equal
  if (src0->type == dst->type &&
  ne00 == ne0 &&
  nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
@@ -79,275 +85,11 @@ static void ggml_compute_forward_dup_f16(
  return;
  }

- // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
- if (ggml_is_contiguous(dst)) {
- if (nb00 == sizeof(ggml_fp16_t)) {
- if (dst->type == GGML_TYPE_F16) {
- size_t id = 0;
- const size_t rs = ne00 * nb00;
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
- memcpy(dst_ptr + id, src0_ptr, rs);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- for (int i00 = 0; i00 < ne00; i00++) {
- dst_ptr[id] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
- ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
- float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
- size_t id = 0;
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
- for (int i00 = 0; i00 < ne00; i00++) {
- src0_f32[i00] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
- }
-
- quantize_row_q(src0_f32, dst_ptr + id, ne00);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- } else {
- //printf("%s: this is not optimal - fix me\n", __func__);
-
- if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_CPU_FP16_TO_FP32(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- size_t id = 0;
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = *src0_ptr;
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- }
- return;
- }
-
- // dst counters
- int64_t i10 = 0;
- int64_t i11 = 0;
- int64_t i12 = 0;
- int64_t i13 = 0;
-
- if (dst->type == GGML_TYPE_F16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
-
- if (++i10 == ne00) {
- i10 = 0;
- if (++i11 == ne01) {
- i11 = 0;
- if (++i12 == ne02) {
- i12 = 0;
- if (++i13 == ne03) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == GGML_TYPE_F32) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(float *) dst_ptr = GGML_CPU_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- }
-
- static void ggml_compute_forward_dup_bf16(
- const ggml_compute_params * params,
- ggml_tensor * dst) {
-
- const ggml_tensor * src0 = dst->src[0];
-
- GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- const int ith = params->ith; // thread index
- const int nth = params->nth; // number of threads
-
- // parallelize by rows
- const int nr = ne01;
- // number of rows per thread
- const int dr = (nr + nth - 1) / nth;
- // row range for this thread
- const int ir0 = dr * ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- if (src0->type == dst->type &&
- ne00 == ne0 &&
- nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
- // copy by rows
- const size_t rs = ne00*nb00;
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- memcpy(
- ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
- ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
- rs);
- }
- }
- }
- return;
- }
-
- // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
+ // case: dst tensor is contiguous
  if (ggml_is_contiguous(dst)) {
- if (nb00 == sizeof(ggml_bf16_t)) {
- if (dst->type == GGML_TYPE_BF16) {
+ if (nb00 == sizeof(src_t)) {
+ if constexpr (std::is_same_v<dst_t, src_t>) {
+ // same type
  size_t id = 0;
  const size_t rs = ne00 * nb00;
  char * dst_ptr = (char *) dst->data;
@@ -363,434 +105,58 @@ static void ggml_compute_forward_dup_bf16(
  id += rs * (ne01 - ir1);
  }
  }
- } else if (dst->type == GGML_TYPE_F16) {
- size_t id = 0;
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- for (int i00 = 0; i00 < ne00; i00++) {
- dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- for (int i00 = 0; i00 < ne00; i00++) {
- dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
- ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
- float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
- size_t id = 0;
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
- for (int i00 = 0; i00 < ne00; i00++) {
- src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
- }
-
- quantize_row_q(src0_f32, dst_ptr + id, ne00);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- } else {
- //printf("%s: this is not optimal - fix me\n", __func__);
-
- if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_BF16) {
- size_t id = 0;
- ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = *src0_ptr;
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- size_t id = 0;
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- }
- return;
- }
-
- // dst counters
- int64_t i10 = 0;
- int64_t i11 = 0;
- int64_t i12 = 0;
- int64_t i13 = 0;
-
- if (dst->type == GGML_TYPE_BF16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
-
- if (++i10 == ne00) {
- i10 = 0;
- if (++i11 == ne01) {
- i11 = 0;
- if (++i12 == ne02) {
- i12 = 0;
- if (++i13 == ne03) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == GGML_TYPE_F32) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- }
-
- static void ggml_compute_forward_dup_f32(
- const ggml_compute_params * params,
- ggml_tensor * dst) {
-
- const ggml_tensor * src0 = dst->src[0];
-
- GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- const int ith = params->ith; // thread index
- const int nth = params->nth; // number of threads
-
- // parallelize by rows
- const int nr = ne01;
- // number of rows per thread
- const int dr = (nr + nth - 1) / nth;
- // row range for this thread
- const int ir0 = dr * ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- if (src0->type == dst->type &&
- ne00 == ne0 &&
- nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
- // copy by rows
- const size_t rs = ne00*nb00;
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- memcpy(
- ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
- ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
- rs);
- }
- }
- }
- return;
- }
-
- if (ggml_is_contiguous(dst)) {
- // TODO: simplify
- if (nb00 == sizeof(float)) {
- if (ggml_get_type_traits_cpu(dst->type)->from_float) {
- ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
-
- size_t id = 0;
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- from_float(src0_ptr, dst_ptr + id, ne00);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- } else {
- //printf("%s: this is not optimal - fix me\n", __func__);
-
- if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = *src0_ptr;
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- size_t id = 0;
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_CPU_FP32_TO_FP16(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_BF16) {
+ } else {
+ // casting between non-quantized types
  size_t id = 0;
- ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
+ dst_t * dst_ptr = (dst_t *) dst->data;

  for (int i03 = 0; i03 < ne03; i03++) {
  for (int i02 = 0; i02 < ne02; i02++) {
  id += ne00 * ir0;
  for (int i01 = ir0; i01 < ir1; i01++) {
+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  for (int i00 = 0; i00 < ne00; i00++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
+ float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
+ dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
  id++;
  }
  }
  id += ne00 * (ne01 - ir1);
  }
  }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
  }
- }
+ } else {
+ //printf("%s: this is not optimal - fix me\n", __func__);
+
+ size_t id = 0;
+ dst_t * dst_ptr = (dst_t *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += ne00 * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

+ float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
+ dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
+ id++;
+ }
+ }
+ id += ne00 * (ne01 - ir1);
+ }
+ }
+ }
  return;
  }

  // dst counters
-
  int64_t i10 = 0;
  int64_t i11 = 0;
  int64_t i12 = 0;
  int64_t i13 = 0;

- if (dst->type == GGML_TYPE_F32) {
+ if constexpr (std::is_same_v<dst_t, src_t>) {
  for (int64_t i03 = 0; i03 < ne03; i03++) {
  for (int64_t i02 = 0; i02 < ne02; i02++) {
  i10 += ne00 * ir0;
@@ -811,15 +177,15 @@ static void ggml_compute_forward_dup_f32(
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

- memcpy(dst_ptr, src0_ptr, sizeof(float));
+ memcpy(dst_ptr, src0_ptr, sizeof(dst_t));

- if (++i10 == ne0) {
+ if (++i10 == ne00) {
  i10 = 0;
- if (++i11 == ne1) {
+ if (++i11 == ne01) {
  i11 = 0;
- if (++i12 == ne2) {
+ if (++i12 == ne02) {
  i12 = 0;
- if (++i13 == ne3) {
+ if (++i13 == ne03) {
  i13 = 0;
  }
  }
@@ -842,7 +208,8 @@ static void ggml_compute_forward_dup_f32(
  }
  }
  }
- } else if (dst->type == GGML_TYPE_F16) {
+
+ } else {
  for (int64_t i03 = 0; i03 < ne03; i03++) {
  for (int64_t i02 = 0; i02 < ne02; i02++) {
  i10 += ne00 * ir0;
@@ -863,7 +230,8 @@ static void ggml_compute_forward_dup_f32(
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

- *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr);
+ float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
+ *(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);

  if (++i10 == ne0) {
  i10 = 0;
@@ -894,60 +262,63 @@ static void ggml_compute_forward_dup_f32(
  }
  }
  }
- } else if (dst->type == GGML_TYPE_BF16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+ }
+ }

- *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);

- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
+ template<typename src_t>
+ static void ggml_compute_forward_dup_to_q(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+ GGML_ASSERT(!ggml_is_quantized(src0->type));
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ const int ith = params->ith; // thread index
+ const int nth = params->nth; // number of threads
+
+ // parallelize by rows
+ const int nr = ne01;
+ // number of rows per thread
+ const int dr = (nr + nth - 1) / nth;
+ // row range for this thread
+ const int ir0 = dr * ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ if (ggml_is_contiguous(dst) &&
+ nb00 == sizeof(src_t) &&
+ ggml_get_type_traits_cpu(dst->type)->from_float) {
+ // casting non-quantized types --> intermediate f32 --> quantized
+ ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+ size_t id = 0;
+ size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+ char * dst_ptr = (char *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += rs * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+ for (int i00 = 0; i00 < ne00; i00++) {
+ src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
  }
+
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
+ id += rs;
  }
+ id += rs * (ne01 - ir1);
  }
  }
  } else {
- GGML_ABORT("fatal error"); // TODO: implement
+ // printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
+ GGML_ABORT("not implemented");
  }
  }

@@ -1101,7 +472,7 @@ static void ggml_compute_forward_dup_bytes(
  }
  }

- static void ggml_compute_forward_dup_q(
+ static void ggml_compute_forward_dup_from_q(
  const ggml_compute_params * params,
  ggml_tensor * dst) {

@@ -1166,20 +537,35 @@ void ggml_compute_forward_dup(
  switch (src0->type) {
  case GGML_TYPE_F16:
  {
- ggml_compute_forward_dup_f16(params, dst);
+ /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_fp16_t>(params, dst);
+ else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_bf16_t>(params, dst);
+ else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<ggml_fp16_t, float >(params, dst);
+ else ggml_compute_forward_dup_to_q<ggml_fp16_t>(params, dst);
  } break;
  case GGML_TYPE_BF16:
  {
- ggml_compute_forward_dup_bf16(params, dst);
+ /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_fp16_t>(params, dst);
+ else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(params, dst);
+ else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<ggml_bf16_t, float >(params, dst);
+ else ggml_compute_forward_dup_to_q<ggml_bf16_t>(params, dst);
  } break;
  case GGML_TYPE_F32:
  {
- ggml_compute_forward_dup_f32(params, dst);
+ /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<float, ggml_fp16_t>(params, dst);
+ else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<float, ggml_bf16_t>(params, dst);
+ else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<float, float >(params, dst);
+ else if (dst->type == GGML_TYPE_I32) ggml_compute_forward_dup_flt<float, int32_t >(params, dst);
+ else ggml_compute_forward_dup_to_q<float>(params, dst);
+ } break;
+ case GGML_TYPE_I32:
+ {
+ if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
+ else GGML_ABORT("not implemented");
  } break;
  default:
  {
  if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
- ggml_compute_forward_dup_q(params, dst);
+ ggml_compute_forward_dup_from_q(params, dst);
  break;
  }
  GGML_ABORT("fatal error");
@@ -1283,6 +669,7 @@ void ggml_compute_forward_add(
  case GGML_TYPE_Q5_0:
  case GGML_TYPE_Q5_1:
  case GGML_TYPE_Q8_0:
+ case GGML_TYPE_MXFP4:
  case GGML_TYPE_Q2_K:
  case GGML_TYPE_Q3_K:
  case GGML_TYPE_Q4_K:
@@ -1309,6 +696,77 @@ void ggml_compute_forward_add(
  }
  }

+ // ggml_compute_forward_add_id
+
+ static void ggml_compute_forward_add_id_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
+
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_TERNARY_OP_LOCALS
+
+ GGML_ASSERT( nb0 == sizeof(float));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 indices
+ const int i3 = ir/(ne2*ne1);
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ // src1 indices
+ const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
+
+ GGML_ASSERT(i11 >= 0 && i11 < ne11);
+
+ ggml_vec_add_f32(ne0,
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+ (float *) ((char *) src1->data + i11*nb11));
+ }
+ }
+
+ void ggml_compute_forward_add_id(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_add_id_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
+ }
+ }
+ }
+
  // ggml_compute_forward_add1

  static void ggml_compute_forward_add1_f32(
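
The new ADD_ID op adds one row of src1 to each row of src0, with the row choice read per destination row from the i32 index tensor src2 (the typical use being per-expert biases in MoE layers). A self-contained sketch of the same semantics over flat arrays — the add_id_sketch name and signature are illustrative, not ggml API:

#include <cassert>
#include <vector>

// rows:   nr x ne0 matrix, updated in place (plays the role of src0/dst)
// biases: nb x ne0 matrix of candidate rows (plays the role of src1)
// ids:    one index per row, each in [0, nb) (plays the role of src2)
void add_id_sketch(std::vector<float> & rows,
                   const std::vector<float> & biases,
                   const std::vector<int> & ids,
                   int ne0, int nb) {
    const int nr = (int) ids.size();
    for (int r = 0; r < nr; ++r) {
        const int id = ids[r];
        assert(id >= 0 && id < nb); // mirrors GGML_ASSERT(i11 >= 0 && i11 < ne11)
        for (int c = 0; c < ne0; ++c) {
            rows[(size_t) r * ne0 + c] += biases[(size_t) id * ne0 + c];
        }
    }
}
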
@@ -1660,6 +1118,7 @@ void ggml_compute_forward_add1(
  case GGML_TYPE_Q5_1:
  case GGML_TYPE_Q8_0:
  case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
  case GGML_TYPE_Q2_K:
  case GGML_TYPE_Q3_K:
  case GGML_TYPE_Q4_K:
@@ -1787,6 +1246,7 @@ void ggml_compute_forward_acc(
  case GGML_TYPE_Q5_1:
  case GGML_TYPE_Q8_0:
  case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
  case GGML_TYPE_Q2_K:
  case GGML_TYPE_Q3_K:
  case GGML_TYPE_Q4_K:
@@ -1936,6 +1396,56 @@ void ggml_compute_forward_sum(
  }
  }

+ // ggml_compute_forward_cumsum
+
+ static void ggml_compute_forward_cumsum_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(dst->nb[0] == sizeof(float));
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ GGML_ASSERT(ne0 == ne00);
+ GGML_ASSERT(ne1 == ne01);
+ GGML_ASSERT(ne2 == ne02);
+ GGML_ASSERT(ne3 == ne03);
+
+ const auto [ir0, ir1] = get_thread_range(params, src0);
+
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+ float * dst_row = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+ ggml_vec_cumsum_f32(ne00, dst_row, src_row);
+ }
+ }
+
+ void ggml_compute_forward_cumsum(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_cumsum_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+ }
+
  // ggml_compute_forward_sum_rows

  static void ggml_compute_forward_sum_rows_f32(
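
The cumsum kernel delegates the per-row work to ggml_vec_cumsum_f32 from vec.h, which presumably produces the running sum dst[i] = src[0] + ... + src[i] along the innermost dimension. A standalone sketch of that row primitive under this assumption (the vec_cumsum_f32_sketch name is ours, not the vec.h implementation):

#include <cstdio>

// Running prefix sum along one row: dst[i] = src[0] + ... + src[i].
static void vec_cumsum_f32_sketch(int n, float * dst, const float * src) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum += src[i];
        dst[i] = sum;
    }
}

int main() {
    const float src[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float dst[4];
    vec_cumsum_f32_sketch(4, dst, src);
    for (int i = 0; i < 4; ++i) {
        printf("%g ", dst[i]); // prints: 1 3 6 10
    }
    printf("\n");
    return 0;
}
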
@@ -2656,24 +2166,101 @@ static void ggml_compute_forward_gelu_f16(
  assert(!isnan(v));
  assert(!isinf(v));
  }
- #endif
+ #endif
+ }
+ }
+
+ static void ggml_compute_forward_gelu(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_gelu_f32(params, dst);
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_gelu_f16(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+ }
+
+ // ggml_compute_fill
+
+ static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+ const float c = ggml_get_op_params_f32(dst, 0);
+
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
+
+ const auto [ir0, ir1] = get_thread_range(params, dst);
+
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
+ const int64_t i03 = ir/(ne2*ne1);
+ const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
+ const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
+
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
+
+ ggml_vec_set_f32(ne0, dst_ptr, c);
+ }
+ }
+
+ void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
+ ggml_compute_forward_fill_f32(params, dst);
+ }
+
+ // ggml_compute_tri
+
+ static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+
+ const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ const auto [ir0, ir1] = get_thread_range(params, src0);
+
+ bool (*bipred)(int, int);
+
+ switch (ttype) {
+ case GGML_TRI_TYPE_LOWER: bipred = [](int i, int r) { return i < r; }; break;
+ case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
+ case GGML_TRI_TYPE_UPPER: bipred = [](int i, int r) { return i > r; }; break;
+ case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
+ default: GGML_ABORT("invalid tri type");
+ }
+
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+ float * dst_ptr = ( float *) (( char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
+
+ for (int i0 = 0; i0 < ne0; ++i0) {
+ dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
+ }
  }
  }

- static void ggml_compute_forward_gelu(
- const ggml_compute_params * params,
- ggml_tensor * dst) {
-
+ void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
  const ggml_tensor * src0 = dst->src[0];

  switch (src0->type) {
  case GGML_TYPE_F32:
  {
- ggml_compute_forward_gelu_f32(params, dst);
- } break;
- case GGML_TYPE_F16:
- {
- ggml_compute_forward_gelu_f16(params, dst);
+ ggml_compute_forward_tri_f32(params, dst);
  } break;
  default:
  {
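
ggml_compute_forward_tri_f32 keeps or zeroes each element by comparing its column index i0 against its row index i01 through one of four captureless-lambda predicates, selected once per call. A compact standalone demo of the same predicate selection — plain function pointers over a 3x3 all-ones matrix, not upstream code:

#include <cstdio>

int main() {
    // The same four masks as the GGML_TRI_TYPE_* switch above.
    bool (*lower)     (int, int) = [](int i, int r) { return i <  r; };
    bool (*lower_diag)(int, int) = [](int i, int r) { return i <= r; };
    bool (*upper)     (int, int) = [](int i, int r) { return i >  r; };
    bool (*upper_diag)(int, int) = [](int i, int r) { return i >= r; };

    bool (*bipred)(int, int) = lower_diag; // pick one, as the kernel does
    (void) lower; (void) upper; (void) upper_diag;

    for (int r = 0; r < 3; ++r) {       // r plays the role of i01 (row)
        for (int i = 0; i < 3; ++i) {   // i plays the role of i0 (column)
            printf("%.0f ", bipred(i, r) ? 1.0f : 0.0f);
        }
        printf("\n");                   // prints a lower-triangular mask incl. diagonal
    }
    return 0;
}

Selecting the predicate once outside the element loop (rather than switching on ttype per element) keeps the inner loop branch-light; captureless lambdas convert to plain function pointers, so no std::function overhead is involved.
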
@@ -3032,27 +2619,281 @@ static void ggml_compute_forward_leaky_relu_f16(
3032
2619
  return;
3033
2620
  }
3034
2621
 
3035
- assert(ggml_is_contiguous_1(src0));
3036
- assert(ggml_is_contiguous_1(dst));
3037
- assert(ggml_are_same_shape(src0, dst));
2622
+ assert(ggml_is_contiguous_1(src0));
2623
+ assert(ggml_is_contiguous_1(dst));
2624
+ assert(ggml_are_same_shape(src0, dst));
2625
+
2626
+ const int n = ggml_nrows(src0);
2627
+ const int nc = src0->ne[0];
2628
+
2629
+ float negative_slope;
2630
+ memcpy(&negative_slope, dst->op_params, sizeof(float));
2631
+
2632
+ assert(dst->nb[0] == sizeof(ggml_fp16_t));
2633
+ assert(src0->nb[0] == sizeof(ggml_fp16_t));
2634
+
2635
+ for (int i = 0; i < n; i++) {
2636
+ ggml_vec_leaky_relu_f16(nc,
2637
+ (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])),
2638
+ (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
2639
+ }
2640
+ }
2641
+
2642
+ void ggml_compute_forward_leaky_relu(
2643
+ const ggml_compute_params * params,
2644
+ ggml_tensor * dst) {
2645
+
2646
+ const ggml_tensor * src0 = dst->src[0];
2647
+
2648
+ switch (src0->type) {
2649
+ case GGML_TYPE_F32:
2650
+ {
2651
+ ggml_compute_forward_leaky_relu_f32(params, dst);
2652
+ } break;
2653
+ case GGML_TYPE_F16:
2654
+ {
2655
+ ggml_compute_forward_leaky_relu_f16(params, dst);
2656
+ } break;
2657
+ default:
2658
+ {
2659
+ GGML_ABORT("fatal error");
2660
+ }
2661
+ }
2662
+ }
+
+ // ggml_compute_forward_silu_back
+
+ static void ggml_compute_forward_silu_back_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * grad = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ assert(ggml_is_contiguous_1(grad));
+ assert(ggml_is_contiguous_1(src1));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src1, dst));
+ assert(ggml_are_same_shape(src1, grad));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1->ne[0];
+ const int nr = ggml_nrows(src1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_vec_silu_backward_f32(nc,
+ (float *) ((char *) dst->data + i1*( dst->nb[1])),
+ (float *) ((char *) src1->data + i1*(src1->nb[1])),
+ (float *) ((char *) grad->data + i1*(grad->nb[1])));
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ GGML_UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+ #endif
+ }
+ }
+
+ static void ggml_compute_forward_silu_back_f16(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * grad = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ assert(ggml_is_contiguous_1(grad));
+ assert(ggml_is_contiguous_1(src1));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src1, dst));
+ assert(ggml_are_same_shape(src1, grad));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1->ne[0];
+ const int nr = ggml_nrows(src1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_vec_silu_backward_f16(nc,
+ (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
+ (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
+ (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ const float v = GGML_CPU_FP16_TO_FP32(x);
+ GGML_UNUSED(v);
+ assert(!isnan(v));
+ assert(!isinf(v));
+ }
+ #endif
+ }
+ }
+
+ void ggml_compute_forward_silu_back(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_silu_back_f32(params, dst);
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_silu_back_f16(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+ }
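For reference, silu(x) = x*sigmoid(x), so the gradient that ggml_vec_silu_backward_* applies is dy * sigmoid(x) * (1 + x*(1 - sigmoid(x))). A scalar sketch of that chain rule, with an illustrative helper name:

#include <cmath>

// d/dx [x * sigmoid(x)], multiplied by the incoming gradient dy
inline float silu_backward(float x, float dy) {
    const float s = 1.0f / (1.0f + expf(-x));
    return dy * s * (1.0f + x * (1.0f - s));
}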
+
+ // ggml_compute_forward_reglu
+
+ static void ggml_compute_forward_reglu_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * src0_p = (float *) (src0_d + i1*src0_o);
+ float * src1_p = (float *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ GGML_UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+ #endif
+ }
+ }
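All of the GLU kernels below share one input convention: with a second tensor src1 the activation half and the gate half come from separate tensors, otherwise a single src0 row of width 2*nc is split in half and the swapped flag decides which half is which. A standalone sketch of that addressing, with illustrative names:

#include <utility>

// given one fused row of width 2*nc, return (activation, gate) pointers
std::pair<const float *, const float *> glu_halves(const float * row, int nc, bool swapped) {
    const float * a = row + (swapped ? nc : 0); // activation half
    const float * g = row + (swapped ? 0 : nc); // gate half
    return {a, g};
}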
+
+ static void ggml_compute_forward_reglu_f16(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
 
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
 
- float negative_slope;
- memcpy(&negative_slope, dst->op_params, sizeof(float));
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
 
- assert(dst->nb[0] == sizeof(ggml_fp16_t));
- assert(src0->nb[0] == sizeof(ggml_fp16_t));
+ ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
- for (int i = 0; i < n; i++) {
- ggml_vec_leaky_relu_f16(nc,
- (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])),
- (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ const float v = GGML_FP16_TO_FP32(x);
+ GGML_UNUSED(v);
+ assert(!isnan(v));
+ assert(!isinf(v));
+ }
+ #endif
  }
  }
 
- void ggml_compute_forward_leaky_relu(
+ static void ggml_compute_forward_reglu(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
 
@@ -3061,11 +2902,11 @@ void ggml_compute_forward_leaky_relu(
  switch (src0->type) {
  case GGML_TYPE_F32:
  {
- ggml_compute_forward_leaky_relu_f32(params, dst);
+ ggml_compute_forward_reglu_f32(params, dst);
  } break;
  case GGML_TYPE_F16:
  {
- ggml_compute_forward_leaky_relu_f16(params, dst);
+ ggml_compute_forward_reglu_f16(params, dst);
  } break;
  default:
  {
@@ -3074,26 +2915,37 @@ void ggml_compute_forward_leaky_relu(
  }
  }
 
- // ggml_compute_forward_silu_back
+ // ggml_compute_forward_geglu
 
- static void ggml_compute_forward_silu_back_f32(
+ static void ggml_compute_forward_geglu_f32(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
 
- const ggml_tensor * grad = dst->src[0];
+ const ggml_tensor * src0 = dst->src[0];
  const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 
- assert(ggml_is_contiguous_1(grad));
- assert(ggml_is_contiguous_1(src1));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src1, dst));
- assert(ggml_are_same_shape(src1, grad));
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
 
  const int ith = params->ith;
  const int nth = params->nth;
 
- const int nc = src1->ne[0];
- const int nr = ggml_nrows(src1);
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 
  // rows per thread
  const int dr = (nr + nth - 1)/nth;
@@ -3103,10 +2955,15 @@ static void ggml_compute_forward_silu_back_f32(
  const int ir1 = MIN(ir0 + dr, nr);
 
  for (int i1 = ir0; i1 < ir1; i1++) {
- ggml_vec_silu_backward_f32(nc,
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
- (float *) ((char *) src1->data + i1*(src1->nb[1])),
- (float *) ((char *) grad->data + i1*(grad->nb[1])));
+ float * src0_p = (float *) (src0_d + i1*src0_o);
+ float * src1_p = (float *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
  #ifndef NDEBUG
  for (int k = 0; k < nc; k++) {
@@ -3119,24 +2976,35 @@ static void ggml_compute_forward_silu_back_f32(
  }
  }
 
- static void ggml_compute_forward_silu_back_f16(
+ static void ggml_compute_forward_geglu_f16(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
 
- const ggml_tensor * grad = dst->src[0];
+ const ggml_tensor * src0 = dst->src[0];
  const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 
- assert(ggml_is_contiguous_1(grad));
- assert(ggml_is_contiguous_1(src1));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src1, dst));
- assert(ggml_are_same_shape(src1, grad));
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
 
  const int ith = params->ith;
  const int nth = params->nth;
 
- const int nc = src1->ne[0];
- const int nr = ggml_nrows(src1);
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 
  // rows per thread
  const int dr = (nr + nth - 1)/nth;
@@ -3146,24 +3014,29 @@ static void ggml_compute_forward_silu_back_f16(
  const int ir1 = MIN(ir0 + dr, nr);
 
  for (int i1 = ir0; i1 < ir1; i1++) {
- ggml_vec_silu_backward_f16(nc,
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
- (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
- (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
 
- #ifndef NDEBUG
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
  for (int k = 0; k < nc; k++) {
- const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
- const float v = GGML_CPU_FP16_TO_FP32(x);
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ const float v = GGML_FP16_TO_FP32(x);
  GGML_UNUSED(v);
  assert(!isnan(v));
  assert(!isinf(v));
  }
- #endif
+ #endif
  }
  }
 
- void ggml_compute_forward_silu_back(
+ static void ggml_compute_forward_geglu(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
 
@@ -3172,11 +3045,11 @@ void ggml_compute_forward_silu_back(
  switch (src0->type) {
  case GGML_TYPE_F32:
  {
- ggml_compute_forward_silu_back_f32(params, dst);
+ ggml_compute_forward_geglu_f32(params, dst);
  } break;
  case GGML_TYPE_F16:
  {
- ggml_compute_forward_silu_back_f16(params, dst);
+ ggml_compute_forward_geglu_f16(params, dst);
  } break;
  default:
  {
@@ -3185,9 +3058,9 @@ void ggml_compute_forward_silu_back(
  }
  }
 
- // ggml_compute_forward_reglu
+ // ggml_compute_forward_swiglu
 
- static void ggml_compute_forward_reglu_f32(
+ static void ggml_compute_forward_swiglu_f32(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
 
@@ -3233,7 +3106,7 @@ static void ggml_compute_forward_reglu_f32(
  src1_p += swapped ? 0 : nc;
  }
 
- ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+ ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
  #ifndef NDEBUG
  for (int k = 0; k < nc; k++) {
@@ -3246,7 +3119,7 @@ static void ggml_compute_forward_reglu_f32(
  }
  }
 
- static void ggml_compute_forward_reglu_f16(
+ static void ggml_compute_forward_swiglu_f16(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
 
@@ -3292,7 +3165,7 @@ static void ggml_compute_forward_reglu_f16(
  src1_p += swapped ? 0 : nc;
  }
 
- ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+ ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
  #ifndef NDEBUG
  for (int k = 0; k < nc; k++) {
@@ -3306,7 +3179,7 @@ static void ggml_compute_forward_reglu_f16(
  }
  }
 
- static void ggml_compute_forward_reglu(
+ static void ggml_compute_forward_swiglu(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
 
@@ -3315,11 +3188,11 @@ static void ggml_compute_forward_reglu(
  switch (src0->type) {
  case GGML_TYPE_F32:
  {
- ggml_compute_forward_reglu_f32(params, dst);
+ ggml_compute_forward_swiglu_f32(params, dst);
  } break;
  case GGML_TYPE_F16:
  {
- ggml_compute_forward_reglu_f16(params, dst);
+ ggml_compute_forward_swiglu_f16(params, dst);
  } break;
  default:
  {
@@ -3328,9 +3201,9 @@ static void ggml_compute_forward_reglu(
  }
  }
 
- // ggml_compute_forward_geglu
+ // ggml_compute_forward_swiglu_oai
 
- static void ggml_compute_forward_geglu_f32(
+ static void ggml_compute_forward_swiglu_oai_f32(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
 
@@ -3359,6 +3232,8 @@ static void ggml_compute_forward_geglu_f32(
  GGML_ASSERT(ggml_nrows(dst) == nr);
 
  const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+ const float alpha = ggml_get_op_params_f32(dst, 2);
+ const float limit = ggml_get_op_params_f32(dst, 3);
 
  // rows per thread
  const int dr = (nr + nth - 1)/nth;
@@ -3370,13 +3245,98 @@ static void ggml_compute_forward_geglu_f32(
  for (int i1 = ir0; i1 < ir1; i1++) {
  float * src0_p = (float *) (src0_d + i1*src0_o);
  float * src1_p = (float *) (src1_d + i1*src1_o);
+ float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1]));
 
  if (!src1) {
  src0_p += swapped ? nc : 0;
  src1_p += swapped ? 0 : nc;
  }
 
- ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+ for (int k = 0; k < nc; k++) {
+ const float x = std::min(src0_p[k], limit);
+ const float y = std::clamp(src1_p[k], -limit, limit);
+ const float out_glu = x / (1.f + expf(alpha * (-x)));
+ dst_p[k] = out_glu * (y + 1.f);
+ }
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = dst_p[k];
+ GGML_UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+ #endif
+ }
+ }
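The swiglu_oai variant adds two op params beyond the shared GLU layout: the activation input is clamped from above by limit, the gate is clamped to [-limit, limit], and the sigmoid is sharpened by alpha, giving dst = (x * sigmoid(alpha*x)) * (y + 1). A scalar restatement of the loop above (same parameter meanings assumed):

#include <algorithm>
#include <cmath>

inline float swiglu_oai(float x, float y, float alpha, float limit) {
    x = std::min(x, limit);                               // clamp activation from above
    y = std::clamp(y, -limit, limit);                     // clamp gate symmetrically
    const float glu = x / (1.0f + std::exp(-alpha * x));  // x * sigmoid(alpha*x)
    return glu * (y + 1.0f);
}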
+
+ static void ggml_compute_forward_swiglu_oai(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_swiglu_oai_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+ }
+
+ // ggml_compute_forward_geglu_erf
+
+ static void ggml_compute_forward_geglu_erf_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * src0_p = (float *) (src0_d + i1*src0_o);
+ float * src1_p = (float *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
  #ifndef NDEBUG
  for (int k = 0; k < nc; k++) {
@@ -3389,7 +3349,7 @@ static void ggml_compute_forward_geglu_f32(
  }
  }
 
- static void ggml_compute_forward_geglu_f16(
+ static void ggml_compute_forward_geglu_erf_f16(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
 
@@ -3435,7 +3395,7 @@ static void ggml_compute_forward_geglu_f16(
  src1_p += swapped ? 0 : nc;
  }
 
- ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+ ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
  #ifndef NDEBUG
  for (int k = 0; k < nc; k++) {
@@ -3449,7 +3409,7 @@ static void ggml_compute_forward_geglu_f16(
  }
  }
 
- static void ggml_compute_forward_geglu(
+ static void ggml_compute_forward_geglu_erf(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
 
@@ -3458,11 +3418,11 @@ static void ggml_compute_forward_geglu(
  switch (src0->type) {
  case GGML_TYPE_F32:
  {
- ggml_compute_forward_geglu_f32(params, dst);
+ ggml_compute_forward_geglu_erf_f32(params, dst);
  } break;
  case GGML_TYPE_F16:
  {
- ggml_compute_forward_geglu_f16(params, dst);
+ ggml_compute_forward_geglu_erf_f16(params, dst);
  } break;
  default:
  {
@@ -3471,9 +3431,9 @@ static void ggml_compute_forward_geglu(
  }
  }
 
- // ggml_compute_forward_swiglu
+ // ggml_compute_forward_geglu_quick
 
- static void ggml_compute_forward_swiglu_f32(
+ static void ggml_compute_forward_geglu_quick_f32(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
 
@@ -3519,7 +3479,7 @@ static void ggml_compute_forward_swiglu_f32(
  src1_p += swapped ? 0 : nc;
  }
 
- ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+ ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
  #ifndef NDEBUG
  for (int k = 0; k < nc; k++) {
@@ -3532,7 +3492,7 @@ static void ggml_compute_forward_swiglu_f32(
  }
  }
 
- static void ggml_compute_forward_swiglu_f16(
+ static void ggml_compute_forward_geglu_quick_f16(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
 
@@ -3578,7 +3538,7 @@ static void ggml_compute_forward_swiglu_f16(
  src1_p += swapped ? 0 : nc;
  }
 
- ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+ ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
  #ifndef NDEBUG
  for (int k = 0; k < nc; k++) {
@@ -3592,7 +3552,7 @@ static void ggml_compute_forward_swiglu_f16(
  }
  }
 
- static void ggml_compute_forward_swiglu(
+ static void ggml_compute_forward_geglu_quick(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
 
@@ -3601,11 +3561,11 @@ static void ggml_compute_forward_swiglu(
  switch (src0->type) {
  case GGML_TYPE_F32:
  {
- ggml_compute_forward_swiglu_f32(params, dst);
+ ggml_compute_forward_geglu_quick_f32(params, dst);
  } break;
  case GGML_TYPE_F16:
  {
- ggml_compute_forward_swiglu_f16(params, dst);
+ ggml_compute_forward_geglu_quick_f16(params, dst);
  } break;
  default:
  {
@@ -3636,31 +3596,27 @@ static void ggml_compute_forward_norm_f32(
 
  GGML_ASSERT(eps >= 0.0f);
 
- // TODO: optimize
  for (int64_t i03 = 0; i03 < ne03; i03++) {
  for (int64_t i02 = 0; i02 < ne02; i02++) {
  for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 
- ggml_float sum = 0.0;
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- sum += (ggml_float)x[i00];
- }
-
+ float sum = 0.0;
+ ggml_vec_sum_f32(ne00, &sum, x);
  float mean = sum/ne00;
 
  float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+ float variance = 0;
 
- ggml_float sum2 = 0.0;
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- float v = x[i00] - mean;
- y[i00] = v;
- sum2 += (ggml_float)(v*v);
- }
+ #ifdef GGML_USE_ACCELERATE
+ mean = -mean;
+ vDSP_vsadd(x, 1, &mean, y, 1, ne00);
+ vDSP_measqv(y, 1, &variance, ne00);
+ #else
+ variance = ggml_vec_cvar_f32(ne00, y, x, mean);
+ #endif //GGML_USE_ACCELERATE
 
- float variance = sum2/ne00;
  const float scale = 1.0f/sqrtf(variance + eps);
-
  ggml_vec_scale_f32(ne00, y, scale);
  }
  }
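Behaviorally nothing changes in norm_f32: each row is still normalized to y = (x - mean) / sqrt(variance + eps); the rewrite just delegates the mean and the centered variance to ggml_vec_sum_f32/ggml_vec_cvar_f32, or to vDSP on Accelerate builds. A plain scalar equivalent for one row, for reference only:

#include <cmath>

void norm_row(const float * x, float * y, int n, float eps) {
    float mean = 0.0f;
    for (int i = 0; i < n; ++i) mean += x[i];
    mean /= n;
    float var = 0.0f;
    for (int i = 0; i < n; ++i) { y[i] = x[i] - mean; var += y[i]*y[i]; }
    var /= n;
    const float scale = 1.0f / std::sqrt(var + eps);
    for (int i = 0; i < n; ++i) y[i] *= scale;
}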
@@ -3729,6 +3685,9 @@ static void ggml_compute_forward_rms_norm_f32(
 
  const float scale = 1.0f/sqrtf(mean + eps);
 
+ // if you hit this, likely you got an inf somewhere earlier
+ assert(scale > 0.0f);
+
  ggml_vec_scale_f32(ne00, y, scale);
  }
  }
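The new assert documents an invariant rather than changing behavior: for RMSNorm, scale = 1/sqrt(mean(x^2) + eps) is strictly positive unless mean(x^2) already overflowed to inf upstream, in which case scale underflows to 0 and the assert fires. A scalar sketch of the whole row computation (illustrative):

#include <cmath>

// scalar RMSNorm for one row: y = x / sqrt(mean(x^2) + eps)
void rms_norm_row(const float * x, float * y, int n, float eps) {
    float ms = 0.0f;
    for (int i = 0; i < n; ++i) ms += x[i]*x[i];
    ms /= n;
    const float scale = 1.0f / std::sqrt(ms + eps); // 0 only if ms became inf
    for (int i = 0; i < n; ++i) y[i] = x[i] * scale;
}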
@@ -4310,6 +4269,7 @@ void ggml_compute_forward_out_prod(
  case GGML_TYPE_Q5_0:
  case GGML_TYPE_Q5_1:
  case GGML_TYPE_Q8_0:
+ case GGML_TYPE_MXFP4:
  case GGML_TYPE_Q2_K:
  case GGML_TYPE_Q3_K:
  case GGML_TYPE_Q4_K:
@@ -4357,9 +4317,11 @@ static void ggml_compute_forward_scale_f32(
  GGML_ASSERT(ggml_is_contiguous(dst));
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
- // scale factor
- float v;
- memcpy(&v, dst->op_params, sizeof(float));
+ float s; // scale factor
+ float b; // bias
+
+ memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
+ memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
 
  const int ith = params->ith;
  const int nth = params->nth;
@@ -4378,12 +4340,22 @@ static void ggml_compute_forward_scale_f32(
 
  const size_t nb1 = dst->nb[1];
 
- for (int i1 = ir0; i1 < ir1; i1++) {
- if (dst->data != src0->data) {
- // src0 is same shape as dst => same indices
- memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+ if (b == 0.0f) {
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ if (dst->data != src0->data) {
+ // src0 is same shape as dst => same indices
+ // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
+ memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+ }
+ ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
+ }
+ } else {
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_vec_mad1_f32(nc,
+ (float *) ((char *) dst->data + i1*nb1),
+ (float *) ((char *) src0->data + i1*nb1),
+ s, b);
  }
- ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
  }
  }
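The scale op therefore now computes dst = s*x + b instead of just s*x, with the b == 0 case kept as the pure-scale fast path. A scalar equivalent of what the ggml_vec_mad1_f32 branch does per row (illustrative names):

// fused multiply-add with scalar scale s and scalar bias b
void scale_bias_row(float * dst, const float * src, int n, float s, float b) {
    for (int i = 0; i < n; ++i) {
        dst[i] = s * src[i] + b;
    }
}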
 
@@ -4572,6 +4544,7 @@ void ggml_compute_forward_set(
  case GGML_TYPE_Q5_1:
  case GGML_TYPE_Q8_0:
  case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
  case GGML_TYPE_Q2_K:
  case GGML_TYPE_Q3_K:
  case GGML_TYPE_Q4_K:
@@ -4611,46 +4584,6 @@ void ggml_compute_forward_cont(
  ggml_compute_forward_dup(params, dst);
  }
 
- // ggml_compute_forward_reshape
-
- void ggml_compute_forward_reshape(
- const ggml_compute_params * params,
- ggml_tensor * dst) {
- // NOP
- GGML_UNUSED(params);
- GGML_UNUSED(dst);
- }
-
- // ggml_compute_forward_view
-
- void ggml_compute_forward_view(
- const ggml_compute_params * params,
- ggml_tensor * dst) {
- // NOP
- GGML_UNUSED(params);
- GGML_UNUSED(dst);
- }
-
- // ggml_compute_forward_permute
-
- void ggml_compute_forward_permute(
- const ggml_compute_params * params,
- ggml_tensor * dst) {
- // NOP
- GGML_UNUSED(params);
- GGML_UNUSED(dst);
- }
-
- // ggml_compute_forward_transpose
-
- void ggml_compute_forward_transpose(
- const ggml_compute_params * params,
- ggml_tensor * dst) {
- // NOP
- GGML_UNUSED(params);
- GGML_UNUSED(dst);
- }
-
  // ggml_compute_forward_get_rows
 
  static void ggml_compute_forward_get_rows_q(
@@ -4833,6 +4766,7 @@ void ggml_compute_forward_get_rows(
  case GGML_TYPE_Q5_1:
  case GGML_TYPE_Q8_0:
  case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
  case GGML_TYPE_Q2_K:
  case GGML_TYPE_Q3_K:
  case GGML_TYPE_Q4_K:
@@ -4890,6 +4824,7 @@ void ggml_compute_forward_get_rows(
  //}
  }
 
+ template<typename idx_t>
  static void ggml_compute_forward_set_rows_f32(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
@@ -4928,7 +4863,7 @@ static void ggml_compute_forward_set_rows_f32(
  const int64_t i11 = i02%ne11;
  const int64_t i10 = i;
 
- const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+ const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
  GGML_ASSERT(i1 >= 0 && i1 < ne1);
 
@@ -4945,11 +4880,18 @@ void ggml_compute_forward_set_rows(
  ggml_tensor * dst) {
 
  const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
 
  switch (src0->type) {
  case GGML_TYPE_F32:
  {
- ggml_compute_forward_set_rows_f32(params, dst);
+ if (src1->type == GGML_TYPE_I64) {
+ ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
+ } else if (src1->type == GGML_TYPE_I32) {
+ ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
+ } else {
+ GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
+ }
  } break;
  default:
  {
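Templating set_rows on the index type lets the same kernel accept either I64 or I32 row indices from src1; the only type-dependent access is the single index load. A compact illustration of the idea (simplified, not the ggml signatures):

#include <cstdint>
#include <cstring>

template <typename idx_t>
int64_t load_row_index(const char * base, size_t offset) {
    idx_t v;
    std::memcpy(&v, base + offset, sizeof(v)); // alignment-safe typed load
    return (int64_t) v;
}
// callers pick load_row_index<int64_t> or load_row_index<int32_t> from the tensor type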
@@ -5222,6 +5164,7 @@ static void ggml_compute_forward_soft_max_f32(
 
  const ggml_tensor * src0 = dst->src[0];
  const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
 
  assert(ggml_is_contiguous(dst));
  assert(ggml_are_same_shape(src0, dst));
@@ -5232,14 +5175,17 @@ static void ggml_compute_forward_soft_max_f32(
  memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
  memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
- // TODO: handle transposed/permuted matrices
-
  const int ith = params->ith;
  const int nth = params->nth;
 
  GGML_TENSOR_UNARY_OP_LOCALS
 
- //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+ const int64_t nb11 = src1 ? src1->nb[1] : 1;
+ const int64_t nb12 = src1 ? src1->nb[2] : 1;
+ const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+ const int64_t ne12 = src1 ? src1->ne[2] : 1;
+ const int64_t ne13 = src1 ? src1->ne[3] : 1;
 
  // TODO: is this supposed to be ceil instead of floor?
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -5249,68 +5195,78 @@ static void ggml_compute_forward_soft_max_f32(
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
- const int nc = src0->ne[0];
- const int nr = ggml_nrows(src0);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+ float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
  const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
- for (int i1 = ir0; i1 < ir1; i1++) {
- // ALiBi
- const uint32_t h = (i1/ne01)%ne02; // head
- const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
- float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
- float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
+ // sinks
+ const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
 
- // broadcast the mask across rows
- ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
- float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-
- ggml_vec_cpy_f32 (nc, wp, sp);
- ggml_vec_scale_f32(nc, wp, scale);
- if (mp_f32) {
- if (use_f16) {
- for (int i = 0; i < nc; ++i) {
- wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
- }
- } else {
- for (int i = 0; i < nc; ++i) {
- wp[i] += slope*mp_f32[i];
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+ const int64_t i11 = i01;
+ const int64_t i12 = i02%ne12;
+ const int64_t i13 = i03%ne13;
+
+ // ALiBi
+ const uint32_t h = i02; // head
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+ float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+ float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+ // broadcast the mask across rows
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+ float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+
+ ggml_vec_cpy_f32 (ne00, wp, sp);
+ ggml_vec_scale_f32(ne00, wp, scale);
+ if (mp_f32) {
+ if (use_f16) {
+ for (int i = 0; i < ne00; ++i) {
+ wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
+ }
+ } else {
+ for (int i = 0; i < ne00; ++i) {
+ wp[i] += slope*mp_f32[i];
+ }
+ }
  }
- }
- }
 
  #ifndef NDEBUG
- for (int i = 0; i < nc; ++i) {
- //printf("p[%d] = %f\n", i, p[i]);
- assert(!isnan(wp[i]));
- }
+ for (int i = 0; i < ne00; ++i) {
+ //printf("p[%d] = %f\n", i, p[i]);
+ assert(!isnan(wp[i]));
+ }
  #endif
 
- float max = -INFINITY;
- ggml_vec_max_f32(nc, &max, wp);
+ float max = -INFINITY;
+ ggml_vec_max_f32(ne00, &max, wp);
 
- ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
- assert(sum > 0.0);
+ // if we have sinks, make a correction as if they were included in the softmax
+ if (sk) {
+ max = MAX(max, sk[i02]);
+ }
 
- sum = 1.0/sum;
- ggml_vec_scale_f32(nc, dp, sum);
+ ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
+ assert(sum > 0.0);
+
+ if (sk) {
+ sum += (ggml_float) expf(sk[i02] - max);
+ }
+
+ sum = 1.0/sum;
+ ggml_vec_scale_f32(ne00, dp, sum);
 
  #ifndef NDEBUG
- for (int i = 0; i < nc; ++i) {
- assert(!isnan(dp[i]));
- assert(!isinf(dp[i]));
- }
+ for (int i = 0; i < ne00; ++i) {
+ assert(!isnan(dp[i]));
+ assert(!isinf(dp[i]));
+ }
  #endif
+ }
+ }
  }
  }
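The sink correction folds one extra per-head logit sk[h] into each row's softmax without storing it in the output: both the running max and the denominator account for it, so the emitted row sums to sum(exp(w)) / (sum(exp(w)) + exp(sk)) rather than 1. A scalar model of that correction (illustrative only):

#include <algorithm>
#include <cmath>
#include <vector>

// softmax over w with one extra "sink" logit included in the normalization only
std::vector<float> softmax_with_sink(const std::vector<float> & w, float sink) {
    float m = sink;
    for (float v : w) m = std::max(m, v);
    double sum = std::exp((double) sink - m); // the sink's share of the denominator
    std::vector<float> p(w.size());
    for (size_t i = 0; i < w.size(); ++i) {
        p[i] = std::exp(w[i] - m);
        sum += p[i];
    }
    for (float & v : p) v = (float) (v / sum);
    return p;
}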
 
@@ -5534,6 +5490,7 @@ void ggml_compute_forward_clamp(
  case GGML_TYPE_Q5_1:
  case GGML_TYPE_Q8_0:
  case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
  case GGML_TYPE_Q2_K:
  case GGML_TYPE_Q3_K:
  case GGML_TYPE_Q4_K:
@@ -5580,276 +5537,123 @@ static void rope_yarn(
  float theta = theta_interp;
  if (ext_factor != 0.0f) {
  float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
- theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
- // Get n-d magnitude scaling corrected for interpolation
- mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
- }
- *cos_theta = cosf(theta) * mscale;
- *sin_theta = sinf(theta) * mscale;
- }
-
- static void ggml_rope_cache_init(
- float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
- float * cache, float sin_sign, float theta_scale) {
- // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
- float theta = theta_base;
- for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
- const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
- rope_yarn(
- theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
- );
- cache[i0 + 1] *= sin_sign;
-
- theta *= theta_scale;
- }
- }
-
- static void ggml_mrope_cache_init(
- float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
- float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
- float * cache, float sin_sign, float theta_scale) {
- // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
- float theta_t = theta_base_t;
- float theta_h = theta_base_h;
- float theta_w = theta_base_w;
- float theta_e = theta_base_e; // extra position id for vision encoder
- int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
- int sec_w = sections[1] + sections[0];
- int sec_e = sections[2] + sec_w;
- GGML_ASSERT(sect_dims <= ne0);
-
- for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
- const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
-
- int sector = (i0 / 2) % sect_dims;
- if (indep_sects) {
- // compute theta independently for each dim sections
- // (i.e. reset corresponding theta when `i0` go from one section to another)
- if (sector == 0) {
- theta_t = theta_base_t;
- }
- else if (sector == sections[0]) {
- theta_h = theta_base_h;;
- }
- else if (sector == sec_w) {
- theta_w = theta_base_w;
- }
- else if (sector == sec_e) {
- theta_e = theta_base_e;
- }
- }
-
- float theta = theta_t;
- if (sector >= sections[0] && sector < sec_w) {
- theta = theta_h;
- }
- else if (sector >= sec_w && sector < sec_w + sections[2]) {
- theta = theta_w;
- }
- else if (sector >= sec_w + sections[2]) {
- theta = theta_e;
- }
-
- rope_yarn(
- theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
- );
- cache[i0 + 1] *= sin_sign;
-
- theta_t *= theta_scale;
- theta_w *= theta_scale;
- theta_h *= theta_scale;
- theta_e *= theta_scale;
- }
- }
-
- static void ggml_compute_forward_rope_f32(
- const ggml_compute_params * params,
- ggml_tensor * dst,
- const bool forward) {
-
- const ggml_tensor * src0 = dst->src[0];
- const ggml_tensor * src1 = dst->src[1];
- const ggml_tensor * src2 = dst->src[2];
-
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
- int sections[4];
-
- //const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_dims = ((int32_t *) dst->op_params)[1];
- const int mode = ((int32_t *) dst->op_params)[2];
- //const int n_ctx = ((int32_t *) dst->op_params)[3];
- const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
- memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
- //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
-
- GGML_ASSERT(nb00 == sizeof(float));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(dst);
-
- GGML_ASSERT(n_dims <= ne0);
- GGML_ASSERT(n_dims % 2 == 0);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- // row index used to determine which thread to use
- int ir = 0;
-
- const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
- float corr_dims[2];
- ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
- const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
- const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
-
- if (is_mrope) {
- GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
- }
-
- if (is_vision) {
- GGML_ASSERT(n_dims == ne0/2);
- }
-
- const float * freq_factors = NULL;
- if (src2 != NULL) {
- GGML_ASSERT(src2->type == GGML_TYPE_F32);
- GGML_ASSERT(src2->ne[0] >= n_dims / 2);
- freq_factors = (const float *) src2->data;
- }
-
- // backward process uses inverse rotation by cos and sin.
- // cos and sin build a rotation matrix, where the inverse is the transpose.
- // this essentially just switches the sign of sin.
- const float sin_sign = forward ? 1.0f : -1.0f;
-
- const int32_t * pos = (const int32_t *) src1->data;
-
- for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
- for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
-
- float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
- if (!is_mrope) {
- const int64_t p = pos[i2];
- ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
- }
- else {
- const int64_t p_t = pos[i2];
- const int64_t p_h = pos[i2 + ne2];
- const int64_t p_w = pos[i2 + ne2 * 2];
- const int64_t p_e = pos[i2 + ne2 * 3];
- ggml_mrope_cache_init(
- p_t, p_h, p_w, p_e, sections, is_vision,
- freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
- }
-
- for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
- if (ir++ < ir0) continue;
- if (ir > ir1) break;
-
- if (is_neox || is_mrope) {
- if (is_vision){
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
- const int64_t ic = i0/2;
-
- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
-
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
 
- const float x0 = src[0];
- const float x1 = src[n_dims];
+ // Get n-d magnitude scaling corrected for interpolation
+ mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+ }
+ *cos_theta = cosf(theta) * mscale;
+ *sin_theta = sinf(theta) * mscale;
+ }
 
- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
- }
- } else {
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
- const int64_t ic = i0/2;
+ static void ggml_rope_cache_init(
+ float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+ float * cache, float sin_sign, float theta_scale) {
+ // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
+ float theta = theta_base;
+ for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+ const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
+ rope_yarn(
+ theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+ );
+ cache[i0 + 1] *= sin_sign;
 
- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
+ theta *= theta_scale;
+ }
+ }
 
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+ static void ggml_mrope_cache_init(
+ float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
+ float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+ float * cache, float sin_sign, float theta_scale) {
+ // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
+ float theta_t = theta_base_t;
+ float theta_h = theta_base_h;
+ float theta_w = theta_base_w;
+ float theta_e = theta_base_e; // extra position id for vision encoder
+ int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
+ int sec_w = sections[1] + sections[0];
+ int sec_e = sections[2] + sec_w;
+ GGML_ASSERT(sect_dims <= ne0);
 
- const float x0 = src[0];
- const float x1 = src[n_dims/2];
+ for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+ const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
 
- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
- }
- }
- } else {
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
+ int sector = (i0 / 2) % sect_dims;
+ if (indep_sects) {
+ // compute theta independently for each dim sections
+ // (i.e. reset corresponding theta when `i0` go from one section to another)
+ if (sector == 0) {
+ theta_t = theta_base_t;
+ }
+ else if (sector == sections[0]) {
+ theta_h = theta_base_h;;
+ }
+ else if (sector == sec_w) {
+ theta_w = theta_base_w;
+ }
+ else if (sector == sec_e) {
+ theta_e = theta_base_e;
+ }
+ }
 
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ float theta = theta_t;
+ if (is_imrope) { // qwen3vl apply interleaved mrope
+ if (sector % 3 == 1 && sector < 3 * sections[1]) {
+ theta = theta_h;
+ } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
+ theta = theta_w;
+ } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
+ theta = theta_t;
+ } else {
+ theta = theta_e;
+ }
+ } else {
+ if (sector >= sections[0] && sector < sec_w) {
+ theta = theta_h;
+ }
+ else if (sector >= sec_w && sector < sec_w + sections[2]) {
+ theta = theta_w;
+ }
+ else if (sector >= sec_w + sections[2]) {
+ theta = theta_e;
+ }
+ }
 
- const float x0 = src[0];
- const float x1 = src[1];
+ rope_yarn(
+ theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+ );
+ cache[i0 + 1] *= sin_sign;
 
- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[1] = x0*sin_theta + x1*cos_theta;
- }
- }
+ theta_t *= theta_scale;
+ theta_w *= theta_scale;
+ theta_h *= theta_scale;
+ theta_e *= theta_scale;
+ }
+ }
+ }
5819
5634
 
5820
- if (is_vision) {
5821
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5822
- const int64_t ic = i0/2;
5823
5635
 
5824
- const float cos_theta = cache[i0 + 0];
5825
- const float sin_theta = cache[i0 + 1];
5636
+ template<typename T>
5637
+ static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
5638
+ for (int64_t i0 = 0; i0 < n; i0 += 2) {
5639
+ const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
5826
5640
 
5827
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5828
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5641
+ const float cos_theta = cache[i0 + 0];
5642
+ const float sin_theta = cache[i0 + 1];
5829
5643
 
5830
- const float x0 = src[0];
5831
- const float x1 = src[n_dims];
5644
+ const T * const src = src_data + ic;
5645
+ T * dst = dst_data + ic;
5832
5646
 
5833
- dst_data[0] = x0*cos_theta - x1*sin_theta;
5834
- dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
5835
- }
5836
- } else {
5837
- // fill the remain channels with data from src tensor
5838
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5839
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5840
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5647
+ const float x0 = type_conversion_table<T>::to_f32(src[0]);
5648
+ const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
5841
5649
 
5842
- dst_data[0] = src[0];
5843
- dst_data[1] = src[1];
5844
- }
5845
- }
5846
- }
5847
- }
5848
- }
5650
+ dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
5651
+ dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
5652
+ }
5849
5653
  }
5850
5654
 
5851
- // TODO: deduplicate f16/f32 code
5852
- static void ggml_compute_forward_rope_f16(
5655
+ template<typename T> //float or ggml_fp16_t
5656
+ static void ggml_compute_forward_rope_flt(
5853
5657
  const ggml_compute_params * params,
5854
5658
  ggml_tensor * dst,
5855
5659
  const bool forward) {
@@ -5858,6 +5662,9 @@ static void ggml_compute_forward_rope_f16(
5858
5662
  const ggml_tensor * src1 = dst->src[1];
5859
5663
  const ggml_tensor * src2 = dst->src[2];
5860
5664
 
5665
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
5666
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
5667
+
5861
5668
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5862
5669
  int sections[4];
5863
5670
 
@@ -5866,6 +5673,7 @@ static void ggml_compute_forward_rope_f16(
5866
5673
  const int mode = ((int32_t *) dst->op_params)[2];
5867
5674
  //const int n_ctx = ((int32_t *) dst->op_params)[3];
5868
5675
  const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
5676
+
5869
5677
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
5870
5678
  memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
5871
5679
  memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -5874,13 +5682,13 @@ static void ggml_compute_forward_rope_f16(
5874
5682
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
5875
5683
  memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
5876
5684
 
5877
-
5878
5685
  GGML_TENSOR_UNARY_OP_LOCALS
5879
5686
 
5880
5687
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
5881
5688
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
5882
5689
 
5883
- GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
5690
+ GGML_ASSERT(nb0 == nb00);
5691
+ GGML_ASSERT(nb0 == sizeof(T));
5884
5692
 
5885
5693
  const int ith = params->ith;
5886
5694
  const int nth = params->nth;
@@ -5905,11 +5713,11 @@ static void ggml_compute_forward_rope_f16(
5905
5713
  float corr_dims[2];
5906
5714
  ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
5907
5715
 
5908
- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
5909
- const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
5716
+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
5717
+ const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
5910
5718
  const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
5911
5719
 
5912
- if (is_mrope) {
5720
+ if (mrope_used) {
5913
5721
  GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
5914
5722
  }
5915
5723
 
@@ -5931,11 +5739,11 @@ static void ggml_compute_forward_rope_f16(
5931
5739
 
5932
5740
  const int32_t * pos = (const int32_t *) src1->data;
5933
5741
 
5934
- for (int64_t i3 = 0; i3 < ne3; i3++) {
5935
- for (int64_t i2 = 0; i2 < ne2; i2++) {
5742
+ for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
5743
+ for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
5936
5744
 
5937
5745
  float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5938
- if (!is_mrope) {
5746
+ if (!mrope_used) {
5939
5747
  const int64_t p = pos[i2];
5940
5748
  ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5941
5749
  }
@@ -5945,90 +5753,44 @@ static void ggml_compute_forward_rope_f16(
  const int64_t p_w = pos[i2 + ne2 * 2];
  const int64_t p_e = pos[i2 + ne2 * 3];
  ggml_mrope_cache_init(
- p_t, p_h, p_w, p_e, sections, is_vision,
+ p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
  freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
  }
 
- for (int64_t i1 = 0; i1 < ne1; i1++) {
+ for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
  if (ir++ < ir0) continue;
  if (ir > ir1) break;
 
- if (is_neox || is_mrope) {
- if (is_vision) {
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
- const int64_t ic = i0/2;
-
- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
-
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
- const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
-
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
- dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
- }
- } else {
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
- const int64_t ic = i0/2;
-
- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
-
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
- const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
-
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
- dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
- }
- }
- } else {
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
-
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
- const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
-
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
- dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
- }
- }
-
- if (is_vision) {
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
- const int64_t ic = i0/2;
-
- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
-
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
- const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
-
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
- dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
- }
- } else {
+ T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+ T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+
+ switch (mode) {
+ case GGML_ROPE_TYPE_NORMAL:
+ rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
+ break;
+ case GGML_ROPE_TYPE_NEOX:
+ case GGML_ROPE_TYPE_MROPE:
+ case GGML_ROPE_TYPE_IMROPE:
+ rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
+ break;
+ case GGML_ROPE_TYPE_VISION:
+ rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
+ break;
+ default:
+ GGML_ABORT("rope type not supported");
+ }
+
+ if (!is_vision) {
+ // fill the remaining channels with data from the src tensor
  for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
  dst_data[0] = src[0];
  dst_data[1] = src[1];
  }
  }
- }
+ } //attn-heads
  }
  }
  }
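Note: the rewritten hunk above collapses the four per-mode rotation loops into rotate_pairs<T>. Below is a minimal standalone sketch of the rotation it performs, using plain floats and local names only; the template's real signature is not shown in this diff, so the extra trailing argument passed for GGML_ROPE_TYPE_NORMAL is not modeled here.

    #include <cstdint>
    #include <cstdio>

    // stride selects the pairing: 1 pairs (i, i+1) as in GGML_ROPE_TYPE_NORMAL,
    // n_dims/2 pairs (i, i + n_dims/2) as in the NeoX/MROPE layouts above.
    static void rotate_pairs_sketch(int64_t n, int64_t stride,
                                    const float * cache, const float * src, float * dst) {
        for (int64_t i0 = 0; i0 < n; i0 += 2) {
            const int64_t ic = (stride == 1) ? i0 : i0/2; // first element of the pair
            const float cos_theta = cache[i0 + 0];
            const float sin_theta = cache[i0 + 1];
            const float x0 = src[ic];
            const float x1 = src[ic + stride];
            dst[ic]          = x0*cos_theta - x1*sin_theta; // plain 2-D rotation
            dst[ic + stride] = x0*sin_theta + x1*cos_theta;
        }
    }

    int main() {
        const float cache[4] = {1.0f, 0.0f, 0.0f, 1.0f};  // (cos, sin) per pair
        float row[4] = {1, 2, 3, 4};
        rotate_pairs_sketch(4, 2, cache, row, row);       // NeoX-style pairing, in place
        printf("%g %g %g %g\n", row[0], row[1], row[2], row[3]); // 1 -4 3 2
    }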
@@ -6042,11 +5804,11 @@ void ggml_compute_forward_rope(
  switch (src0->type) {
  case GGML_TYPE_F16:
  {
- ggml_compute_forward_rope_f16(params, dst, true);
+ ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
  } break;
  case GGML_TYPE_F32:
  {
- ggml_compute_forward_rope_f32(params, dst, true);
+ ggml_compute_forward_rope_flt<float>(params, dst, true);
  } break;
  default:
  {
@@ -6066,11 +5828,11 @@ void ggml_compute_forward_rope_back(
  switch (src0->type) {
  case GGML_TYPE_F16:
  {
- ggml_compute_forward_rope_f16(params, dst, false);
+ ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
  } break;
  case GGML_TYPE_F32:
  {
- ggml_compute_forward_rope_f32(params, dst, false);
+ ggml_compute_forward_rope_flt<float>(params, dst, false);
  } break;
  default:
  {
@@ -6477,68 +6239,251 @@ void ggml_compute_forward_im2col_back_f32(
  const int ith = params->ith;
  const int nth = params->nth;
 
- const int64_t N = is_2D ? ne3 : ne2;
- const int64_t IC = is_2D ? ne2 : ne1;
- const int64_t IH = is_2D ? ne1 : 1;
- const int64_t IW = ne0;
+ const int64_t N = is_2D ? ne3 : ne2;
+ const int64_t IC = is_2D ? ne2 : ne1;
+ const int64_t IH = is_2D ? ne1 : 1;
+ const int64_t IW = ne0;
+
+ const int64_t KH = is_2D ? ne11 : 1;
+ const int64_t KW = ne10;
+
+ const int64_t OH = is_2D ? ne02 : 1;
+ const int64_t OW = ne01;
+
+ int ofs0 = is_2D ? nb3 : nb2;
+ int ofs1 = is_2D ? nb2 : nb1;
+
+ GGML_ASSERT(nb0 == sizeof(float));
+
+ // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+ {
+ float * const wdata = (float *) dst->data;
+
+ for (int64_t in = 0; in < N; in++) {
+ for (int64_t iic = ith; iic < IC; iic += nth) {
+ for (int64_t iih = 0; iih < IH; iih++) {
+ for (int64_t iiw = 0; iiw < IW; iiw++) {
+
+ // micro kernel
+ float grad = 0.0f;
+ for (int64_t ikh = 0; ikh < KH; ikh++) {
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ // For s0 > 1 some values were skipped over in the forward pass.
+ // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
+ const int64_t tmpw = (iiw + p0 - ikw*d0);
+ if (tmpw % s0 != 0) {
+ continue;
+ }
+ const int64_t iow = tmpw / s0;
+
+ // Equivalent logic as above except for s1.
+ int64_t ioh;
+ if (is_2D) {
+ const int64_t tmph = iih + p1 - ikh*d1;
+
+ if (tmph % s1 != 0) {
+ continue;
+ }
+
+ ioh = tmph / s1;
+ } else {
+ ioh = 0;
+ }
+
+ if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
+ continue;
+ }
+
+ const float * const grad_in = (const float *) src0->data
+ + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+ grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
+ }
+ }
+ float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
+ dst_data[iih*IW + iiw] = grad;
+ }
+ }
+ }
+ }
+ }
+ }
+
+
+ // ggml_compute_forward_im2col_3d_f16
+ // src0: kernel [OC*IC, KD, KH, KW]
+ // src1: image [N*IC, ID, IH, IW]
+ // dst: result [N*OD, OH, OW, IC * KD * KH * KW]
+ static void ggml_compute_forward_im2col_3d_f16(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+ const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+ const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t N = ne13 / IC;
+ const int64_t ID = ne12;
+ const int64_t IH = ne11;
+ const int64_t IW = ne10;
+
+ const int64_t OC = ne03 / IC;
+ GGML_UNUSED(OC);
+ const int64_t KD = ne02;
+ const int64_t KH = ne01;
+ const int64_t KW = ne00;
+
+ const int64_t OD = ne3 / N;
+ const int64_t OH = ne2;
+ const int64_t OW = ne1;
+ const int64_t OH_OW = OH*OW;
+ const int64_t KD_KH_KW = KD*KH*KW;
+ const int64_t KH_KW = KH*KW;
+ const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+ {
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
+
+ for (int64_t in = 0; in < N; in++) {
+ for (int64_t iod = 0; iod < OD; iod++) {
+ for (int64_t ioh = 0; ioh < OH; ioh++) {
+ for (int64_t iow = 0; iow < OW; iow++) {
+ for (int64_t iic = ith; iic < IC; iic += nth) {
+
+ // micro kernel
+ ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+ const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+ for (int64_t ikd = 0; ikd < KD; ikd++) {
+ for (int64_t ikh = 0; ikh < KH; ikh++) {
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
+ const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+ if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
+ } else {
+ const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // ggml_compute_forward_im2col_3d_f32
+ // src0: kernel [OC*IC, KD, KH, KW]
+ // src1: image [N*IC, ID, IH, IW]
+ // dst: result [N*OD, OH, OW, IC * KD * KH * KW]
+ static void ggml_compute_forward_im2col_3d_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+ const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+ const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t N = ne13 / IC;
+ const int64_t ID = ne12;
+ const int64_t IH = ne11;
+ const int64_t IW = ne10;
 
- const int64_t KH = is_2D ? ne11 : 1;
- const int64_t KW = ne10;
+ const int64_t OC = ne03 / IC;
+ GGML_UNUSED(OC);
+ const int64_t KD = ne02;
+ const int64_t KH = ne01;
+ const int64_t KW = ne00;
 
- const int64_t OH = is_2D ? ne02 : 1;
- const int64_t OW = ne01;
+ const int64_t OD = ne3 / N;
+ const int64_t OH = ne2;
+ const int64_t OW = ne1;
 
- int ofs0 = is_2D ? nb3 : nb2;
- int ofs1 = is_2D ? nb2 : nb1;
+ const int64_t OH_OW = OH*OW;
+ const int64_t KD_KH_KW = KD*KH*KW;
+ const int64_t KH_KW = KH*KW;
+ const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
 
- GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb10 == sizeof(float));
 
- // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+ // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
  {
  float * const wdata = (float *) dst->data;
 
  for (int64_t in = 0; in < N; in++) {
- for (int64_t iic = ith; iic < IC; iic += nth) {
- for (int64_t iih = 0; iih < IH; iih++) {
- for (int64_t iiw = 0; iiw < IW; iiw++) {
-
- // micro kernel
- float grad = 0.0f;
- for (int64_t ikh = 0; ikh < KH; ikh++) {
- for (int64_t ikw = 0; ikw < KW; ikw++) {
- // For s0 > 1 some values were skipped over in the forward pass.
- // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
- const int64_t tmpw = (iiw + p0 - ikw*d0);
- if (tmpw % s0 != 0) {
- continue;
- }
- const int64_t iow = tmpw / s0;
-
- // Equivalent logic as above except for s1.
- int64_t ioh;
- if (is_2D) {
- const int64_t tmph = iih + p1 - ikh*d1;
-
- if (tmph % s1 != 0) {
- continue;
+ for (int64_t iod = 0; iod < OD; iod++) {
+ for (int64_t ioh = 0; ioh < OH; ioh++) {
+ for (int64_t iow = 0; iow < OW; iow++) {
+ for (int64_t iic = ith; iic < IC; iic += nth) {
+
+ // micro kernel
+ float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+ const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+ for (int64_t ikd = 0; ikd < KD; ikd++) {
+ for (int64_t ikh = 0; ikh < KH; ikh++) {
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
+ const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+ if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
+ } else {
+ const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
+ }
  }
-
- ioh = tmph / s1;
- } else {
- ioh = 0;
- }
-
- if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
- continue;
  }
-
- const float * const grad_in = (const float *) src0->data
- + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
- grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
  }
  }
- float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
- dst_data[iih*IW + iiw] = grad;
  }
  }
  }
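Note: all the im2col variants above share one coordinate convention: source index = out*stride + k*dilation - pad, with zero fill for out-of-range taps. A standalone 1-D illustration of that mapping (local names only, not library code):

    #include <cstdio>
    #include <vector>

    int main() {
        const int IW = 5, KW = 3, s0 = 1, p0 = 1, d0 = 1;
        const int OW = (IW + 2*p0 - d0*(KW - 1) - 1)/s0 + 1; // = 5
        const std::vector<float> src = {1, 2, 3, 4, 5};
        std::vector<float> cols(OW*KW, 0.0f); // layout [OW, KW]
        for (int iow = 0; iow < OW; ++iow) {
            for (int ikw = 0; ikw < KW; ++ikw) {
                const int iiw = iow*s0 + ikw*d0 - p0;          // same formula as above
                cols[iow*KW + ikw] = (iiw < 0 || iiw >= IW) ? 0.0f : src[iiw];
            }
        }
        for (int iow = 0; iow < OW; ++iow) {
            printf("row %d: %g %g %g\n", iow, cols[iow*KW+0], cols[iow*KW+1], cols[iow*KW+2]);
        }
    }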
@@ -6546,6 +6491,26 @@ void ggml_compute_forward_im2col_back_f32(
  }
  }
 
+
+ void ggml_compute_forward_im2col_3d(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+ switch (dst->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_im2col_3d_f16(params, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_im2col_3d_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+ }
+
  static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
  void * a, void * b, float * c) {
  const ggml_type_traits * traits = ggml_get_type_traits(type);
@@ -6589,8 +6554,13 @@ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params
  ggml_compute_forward_mul_mat(params, &dst);
  }
 
+ static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
+ return (coord + size) % size; // adding size avoids negative number weirdness
+ }
+
  // ggml_compute_forward_conv_2d
 
+
  static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
  const ggml_tensor * kernel, // [KW, KH, IC, OC]
  const ggml_tensor * src, // [W, H, C, N]
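Note: ggml_wrap_around adds size before taking % because C++'s % keeps the sign of the dividend (-1 % 5 == -1), which would produce an out-of-range index for negative coordinates. A standalone check (this assumes coord >= -size, which holds for the pad offsets it is used with later in this diff):

    #include <cstdint>
    #include <cstdio>

    static int64_t wrap_around(int64_t coord, int64_t size) {
        return (coord + size) % size; // valid as long as coord >= -size
    }

    int main() {
        printf("%lld\n", (long long) wrap_around(-1, 5)); // 4
        printf("%lld\n", (long long) wrap_around( 5, 5)); // 0
    }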
@@ -6726,6 +6696,148 @@ void ggml_compute_forward_conv_2d(
  ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
  }
 
+ // ggml_compute_forward_conv_3d
+
+ static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
+ const ggml_tensor * kernel,
+ const ggml_tensor * src,
+ ggml_tensor * dst,
+ ggml_type kernel_type) {
+
+ GGML_ASSERT(ggml_is_contiguous(kernel));
+ GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
+ GGML_ASSERT(kernel->type == kernel_type);
+
+ const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
+
+ const int32_t s0 = dst->op_params[0];
+ const int32_t s1 = dst->op_params[1];
+ const int32_t s2 = dst->op_params[2];
+ const int32_t p0 = dst->op_params[3];
+ const int32_t p1 = dst->op_params[4];
+ const int32_t p2 = dst->op_params[5];
+ const int32_t d0 = dst->op_params[6];
+ const int32_t d1 = dst->op_params[7];
+ const int32_t d2 = dst->op_params[8];
+ const int32_t c = dst->op_params[9];
+ const int32_t n = dst->op_params[10];
+ const int32_t oc = dst->op_params[11];
+
+ const int64_t src_w = src->ne[0];
+ const int64_t src_h = src->ne[1];
+ const int64_t src_d = src->ne[2];
+ const int64_t knl_w = kernel->ne[0];
+ const int64_t knl_h = kernel->ne[1];
+ const int64_t knl_d = kernel->ne[2];
+ const int64_t dst_w = dst->ne[0];
+ const int64_t dst_h = dst->ne[1];
+ const int64_t dst_d = dst->ne[2];
+
+ const float * src_data = (float *) src->data;
+ void * knl_data = kernel->data;
+ float * dst_data = (float *) dst->data;
+
+ const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
+ const int64_t knl_n_total = knl_n_per_channel * c;
+ const int64_t patch_total = n * dst_w * dst_h * dst_d;
+
+ const int64_t space_per_patch = knl_n_total * traits->type_size + oc * sizeof(float);
+ const int64_t batch_size = params->wsize / space_per_patch;
+ const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
+ const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+ GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
+
+ void * tmp = params->wdata;
+
+ for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+ const int64_t patch_start_batch = batch_i * patches_per_batch;
+ const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, patch_total);
+ const int64_t patch_n_in_batch = patch_end_batch - patch_start_batch;
+
+ const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+ const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
+ const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
+
+ for (int64_t p = patch_start; p < patch_end; ++p) {
+ const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+ const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+ const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
+ const int64_t dst_z = p_in_batch / (dst_w * dst_h);
+ const int64_t dst_y = p_in_depth / dst_w;
+ const int64_t dst_x = p_in_depth % dst_w;
+
+ char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
+
+ for (int64_t ic = 0; ic < c; ++ic) {
+ for (int64_t kz = 0; kz < knl_d; ++kz) {
+ for (int64_t ky = 0; ky < knl_h; ++ky) {
+ for (int64_t kx = 0; kx < knl_w; ++kx) {
+ const int64_t sz = dst_z * s2 + kz * d2 - p2;
+ const int64_t sy = dst_y * s1 + ky * d1 - p1;
+ const int64_t sx = dst_x * s0 + kx * d0 - p0;
+
+ int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
+
+ float src_val;
+ if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+ src_val = 0.0f;
+ } else {
+ const int64_t cn_idx = batch_idx * c + ic;
+ const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
+ src_val = *src_ptr;
+ }
+
+ char * element_ptr = dst_row + dst_idx * traits->type_size;
+ if (kernel_type == GGML_TYPE_F32) {
+ *(float *)element_ptr = src_val;
+ } else if (kernel_type == GGML_TYPE_F16) {
+ *(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ ggml_barrier(params->threadpool);
+
+ float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
+ ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
+
+ ggml_barrier(params->threadpool);
+
+ const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+ const int64_t permute_start = params->ith * permute_per_thread;
+ const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
+
+ for (int64_t i = permute_start; i < permute_end; ++i) {
+ const int64_t p = patch_start_batch + i;
+ const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+ const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+ const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
+ const int64_t dst_z = p_in_batch / (dst_w * dst_h);
+ const int64_t dst_y = p_in_depth / dst_w;
+ const int64_t dst_x = p_in_depth % dst_w;
+
+ for (int64_t ioc = 0; ioc < oc; ++ioc) {
+ const float value = gemm_output[i * oc + ioc];
+ const int64_t ocn_idx = batch_idx * oc + ioc;
+ float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
+ *dst_ptr = value;
+ }
+ }
+ }
+ }
+
+ void ggml_compute_forward_conv_3d(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
+ }
+
  // ggml_compute_forward_conv_transpose_2d
 
  void ggml_compute_forward_conv_transpose_2d(
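Note: the conv_3d kernel above flattens every output voxel of every image into one patch index p and decodes it with the same div/mod chain in two places. A standalone round-trip check of that decode (same arithmetic, local names only):

    #include <cassert>
    #include <cstdint>

    int main() {
        const int64_t dst_w = 4, dst_h = 3, dst_d = 2, n = 2;
        for (int64_t p = 0; p < n*dst_w*dst_h*dst_d; ++p) {
            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
            const int64_t dst_z = p_in_batch / (dst_w * dst_h);
            const int64_t dst_y = p_in_depth / dst_w;
            const int64_t dst_x = p_in_depth % dst_w;
            // re-encode and verify the round trip
            assert(p == ((batch_idx*dst_d + dst_z)*dst_h + dst_y)*dst_w + dst_x);
        }
        return 0;
    }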
@@ -6857,7 +6969,11 @@ static void ggml_compute_forward_conv_2d_dw_cwhn(
  const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
 
  #ifdef GGML_SIMD
- const int64_t pkg_size = GGML_F32_EPR;
+ #if defined(__ARM_FEATURE_SVE)
+ const int64_t pkg_size = svcntw();
+ #else
+ const int64_t pkg_size = GGML_F32_EPR;
+ #endif
  const int64_t pkg_count = c / pkg_size;
  const int64_t c_pkg_end = pkg_count * pkg_size;
  #else
@@ -7280,10 +7396,17 @@ static void ggml_compute_forward_upscale_f32(
  float sf1 = (float)ne1/src0->ne[1];
  float sf2 = (float)ne2/src0->ne[2];
  float sf3 = (float)ne3/src0->ne[3];
+ float pixel_offset = 0.5f;
 
  const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
  const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
 
+ if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
+ pixel_offset = 0.0f;
+ sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
+ sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
+ }
+
  if (mode == GGML_SCALE_MODE_NEAREST) {
  for (int64_t i3 = 0; i3 < ne3; i3++) {
  const int64_t i03 = i3 / sf3;
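Note: the two destination-to-source mappings set up above, illustrated standalone (local names only; the kernels below fold the half-pixel offset into each mode slightly differently):

    #include <cstdio>

    int main() {
        const int in = 4, out = 8;
        // half-pixel (default): sf = out/in, src coord ~ (i + 0.5)/sf
        const float sf_half  = (float) out / in;
        // align-corners: sf = (out-1)/(in-1), src coord = i/sf,
        // so i = 0 maps to 0 and i = out-1 maps exactly to in-1
        const float sf_align = (float) (out - 1) / (in - 1);
        printf("half-pixel src of i=7: %.3f\n", (7 + 0.5f)/sf_half); // 3.750
        printf("align      src of i=7: %.3f\n",  7.0f   /sf_align); // 3.000
    }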
@@ -7302,14 +7425,66 @@ static void ggml_compute_forward_upscale_f32(
  }
  }
  }
- } else if (mode == GGML_SCALE_MODE_BILINEAR) {
- float pixel_offset = 0.5f;
- if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
- pixel_offset = 0.0f;
- sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
- sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
- }
+ } else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {
+ // Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
+ // https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
+ auto triangle_filter = [](float x) -> float {
+ return std::max(1.0f - fabsf(x), 0.0f);
+ };
+
+ // support and invscale, minimum 1 pixel for bilinear
+ const float support1 = std::max(1.0f, 1.0f / sf1);
+ const float invscale1 = 1.0f / support1;
+ const float support0 = std::max(1.0f, 1.0f / sf0);
+ const float invscale0 = 1.0f / support0;
+
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
+ const int64_t i03 = i3 / sf3;
+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+ const int64_t i02 = i2 / sf2;
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
+ const float y = ((float) i1 + pixel_offset) / sf1;
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
+ const float x = ((float) i0 + pixel_offset) / sf0;
+
+ // the range of source pixels that contribute
+ const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);
+ const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);
+ const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);
+ const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);
+
+ // bilinear filter with antialiasing
+ float val = 0.0f;
+ float total_weight = 0.0f;
+
+ for (int64_t sy = y_min; sy < y_max; sy++) {
+ const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
+
+ for (int64_t sx = x_min; sx < x_max; sx++) {
+ const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
+ const float weight = weight_x * weight_y;
+
+ if (weight <= 0.0f) {
+ continue;
+ }
+
+ const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);
+ val += pixel * weight;
+ total_weight += weight;
+ }
+ }
 
+ if (total_weight > 0.0f) {
+ val /= total_weight;
+ }
+
+ float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+ *dst_ptr = val;
+ }
+ }
+ }
+ }
+ } else if (mode == GGML_SCALE_MODE_BILINEAR) {
  for (int64_t i3 = 0; i3 < ne3; i3++) {
  const int64_t i03 = i3 / sf3;
  for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
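Note: a 1-D rendition of the antialiased path above: downscaling widens the triangle filter to cover 1/sf source pixels and renormalizes by the accumulated weight. Standalone sketch with local names:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<float> src = {0, 1, 2, 3, 4, 5, 6, 7};
        const int out = 4;
        const float sf = (float) out / src.size();       // 0.5 => downscale by 2
        const float support = std::max(1.0f, 1.0f / sf); // covers 2 source pixels
        const float invscale = 1.0f / support;
        for (int i = 0; i < out; ++i) {
            const float x = (i + 0.5f) / sf;
            const int x_min = std::max(0, (int) (x - support + 0.5f));
            const int x_max = std::min((int) src.size(), (int) (x + support + 0.5f));
            float val = 0.0f, total = 0.0f;
            for (int sx = x_min; sx < x_max; ++sx) {
                const float w = std::max(1.0f - std::fabs((sx - x + 0.5f) * invscale), 0.0f);
                val += src[sx] * w;
                total += w;
            }
            printf("dst[%d] = %g\n", i, total > 0 ? val/total : 0.0f);
        }
    }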
@@ -7344,6 +7519,51 @@ static void ggml_compute_forward_upscale_f32(
 
  const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
 
+ float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+ *y_dst = val;
+ }
+ }
+ }
+ }
+ } else if (mode == GGML_SCALE_MODE_BICUBIC) {
+ // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
+ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
+ auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
+ auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
+ auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
+ const float w0 = weight2(x + 1);
+ const float w1 = weight1(x + 0);
+ const float w2 = weight1(1 - x);
+ const float w3 = weight2(2 - x);
+ return p0*w0 + p1*w1 + p2*w2 + p3*w3;
+ };
+
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
+ const int64_t i03 = i3 / sf3;
+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+ const int64_t i02 = i2 / sf2;
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
+ const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
+ const int64_t y0 = (int64_t)floorf(y);
+ const float dy = y - (float)y0;
+
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
+ const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
+ const int64_t x0 = (int64_t)floorf(x);
+ const float dx = x - (float)x0;
+
+ auto p = [=](int64_t x_off, int64_t y_off) -> float {
+ int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
+ int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
+ return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ };
+
+ const float val = bicubic(
+ bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
+ bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
+ bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
+ bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
+
  float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  *y_dst = val;
  }
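Note: a quick standalone check of the cubic convolution weights above: at a sample point (x = 0) they collapse to (0, 1, 0, 0), so the filter reproduces source pixels exactly. Same polynomials and a = -0.75, local names only:

    #include <cstdio>

    int main() {
        const float a = -0.75f;
        auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
        auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
        const float x = 0.0f;
        printf("w = (%g, %g, %g, %g)\n",
               weight2(x + 1), weight1(x), weight1(1 - x), weight2(2 - x));
        // prints w = (0, 1, 0, 0); the four weights also sum to 1 for any x in [0,1]
    }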
@@ -7376,6 +7596,7 @@ void ggml_compute_forward_upscale(
 
  // ggml_compute_forward_pad
 
+ template<bool circular_t>
  static void ggml_compute_forward_pad_f32(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
@@ -7391,6 +7612,14 @@ static void ggml_compute_forward_pad_f32(
  GGML_TENSOR_UNARY_OP_LOCALS
 
  float * dst_ptr = (float *) dst->data;
+ const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
+ const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
+ const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
+ const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
+ const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
+ const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
+ const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
+ const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
 
  // TODO: optimize
 
@@ -7398,14 +7627,34 @@ static void ggml_compute_forward_pad_f32(
  for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
  for (int64_t i0 = 0; i0 < ne0; ++i0) {
  for (int64_t i3 = 0; i3 < ne3; ++i3) {
- const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-
- const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-
- if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ // circular means wrap around on a torus, so x and y loop around
+ if constexpr (circular_t) {
+ const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+ const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);
+ const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);
+ const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);
+ const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);
+
+ const int64_t src_idx =
+ src_i3*nb03 +
+ src_i2*nb02 +
+ src_i1*nb01 +
+ src_i0*nb00;
+
+ const float * src_ptr = (const float *)((char *) src0->data + src_idx);
  dst_ptr[dst_idx] = *src_ptr;
  } else {
- dst_ptr[dst_idx] = 0;
+ const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+ if ((i0 >= lp0 && i0 < ne0 - rp0) \
+ && (i1 >= lp1 && i1 < ne1 - rp1) \
+ && (i2 >= lp2 && i2 < ne2 - rp2) \
+ && (i3 >= lp3 && i3 < ne3 - rp3)) {
+ const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+ const float * src_ptr = (const float *)((char *) src0->data + src_idx);
+ dst_ptr[dst_idx] = *src_ptr;
+ } else {
+ dst_ptr[dst_idx] = 0;
+ }
  }
  }
  }
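Note: a 1-D picture of the circular branch above: each destination index maps back into the source modulo its size, so the pad repeats the row like a ring. Standalone sketch; the double-mod used here is a more general form of ggml_wrap_around:

    #include <cstdio>

    int main() {
        const int ne00 = 4, lp0 = 2, rp0 = 2;         // source width, left/right pad
        const float src[ne00] = {10, 20, 30, 40};
        const int ne0 = ne00 + lp0 + rp0;             // padded width
        for (int i0 = 0; i0 < ne0; ++i0) {
            const int s = ((i0 - lp0) % ne00 + ne00) % ne00; // wrap-around
            printf("%g ", src[s]);                    // 30 40 10 20 30 40 10 20
        }
        printf("\n");
    }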
@@ -7413,16 +7662,20 @@ static void ggml_compute_forward_pad_f32(
  }
  }
 
+
  void ggml_compute_forward_pad(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
-
  const ggml_tensor * src0 = dst->src[0];
-
+ const bool circular = (bool) ggml_get_op_params_i32(dst, 8);
  switch (src0->type) {
  case GGML_TYPE_F32:
  {
- ggml_compute_forward_pad_f32(params, dst);
+ if (circular) {
+ ggml_compute_forward_pad_f32<true>(params, dst);
+ } else {
+ ggml_compute_forward_pad_f32<false>(params, dst);
+ }
  } break;
  default:
  {
@@ -7601,7 +7854,7 @@ static void ggml_compute_forward_timestep_embedding_f32(
  embed_data[j + half] = sinf(arg);
  }
  if (dim % 2 != 0 && ith == 0) {
- embed_data[dim] = 0.f;
+ embed_data[2 * half] = 0.f;
  }
  }
  }
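Note on the one-line fix above: half = dim/2 rounds down, so for an odd dim the new index 2 * half equals dim - 1 and the padding zero lands in the final slot of the embedding row rather than at index dim, one past it (e.g. dim = 5: half = 2, so index 4 instead of 5).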
@@ -7615,7 +7868,80 @@ void ggml_compute_forward_timestep_embedding(
  switch (src0->type) {
  case GGML_TYPE_F32:
  {
- ggml_compute_forward_timestep_embedding_f32(params, dst);
+ ggml_compute_forward_timestep_embedding_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+ }
+
+ // ggml_compute_forward_argsort
+
+ template<enum ggml_sort_order order>
+ struct cmp_argsort {
+ const float * data;
+ bool operator()(int32_t a, int32_t b) const {
+ if constexpr (order == GGML_SORT_ORDER_ASC) {
+ return data[a] < data[b];
+ } else {
+ return data[a] > data[b];
+ }
+ }
+ };
+
+ static void ggml_compute_forward_argsort_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ GGML_ASSERT(nb0 == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t nr = ggml_nrows(src0);
+
+ ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
+
+ for (int64_t i = ith; i < nr; i += nth) {
+ const float * src_data = (float *)((char *) src0->data + i*nb01);
+
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+
+ for (int64_t j = 0; j < ne0; j++) {
+ dst_data[j] = j;
+ }
+
+ switch (order) {
+ case GGML_SORT_ORDER_ASC:
+ std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
+ break;
+
+ case GGML_SORT_ORDER_DESC:
+ std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
+ break;
+
+ default:
+ GGML_ABORT("invalid sort order");
+ }
+ }
+ }
+
+ void ggml_compute_forward_argsort(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_argsort_f32(params, dst);
  } break;
  default:
  {
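Note: the comparator-over-indices pattern introduced by cmp_argsort above, reduced to a standalone example — the sort permutes indices, not values, so the output row answers "which source element ranks j-th":

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<float> data = {0.3f, -1.0f, 2.5f, 0.0f};
        std::vector<int> idx = {0, 1, 2, 3};
        std::sort(idx.begin(), idx.end(),
                  [&](int a, int b) { return data[a] < data[b]; }); // ascending by value
        for (int j : idx) printf("%d ", j);                         // 1 3 0 2
        printf("\n");
    }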
@@ -7624,9 +7950,16 @@ void ggml_compute_forward_timestep_embedding(
  }
  }
 
- // ggml_compute_forward_argsort
+ // ggml_compute_forward_top_k
 
- static void ggml_compute_forward_argsort_f32(
+ struct cmp_top_k {
+ const float * data;
+ bool operator()(int32_t a, int32_t b) const {
+ return data[a] > data[b];
+ }
+ };
+
+ static void ggml_compute_forward_top_k_f32(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
 
@@ -7641,31 +7974,31 @@ static void ggml_compute_forward_argsort_f32(
 
  const int64_t nr = ggml_nrows(src0);
 
- ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
+ const int top_k = ne0;
+
+ int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
  for (int64_t i = ith; i < nr; i += nth) {
- int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
  const float * src_data = (float *)((char *) src0->data + i*nb01);
 
- for (int64_t j = 0; j < ne0; j++) {
- dst_data[j] = j;
+ for (int64_t j = 0; j < ne00; j++) {
+ tmp[j] = j;
  }
 
- // C doesn't have a functional sort, so we do a bubble sort instead
- for (int64_t j = 0; j < ne0; j++) {
- for (int64_t k = j + 1; k < ne0; k++) {
- if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
- (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
- int32_t tmp = dst_data[j];
- dst_data[j] = dst_data[k];
- dst_data[k] = tmp;
- }
- }
+ std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
+
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+
+ std::copy(tmp, tmp + top_k, dst_data);
+
+ // emphasize that the order is not important
+ if (top_k > 1) {
+ std::swap(dst_data[0], dst_data[1]);
  }
  }
  }
 
- void ggml_compute_forward_argsort(
+ void ggml_compute_forward_top_k(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
 
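Note: what std::partial_sort buys the top-k rewrite above, standalone: only the first k slots come out ordered (descending by value here) and the rest are left unspecified, which is also why the kernel can afford the deliberate swap of its first two outputs:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<float> data = {0.1f, 0.9f, 0.4f, 0.7f, 0.2f};
        std::vector<int> idx = {0, 1, 2, 3, 4};
        const int k = 2;
        std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                          [&](int a, int b) { return data[a] > data[b]; });
        printf("top-%d indices: %d %d\n", k, idx[0], idx[1]); // 1 3
    }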
@@ -7674,7 +8007,7 @@ void ggml_compute_forward_argsort(
  switch (src0->type) {
  case GGML_TYPE_F32:
  {
- ggml_compute_forward_argsort_f32(params, dst);
+ ggml_compute_forward_top_k_f32(params, dst);
  } break;
  default:
  {
@@ -7685,13 +8018,15 @@ void ggml_compute_forward_argsort(
 
  // ggml_compute_forward_flash_attn_ext
 
- static void ggml_compute_forward_flash_attn_ext_f16(
+ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
  const ggml_compute_params * params,
- const ggml_tensor * q,
- const ggml_tensor * k,
- const ggml_tensor * v,
- const ggml_tensor * mask,
- ggml_tensor * dst) {
+ ggml_tensor * dst,
+ int ir0, int ir1) {
+ const ggml_tensor * q = dst->src[0];
+ const ggml_tensor * k = dst->src[1];
+ const ggml_tensor * v = dst->src[2];
+ const ggml_tensor * mask = dst->src[3];
+ const ggml_tensor * sinks = dst->src[4];
 
  GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
  GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
@@ -7702,9 +8037,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
 
- const int ith = params->ith;
- const int nth = params->nth;
-
  const int64_t DK = nek0;
  const int64_t DV = nev0;
  const int64_t N = neq1;
@@ -7738,16 +8070,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
  // parallelize by q rows using ggml_vec_dot_f32
 
- // total rows in q
- const int nr = neq1*neq2*neq3;
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
  float scale = 1.0f;
  float max_bias = 0.0f;
  float logit_softcap = 0.0f;
@@ -7766,7 +8088,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
- ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
+ ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
  ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
  ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
  ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
@@ -7774,6 +8096,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
  GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
 
+ int ith = params->ith;
+
  // loop over n_batch and n_head
  for (int ir = ir0; ir < ir1; ++ir) {
  // q indices
@@ -7798,7 +8122,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  memset(VKQ32, 0, DV*sizeof(float));
  }
 
- const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
 
  // k indices
  const int ik3 = iq3 / rk3;
@@ -7887,8 +8211,25 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  }
  }
 
+ // sinks
+ if (sinks) {
+ const float s = ((float *)((char *) sinks->data))[h];
+
+ float ms = 1.0f;
+ float vs = 1.0f;
+
+ if (s > M) {
+ ms = expf(M - s);
+ ggml_vec_scale_f32(DV, VKQ32, ms);
+ } else {
+ vs = expf(s - M);
+ }
+
+ S = S*ms + vs;
+ }
+
  // V /= S
- const float S_inv = 1.0f/S;
+ const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
  ggml_vec_scale_f32(DV, VKQ32, S_inv);
 
  // dst indices
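Note on the sink branch above: it is the usual streaming-softmax update applied to one extra logit s that carries probability mass but no value vector. If s <= M it just adds exp(s - M) to the running denominator S; if s > M the maximum moves to s, so the accumulator VKQ32 and S are rescaled by ms = exp(M - s) and the sink's own term becomes exp(s - s) = 1. Either way S ends up as sum_i exp(l_i - max) + exp(s - max). The new S == 0.0f guard covers fully-masked rows when no sink is present.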
@@ -7904,19 +8245,100 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  }
  }
 
+ static void ggml_compute_forward_flash_attn_ext_f16(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * q = dst->src[0];
+ const ggml_tensor * k = dst->src[1];
+ const ggml_tensor * v = dst->src[2];
+
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+ const int64_t DK = nek0;
+ const int64_t DV = nev0;
+ const int64_t N = neq1;
+
+ GGML_ASSERT(ne0 == DV);
+ GGML_ASSERT(ne2 == N);
+
+ // input tensor rows must be contiguous
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+ GGML_ASSERT(neq0 == DK);
+ GGML_ASSERT(nek0 == DK);
+ GGML_ASSERT(nev0 == DV);
+
+ GGML_ASSERT(neq1 == N);
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb0 <= nb1);
+ GGML_ASSERT(nb1 <= nb2);
+ GGML_ASSERT(nb2 <= nb3);
+
+ // parallelize by q rows using ggml_vec_dot_f32
+
+ // total rows in q
+ const int64_t nr = neq1*neq2*neq3;
+
+ // rows per thread
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ // disable for NUMA
+ const bool disable_chunking = ggml_is_numa();
+
+ // 4x chunks per thread
+ int nth_scaled = nth * 4;
+ int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+ int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
+
+ if (nth == 1 || nchunk < nth || disable_chunking) {
+ nchunk = nth;
+ }
+
+ if (ith == 0) {
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+ ggml_threadpool_chunk_set(params->threadpool, nth);
+ }
+
+ ggml_barrier(params->threadpool);
+
+ // The number of elements in each chunk
+ const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+ // The first chunk comes from our thread_id, the rest will get auto-assigned.
+ int current_chunk = ith;
+
+ while (current_chunk < nchunk) {
+ const int64_t ir0 = dr * current_chunk;
+ const int64_t ir1 = MIN(ir0 + dr, nr);
+
+ ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+
+ current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+ }
+ }
+
  void ggml_compute_forward_flash_attn_ext(
  const ggml_compute_params * params,
- const ggml_tensor * q,
- const ggml_tensor * k,
- const ggml_tensor * v,
- const ggml_tensor * mask,
  ggml_tensor * dst) {
  switch (dst->op_params[3]) {
  case GGML_PREC_DEFAULT:
  case GGML_PREC_F32:
  {
  // uses F32 accumulators
- ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+ ggml_compute_forward_flash_attn_ext_f16(params, dst);
  } break;
  default:
  {
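Note: the chunked dispatch above seeds each thread with chunk ith and then claims further chunks from a shared counter that starts at nth. A standalone miniature with std::atomic standing in for the ggml_threadpool_chunk_* helpers (which this sketch assumes behave like an atomic store and fetch-add):

    #include <atomic>
    #include <cstdio>
    #include <thread>
    #include <vector>

    int main() {
        const int nth = 4, nchunk = 16;
        std::atomic<int> next_chunk{nth}; // chunks 0..nth-1 are pre-assigned
        std::vector<std::thread> pool;
        for (int ith = 0; ith < nth; ++ith) {
            pool.emplace_back([&, ith] {
                int current = ith;
                while (current < nchunk) {
                    // ... process rows [dr*current, dr*(current+1)) here ...
                    current = next_chunk.fetch_add(1); // claim the next unprocessed chunk
                }
            });
        }
        for (auto & t : pool) t.join();
        printf("done, counter reached %d\n", next_chunk.load());
    }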
@@ -8336,120 +8758,214 @@ void ggml_compute_forward_ssm_conv(
  static void ggml_compute_forward_ssm_scan_f32(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
- const ggml_tensor * src0 = dst->src[0]; // s
- const ggml_tensor * src1 = dst->src[1]; // x
- const ggml_tensor * src2 = dst->src[2]; // dt
- const ggml_tensor * src3 = dst->src[3]; // A
- const ggml_tensor * src4 = dst->src[4]; // B
- const ggml_tensor * src5 = dst->src[5]; // C
+ const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
+ const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
+ const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
+ const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
+ const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
+ const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
+ const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
 
  const int ith = params->ith;
  const int nth = params->nth;
 
- const int64_t nc = src0->ne[0]; // d_state
- const int64_t nr = src0->ne[1]; // d_inner
- const int64_t n_t = src1->ne[1]; // number of tokens per sequence
- const int64_t n_s = src0->ne[2]; // number of sequences in the batch
+ const int64_t nc = src0->ne[0]; // d_state
+ const int64_t nr = src0->ne[1]; // dim
+ const int64_t nh = src1->ne[1]; // n_head
+ const int64_t ng = src4->ne[1];
+ const int64_t nt = src1->ne[2]; // number of tokens per sequence
+ const int64_t ns = src1->ne[3]; // number of sequences in the batch
+
+ // can't use ggml_nbytes because src1 is not necessarily contiguous
+ const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
 
- GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+ GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
  GGML_ASSERT(src0->nb[0] == sizeof(float));
  GGML_ASSERT(src1->nb[0] == sizeof(float));
  GGML_ASSERT(src2->nb[0] == sizeof(float));
  GGML_ASSERT(src3->nb[0] == sizeof(float));
  GGML_ASSERT(src4->nb[0] == sizeof(float));
  GGML_ASSERT(src5->nb[0] == sizeof(float));
- // required for the dot product between s and C
- GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
- // required for per-sequence offsets for states
- GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
- // required to get correct offset for state destination (i.e. src1->nb[3])
- GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
+ GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
+ GGML_ASSERT(nh % ng == 0);
 
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
+ // heads per thread
+ const int dh = (nh + nth - 1)/nth;
 
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
- const int ir = ir1 - ir0;
+ // head range for this thread
+ const int ih0 = dh*ith;
+ const int ih1 = MIN(ih0 + dh, nh);
+
+ const int32_t * ids = (const int32_t *) src6->data;
+
+ for (int i3 = 0; i3 < ns; ++i3) {
+ const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
+ float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
+
+ for (int i2 = 0; i2 < nt; ++i2) {
+ const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
+ const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
+ const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
+ float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
+
+ if (src3->ne[0] == 1) {
+ // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
+
+ // n_head
+ for (int h = ih0; h < ih1; ++h) {
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+ const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
+ const float dA = expf(dt_soft_plus * A[h]);
+ const int g = h / (nh / ng); // repeat_interleave
+
+ // dim
+ for (int i1 = 0; i1 < nr; ++i1) {
+ const int ii = i1 + h*nr;
+ const float x_dt = x[ii] * dt_soft_plus;
+ float sumf = 0.0f;
+ #if defined(GGML_SIMD)
+ #if defined(__ARM_FEATURE_SVE)
+ const int ggml_f32_epr = svcntw();
+ const int ggml_f32_step = 1 * ggml_f32_epr;
+
+ const int np = (nc & ~(ggml_f32_step - 1));
+
+ GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
+
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+ for (int i = 0; i < np; i += ggml_f32_step) {
+ // TODO: maybe unroll more?
+ for (int j = 0; j < 1; j++) {
+ GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
+ GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
+ GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
+
+ t0 = GGML_F32_VEC_MUL(t0, adA);
+ t1 = GGML_F32_VEC_MUL(t1, axdt);
+
+ t0 = GGML_F32_VEC_ADD(t0, t1);
+
+ sum = GGML_F32_VEC_FMA(sum, t0, t2);
 
- #ifdef __ARM_FEATURE_SVE
- for (int i3 = 0; i3 < n_s; ++i3) {
- for (int i2 = 0; i2 < n_t; ++i2) {
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
- // use the output as the source for the next token-wise iterations
- if (i2 > 0) { s0 = s; }
-
- // d_inner
- for (int i1 = 0; i1 < ir; ++i1) {
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
- float x_dt = x[i1] * dt_soft_plus;
- svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
- svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
- svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
-
- for (int64_t k = 0; k < nc; k += svcntw()) {
- svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
- svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
- svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
- svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
-
- svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
- t1 = exp_ps_sve(svptrue_b32(), t1);
- svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
-
- vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
- r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
-
- GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
+ GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
+ }
+ }
+
+ sumf = GGML_F32xt_REDUCE_ONE(sum);
+ #elif defined(__riscv_v_intrinsic)
+ // todo: RVV implementation
+ const int np = 0;
+ #else
+ const int np = (nc & ~(GGML_F32_STEP - 1));
+
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+ GGML_F32_VEC ax[GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+ GGML_F32_VEC az[GGML_F32_ARR];
+
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
+ ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
+ az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
+
+ ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
+ ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
+
+ ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
+
+ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
+
+ GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
+ }
+ }
+
+ // reduce sum0..sum3 to sum0
+ GGML_F32_VEC_REDUCE(sumf, sum);
+ #endif
+ #else
+ const int np = 0;
+ #endif
+ // d_state
+ for (int i0 = np; i0 < nc; ++i0) {
+ const int i = i0 + ii*nc;
+ const int ig = i0 + g*nc;
+ // state = prev_state * dA + dB * x
+ const float state = (s0[i] * dA) + (B[ig] * x_dt);
+ // y = rowwise_dotprod(state, C)
+ sumf += state * C[ig];
+ s[i] = state;
+ }
+ y[ii] = sumf;
  }
- y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
  }
- }
- }
- #else
- for (int i3 = 0; i3 < n_s; ++i3) {
- for (int i2 = 0; i2 < n_t; ++i2) {
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
- // use the output as the source for the next token-wise iterations
- if (i2 > 0) { s0 = s; }
-
- // d_inner
- for (int i1 = 0; i1 < ir; ++i1) {
- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
- float x_dt = x[i1] * dt_soft_plus;
- float sumf = 0.0f;
- // d_state
- for (int i0 = 0; i0 < nc; ++i0) {
- int i = i0 + i1*nc;
- // state = prev_state * dA + dB * x
- float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
- // y = rowwise_dotprod(state, C)
- sumf += state * C[i0];
- s[i] = state;
+ } else {
+ // Mamba-1 has an element-wise decay factor for the states
+
+ // n_head
+ for (int h = ih0; h < ih1; ++h) {
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+ const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
+ const int g = h / (nh / ng); // repeat_interleave
+
+ // dim
+ for (int i1 = 0; i1 < nr; ++i1) {
+ const int ii = i1 + h*nr;
+ const float x_dt = x[ii] * dt_soft_plus;
+ #if defined(__ARM_FEATURE_SVE)
+ svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
+ svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+ svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
+
+ // d_state
+ // TODO: what happens when (d_state % svcntw()) != 0?
+ for (int64_t k = 0; k < nc; k += svcntw()) {
+ svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
+ svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
+ svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
+ svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
+
+ svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+ t1 = exp_ps_sve(svptrue_b32(), t1);
+ svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+ vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
+ r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+ GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
+ }
+ y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
+ #else
+ float sumf = 0.0f;
+ // NOTE: can't really use GGML_SIMD here because d_state is usually 16
+ // and also because expf is used within the loop.
+ // d_state
+ for (int i0 = 0; i0 < nc; ++i0) {
+ const int i = i0 + ii*nc;
+ const int ig = i0 + g*nc;
+ // state = prev_state * dA + dB * x
+ const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
+ // y = rowwise_dotprod(state, C)
+ sumf += state * C[ig];
+ s[i] = state;
+ }
+ y[ii] = sumf;
+ #endif
  }
- y[i1] = sumf;
  }
8964
  }
8965
+ // use the output as the source when it's not the first token-wise iteration
8966
+ s0 = s;
8451
8967
  }
8452
- #endif
8968
+ }
8453
8969
  }
8454
8970
 
8455
8971
  void ggml_compute_forward_ssm_scan(
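The hunk above splits the SSM scan into a Mamba-2 path (a single per-head decay scalar dA, so the d_state loop is a pure FMA and vectorizes well) and a Mamba-1 path (an element-wise decay expf(dt_soft_plus * A[i]), which is why expf stays inside the scalar loop). A minimal self-contained sketch of the per-token recurrence, using illustrative names rather than the ggml API:

    #include <math.h>

    // One head, one token: update the hidden state in place and return
    // y = dot(state, C). d_state, prev_state, A, B, C are illustrative names.
    static float ssm_scan_step(int d_state, float * state, const float * prev_state,
                               const float * A, const float * B, const float * C,
                               float x, float dt) {
        // numerically stable softplus, same <= 20.0f cutoff as the kernel above
        const float dt_sp = dt <= 20.0f ? log1pf(expf(dt)) : dt;
        const float x_dt  = x * dt_sp;
        float y = 0.0f;
        for (int i = 0; i < d_state; ++i) {
            // state = prev_state * dA + dB * x, with element-wise dA (Mamba-1)
            const float s = prev_state[i] * expf(dt_sp * A[i]) + B[i] * x_dt;
            y += s * C[i];
            state[i] = s;
        }
        return y;
    }

In the Mamba-2 case A has one value per head, so expf(dt_sp * A[i]) collapses to a loop-invariant scalar and the loop body reduces to state = prev*dA + B*x_dt, exactly the form the SIMD branches exploit.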
@@ -8660,6 +9176,34 @@ void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_exp(params, dst);
             } break;
+        case GGML_UNARY_OP_FLOOR:
+            {
+                ggml_compute_forward_floor(params, dst);
+            } break;
+        case GGML_UNARY_OP_CEIL:
+            {
+                ggml_compute_forward_ceil(params, dst);
+            } break;
+        case GGML_UNARY_OP_ROUND:
+            {
+                ggml_compute_forward_round(params, dst);
+            } break;
+        case GGML_UNARY_OP_TRUNC:
+            {
+                ggml_compute_forward_trunc(params, dst);
+            } break;
+        case GGML_UNARY_OP_XIELU:
+            {
+                ggml_compute_forward_xielu(params, dst);
+            } break;
+        case GGML_UNARY_OP_EXPM1:
+            {
+                ggml_compute_forward_expm1(params, dst);
+            } break;
+        case GGML_UNARY_OP_SOFTPLUS:
+            {
+                ggml_compute_forward_softplus(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
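The seven new unary cases dispatch to kernels defined elsewhere in this file; only the switch is visible in this hunk. For orientation, plausible scalar cores under that assumption (xielu is a parameterized activation and is omitted; the softplus form mirrors the stable variant the ssm_scan code uses):

    #include <math.h>

    // illustrative scalar cores, not the ggml kernels themselves
    static inline float op_floor   (float x) { return floorf(x); }
    static inline float op_ceil    (float x) { return ceilf(x);  }
    static inline float op_round   (float x) { return roundf(x); }
    static inline float op_trunc   (float x) { return truncf(x); }
    static inline float op_expm1   (float x) { return expm1f(x); } // exp(x) - 1, accurate near 0
    static inline float op_softplus(float x) {
        return x <= 20.0f ? log1pf(expf(x)) : x; // log(1 + exp(x)), overflow-safe
    }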
@@ -8688,6 +9232,18 @@ void ggml_compute_forward_glu(
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
+        case GGML_GLU_OP_SWIGLU_OAI:
+            {
+                ggml_compute_forward_swiglu_oai(params, dst);
+            } break;
+        case GGML_GLU_OP_GEGLU_ERF:
+            {
+                ggml_compute_forward_geglu_erf(params, dst);
+            } break;
+        case GGML_GLU_OP_GEGLU_QUICK:
+            {
+                ggml_compute_forward_geglu_quick(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
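A GLU op pairs a gate half a with a linear half b and computes gate(a) * b; the three new cases differ only in the gate function. A hedged sketch of the usual forms (the exact constants are assumptions based on the common GELU/SwiGLU definitions, not taken from this hunk):

    #include <math.h>

    // common gate functions; shown for orientation only
    static inline float gelu_erf_gate  (float a) { return 0.5f * a * (1.0f + erff(a / sqrtf(2.0f))); }
    static inline float gelu_quick_gate(float a) { return a / (1.0f + expf(-1.702f * a)); } // sigmoid approx.
    static inline float silu_gate      (float a) { return a / (1.0f + expf(-a)); }          // swiglu family

    static inline float glu(float (*gate)(float), float a, float b) {
        return gate(a) * b;
    }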
@@ -9244,6 +9800,76 @@ void ggml_compute_forward_gla(
     }
 }
 
+static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
+    const struct ggml_tensor * src1 = dst->src[1]; // B (RHS)
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ne00 == ne01); // A must be square
+    GGML_ASSERT(ne0 == ne10); // solution cols == B cols
+    GGML_ASSERT(ne1 == ne11); // solution rows == B rows
+
+    GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
+    GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t k = ne10; // number of RHS columns
+    const int64_t n = ne11; // A is n×n
+    const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
+
+    // chunks per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // chunk range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    const float * A = (const float *) src0->data; // [n, n, B1, B2]
+    const float * B = (const float *) src1->data; // [n, k, B1, B2]
+    float * X = ( float *) dst->data; // [n, k, B1, B2]
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*k);
+        const int64_t i02 = (ir - i03*ne02*k)/k;
+        const int64_t i01 = (ir - i03*ne02*k - i02*k);
+
+        const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
+        const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
+
+        float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
+
+        for (int64_t i00 = 0; i00 < n; ++i00) {
+            float sum = 0.0f;
+            for (int64_t t = 0; t < i00; ++t) {
+                sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
+            }
+
+            const float diag = A_batch[i00 * n + i00];
+            assert(diag != 0.0f && "Zero diagonal in triangular matrix");
+
+            X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
+        }
+    }
+}
+
+void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        ggml_compute_forward_solve_tri_f32(params, dst);
+    } else {
+        GGML_ABORT("fatal error");
+    }
+}
+
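The new ggml_compute_forward_solve_tri_f32 is plain forward substitution: because A is lower triangular, row i of the solution depends only on rows t < i, so each right-hand-side column can be solved independently, which is exactly what the ir loop parallelizes over. A standalone sketch for one n x n system and one RHS column:

    #include <assert.h>

    // solve L * x = b for lower-triangular L (n x n), single RHS column
    static void solve_lower_tri(int n, const float * L, const float * b, float * x) {
        for (int i = 0; i < n; ++i) {
            float sum = 0.0f;
            for (int t = 0; t < i; ++t) {
                sum += L[i*n + t] * x[t];
            }
            assert(L[i*n + i] != 0.0f); // same zero-diagonal guard as the kernel
            x[i] = (b[i] - sum) / L[i*n + i];
        }
    }

For example, L = {2,0, 1,3} with b = {4,7} gives x = {2, 5/3}: the first row yields 4/2, the second (7 - 1*2)/3.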
 // ggml_compute_forward_rwkv_wkv7
 
 static void ggml_compute_forward_rwkv_wkv7_f32(
@@ -9283,8 +9909,8 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
     int64_t h_stride_2d = head_size * head_size;
 
 #if defined(GGML_SIMD)
-    #if defined(__ARM_FEATURE_SVE)
-    // scalar Route to scalar implementation //TODO: Write SVE code
+    #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
+    // route to the scalar implementation // TODO: write SVE and RVV code
     for (int64_t t = 0; t < T; t++) {
         int64_t t_offset = t * t_stride;
         int64_t state_offset = head_size * C * (t / (T / n_seqs));
@@ -9732,6 +10358,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
         const int ir1 = MIN(ir0 + dr, nr);
 
         const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+
         const float alpha = adamw_params_ptr[0];
         const float beta1 = adamw_params_ptr[1];
         const float beta2 = adamw_params_ptr[2];
@@ -9739,7 +10366,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
         const float wd = adamw_params_ptr[4];
         const float beta1h = adamw_params_ptr[5];
         const float beta2h = adamw_params_ptr[6];
-
+        const float keep = 1.f - alpha * wd;
         for (int ir = ir0; ir < ir1; ++ir) {
             const int64_t i03 = ir/(ne02*ne01);
             const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -9762,7 +10389,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
                 // The weight decay is applied independently of the Adam momenta m and v.
                 // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
                 // See: https://arxiv.org/pdf/1711.05101v3.pdf
-                w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+                w[i00] = w[i00] * keep - alpha * mh / vh;
             }
         }
     }
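The change above only hoists the loop-invariant decay factor keep = 1 - alpha*wd out of the inner loop; the update itself is unchanged decoupled weight decay. For comparison, a textbook single-weight AdamW step (illustrative: the kernel instead receives beta1h/beta2h as precomputed bias-correction factors rather than computing them per step):

    #include <math.h>

    static float adamw_step(float w, float g, float * m, float * v, int t,
                            float alpha, float beta1, float beta2, float eps, float wd) {
        *m = beta1 * (*m) + (1.0f - beta1) * g;     // first moment
        *v = beta2 * (*v) + (1.0f - beta2) * g * g; // second moment
        const float mh = *m / (1.0f - powf(beta1, (float) t)); // bias-corrected
        const float vh = *v / (1.0f - powf(beta2, (float) t)); // bias-corrected
        // weight decay multiplies w directly, independent of m and v ("decoupled")
        return w * (1.0f - alpha * wd) - alpha * mh / (sqrtf(vh) + eps);
    }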
@@ -9784,3 +10411,63 @@ void ggml_compute_forward_opt_step_adamw(
         }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0       = dst->src[0];
+    const ggml_tensor * src0_grad  = dst->src[1];
+    const ggml_tensor * sgd_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(sgd_params) == 2);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using the adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
+    const float alpha = sgd_params_ptr[0];
+    const float keep  = 1.f - alpha * sgd_params_ptr[1];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float * w = (float *) ((char *) src0->data + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}
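The new SGD op reads just two parameters (alpha and wd) and applies the same decoupled decay as AdamW: w = w*(1 - alpha*wd) - alpha*g. A minimal check of that rule with hypothetical values:

    #include <stdio.h>

    int main(void) {
        const float alpha = 0.1f, wd = 0.01f;
        const float keep  = 1.f - alpha * wd; // 0.999
        float w = 1.0f;
        const float g = 0.5f;
        w = w * keep - alpha * g;             // 1.0*0.999 - 0.05
        printf("updated weight: %f\n", w);    // ~0.949
        return 0;
    }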