whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -1,5 +1,5 @@
1
1
  /*
2
- * Copyright (c) 2023-2024 The ggml authors
2
+ * Copyright (c) 2023-2026 The ggml authors
3
3
  *
4
4
  * Permission is hereby granted, free of charge, to any person obtaining a copy
5
5
  * of this software and associated documentation files (the "Software"), to
@@ -58,6 +58,7 @@
58
58
  #include <aclnnop/aclnn_mean.h>
59
59
  #include <aclnnop/aclnn_mm.h>
60
60
  #include <aclnnop/aclnn_mul.h>
61
+ #include <aclnnop/aclnn_mv.h>
61
62
  #include <aclnnop/aclnn_permute.h>
62
63
  #include <aclnnop/aclnn_pow.h>
63
64
  #include <aclnnop/aclnn_pow_tensor_tensor.h>
@@ -2338,20 +2339,21 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
2338
2339
 
2339
2340
  // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
2340
2341
  // TODO: acl_yarn_ramp_tensor use rope cache.
2341
- bool yarn_ramp_tensor_updated = false;
2342
- acl_tensor_ptr acl_yarn_ramp_tensor;
2342
+ bool yarn_ramp_tensor_updated = false;
2343
+ acl_tensor_ptr acl_yarn_ramp_tensor;
2343
2344
  if (ext_factor != 0 && (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length ||
2344
2345
  ctx.rope_cache.freq_scale != freq_scale)) {
2345
2346
  yarn_ramp_tensor_updated = true;
2346
2347
  if (ctx.rope_cache.yarn_ramp_cache != nullptr) {
2347
2348
  ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache));
2348
2349
  }
2349
- ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
2350
+ ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float),
2351
+ ACL_MEM_MALLOC_HUGE_FIRST));
2350
2352
  // -rope_yarn_ramp
2351
2353
  // const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
2352
2354
  // return MIN(1, MAX(0, y)) - 1;
2353
- acl_yarn_ramp_tensor =
2354
- ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
2355
+ acl_yarn_ramp_tensor = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float),
2356
+ theta_scale_ne, theta_scale_nb, 1);
2355
2357
  float zero_value = 0, one_value = 1;
2356
2358
  float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
2357
2359
  acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
@@ -2382,8 +2384,8 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
2382
2384
  GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
2383
2385
  GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
2384
2386
  } else {
2385
- acl_yarn_ramp_tensor =
2386
- ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
2387
+ acl_yarn_ramp_tensor = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float),
2388
+ theta_scale_ne, theta_scale_nb, 1);
2387
2389
  }
2388
2390
  // Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
2389
2391
  if (ext_factor != 0) {
@@ -2991,20 +2993,20 @@ void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
2991
2993
  GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src.get(), 3, false, acl_dst.get());
2992
2994
  }
2993
2995
 
2994
- void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
2996
+ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
2995
2997
  ggml_tensor * src0 = dst->src[0];
2996
2998
  ggml_tensor * src1 = dst->src[1];
2997
2999
 
2998
3000
  // stride
2999
- int64_t s0 = ((const int32_t*)(dst->op_params))[0];
3001
+ int64_t s0 = ((const int32_t *) (dst->op_params))[0];
3000
3002
 
3001
- acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
3003
+ acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
3002
3004
  acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL);
3003
- acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);
3005
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);
3004
3006
 
3005
3007
  // get base information of input and kernel
3006
- int64_t input_len = *(src1->ne);
3007
- int64_t dst_len = *(dst->ne);
3008
+ int64_t input_len = *(src1->ne);
3009
+ int64_t dst_len = *(dst->ne);
3008
3010
  int64_t kernel_size = *(src0->ne);
3009
3011
 
3010
3012
  // set the max kernel size for each conv
@@ -3012,56 +3014,55 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
3012
3014
 
3013
3015
  // compute the partition of kernel
3014
3016
  int64_t part_num = 1;
3015
- part_num = (kernel_size + max_kernel_size - 1) / max_kernel_size;
3017
+ part_num = (kernel_size + max_kernel_size - 1) / max_kernel_size;
3016
3018
 
3017
3019
  int64_t strideVal[1];
3018
- strideVal[0] = s0;
3019
- acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
3020
- int64_t paddingVal[] = {0};
3021
- acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
3022
- int64_t dilationVal[] = {1};
3023
- acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
3024
- bool transposed = true;
3025
- int64_t groups = 1;
3026
- int8_t cubeMathType = 0;
3020
+ strideVal[0] = s0;
3021
+ acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
3022
+ int64_t paddingVal[] = { 0 };
3023
+ acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
3024
+ int64_t dilationVal[] = { 1 };
3025
+ acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
3026
+ bool transposed = true;
3027
+ int64_t groups = 1;
3028
+ int8_t cubeMathType = 0;
3027
3029
 
3028
3030
  #ifdef ASCEND_310P
3029
3031
  cubeMathType = 1;
3030
3032
  #endif
3031
3033
 
3032
3034
  auto weight_type = ggml_cann_type_mapping(src0->type);
3033
- auto dst_type = ggml_cann_type_mapping(dst->type);
3035
+ auto dst_type = ggml_cann_type_mapping(dst->type);
3034
3036
 
3035
3037
  // slice the kernel to make each conv available
3036
- int64_t slice_dim = -1;
3038
+ int64_t slice_dim = -1;
3037
3039
  int64_t slice_start = 0;
3038
- int64_t slice_end = max_kernel_size;
3039
- int64_t slice_step = 1;
3040
- int64_t interval = max_kernel_size;
3040
+ int64_t slice_end = max_kernel_size;
3041
+ int64_t slice_step = 1;
3042
+ int64_t interval = max_kernel_size;
3041
3043
 
3042
- int64_t left_pad_len = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0];
3044
+ int64_t left_pad_len = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0];
3043
3045
  int64_t right_pad_len = 0;
3044
3046
 
3045
- acl_scalar_ptr alpha = nullptr;
3046
- float alphaValue = 1.0;
3047
- alpha = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);
3047
+ acl_scalar_ptr alpha = nullptr;
3048
+ float alphaValue = 1.0;
3049
+ alpha = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);
3048
3050
 
3049
3051
  // set zero to destination
3050
3052
  GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
3051
3053
 
3052
- for(int k = 0; k < part_num; k++){
3053
-
3054
+ for (int k = 0; k < part_num; k++) {
3054
3055
  // create part kernel tensor and slice from big kernel
3055
3056
  slice_start = max_kernel_size * k;
3056
- if(k == part_num - 1){
3057
+ if (k == part_num - 1) {
3057
3058
  slice_end = kernel_size;
3058
- interval = kernel_size - max_kernel_size * k;
3059
- }else{
3060
- slice_end = max_kernel_size * (k+1);
3059
+ interval = kernel_size - max_kernel_size * k;
3060
+ } else {
3061
+ slice_end = max_kernel_size * (k + 1);
3061
3062
  }
3062
3063
 
3063
3064
  int64_t part_ne[4];
3064
- for(int i = 0; i < 4; i++) {
3065
+ for (int i = 0; i < 4; i++) {
3065
3066
  part_ne[i] = *(src0->ne + i);
3066
3067
  }
3067
3068
  part_ne[0] = interval;
@@ -3074,16 +3075,17 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
3074
3075
 
3075
3076
  ggml_cann_pool_alloc part_kernel_allocator;
3076
3077
  part_kernel_allocator.alloc(ctx.pool(), part_nb[3]);
3077
- void* part_kernel_buf = part_kernel_allocator.get();
3078
+ void * part_kernel_buf = part_kernel_allocator.get();
3078
3079
 
3079
- acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type,
3080
- ggml_element_size(src0), part_ne, part_nb, 3, ACL_FORMAT_NCL);
3080
+ acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type, ggml_element_size(src0),
3081
+ part_ne, part_nb, 3, ACL_FORMAT_NCL);
3081
3082
 
3082
- GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step, part_kernel.get());
3083
+ GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step,
3084
+ part_kernel.get());
3083
3085
 
3084
3086
  // create the part conv result tensor
3085
3087
  int64_t part_dst_ne[4];
3086
- for(int i = 0; i < 4; i++){
3088
+ for (int i = 0; i < 4; i++) {
3087
3089
  part_dst_ne[i] = *(dst->ne + i);
3088
3090
  }
3089
3091
  part_dst_ne[0] = (input_len - 1) * strideVal[0] - 2 * paddingVal[0] + dilationVal[0] * (part_ne[0] - 1) + 1;
@@ -3095,32 +3097,33 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
3095
3097
  }
3096
3098
  ggml_cann_pool_alloc part_dst_allocator;
3097
3099
  part_dst_allocator.alloc(ctx.pool(), part_dst_nb[3]);
3098
- void* part_dst_buf = part_dst_allocator.get();
3100
+ void * part_dst_buf = part_dst_allocator.get();
3099
3101
 
3100
3102
  acl_tensor_ptr acl_part_dst = ggml_cann_create_tensor(part_dst_buf, dst_type, ggml_element_size(dst),
3101
- part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL);
3103
+ part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL);
3102
3104
  GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_part_dst.get());
3103
3105
 
3104
3106
  // compute part conv transpose 1d
3105
3107
  GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), part_kernel.get(), nullptr, stride.get(),
3106
- padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(), cubeMathType);
3108
+ padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(),
3109
+ cubeMathType);
3107
3110
 
3108
3111
  // compute the position of part result in final result
3109
3112
  int64_t global_start = slice_start;
3110
- int64_t global_end = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len);
3113
+ int64_t global_end = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len);
3111
3114
 
3112
- left_pad_len = global_start;
3115
+ left_pad_len = global_start;
3113
3116
  right_pad_len = dst_len - global_end;
3114
3117
 
3115
- std::vector<int64_t> padDataVal = {left_pad_len,right_pad_len};
3116
- acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2);
3118
+ std::vector<int64_t> padDataVal = { left_pad_len, right_pad_len };
3119
+ acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2);
3117
3120
 
3118
- acl_scalar_ptr pad_value = nullptr;
3119
- float pad_valueVal = 0.0;
3120
- pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT);
3121
+ acl_scalar_ptr pad_value = nullptr;
3122
+ float pad_valueVal = 0.0;
3123
+ pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT);
3121
3124
 
3122
3125
  int64_t conv_result_ne[4];
3123
- for(int i = 0; i < 4; i++){
3126
+ for (int i = 0; i < 4; i++) {
3124
3127
  conv_result_ne[i] = *(dst->ne + i);
3125
3128
  }
3126
3129
 
@@ -3132,13 +3135,14 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
3132
3135
 
3133
3136
  ggml_cann_pool_alloc conv_result_allocator;
3134
3137
  conv_result_allocator.alloc(ctx.pool(), conv_result_nb[3]);
3135
- void* conv_result_buf = conv_result_allocator.get();
3138
+ void * conv_result_buf = conv_result_allocator.get();
3136
3139
 
3137
3140
  acl_tensor_ptr conv_result = ggml_cann_create_tensor(conv_result_buf, dst_type, ggml_element_size(dst),
3138
- conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL);
3141
+ conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL);
3139
3142
 
3140
3143
  GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result.get());
3141
- GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(), conv_result.get());
3144
+ GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(),
3145
+ conv_result.get());
3142
3146
  GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), conv_result.get(), alpha.get());
3143
3147
  }
3144
3148
  }
@@ -3282,130 +3286,223 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context & ctx, ggml_tensor
3282
3286
  }
3283
3287
 
3284
3288
  /**
3285
- * @brief Performs expert-specific matrix multiplication (MoE) with
3286
- * quantized precision using the CANN backend.
3287
- *
3288
- * This function executes a matrix multiplication operation tailored for
3289
- * Mixture of Experts (MoE) models, where the input tensor is multiplied
3290
- * with expert-specific quantized weight matrices. It leverages the CANN
3291
- * backend to perform efficient low-precision computations and stores the
3292
- * quantized result in the destination tensor `dst`.
3293
- *
3294
- * Quantization techniques reduce memory footprint and improve performance
3295
- * by using lower-bit representations (e.g., int8) instead of floating-point.
3296
- * This function is designed to work with such formats and may incorporate
3297
- * optimizations like identity-based fast paths or routing masks for sparse
3298
- * expert selection.
3299
- *
3300
- * @param ctx The context for executing CANN backend operations.
3301
- * @param dst The destination tensor where the quantized MoE multiplication result
3302
- * will be stored.
3303
- *
3304
- * @note This function assumes quantized data types and is designed for
3305
- * MoE architectures with potential sparse expert routing.
3289
+ * @brief Performs quantized matrix multiplication for Mixture of Experts (MoE)
3290
+ * models using the CANN backend.
3291
+ *
3292
+ * This function implements MUL_MAT_ID operation for quantized weight matrices
3293
+ * (Q4_0 and Q8_0 formats). It selects expert-specific weight matrices based on
3294
+ * the provided expert indices, and computes matrix multiplication using CANN's
3295
+ * WeightQuantBatchMatmulV2 operator.
3296
+ *
3297
+ * The function performs the following steps:
3298
+ * 1. Converts input/output tensors to F16 format if necessary
3299
+ * 2. Uses IndexSelect to extract expert-specific weights and scales based on indices
3300
+ * 3. Performs quantized matrix multiplication for each expert using WeightQuantBatchMatmulV2
3301
+ * 4. Converts output back to the target type if needed
3302
+ *
3303
+ * Tensor shapes:
3304
+ * - dst: [M, K, N, 1] - output tensor
3305
+ * - src0: [D, M, A, 1] - quantized weight matrices (Q4_0 or Q8_0)
3306
+ * - src1: [D, B, N, 1] - input activations (B = K for per-expert input, or B = 1 for broadcast)
3307
+ * - ids: [K, N] - expert indices for routing
3308
+ *
3309
+ * @param ctx The CANN backend context for operation execution.
3310
+ * @param dst The destination tensor where the multiplication result will be stored.
3311
+ *
3312
+ * @note Only Q4_0 and Q8_0 quantization formats are supported.
3313
+ * @note The function handles automatic type conversion to/from F16 as needed by the hardware.
3306
3314
  */
