whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -44,6 +44,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
44
44
  void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
45
45
 
46
46
  void ggml_vec_silu_f32(const int n, float * y, const float * x);
47
+ ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean); //it will also center y ( y = y - mean )
47
48
  ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max);
48
49
  ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max);
49
50
 
@@ -143,14 +144,14 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
143
144
  for (int i = 0; i < np; i += ggml_f16_step) {
144
145
  ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements
145
146
 
146
- ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elemnst
147
+ ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements
147
148
  sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1
148
149
  ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements
149
150
  sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1);
150
151
 
151
152
  ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements
152
153
 
153
- ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 ekements
154
+ ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements
154
155
  sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2);
155
156
  ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1);
156
157
  sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2);
@@ -159,7 +160,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
159
160
 
160
161
  ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2);
161
162
  sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3);
162
- ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
163
+ ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
163
164
  sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3);
164
165
 
165
166
  ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
@@ -223,13 +224,71 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
223
224
  }
224
225
  GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
225
226
  GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
226
- #elif defined(__riscv_v_intrinsic)
227
- // todo: RVV impl
228
- for (int i = 0; i < n; ++i) {
229
- for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
230
- sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
231
- }
232
- }
227
+
228
+ #elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
229
+ size_t vl = __riscv_vsetvlmax_e32m4();
230
+
231
+ // initialize accumulators to all zeroes
232
+ vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
233
+ vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
234
+ vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
235
+ vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
236
+
237
+ // calculate step size
238
+ const size_t epr = __riscv_vsetvlmax_e16m2();
239
+ const size_t step = epr * 2;
240
+ const int np = (n & ~(step - 1));
241
+
242
+ // unroll by 2 along the row dimension
243
+ for (int i = 0; i < np; i += step) {
244
+ vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr);
245
+ vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr);
246
+ vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr);
247
+ vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr);
248
+ vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr);
249
+
250
+ vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr);
251
+ vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr);
252
+ vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr);
253
+ vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr);
254
+ vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr);
255
+ }
256
+
257
+ vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl);
258
+ vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl);
259
+
260
+ // leftovers
261
+ for (int i = np; i < n; i += vl) {
262
+ vl = __riscv_vsetvl_e16m2(n - i);
263
+ vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl);
264
+ vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl);
265
+ vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl);
266
+
267
+ vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl);
268
+ vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl);
269
+ }
270
+
271
+ // reduce
272
+ vl = __riscv_vsetvlmax_e32m2();
273
+ vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0),
274
+ __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl);
275
+ vl = __riscv_vsetvlmax_e32m1();
276
+ vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0),
277
+ __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl);
278
+ vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1(
279
+ acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
280
+
281
+ vl = __riscv_vsetvlmax_e32m2();
282
+ vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0),
283
+ __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl);
284
+ vl = __riscv_vsetvlmax_e32m1();
285
+ vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0),
286
+ __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl);
287
+ vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1(
288
+ acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
289
+ sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0);
290
+ sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1);
291
+
233
292
  #else
234
293
  const int np = (n & ~(GGML_F16_STEP - 1));
235
294
 
@@ -396,119 +455,142 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
396
455
  }
397
456
 
398
457
  inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
399
- #if defined(GGML_SIMD)
400
- #if defined(__ARM_FEATURE_SVE)
401
- const int sve_register_length = svcntb() * 8;
402
- const int ggml_f16_epr = sve_register_length / 16;
403
- const int ggml_f16_step = 8 * ggml_f16_epr;
458
+ #if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
459
+ const int sve_register_length = svcntb() * 8;
460
+ const int ggml_f16_epr = sve_register_length / 16;
461
+ const int ggml_f16_step = 8 * ggml_f16_epr;
404
462
 
405
- GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
463
+ GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
406
464
 
407
- const int np= (n & ~(ggml_f16_step - 1));
465
+ int np = (n & ~(ggml_f16_step - 1));
408
466
 
