whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610) hide show
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -29,7 +29,10 @@ layout (push_constant) uniform parameter
29
29
  #ifdef MUL_MAT_ID
30
30
  uint nei0;
31
31
  uint ne11;
32
+ uint expert_i1;
33
+ uint nbi1;
32
34
  #else
35
+ uint base_work_group_y;
33
36
  uint ne02;
34
37
  uint ne12;
35
38
  uint broadcast2;
@@ -43,9 +46,9 @@ uint expert_id;
43
46
 
44
47
  void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
45
48
  #ifdef MUL_MAT_ID
46
- const uint expert_idx = gl_GlobalInvocationID.y;
49
+ const uint expert_i0 = gl_WorkGroupID.y;
47
50
  #else
48
- const uint batch_idx = gl_GlobalInvocationID.y;
51
+ const uint batch_idx = gl_WorkGroupID.y + p.base_work_group_y;
49
52
  #endif
50
53
 
51
54
  #ifndef MUL_MAT_ID
@@ -60,7 +63,7 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
60
63
  batch_idx_a = i03 * p.ne02 + i02;
61
64
  }
62
65
  #else
63
- expert_id = data_ids[expert_idx];
66
+ expert_id = data_ids[expert_i0 + p.expert_i1 * p.nbi1];
64
67
  #endif
65
68
 
66
69
  a_offset =
@@ -71,13 +74,13 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
71
74
  #endif
72
75
  b_offset =
73
76
  #ifdef MUL_MAT_ID
74
- (expert_idx % p.ne11) * p.stride_b;
77
+ (expert_i0 % p.ne11) * p.stride_b + p.expert_i1 * p.batch_stride_b;
75
78
  #else
76
79
  batch_idx * p.batch_stride_b;
77
80
  #endif
78
81
  d_offset =
79
82
  #ifdef MUL_MAT_ID
80
- expert_idx * p.stride_d;
83
+ expert_i0 * p.stride_d + p.expert_i1 * p.batch_stride_d;
81
84
  #else
82
85
  batch_idx * p.batch_stride_d;
83
86
  #endif
@@ -103,12 +106,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
103
106
  temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
104
107
  }
105
108
  if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
106
- const uint expert_idx = gl_GlobalInvocationID.y;
107
- temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
109
+ const uint expert_i0 = gl_GlobalInvocationID.y;
110
+ temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
108
111
  }
109
112
  if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
110
- const uint expert_idx = gl_GlobalInvocationID.y;
111
- temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
113
+ const uint expert_i0 = gl_GlobalInvocationID.y;
114
+ temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
112
115
  }
113
116
  #else
114
117
  if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
@@ -158,12 +161,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
158
161
  temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
159
162
  }
160
163
  if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
161
- const uint expert_idx = gl_GlobalInvocationID.y;
162
- temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
164
+ const uint expert_i0 = gl_GlobalInvocationID.y;
165
+ temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
163
166
  }
164
167
  if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
165
- const uint expert_idx = gl_GlobalInvocationID.y;
166
- temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
168
+ const uint expert_i0 = gl_GlobalInvocationID.y;
169
+ temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
167
170
  }
168
171
  #else
169
172
  if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
@@ -203,12 +206,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
203
206
  tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
204
207
  }
205
208
  if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
206
- const uint expert_idx = gl_GlobalInvocationID.y;
207
- tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]);
209
+ const uint expert_i0 = gl_GlobalInvocationID.y;
210
+ tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_i0]);
208
211
  }
209
212
  if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
210
- const uint expert_idx = gl_GlobalInvocationID.y;
211
- tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]);
213
+ const uint expert_i0 = gl_GlobalInvocationID.y;
214
+ tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_i0]);
212
215
  }
213
216
  #else
214
217
  if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
@@ -90,6 +90,8 @@ layout (push_constant) uniform parameter
90
90
  uint nbi1;
91
91
  uint ne11;
92
92
  #else
93
+ uint base_work_group_z;
94
+ uint num_batches;
93
95
  uint k_split;
94
96
  uint ne02;
95
97
  uint ne12;