3307
3315
  static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3308
- // TODO: Use aclnnGroupedMatMul
3309
- //dst [M, K, N, 1]
3310
- ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
3311
- ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
3312
- ggml_tensor * ids = dst->src[2]; //ids [K, N]
3316
+ // dst: [M, K, N, 1]
3317
+ // src0: [D, M, A, 1] - quantized weights
3318
+ // src1: [D, B, N, 1] - input activations, B = K or B = 1
3319
+ // ids: [K, N] - expert indices
3320
+ ggml_tensor * src0 = dst->src[0];
3321
+ ggml_tensor * src1 = dst->src[1];
3322
+ ggml_tensor * ids = dst->src[2];
3313
3323
 
3314
- GGML_TENSOR_BINARY_OP_LOCALS
3324
+ GGML_ASSERT(src0->ne[3] == 1);
3325
+ GGML_ASSERT(src1->ne[3] == 1);
3326
+ GGML_ASSERT(dst->ne[3] == 1);
3327
+ GGML_ASSERT(src1->ne[2] == ids->ne[1]);
3328
+
3329
+ const int64_t n_batches = ids->ne[1];
3330
+ const int64_t n_select_experts = ids->ne[0];
3331
+ const enum ggml_type type = src0->type;
3332
+
3333
+ const int32_t group_size = QK8_0; // Both Q4_0 and Q8_0 use group size of 32
3334
+ GGML_ASSERT(group_size == QK4_0);
3335
+
3336
+ // Calculate element size for quantized weights
3337
+ const float weight_elem_size =
3338
+ (type == GGML_TYPE_Q4_0) ? 0.5f :
3339
+ (type == GGML_TYPE_Q8_0) ? 1.0f :
3340
+ (GGML_ABORT("MUL_MAT_ID only supports Q4_0 and Q8_0"), 0.0f);
3341
+
3342
+ // Calculate scale offset in memory
3343
+ const size_t weight_size = src0->ne[0] * src0->ne[1] * src0->ne[2] * weight_elem_size;
3344
+ const size_t scale_elem_size = sizeof(uint16_t);
3345
+ char * scale_data = (char *) src0->data + weight_size;
3346
+
3347
+ // Allocate buffers for selected expert weights and scales
3348
+ const size_t selected_weight_size = src0->ne[0] * src0->ne[1] * n_select_experts * weight_elem_size;
3349
+ ggml_cann_pool_alloc selected_weight_alloc(ctx.pool(), selected_weight_size);
3350
+ void * selected_weight_buffer = selected_weight_alloc.get();
3351
+
3352
+ const size_t selected_scale_size = (src0->ne[0] / group_size) * src0->ne[1] * n_select_experts * scale_elem_size;
3353
+ ggml_cann_pool_alloc selected_scale_alloc(ctx.pool(), selected_scale_size);
3354
+ void * selected_scale_buffer = selected_scale_alloc.get();
3355
+
3356
+ // Helper lambda to allocate and cast tensor to F16 if needed
3357
+ constexpr size_t f16_elem_size = sizeof(uint16_t);
3358
+ auto prepare_f16_buffer = [&](ggml_tensor * tensor, ggml_cann_pool_alloc & allocator,
3359
+ bool need_cast = false) -> void * {
3360
+ if (tensor->type == GGML_TYPE_F16) {
3361
+ return tensor->data;
3362
+ }
3315
3363
 
3316
- // copy index from npu to cpu
3317
- int64_t n_as = ne02; // A
3318
- int64_t n_ids = ids->ne[0]; // K
3364
+ size_t total_size = f16_elem_size;
3365
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
3366
+ total_size *= tensor->ne[i];
3367
+ }
3368
+ void * buffer = allocator.alloc(total_size);
3319
3369
 