409
- svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
410
- svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
411
- for (int i = 0; i < np; i += ggml_f16_step) {
412
- ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
413
- ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
414
- ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
467
+ svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
468
+ svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
469
+ for (int i = 0; i < np; i += ggml_f16_step) {
470
+ ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
471
+ ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
472
+ ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
415
473
 
416
- GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
474
+ GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
417
475
 
418
- ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
419
- ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
420
- ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
476
+ ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
477
+ ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
478
+ ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
421
479
 
422
- GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
480
+ GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
423
481
 
424
- ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
425
- ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
426
- ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
482
+ ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
483
+ ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
484
+ ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
427
485
 
428
- GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
486
+ GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
429
487
 
430
- ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
431
- ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
432
- ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
488
+ ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
489
+ ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
490
+ ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
433
491
 
434
- GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
492
+ GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
435
493
 
436
- ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
437
- ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
438
- ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
494
+ ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
495
+ ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
496
+ ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
439
497
 
440
- GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
498
+ GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
441
499
 
442
- ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
443
- ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
444
- ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
500
+ ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
501
+ ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
502
+ ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
445
503
 
446
- GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
504
+ GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
447
505
 
448
- ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
449
- ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
450
- ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
506
+ ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
507
+ ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
508
+ ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
451
509
 
452
- GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
510
+ GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
453
511
 
454
- ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
455
- ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
456
- ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
512
+ ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
513
+ ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
514
+ ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
457
515
 
458
- GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
459
- }
460
- const int np2 = (n & ~(ggml_f16_epr - 1));
461
- for (int k = np; k < np2; k += ggml_f16_epr) {
462
- svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
463
- svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
464
- ry = GGML_F16x_VEC_FMA(ry, rx, vx);
516
+ GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
517
+ }
518
+ const int np2 = (n & ~(ggml_f16_epr - 1));
519
+ for (int k = np; k < np2; k += ggml_f16_epr) {
520
+ svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
521
+ svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
522
+ ry = GGML_F16x_VEC_FMA(ry, rx, vx);
465
523
 
466
- GGML_F16x_VEC_STORE(y + k, ry, 0);
467
- }
524
+ GGML_F16x_VEC_STORE(y + k, ry, 0);
525
+ }
468
526
 
469
- if (np2 < n) {
470
- svbool_t pg = svwhilelt_b16(np2, n);
471
- svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
472
- svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
473
- hy = svmad_f16_x(pg, hx, vx, hy);
474
- svst1_f16(pg, (__fp16 *)(y + np2), hy);
475
- }
527
+ if (np2 < n) {
528
+ svbool_t pg = svwhilelt_b16(np2, n);
529
+ svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
530
+ svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
531
+ hy = svmad_f16_x(pg, hx, vx, hy);
532
+ svst1_f16(pg, (__fp16 *)(y + np2), hy);
533
+ }
534
+ np = n;
535
+ #elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
536
+ const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
537
+ const _Float16 scale = *(const _Float16*)(&s);
538
+
539
+ // calculate step size
540
+ const int epr = __riscv_vsetvlmax_e16m4();
541
+ const int step = epr * 2;
542
+ int np = (n & ~(step - 1));
543
+
544
+ // unroll by 2
545
+ for (int i = 0; i < np; i += step) {
546
+ vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr);
547
+ vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
548
+ ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr);
549
+ __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
550
+ __asm__ __volatile__ ("" ::: "memory");
551
+
552
+ vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr);
553
+ vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
554
+ ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr);
555
+ __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
556
+ __asm__ __volatile__ ("" ::: "memory");
557
+ }
476
558
 
477
- #elif defined(__riscv_v_intrinsic)
478
- // todo: RVV impl
479
- // scalar
480
- for (int i = 0; i < n; ++i) {
481
- y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
482
- }
483
- #else
484
- const int np = (n & ~(GGML_F16_STEP - 1));
559
+ // leftovers
560
+ int vl;
561
+ for (int i = np; i < n; i += vl) {
562
+ vl = __riscv_vsetvl_e16m4(n - i);
563
+ vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl);
564
+ vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
565
+ ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);
566
+ __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
567
+ }
568
+ np = n;
569
+ #elif defined(GGML_SIMD)
570
+ const int np = (n & ~(GGML_F16_STEP - 1));
485
571
 
486
- GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
572
+ GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
487
573
 
488
- GGML_F16_VEC ax[GGML_F16_ARR];
489
- GGML_F16_VEC ay[GGML_F16_ARR];
574
+ GGML_F16_VEC ax[GGML_F16_ARR];
575
+ GGML_F16_VEC ay[GGML_F16_ARR];
490
576
 
491
- for (int i = 0; i < np; i += GGML_F16_STEP) {
492
- for (int j = 0; j < GGML_F16_ARR; j++) {
493
- ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
494
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
495
- ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
577
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
578
+ for (int j = 0; j < GGML_F16_ARR; j++) {
579
+ ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
580
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
581
+ ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
496
582
 
497
- GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
498
- }
583
+ GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
499
584
  }
500
-
501
- // leftovers
502
- for (int i = np; i < n; ++i) {
503
- y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
504
- }
505
- #endif
585
+ }
506
586
  #else