@@ -139,7 +141,7 @@ void main() {
139
141
  const uint ic = gl_WorkGroupID.y;
140
142
 
141
143
  #ifdef MUL_MAT_ID
142
- const uint expert_idx = gl_GlobalInvocationID.z;
144
+ const uint expert_idx = gl_WorkGroupID.z;
143
145
  if (ic * BN >= data_expert_count[expert_idx]) {
144
146
  return;
145
147
  }
@@ -149,7 +151,7 @@ void main() {
149
151
  #endif
150
152
 
151
153
  #ifndef MUL_MAT_ID
152
- const uint batch_idx = gl_GlobalInvocationID.z;
154
+ const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
153
155
 
154
156
  const uint i13 = batch_idx / p.ne12;
155
157
  const uint i12 = batch_idx % p.ne12;
@@ -366,7 +368,7 @@ void main() {
366
368
  const uint dc = ic * BN + warp_c * WN;
367
369
 
368
370
  #ifndef MUL_MAT_ID
369
- const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
371
+ const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
370
372
  #endif
371
373
 
372
374
  #ifdef COOPMAT
@@ -375,6 +377,7 @@ void main() {
375
377
  [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
376
378
  coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
377
379
 
380
+ barrier();
378
381
  [[unroll]] for (uint col = 0; col < TN; col += storestride) {
379
382
  const uint row_i = dc + cm_col * TN + col + store_c;
380
383
  if (row_i >= _ne1) break;
@@ -385,6 +388,7 @@ void main() {
385
388
  data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
386
389
  }
387
390
  }
391
+ barrier();
388
392
  }
389
393
  }
390
394
  #else
@@ -402,18 +406,22 @@ void main() {
402
406
  // Full coopMat is within bounds, but stride_d is not aligned
403
407
  coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
404
408
 
409
+ controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
405
410
  [[unroll]] for (uint col = 0; col < TN; col += storestride) {
406
411
  data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
407
412
  }
413
+ controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
408
414
  } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
409
415
  // Partial coopMat is within bounds
410
416
  coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
411
417
 
418
+ controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
412
419
  [[unroll]] for (uint col = 0; col < TN; col += storestride) {
413
420
  if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
414
421
  data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
415
422
  }
416
423
  }
424
+ controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
417
425
  }
418
426
  }
419
427
  }
@@ -53,6 +53,8 @@ layout (push_constant) uniform parameter
53
53
  uint nbi1;
54
54
  uint ne11;
55
55
  #else
56
+ uint base_work_group_z;
57
+ uint num_batches;
56
58
  uint k_split;
57
59
  uint ne02;
58
60
  uint ne12;
@@ -165,7 +167,9 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
165
167
  uint id = ids[iter++];
166
168
  uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
167
169
 
168
- ballots_sh[gl_SubgroupID] = ballot;
170
+ if (gl_SubgroupInvocationID == 0) {
171
+ ballots_sh[gl_SubgroupID] = ballot;
172
+ }
169
173
  barrier();
170
174
 
171
175
  uint subgroup_base = 0;
@@ -197,7 +201,7 @@ void main() {
197
201
  const uint ic = gl_WorkGroupID.y;
198
202
 
199
203
  #ifdef MUL_MAT_ID
200
- const uint expert_idx = gl_GlobalInvocationID.z;
204
+ const uint expert_idx = gl_WorkGroupID.z;
201
205
  if (ic * BN >= data_expert_count[expert_idx]) {
202
206
  return;
203
207
  }
@@ -215,7 +219,7 @@ void main() {
215
219
  #endif
216
220
 
217
221
  #ifndef MUL_MAT_ID
218
- const uint batch_idx = gl_GlobalInvocationID.z;
222
+ const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
219
223
 
220
224
  const uint i13 = batch_idx / p.ne12;
221
225
  const uint i12 = batch_idx % p.ne12;
@@ -255,7 +259,7 @@ void main() {
255
259
  #else
256
260
  uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K);
257
261
  uint pos_b = batch_idx * p.batch_stride_b;
258
- uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
262
+ uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
259
263
  #endif
260
264
 
261
265
  uint stride_a = p.stride_a / QUANT_K;
@@ -43,7 +43,9 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
43
43
  uint id = ids[iter++];
44
44
  uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
45
45
 
46
- ballots_sh[gl_SubgroupID] = ballot;
46
+ if (gl_SubgroupInvocationID == 0) {
47
+ ballots_sh[gl_SubgroupID] = ballot;
48
+ }
47
49
  barrier();
48
50
 
49
51
  uint subgroup_base = 0;
@@ -57,6 +57,8 @@ layout (push_constant) uniform parameter
57
57
  uint nbi1;
58
58
  uint ne11;
59
59
  #else
60
+ uint base_work_group_z;
61
+ uint num_batches;
60
62
  uint k_split;
61
63
  uint ne02;
62
64
  uint ne12;
@@ -108,7 +110,7 @@ void main() {
108
110
  const uint ic = gl_WorkGroupID.y;
109
111
 
110
112
  #ifdef MUL_MAT_ID
111
- const uint expert_idx = gl_GlobalInvocationID.z;
113
+ const uint expert_idx = gl_WorkGroupID.z;
112
114
  if (ic * BN >= data_expert_count[expert_idx]) {
113
115
  return;
114
116
  }
@@ -118,7 +120,7 @@ void main() {
118
120
  #endif
119
121
 
120
122
  #ifndef MUL_MAT_ID
121
- const uint batch_idx = gl_GlobalInvocationID.z;
123
+ const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
122
124
 
123
125
  const uint i13 = batch_idx / p.ne12;
124
126
  const uint i12 = batch_idx % p.ne12;
@@ -276,7 +278,7 @@ void main() {
276
278
  const uint dc = ic * BN + warp_c * WN;
277
279
 
278
280
  #ifndef MUL_MAT_ID
279
- const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
281
+ const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
280
282
  #endif
281
283
 
282
284
  [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
@@ -264,7 +264,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
264
264
  const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
265
265
  (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))).xy); // vec4 used due to #12147
266
266
 
267
- buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
267
+ buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales - 32));
268
268
  }