3320
- std::vector<char> ids_host(ggml_nbytes(ids));
3321
- ACL_CHECK(aclrtMemcpyAsync(ids_host.data(), ggml_nbytes(ids), ids->data, ggml_nbytes(ids),
3322
- ACL_MEMCPY_DEVICE_TO_HOST, ctx.stream()));
3323
- ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
3370
+ if (need_cast == false) {
3371
+ return buffer;
3372
+ }
3324
3373
 
3325
- char * src0_original = (char *) src0->data;
3326
- char * src1_original = (char *) src1->data;
3327
- char * dst_original = (char *) dst->data;
3374
+ int64_t ne[GGML_MAX_DIMS];
3375
+ size_t nb[GGML_MAX_DIMS] = { f16_elem_size };
3376
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
3377
+ ne[i] = tensor->ne[i];
3378
+ if (i > 0) {
3379
+ nb[i] = nb[i - 1] * ne[i - 1];
3380
+ }
3381
+ }
3328
3382
 
3329
- ggml_tensor src0_row = *src0;
3330
- ggml_tensor src1_row = *src1;
3331
- ggml_tensor dst_row = *dst;
3383
+ acl_tensor_ptr src_tensor = ggml_cann_create_tensor(tensor);
3384
+ acl_tensor_ptr f16_tensor = ggml_cann_create_tensor(buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS);
3385
+ aclnn_cast(ctx, src_tensor.get(), f16_tensor.get(), ACL_FLOAT16);
3332
3386
 