507
- // scalar
508
- for (int i = 0; i < n; ++i) {
587
+ const int np = 0;
588
+ #endif
589
+
590
+ // leftovers
591
+ for (int i = np; i < n; ++i) {
509
592
  y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
510
593
  }
511
- #endif
512
594
  }
513
595
 
514
596
  // xs and vs are byte strides of x and v
@@ -654,11 +736,11 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
654
736
  }
655
737
  // leftovers
656
738
  // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
657
- if (np < n) {
658
- svbool_t pg = svwhilelt_b32(np, n);
659
- ay1 = svld1_f32(pg, y + np);
739
+ for (int i = np; i < n; i += ggml_f32_epr) {
740
+ svbool_t pg = svwhilelt_b32(i, n);
741
+ ay1 = svld1_f32(pg, y + i);
660
742
  ay1 = svmul_f32_m(pg, ay1, vx);
661
- svst1_f32(pg, y + np, ay1);
743
+ svst1_f32(pg, y + i, ay1);
662
744
  }
663
745
  #elif defined(__riscv_v_intrinsic)
664
746
  for (int i = 0, avl; i < n; i += avl) {
@@ -697,60 +779,82 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
697
779
  }
698
780
 
699
781
  inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
700
- #if defined(GGML_SIMD)
701
- #if defined(__ARM_FEATURE_SVE)
702
- const int sve_register_length = svcntb() * 8;
703
- const int ggml_f16_epr = sve_register_length / 16;
704
- const int ggml_f16_step = 2 * ggml_f16_epr;
705
-
706
- GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
707
- const int np = (n & ~(ggml_f16_step - 1));
708
- svfloat16_t ay1, ay2;
709
-
710
- for (int i = 0; i < np; i += ggml_f16_step) {
711
- ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0);
712
- ay1 = GGML_F16x_VEC_MUL(ay1, vx);
713
- GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0);
782
+ #if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
783
+ const int sve_register_length = svcntb() * 8;
784
+ const int ggml_f16_epr = sve_register_length / 16;
785
+ const int ggml_f16_step = 2 * ggml_f16_epr;
786
+
787
+ GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
788
+ const int np = (n & ~(ggml_f16_step - 1));
789
+ svfloat16_t ay1, ay2;
790
+
791
+ for (int i = 0; i < np; i += ggml_f16_step) {
792
+ ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0);
793
+ ay1 = GGML_F16x_VEC_MUL(ay1, vx);
794
+ GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0);
795
+
796
+ ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1);
797
+ ay2 = GGML_F16x_VEC_MUL(ay2, vx);
798
+ GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1);
799
+ }
800
+ // leftovers
801
+ // maximum number of leftover elements will be less than ggmlF_16x_epr. Apply predicated svmul on available elements only
802
+ if (np < n) {
803
+ svbool_t pg = svwhilelt_b16(np, n);
804
+ svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));
805
+ svfloat16_t out = svmul_f16_m(pg, hy, vx);
806
+ svst1_f16(pg, (__fp16 *)(y + np), out);
807
+ }
808
+ #elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
809
+ const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
810
+ const _Float16 scale = *(const _Float16*)(&s);
811
+
812
+ // calculate step size
813
+ const int epr = __riscv_vsetvlmax_e16m4();
814
+ const int step = epr * 2;
815
+ const int np = (n & ~(step - 1));
816
+
817
+ // unroll by 2
818
+ for (int i = 0; i < np; i += step) {
819
+ vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
820
+ ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr);
821
+ __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
822
+ __asm__ __volatile__ ("" ::: "memory");
823
+
824
+ vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
825
+ ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr);
826
+ __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
827
+ __asm__ __volatile__ ("" ::: "memory");
828
+ }
714
829
 
715
- ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1);
716
- ay2 = GGML_F16x_VEC_MUL(ay2, vx);
717
- GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1);
718
- }
719
- // leftovers
720
- // maximum number of leftover elements will be less than ggmlF_16x_epr. Apply predicated svmul on available elements only
721
- if (np < n) {
722
- svbool_t pg = svwhilelt_b16(np, n);
723
- svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));
724
- svfloat16_t out = svmul_f16_m(pg, hy, vx);
725
- svst1_f16(pg, (__fp16 *)(y + np), out);
726
- }
727
- #elif defined(__riscv_v_intrinsic)
728
- // todo: RVV impl
729
- // scalar
730
- for (int i = 0; i < n; ++i) {
731
- y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
732
- }
733
- #else
734
- const int np = (n & ~(GGML_F16_STEP - 1));
830
+ // leftovers
831
+ int vl;
832
+ for (int i = np; i < n; i += vl) {
833
+ vl = __riscv_vsetvl_e16m4(n - i);
834
+ vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
835
+ ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl);
836
+ __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
837
+ }
838
+ #elif defined(GGML_SIMD)
839
+ const int np = (n & ~(GGML_F16_STEP - 1));
735
840
 