269
269
  }
270
270
 
@@ -334,7 +334,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
334
334
  (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
335
335
  }
336
336
 
337
- buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
337
+ buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm));
338
338
  }
339
339
  }
340
340
 
@@ -385,7 +385,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
385
385
  const uint is = iqs_k / 4;
386
386
  const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[is / 2])).xy;
387
387
 
388
- buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
388
+ buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales));
389
389
  }
390
390
  }
391
391
 
@@ -112,12 +112,11 @@ void rms_norm(uint num_iters) {
112
112
  #if RMS_NORM_ROPE_FUSION
113
113
  barrier();
114
114
  rope_params rp = p.rope;
115
- uint rope_row = (samp*nchannels + channel)*nrows + row;
116
115
  for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) {
117
116
  if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) {
118
- rope_neox(t, rope_row, rp);
117
+ rope_neox(t, row, channel, samp, rp);
119
118
  } else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) {
120
- rope_norm(t, rope_row, rp);
119
+ rope_norm(t, row, channel, samp, rp);
121
120
  }
122
121
  }
123
122
  #endif
@@ -4,12 +4,12 @@ float rope_yarn_ramp(const float low, const float high, const uint i0) {
4
4
  return 1.0f - min(1.0f, max(0.0f, y));
5
5
  }
6
6
 
7
- uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) {
7
+ uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, rope_params p) {
8
8
  #if RMS_NORM_ROPE_FUSION
9
9
  // Per-row offset in shared memory
10
10
  const uint ix = i0;
11
11
  #else
12
- const uint ix = i02*p.nb02 + i01*p.nb01 + i0;
12
+ const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0;
13
13
  #endif
14
14
  return ix;
15
15
  }
@@ -34,26 +34,19 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out
34
34
  sin_theta = sin(theta) * mscale;
35
35
  }
36
36
 