3333
- const enum ggml_type type = dst->src[0]->type;
3334
- float weight_elem_size;
3335
- if (type == GGML_TYPE_Q4_0) {
3336
- weight_elem_size = float(sizeof(uint8_t)) / 2;
3337
- } else if (type == GGML_TYPE_Q8_0) {
3338
- weight_elem_size = float(sizeof(uint8_t));
3339
- } else {
3340
- GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 ");
3341
- }
3387
+ return buffer;
3388
+ };
3342
3389
 
3343
- // src0_row [D, M, 1, 1] weight without permute
3344
- src0_row.ne[2] = 1;
3345
- src0_row.ne[3] = 1;
3346
- src0_row.nb[0] = weight_elem_size;
3347
- src0_row.nb[1] = weight_elem_size * ne00;
3348
- src0_row.nb[2] = weight_elem_size * ne00;
3349
- src0_row.nb[3] = weight_elem_size * ne00;
3350
- size_t weight_stride = ne00 * ne01 * weight_elem_size;
3351
- size_t weight_size = weight_stride * ne02 * ne03;
3390
+ // Prepare input and output buffers
3391
+ ggml_cann_pool_alloc input_alloc(ctx.pool());
3392
+ void * input_buffer = prepare_f16_buffer(src1, input_alloc, true);
3352
3393
 