736
- GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
841
+ GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
737
842
 
738
- GGML_F16_VEC ay[GGML_F16_ARR];
843
+ GGML_F16_VEC ay[GGML_F16_ARR];
739
844
 
740
- for (int i = 0; i < np; i += GGML_F16_STEP) {
741
- for (int j = 0; j < GGML_F16_ARR; j++) {
742
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
743
- ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
845
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
846
+ for (int j = 0; j < GGML_F16_ARR; j++) {
847
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
848
+ ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
744
849
 
745
- GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
746
- }
850
+ GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
747
851
  }
852
+ }
748
853
 
749
- // leftovers
750
- for (int i = np; i < n; ++i) {
751
- y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
752
- }
753
- #endif
854
+ // leftovers
855
+ for (int i = np; i < n; ++i) {
856
+ y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
857
+ }
754
858
  #else
755
859
  // scalar
756
860
  for (int i = 0; i < n; ++i) {
@@ -819,7 +923,8 @@ inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_f
819
923
  inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
820
924
  inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
821
925
  for (int i = 0; i < n; ++i) {
822
- y[i] = GGML_CPU_FP32_TO_FP16(expm1f(GGML_CPU_FP16_TO_FP32(x[i])));
926
+ const float v = GGML_CPU_FP16_TO_FP32(x[i]);
927
+ y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : expm1f(v));
823
928
  }
824
929
  }
825
930
  inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
@@ -1414,6 +1519,16 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
1414
1519
  #endif
1415
1520
  }
1416
1521
 
1522
+ inline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) {
1523
+ for (int i = 0; i < n; ++i) {
1524
+ if (i == 0) {
1525
+ y[i] = x[i];
1526
+ } else {
1527
+ y[i] = y[i - 1] + x[i];
1528
+ }
1529
+ }
1530
+ }
1531
+
1417
1532
  inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
1418
1533
  ggml_float sum = 0.0;
1419
1534
  for (int i = 0; i < n; ++i) {
@@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND)
15
15
  # 80 == Ampere, asynchronous data loading, faster tensor core instructions
16
16
  # 86 == RTX 3000, needs CUDA v11.1
17
17
  # 89 == RTX 4000, needs CUDA v11.8
18
+ # 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores
18
19
  #
19
20
  # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
20
21
  # XX-real == compile CUDA code as device code for this specific architecture
@@ -34,16 +35,75 @@ if (CUDAToolkit_FOUND)
34
35
  if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
35
36
  list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real)
36
37
  endif()
38
+
39
+ if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
40
+ # The CUDA architecture 120f-virtual would in principle work for Blackwell support
41
+ # but the newly added "f" suffix conflicted with a preexisting regex for validating CUDA architectures in CMake.
42
+ # So either a recent CMake version or one with the backported fix is needed.
43
+ # The following versions should work:
44
+ # - CMake >= v3.31.8 && CMake < v4.0.0
45
+ # - CMake >= v4.0.2
46
+ # This is NOT documented in the CMake release notes,
47
+ # check Modules/Internal/CMakeCUDAArchitecturesValidate.cmake in the CMake git repository instead.
48
+ # However, the architectures 120a-real and 121a-real should work with basically any CMake version and
49
+ # until the release of e.g. Rubin there is no benefit to shipping virtual architectures for Blackwell.
50
+ list(APPEND CMAKE_CUDA_ARCHITECTURES 120a-real)
51
+ endif()
52
+ if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.9")
53
+ list(APPEND CMAKE_CUDA_ARCHITECTURES 121a-real)
54
+ endif()
37
55
  endif()
38
56
  endif()
39
- message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
40
57
 
41
58
  enable_language(CUDA)
42
59
 