37
- void rope_norm(const uint i0, const uint i1, rope_params p) {
38
- uint ne0 = p.ncols;
39
- uint ne1 = p.p_delta_rows;
40
-
41
- if (i0 >= ne0) {
37
+ void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
38
+ if (i0 >= p.ne00) {
42
39
  return;
43
40
  }
44
41
 
45
- // i1 is actually i2*nb2+i1, but the rows are contiguous
46
- const uint i01 = i1 % ne1;
47
- const uint i02 = i1 / ne1;
48
-
49
- uint idst = i1*ne0 + i0;
50
- const uint ix = rope_a_coord(i0, i01, i02, p);
42
+ uint idst = i0 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
43
+ const uint ix = rope_a_coord(i0, i1, i2, i3, p);
51
44
 
52
45
  // Fusion optimization: ROPE + VIEW + SET_ROWS.
53
46
  // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
54
47
  if (p.set_rows_stride != 0) {
55
- idst = i01*ne0 + i0;
56
- idst += rope_data_i[i02].x * p.set_rows_stride;
48
+ idst = i1*p.nb11 + i0;
49
+ idst += rope_data_i[i2].x * p.set_rows_stride;
57
50
  }
58
51
 
59
52
  if (i0 >= p.n_dims) {
@@ -63,7 +56,7 @@ void rope_norm(const uint i0, const uint i1, rope_params p) {
63
56
  return;
64
57
  }
65
58
 
66
- const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
59
+ const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);
67
60
 
68
61
  const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
69
62
 
@@ -77,25 +70,19 @@ void rope_norm(const uint i0, const uint i1, rope_params p) {
77
70
  rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
78
71
  }
79
72
 
80
- void rope_neox(const uint i0, const uint i1, rope_params p) {
81
- uint ne0 = p.ncols;
82
- uint ne1 = p.p_delta_rows;
83
-
84
- if (i0 >= ne0) {
73
+ void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
74
+ if (i0 >= p.ne00) {
85
75
  return;
86
76
  }
87
77
 
88
- const uint i01 = i1 % ne1;
89
- const uint i02 = i1 / ne1;
90
-
91
- uint idst = i1*ne0 + i0/2;
92
- const uint ix = rope_a_coord(i0/2, i01, i02, p);
78
+ uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
79
+ const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
93
80
 
94
81
  // Fusion optimization: ROPE + VIEW + SET_ROWS.
95
82
  // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
96
83
  if (p.set_rows_stride != 0) {
97
- idst = i01*ne0 + i0/2;
98
- idst += rope_data_i[i02].x * p.set_rows_stride;
84
+ idst = i1*p.nb11 + i0/2;
85
+ idst += rope_data_i[i2].x * p.set_rows_stride;
99
86
  }
100
87
 
101
88
  if (i0 >= p.n_dims) {
@@ -105,7 +92,7 @@ void rope_neox(const uint i0, const uint i1, rope_params p) {
105
92
  return;
106
93
  }
107
94
 
108
- const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
95
+ const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);
109
96
 
110
97
  const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
111
98
 
@@ -120,26 +107,19 @@ void rope_neox(const uint i0, const uint i1, rope_params p) {
120
107
  }
121
108
 
122
109
 
123
- void rope_multi(const uint i0, const uint i1, rope_params p) {
124
- uint ne0 = p.ncols;
125
- uint ne1 = p.p_delta_rows;
126
- uint ne2 = p.ne02;
127
-
128
- if (i0 >= ne0) {
110
+ void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
111
+ if (i0 >= p.ne00) {
129
112
  return;
130
113
  }
131
114
 
132
- const uint i01 = i1 % ne1;
133
- const uint i02 = i1 / ne1;
134
-
135
- uint idst = i1*ne0 + i0/2;
136
- const uint ix = rope_a_coord(i0/2, i01, i02, p);
115
+ uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
116
+ const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
137
117
 
138
118
  // Fusion optimization: ROPE + VIEW + SET_ROWS.
139
119
  // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
140
120
  if (p.set_rows_stride != 0) {
141
- idst = i01*ne0 + i0/2;
142
- idst += rope_data_i[i02].x * p.set_rows_stride;
121
+ idst = i1*p.nb11 + i0/2;
122
+ idst += rope_data_i[i2].x * p.set_rows_stride;
143
123
  }
144
124
 
145
125
  if (i0 >= p.n_dims) {
@@ -156,26 +136,26 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
156
136
  float theta_base = 0.0;
157
137
  if (p.is_imrope != 0) {
158
138
  if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
159
- theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
139
+ theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);
160
140
  } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
161
- theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
141
+ theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);
162
142
  } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
163
- theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
143
+ theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);
164
144
  } else {
165
- theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
145
+ theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);
166
146
  }
167
147
  } else {
168
148
  if (sector < p.sections[0]) {
169
- theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
149
+ theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);
170
150
  }
171
151
  else if (sector >= p.sections[0] && sector < sec_w) {
172
- theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
152
+ theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);
173
153
  }
174
154
  else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
175
- theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
155
+ theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);
176
156
  }
177
157
  else if (sector >= sec_w + p.sections[2]) {
178
- theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
158
+ theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);
179
159
  }
180
160
  }
181
161
 
@@ -191,20 +171,13 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
191
171
  rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
192
172
  }
193
173
 