3353
- // scale [D, M, 1, 1] -> scale && permute
3354
- size_t scale_elem_size = sizeof(uint16_t);
3355
- size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
3394
+ ggml_cann_pool_alloc output_alloc(ctx.pool());
3395
+ void * output_buffer = prepare_f16_buffer(dst, output_alloc, false);
3396
+
3397
+ // Process each batch
3398
+ for (int64_t batch_idx = 0; batch_idx < n_batches; batch_idx++) {
3399
+ // Create index tensor for current batch
3400
+ const size_t index_offset = batch_idx * ids->nb[1];
3401
+ acl_tensor_ptr batch_indices = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, index_offset);
3402
+
3403
+ // Select quantized weights using expert indices
3404
+ // Q4_0 stores 2 values per byte, Q8_0 stores 1 value per byte
3405
+ const int64_t weight_d = (type == GGML_TYPE_Q4_0) ? src0->ne[0] / 2 : src0->ne[0];
3406
+ const int64_t weight_m = src0->ne[1];
3407
+ const int64_t weight_n_experts = src0->ne[2];
3408
+
3409
+ int64_t weight_ne[3] = { weight_d, weight_m, weight_n_experts };
3410
+ size_t weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t), weight_d * weight_m * sizeof(int8_t) };
3411
+
3412
+ acl_tensor_ptr all_weights =
3413
+ ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, 3);
3414
+
3415
+ int64_t selected_weight_ne[3] = { weight_d, weight_m, n_select_experts };
3416
+ size_t selected_weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t),
3417
+ weight_d * weight_m * sizeof(int8_t) };
3418
+
3419
+ acl_tensor_ptr selected_weights = ggml_cann_create_tensor(selected_weight_buffer, ACL_INT8, sizeof(int8_t),
3420
+ selected_weight_ne, selected_weight_nb, 3);
3421
+
3422
+ GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_weights.get(), 0, batch_indices.get(), selected_weights.get());
3423
+
3424
+ // Select scales using the same expert indices
3425
+ const int64_t scale_d = src0->ne[0] / group_size;
3426
+ int64_t scale_ne[3] = { scale_d, weight_m, weight_n_experts };
3427
+ size_t scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size, scale_d * weight_m * scale_elem_size };
3428
+
3429
+ acl_tensor_ptr all_scales =
3430
+ ggml_cann_create_tensor(scale_data, ACL_FLOAT16, scale_elem_size, scale_ne, scale_nb, 3);
3431
+
3432
+ int64_t selected_scale_ne[3] = { scale_d, weight_m, n_select_experts };
3433
+ size_t selected_scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size,
3434
+ scale_d * weight_m * scale_elem_size };
3435
+
3436
+ acl_tensor_ptr selected_scales = ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size,
3437
+ selected_scale_ne, selected_scale_nb, 3);
3438
+
3439
+ GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_scales.get(), 0, batch_indices.get(), selected_scales.get());
3440
+
3441
+ // Process each expert for current batch
3442
+ // IndexSelect output layout: [D, M, K] in contiguous format
3443
+ // WeightQuantBatchMatmulV2 expects: [M, D] with row-major stride
3444
+ for (int64_t expert_idx = 0; expert_idx < n_select_experts; expert_idx++) {
3445
+ // Determine input offset: broadcast if src1->ne[1]==1, otherwise use per-expert input
3446
+ const size_t input_offset =
3447
+ (batch_idx * src1->ne[1] + (src1->ne[1] == 1 ? 0 : expert_idx)) * src1->ne[0] * f16_elem_size;
3448
+ const size_t output_offset = (batch_idx * dst->ne[1] + expert_idx) * dst->ne[0] * f16_elem_size;
3449
+
3450
+ // Create weight view for current expert: [D, M, K] -> [M, D]
3451
+ int64_t weight_view_ne[2] = { weight_m, src0->ne[0] };
3452
+ float weight_view_nb[2] = { src0->ne[0] * weight_elem_size, weight_elem_size };
3453
+ const size_t weight_view_offset = expert_idx * selected_weight_nb[2];
3454
+
3455
+ acl_tensor_ptr weight_view =
3456
+ ggml_cann_create_tensor(selected_weight_buffer, ggml_cann_type_mapping(type), weight_elem_size,
3457
+ weight_view_ne, weight_view_nb, 2, ACL_FORMAT_ND, weight_view_offset);
3458
+
3459
+ // Create scale view for current expert: [D, M, K] -> [M, D]
3460
+ int64_t scale_view_ne[2] = { weight_m, scale_d };
3461
+ size_t scale_view_nb[2] = { selected_scale_nb[1], selected_scale_nb[0] };
3462
+ const size_t scale_view_offset = expert_idx * selected_scale_nb[2];
3463
+
3464
+ acl_tensor_ptr scale_view =
3465
+ ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size, scale_view_ne,
3466
+ scale_view_nb, 2, ACL_FORMAT_ND, scale_view_offset);
3356
3467
 