60
+ # TODO: Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit
61
+ if (GGML_CUDA_CUB_3DOT2)
62
+ include(FetchContent)
63
+
64
+ FetchContent_Declare(
65
+ CCCL
66
+ GIT_REPOSITORY https://github.com/nvidia/cccl.git
67
+ GIT_TAG v3.2.0-rc2
68
+ GIT_SHALLOW TRUE
69
+ )
70
+
71
+ FetchContent_MakeAvailable(CCCL)
72
+ endif()
73
+
74
+ # Replace any plain 12X CUDA architectures with their "architecture-specific" equivalents 12Xa.
75
+ # 12X is forwards-compatible, 12Xa is not.
76
+ # Notably the Blackwell FP4 tensor core instructions are not forwards compatible and therefore need 12Xa.
77
+ # But while 12X vs. 12Xa can be checked in device code there is (to my knowledge) no easy way to do the same check in host code.
78
+ # So for now just replace all instances of 12X with 12Xa, this should be fine until Rubin is released.
79
+ foreach(ARCHS IN ITEMS CMAKE_CUDA_ARCHITECTURES CMAKE_CUDA_ARCHITECTURES_NATIVE)
80
+ set(FIXED_ARCHS "")
81
+ foreach(ARCH IN LISTS ${ARCHS})
82
+ if (ARCH MATCHES "^12[0-9](-real|-virtual)?$")
83
+ string(REGEX REPLACE "^(12[0-9])((-real|-virtual)?)$" "\\1a\\2" FIXED_ARCH ${ARCH})
84
+ message(STATUS "Replacing ${ARCH} in ${ARCHS} with ${FIXED_ARCH}")
85
+ list(APPEND FIXED_ARCHS "${FIXED_ARCH}")
86
+ else()
87
+ list(APPEND FIXED_ARCHS "${ARCH}")
88
+ endif()
89
+ endforeach()
90
+ set(${ARCHS} ${FIXED_ARCHS})
91
+ endforeach()
92
+
93
+ # If we try to compile a "native" build it will use the 12X architectures and fail.
94
+ # So we should instead use the native architectures as determined by CMake after replacing 12X with 12Xa.
95
+ # But if at the time of the build no GPUs are connected at all CMAKE_CUDA_ARCHITECTURES will contain garbage that we should not use.
96
+ if (CMAKE_CUDA_ARCHITECTURES STREQUAL "native" AND CMAKE_CUDA_ARCHITECTURES_NATIVE MATCHES "^[0-9]+(a|f)?(-real|-virtual)?(;[0-9]+(a|f)?(-real|-virtual)?|;)*$")
97
+ set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES_NATIVE})
98
+ endif()
99
+ message(STATUS "Using CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} CMAKE_CUDA_ARCHITECTURES_NATIVE=${CMAKE_CUDA_ARCHITECTURES_NATIVE}")
100
+
43
101
  file(GLOB GGML_HEADERS_CUDA "*.cuh")
44
102
  list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
45
103
 
46
104
  file(GLOB GGML_SOURCES_CUDA "*.cu")
105
+ file(GLOB SRCS "template-instances/fattn-tile*.cu")
106
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
47
107
  file(GLOB SRCS "template-instances/fattn-mma*.cu")
48
108
  list(APPEND GGML_SOURCES_CUDA ${SRCS})
49
109
  file(GLOB SRCS "template-instances/mmq*.cu")
@@ -100,6 +160,9 @@ if (CUDAToolkit_FOUND)
100
160
  # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
101
161
  target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
102
162
  else ()
163
+ if (GGML_CUDA_CUB_3DOT2)
164
+ target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
165
+ endif()
103
166
  if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1")
104
167
  target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
105
168
  else()
@@ -107,6 +170,9 @@ if (CUDAToolkit_FOUND)
107
170
  endif()
108
171
  endif()
109
172
  else()
173
+ if (GGML_CUDA_CUB_3DOT2)
174
+ target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
175
+ endif()
110
176
  target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
111
177
  endif()
112
178
 
@@ -122,6 +188,7 @@ if (CUDAToolkit_FOUND)
122
188
 
123
189
  if (GGML_CUDA_DEBUG)
124
190
  list(APPEND CUDA_FLAGS -lineinfo)
191
+ add_compile_definitions(GGML_CUDA_DEBUG)
125
192
  endif()
126
193
 
127
194
  if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
@@ -174,6 +241,10 @@ if (CUDAToolkit_FOUND)
174
241
 
175
242
  if (NOT MSVC)
176
243
  list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)
244
+ else()
245
+ # CCCL 3.2 onwards will require a cpp-standard-compliant preprocessor for MSVC
246
+ # https://github.com/NVIDIA/cccl/pull/6827
247
+ list(APPEND CUDA_CXX_FLAGS /Zc:preprocessor)
177
248
  endif()
178
249
 
179
250
  list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument
@@ -21,7 +21,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
21
21
  }
22
22
 
23
23
  #pragma unroll
24
- for (int offset = 16; offset > 0; offset >>= 1) {
24
+ for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
25
25
  const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
26
26
  const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
27
27
  if (val > maxval) {
@@ -50,7 +50,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
50
50
  argmax = shared_argmax[lane_id];
51
51
  }
52
52
  #pragma unroll
53
- for (int offset = 16; offset > 0; offset >>= 1) {
53
+ for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
54
54
  const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
55
55
  const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
56
56
  if (val > maxval) {