194
- void rope_vision(const uint i0, const uint i1, rope_params p) {
195
- uint ne0 = p.ncols;
196
- uint ne1 = p.p_delta_rows;
197
- uint ne2 = p.ne02;
198
-
199
- if (i0 >= ne0) {
174
+ void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
175
+ if (i0 >= p.ne00) {
200
176
  return;
201
177
  }
202
178
 
203
- const uint i01 = i1 % ne1;
204
- const uint i02 = i1 / ne1;
205
-
206
- const uint idst = i1*ne0 + i0/2;
207
- const uint ix = rope_a_coord(i0/2, i01, i02, p);
179
+ const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
180
+ const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
208
181
 
209
182
  const int sect_dims = p.sections[0] + p.sections[1];
210
183
  const int sec_w = p.sections[1] + p.sections[0];
@@ -213,11 +186,11 @@ void rope_vision(const uint i0, const uint i1, rope_params p) {
213
186
  float theta_base = 0.0;
214
187
  if (sector < p.sections[0]) {
215
188
  const uint p0 = sector;
216
- theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0);
189
+ theta_base = rope_data_pos[i2]*pow(p.theta_scale, p0);
217
190
  }
218
191
  else if (sector >= p.sections[0] && sector < sec_w) {
219
192
  const uint p0 = sector - p.sections[0];
220
- theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0);
193
+ theta_base = rope_data_pos[i2 + p.ne02]*pow(p.theta_scale, p0);
221
194
  }
222
195
 
223
196
  const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
@@ -5,10 +5,13 @@
5
5
 
6
6
  void main() {
7
7
  const uint i0 = 2*gl_GlobalInvocationID.y;
8
- // i1 is actually i2*nb2+i1, but the rows are contiguous
9
- const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
10
- if (i1 >= pc.nrows) {
8
+ const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
9
+ if (row >= pc.nrows) {
11
10
  return;
12
11
  }
13
- rope_multi(i0, i1, pc);
12
+ const uint i3 = row / (pc.ne01*pc.ne02);
13
+ const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
14
+ const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
15
+
16
+ rope_multi(i0, i1, i2, i3, pc);
14
17
  }
@@ -5,10 +5,13 @@
5
5
 
6
6
  void main() {
7
7
  const uint i0 = 2*gl_GlobalInvocationID.y;
8
- // i1 is actually i2*nb2+i1, but the rows are contiguous
9
- const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
10
- if (i1 >= pc.nrows) {
8
+ const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
9
+ if (row >= pc.nrows) {
11
10
  return;
12
11
  }
13
- rope_neox(i0, i1, pc);
12
+ const uint i3 = row / (pc.ne01*pc.ne02);
13
+ const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
14
+ const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
15
+
16
+ rope_neox(i0, i1, i2, i3, pc);
14
17
  }
@@ -5,10 +5,13 @@
5
5
 
6
6
  void main() {
7
7
  const uint i0 = 2*gl_GlobalInvocationID.y;
8
- // i1 is actually i2*nb2+i1, but the rows are contiguous
9
- const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
10
- if (i1 >= pc.nrows) {
8
+ const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
9
+ if (row >= pc.nrows) {
11
10
  return;
12
11
  }
13
- rope_norm(i0, i1, pc);
12
+ const uint i3 = row / (pc.ne01*pc.ne02);
13
+ const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
14
+ const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
15
+
16
+ rope_norm(i0, i1, i2, i3, pc);
14
17
  }
@@ -5,24 +5,29 @@
5
5
 
6
6
  struct rope_params {
7
7
  uint rope_mode;
8
- uint ncols;
9
8
  uint nrows;
10
9
  uint n_dims;
11
10
  float freq_scale;
12
- uint p_delta_rows;
13
11
  float freq_base;
14
12
  float ext_factor;
15
13
  float attn_factor;
16
14
  float corr_dims[2];
17
15
  float theta_scale;
18
16
  uint has_ff;
19
- uint ne02;
20
- uint nb01;
21
- uint nb02;
22
17
  int sections[4];
23
18
  uint is_imrope;
24
19
  uint is_back;
25
20
  uint set_rows_stride;
21
+
22
+ uint ne00;
23
+ uint ne01;
24
+ uint ne02;
25
+ uint nb01;
26
+ uint nb02;
27
+ uint nb03;
28
+ uint nb11;
29
+ uint nb12;
30
+ uint nb13;
26
31
  };
27
32
 
28
33
  #endif // !defined(GGML_ROPE_PARAMS)
@@ -5,10 +5,13 @@
5
5
 
6
6
  void main() {
7
7
  const uint i0 = 2*gl_GlobalInvocationID.y;
8
- // i1 is actually i2*nb2+i1, but the rows are contiguous
9
- const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
10
- if (i1 >= pc.nrows) {
8
+ const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
9
+ if (row >= pc.nrows) {
11
10
  return;
12
11
  }
13
- rope_vision(i0, i1, pc);
12
+ const uint i3 = row / (pc.ne01*pc.ne02);
13
+ const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
14
+ const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
15
+
16
+ rope_vision(i0, i1, i2, i3, pc);
14
17
  }
@@ -0,0 +1,21 @@
1
+ #version 450
2
+
3
+ #include "generic_head.glsl"
4
+ #include "types.glsl"
5
+
6
+ #extension GL_EXT_control_flow_attributes : enable
7
+
8
+ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
9
+
10
+ layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
11
+ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
12
+
13
+ void main() {
14
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
15
+
16
+ if (i >= p.KX) {
17
+ return;
18
+ }
19
+
20
+ data_d[i] = D_TYPE(sign(float(data_a[i])));
21
+ }