3357
- // src1_row [D, 1, 1, 1] -> input
3358
- src1_row.ne[1] = 1;
3359
- src1_row.ne[2] = 1;
3360
- src1_row.ne[3] = 1;
3361
- src1_row.nb[2] = nb11;
3362
- src1_row.nb[3] = nb11;
3363
-
3364
- // dst_row [M, 1, 1, 1] -> out
3365
- dst_row.ne[1] = 1;
3366
- dst_row.ne[2] = 1;
3367
- dst_row.ne[3] = 1;
3368
- dst_row.nb[2] = nb1;
3369
- dst_row.nb[3] = nb1;
3370
-
3371
- //create weight for one row
3372
- ggml_cann_pool_alloc weight_allocator(ctx.pool());
3373
- void * weight_buffer = weight_allocator.alloc(nb02);
3374
- for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
3375
- for (int64_t id = 0; id < n_ids; id++) {
3376
- // expert index
3377
- int32_t i02 = *(int32_t *) (ids_host.data() + iid1 * ids->nb[1] + id * ids->nb[0]);
3378
- GGML_ASSERT(i02 >= 0 && i02 < n_as);
3379
-
3380
- // If B = 1 (broadcast), always use 0; otherwise, use id.
3381
- int64_t i11 = (ne11 == 1 ? 0 : id);
3382
- int64_t i12 = iid1;
3383
-
3384
- int64_t i1 = id;
3385
- int64_t i2 = i12;
3386
-
3387
- void * src0_tmp_ptr = src0_original + i02 * weight_stride;
3388
- void * scale_tmp_ptr = src0_original + weight_size + i02 * scale_stride;
3389
- void * src1_tmp_ptr = src1_original + i11 * nb11 + i12 * nb12;
3390
- void * dst_tmp_ptr = dst_original + i1 * nb1 + i2 * nb2;
3391
-
3392
- // mem cpy
3393
- ACL_CHECK(aclrtMemcpyAsync(weight_buffer, weight_stride, src0_tmp_ptr, weight_stride,
3394
- ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
3395
- void * scale_buffer = (char *) weight_buffer + weight_stride;
3396
- ACL_CHECK(aclrtMemcpyAsync(scale_buffer, scale_stride, scale_tmp_ptr, scale_stride,
3397
- ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
3398
-
3399
- src0_row.data = weight_buffer;
3400
- src1_row.data = src1_tmp_ptr;
3401
- dst_row.data = dst_tmp_ptr;
3402
- dst_row.src[0] = &src0_row;
3403
- dst_row.src[1] = &src1_row;
3404
-
3405
- ggml_cann_mul_mat(ctx, &dst_row);
3468
+ // Create input activation tensor [D, 1]
3469
+ int64_t input_ne[2] = { src1->ne[0], 1 };
3470
+ size_t input_nb[2] = { f16_elem_size, src1->ne[0] * f16_elem_size };
3471
+
3472
+ acl_tensor_ptr input_tensor = ggml_cann_create_tensor(input_buffer, ACL_FLOAT16, f16_elem_size, input_ne,
3473
+ input_nb, 2, ACL_FORMAT_ND, input_offset);
3474
+
3475
+ // Create output tensor [M, 1]
3476
+ int64_t output_ne[2] = { dst->ne[0], 1 };
3477
+ size_t output_nb[2] = { f16_elem_size, dst->ne[0] * f16_elem_size };
3478
+
3479
+ acl_tensor_ptr output_tensor = ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, output_ne,
3480
+ output_nb, 2, ACL_FORMAT_ND, output_offset);
3481
+
3482
+ // Perform quantized matrix multiplication
3483
+ GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, input_tensor.get(), weight_view.get(),
3484
+ scale_view.get(), nullptr, nullptr, nullptr, nullptr, group_size,
3485
+ output_tensor.get());
3406
3486
  }
3407
3487
  }
3408
- return;
3488
+
3489
+ // Cast output back to original type if we used a temporary F16 buffer
3490
+ if (dst->type != GGML_TYPE_F16) {
3491
+ int64_t ne[GGML_MAX_DIMS];
3492
+ size_t nb[GGML_MAX_DIMS] = { f16_elem_size };
3493
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
3494
+ ne[i] = dst->ne[i];
3495
+ if (i > 0) {
3496
+ nb[i] = nb[i - 1] * ne[i - 1];
3497
+ }
3498
+ }
3499
+
3500
+ acl_tensor_ptr f16_output =
3501
+ ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS);
3502
+ acl_tensor_ptr dst_tensor = ggml_cann_create_tensor(dst);
3503
+
3504
+ aclnn_cast(ctx, f16_output.get(), dst_tensor.get(), ggml_cann_type_mapping(dst->type));
3505
+ }
3409
3506
  }
3410
3507
 
3411
3508
  void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
@@ -3742,15 +3839,15 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3742
3839
  // we want a view: ne_w = { nc, 1, nr } // [K, 1, C]
3743
3840
  // so that reversed dims -> [C, 1, K] which matches
3744
3841
  // [out_channels, in_channels/groups, kernel_size]
3745
- int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 }; // [K, 1 input ch. per group, C groups]
3842
+ int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 }; // [K, 1 input ch. per group, C groups]
3746
3843
  // Layout: src1 data is [K, C] with
3747
3844
  // offset(k, c) = k*nb0 + c*nb1
3748
3845
  // We want offset_w(k, 0, c) = k*nb0 + c*nb1,
3749
3846
  // so we can reuse nb0 and nb1, and set nb2 = nb1.
3750
- size_t w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1
3847
+ size_t w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1
3751
3848
 
3752
- acl_tensor_ptr acl_w = ggml_cann_create_tensor(
3753
- src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL);
3849
+ acl_tensor_ptr acl_w = ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type),
3850
+ ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL);
3754
3851
 
3755
3852
  // 3) Output: dst is { d_inner, n_t, n_s } (CLN)
3756
3853
  //
@@ -3768,11 +3865,12 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3768
3865
  // nb_y[0] = nr * sizeof(float); // step in L
3769
3866
  // nb_y[1] = sizeof(float); // step in C
3770
3867
  // nb_y[2] = nr * n_t * sizeof(float); // step in N
3771
- int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N]
3772
- size_t y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float), dst->nb[3] }; // [nr, 1, nr * n_t]
3868
+ int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N]
3869
+ size_t y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float),
3870
+ dst->nb[3] }; // [nr, 1, nr * n_t]
3773
3871
 
3774
- acl_tensor_ptr acl_y = ggml_cann_create_tensor(
3775
- dst->data, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL);
3872
+ acl_tensor_ptr acl_y = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type),
3873
+ ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL);
3776
3874
 
3777
3875
  // --- Conv1d parameters: depthwise, stride 1, no padding ("valid") ---
3778
3876
  int64_t strideVal[1] = { 1 };
@@ -3791,22 +3889,15 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3791
3889
  cubeMathType = 1;
3792
3890
  #endif
3793
3891
 
3794
- GGML_CANN_CALL_ACLNN_OP(ctx,
3795
- Convolution,
3892
+ GGML_CANN_CALL_ACLNN_OP(ctx, Convolution,
3796
3893
  acl_x.get(), // input: N, C, L_in = ncs
3797
3894
  acl_w.get(), // weight: [C, 1, K] with groups=nr
3798
3895
  nullptr, // bias
3799
- stride.get(),
3800
- padding.get(),
3801
- dilation.get(),
3802
- transposed,
3803
- padding.get(), // output padding (unused for non-transposed)
3804
- groups,
3805
- acl_y.get(),
3806
- cubeMathType);
3896
+ stride.get(), padding.get(), dilation.get(), transposed,
3897
+ padding.get(), // output padding (unused for non-transposed)
3898
+ groups, acl_y.get(), cubeMathType);
3807
3899
  }
3808
3900
 
3809
-
3810
3901
  void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
3811
3902
  ggml_tensor * add_node,
3812
3903
  ggml_tensor * rms_norm_node) {
@@ -3860,3 +3951,71 @@ void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
3860
3951
  eps, // double type
3861
3952
  acl_yout.get(), acl_rstd.get(), acl_xout.get());
3862
3953
  }
3954
+
3955
+ void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3956
+ ggml_tensor * k = dst->src[0];
3957
+ ggml_tensor * v = dst->src[1];
3958
+ ggml_tensor * q = dst->src[2];
3959
+ ggml_tensor * g = dst->src[3];
3960
+ ggml_tensor * s = dst->src[4];
3961
+
3962
+ int64_t B = dst->src[4]->ne[1];
3963
+ int64_t T = dst->src[0]->ne[2];
3964
+ int64_t H = dst->src[0]->ne[1];
3965
+ int64_t C = dst->ne[0];
3966
+ int64_t D = C / H;
3967
+ int64_t L = T / B;
3968
+
3969
+ int64_t ne_qkg[2] = { 1, D };
3970
+ int64_t ne_s[2] = { D, D };
3971
+ int64_t ne_st[2] = { ne_s[1], ne_s[0] };
3972
+ int64_t ne_vo[2] = { D, 1 };
3973
+ int64_t ne_q[1] = { D };
3974
+ size_t nb_base = ggml_type_size(k->type);
3975
+ size_t nb_qkg[2] = { nb_base, nb_base };
3976
+ size_t nb_s[2] = { nb_base, D * nb_base };
3977
+ size_t nb_st[2] = { nb_s[1], nb_s[0] };
3978
+ size_t nb_vo[2] = { nb_base, D * nb_base };
3979
+ size_t nb_q[1] = { nb_base };
3980
+
3981
+ const float scale = ggml_get_op_params_f32(dst, 0);
3982
+
3983
+ acl_tensor_ptr acl_s = ggml_cann_create_tensor(s, s->ne, s->nb, 2, ACL_FORMAT_ND);
3984
+ acl_tensor_ptr new_state = ggml_cann_create_tensor(dst, s->ne, s->nb, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base);
3985
+ cann_copy(ctx, acl_s.get(), new_state.get());
3986
+
3987
+ for (int64_t b = 0; b < B; b++) {
3988
+ for (int64_t h = 0; h < H; h++) {
3989
+ size_t s_offset = (b * (H * D * D) + h * (D * D)) * nb_base;
3990
+ // D * D
3991
+ acl_tensor_ptr acl_s_new =
3992
+ ggml_cann_create_tensor(dst, ne_s, nb_s, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);
3993
+ acl_tensor_ptr acl_s_new_t =
3994
+ ggml_cann_create_tensor(dst, ne_st, nb_st, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);
3995
+ for (int64_t l = 0; l < L; l++) {
3996
+ size_t qkvgo_offset = (b * (L * H * D) + l * (H * D) + h * (D)) * nb_base;
3997
+ // D * 1
3998
+ acl_tensor_ptr acl_k = ggml_cann_create_tensor(k, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
3999
+ acl_tensor_ptr acl_g = ggml_cann_create_tensor(g, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
4000
+ // D
4001
+ acl_tensor_ptr acl_q = ggml_cann_create_tensor(q, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
4002
+ // 1 * D
4003
+ acl_tensor_ptr acl_v = ggml_cann_create_tensor(v, ne_vo, nb_vo, 2, ACL_FORMAT_ND, qkvgo_offset);
4004
+ // D
4005
+ acl_tensor_ptr acl_o = ggml_cann_create_tensor(dst, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
4006
+ // k ⊗ v
4007
+ size_t buf_size = D * D * nb_base;
4008
+ ggml_cann_pool_alloc buffer_allocator(ctx.pool(), buf_size);
4009
+ acl_tensor_ptr tmp_tensor = ggml_cann_create_tensor(
4010
+ buffer_allocator.get(), ggml_cann_type_mapping(k->type), nb_base, ne_s, nb_s, 2);
4011
+ aclnn_mul(ctx, acl_k.get(), acl_v.get(), tmp_tensor.get());
4012
+ //s_new = g ⊗ s_old + k ⊗ v
4013
+ aclnn_mul(ctx, acl_s_new.get(), acl_g.get(), nullptr);
4014
+ aclnn_add(ctx, acl_s_new.get(), tmp_tensor.get(), nullptr);
4015
+ // compute output
4016
+ GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_s_new_t.get(), acl_q.get(), acl_o.get(), 1);
4017
+ aclnn_muls(ctx, acl_o.get(), scale, nullptr, true);
4018
+ }
4019
+ }
4020
+ }
4021
+ }