whispercpp 1.3.4 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (891)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +158 -44
  4. data/ext/extconf.rb +3 -2
  5. data/ext/ruby_whisper.c +34 -6
  6. data/ext/ruby_whisper.h +67 -0
  7. data/ext/ruby_whisper_context.c +236 -144
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +12 -13
  10. data/ext/ruby_whisper_params.c +47 -24
  11. data/ext/ruby_whisper_segment.c +84 -20
  12. data/ext/ruby_whisper_token.c +371 -0
  13. data/ext/ruby_whisper_transcribe.cpp +5 -2
  14. data/ext/ruby_whisper_vad_context.c +122 -0
  15. data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +138 -0
  18. data/ext/ruby_whisper_vad_segments.c +105 -0
  19. data/ext/sources/CMakeLists.txt +4 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  22. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  23. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  24. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  25. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  26. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  27. data/ext/sources/examples/bench/bench.cpp +23 -18
  28. data/ext/sources/examples/cli/cli.cpp +129 -112
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  31. data/ext/sources/examples/miniaudio.h +4507 -2131
  32. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/server/server.cpp +28 -15
  34. data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
  35. data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
  36. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
  37. data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
  38. data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
  39. data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
  40. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  41. data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
  42. data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
  43. data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
  44. data/ext/sources/examples/talk-llama/llama-context.h +70 -23
  45. data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
  46. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  47. data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
  48. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  49. data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
  50. data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
  51. data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
  52. data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
  53. data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
  54. data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
  55. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
  56. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
  57. data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
  58. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  59. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  60. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  61. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  62. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
  63. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  64. data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
  65. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  66. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
  67. data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
  68. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
  69. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  70. data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
  71. data/ext/sources/examples/talk-llama/llama-model.h +112 -18
  72. data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
  73. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
  74. data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
  75. data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
  76. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
  77. data/ext/sources/examples/talk-llama/llama.cpp +802 -21
  78. data/ext/sources/examples/talk-llama/llama.h +210 -39
  79. data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
  80. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  81. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  82. data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
  83. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  84. data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
  85. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
  86. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
  87. data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
  88. data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
  89. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  90. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  91. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  92. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  93. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  94. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  95. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  96. data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
  97. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  98. data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
  99. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
  100. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  101. data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
  102. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  103. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
  104. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  105. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  106. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  107. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  108. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  109. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
  110. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  111. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  112. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  113. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  114. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  115. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  116. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  117. data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
  118. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  119. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  120. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
  121. data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
  122. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  123. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
  124. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  125. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
  126. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  127. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  128. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  129. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  130. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  131. data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
  132. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  133. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  134. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  135. data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
  136. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  137. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +704 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  156. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  158. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  159. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  160. data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
  161. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  162. data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
  163. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  166. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
  168. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  169. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
  171. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
  172. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
  173. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
  174. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  178. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  179. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
  180. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  181. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  182. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  183. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  184. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  185. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  186. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  187. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  188. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  189. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  190. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  191. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  192. data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
  193. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  194. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  195. data/ext/sources/ggml/CMakeLists.txt +90 -56
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +5 -2
  198. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  199. data/ext/sources/ggml/include/ggml-cpu.h +6 -0
  200. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  201. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  202. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  203. data/ext/sources/ggml/include/ggml-rpc.h +14 -12
  204. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +246 -21
  207. data/ext/sources/ggml/src/CMakeLists.txt +85 -11
  208. data/ext/sources/ggml/src/ggml-alloc.c +128 -50
  209. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  210. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  211. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  212. data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
  213. data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
  214. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
  215. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
  217. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
  219. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
  220. data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
  221. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
  222. data/ext/sources/ggml/src/ggml-common.h +11 -0
  223. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
  224. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
  225. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  226. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  227. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  228. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
  229. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
  230. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  232. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
  233. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  234. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  235. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  237. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
  238. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
  239. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  240. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
  242. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
  243. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
  245. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  246. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
  248. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  249. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
  250. data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
  251. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  252. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  253. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
  254. data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
  255. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  256. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  258. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
  259. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  260. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
  261. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  262. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  263. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  264. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
  265. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  266. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
  267. data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  269. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  270. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  271. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  273. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  275. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  276. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
  278. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
  285. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  287. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  288. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  289. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
  290. data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
  291. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
  292. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
  293. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
  294. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  295. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  296. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  297. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
  298. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
  299. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
  300. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
  301. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
  302. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  303. data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
  304. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
  305. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  306. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  307. data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
  308. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  309. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  310. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  311. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  312. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
  313. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  314. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  315. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
  316. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  317. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
  334. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
  335. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  336. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
  337. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
  338. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  339. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
  341. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
  342. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  343. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  344. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  345. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
  346. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  347. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  348. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
  349. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
  350. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
  351. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  352. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
  353. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  354. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  355. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
  356. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  357. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  358. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  359. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  360. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  361. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  362. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  363. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
  364. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
  365. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  366. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  367. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  368. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  369. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  370. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  371. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  372. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  373. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  374. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  375. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  376. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  377. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  378. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  379. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
  380. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
  381. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
  382. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
  383. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  384. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
  385. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  386. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  387. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
  388. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  389. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  390. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  391. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  392. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  393. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  394. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  395. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
  396. data/ext/sources/ggml/src/ggml-impl.h +129 -6
  397. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  398. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
  399. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  400. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
  401. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
  402. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
  403. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
  404. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
  405. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
  406. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
  407. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
  408. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
  409. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  410. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
  411. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
  412. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  413. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  414. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  415. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
  416. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  417. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  418. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  419. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  420. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  421. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  422. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  423. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  424. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  425. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  426. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  427. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  428. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  429. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  430. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  431. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  432. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  433. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  434. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  435. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  436. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  437. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  438. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  439. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  440. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  441. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  442. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  443. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  444. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  445. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  446. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  447. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  448. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  449. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  450. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  451. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  452. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  453. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  454. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  455. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  456. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
  457. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  458. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  459. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  460. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  461. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  462. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  463. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  464. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  465. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  466. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  467. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  468. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  469. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  470. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  471. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  472. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  473. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  474. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  475. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  476. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  477. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  478. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  479. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  480. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  481. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  482. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  483. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  484. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  485. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  486. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  487. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  488. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  489. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  490. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  491. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  492. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  493. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  494. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  495. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  496. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  497. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  498. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  499. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  500. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  501. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  502. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  503. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  504. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  505. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  506. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  507. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  508. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
  509. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
  510. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  511. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
  512. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
  513. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  514. data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
  515. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  516. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
  517. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  518. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  519. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  520. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  521. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  522. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
  523. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
  524. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  525. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  526. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  527. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  528. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  529. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  530. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  531. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  532. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
  534. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  535. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
  536. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  537. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  538. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  539. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  540. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  541. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  542. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
  543. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  544. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  545. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  547. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  548. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
  549. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  550. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  551. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  552. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  553. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  554. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  555. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  556. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  557. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  558. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  559. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  560. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  561. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  562. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  563. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  564. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  565. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  566. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  567. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  568. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  569. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  570. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  571. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  572. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  573. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  574. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  575. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  576. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  577. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  578. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  579. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  580. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  581. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  582. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  583. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  584. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  585. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  586. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  587. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  588. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  589. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  590. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  591. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  592. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  593. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  594. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  595. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  596. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  597. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  598. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  599. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  600. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  601. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
  602. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  603. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  604. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  605. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  606. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  607. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  608. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  609. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  610. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  611. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  612. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  613. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  614. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  615. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  616. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  617. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  618. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  619. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  620. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  621. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  622. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  623. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  624. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  625. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  626. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  627. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  628. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  629. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  630. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  631. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  632. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  633. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  634. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  635. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  636. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  637. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  638. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  639. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  640. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  641. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  642. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  643. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  644. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
  646. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  745. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  746. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  747. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  748. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  749. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  750. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  751. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
  752. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
  753. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  754. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
  755. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
  756. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  757. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  758. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  759. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  760. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  761. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  762. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  763. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  764. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  765. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  766. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  767. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  768. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  769. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  770. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
  771. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  772. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  773. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  774. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
  775. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  776. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
  777. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
  778. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
  779. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
  780. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
  781. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  782. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  783. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  784. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  785. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  786. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  787. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  788. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  789. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  790. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  791. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  792. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  793. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  794. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  795. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  796. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  798. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
  799. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  800. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  801. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  802. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  803. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  804. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  805. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  806. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  807. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  808. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  809. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  810. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  811. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  812. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  813. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  814. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  815. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
  816. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  817. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  818. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
  819. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
  820. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  821. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  822. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  823. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  824. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  825. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  826. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  827. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  828. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  829. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
  830. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  831. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  832. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  833. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
  834. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
  835. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
  836. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
  837. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  838. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  839. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  840. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  841. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  842. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  843. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  844. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  845. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  846. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  847. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  848. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
  849. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  850. data/ext/sources/ggml/src/ggml.c +590 -64
  851. data/ext/sources/ggml/src/gguf.cpp +229 -44
  852. data/ext/sources/include/whisper.h +1 -0
  853. data/ext/sources/src/CMakeLists.txt +3 -1
  854. data/ext/sources/src/whisper.cpp +106 -62
  855. data/ext/sources/tests/CMakeLists.txt +2 -2
  856. data/ext/sources/tests/test-vad-full.cpp +4 -2
  857. data/ext/sources/tests/test-vad.cpp +1 -1
  858. data/extsources.rb +1 -0
  859. data/lib/whisper/model/uri.rb +17 -18
  860. data/sig/whisper.rbs +162 -4
  861. data/test/test_context_params.rb +82 -0
  862. data/test/test_params.rb +16 -8
  863. data/test/test_segment.rb +0 -1
  864. data/test/test_token.rb +81 -0
  865. data/test/test_vad.rb +1 -1
  866. data/test/test_vad_context.rb +100 -0
  867. data/test/test_vad_segment.rb +19 -0
  868. data/test/test_vad_segments.rb +16 -0
  869. data/test/test_whisper.rb +27 -0
  870. data/whispercpp.gemspec +1 -1
  871. metadata +502 -37
  872. data/ext/sources/build-xcframework.sh +0 -571
  873. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
  874. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  875. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  876. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  877. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  878. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  879. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  880. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  881. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  882. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  883. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  884. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  885. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  886. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  887. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  888. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  889. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  890. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  891. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -1,16 +1,31 @@
1
- // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
1
+ // SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
2
2
  // SPDX-License-Identifier: MIT
3
3
  //
4
4
  #include <arm_neon.h>
5
5
  #include <assert.h>
6
+ #include <stdio.h>
6
7
  #include <atomic>
7
8
  #include <cfloat>
9
+ #include <algorithm>
10
+ #include <cmath>
8
11
  #include <stdexcept>
9
12
  #include <stdint.h>
10
13
  #include <string.h>
14
+ #include <string>
15
+ #include <vector>
16
+ #include <array>
17
+ #include <cstddef>
18
+ #include <cstdint>
19
+ #include <fstream>
20
+ #include <set>
21
+ #include <iostream>
22
+ #include <climits>
11
23
  #if defined(__linux__)
12
24
  #include <asm/hwcap.h>
13
25
  #include <sys/auxv.h>
26
+ #include <sys/types.h>
27
+ #include <sys/stat.h>
28
+ #include <unistd.h>
14
29
  #elif defined(__APPLE__)
15
30
  #include <string_view>
16
31
  #include <sys/sysctl.h>
@@ -35,90 +50,369 @@
35
50
  #define GGML_COMMON_DECL_CPP
36
51
  #include "ggml-common.h"
37
52
 
53
+ static constexpr int GGML_KLEIDIAI_MAX_KERNEL_SLOTS = 2;
54
+ static constexpr uint32_t GGML_KLEIDIAI_PACK_MAGIC = 0x4b4c4149; // "KLAI"
55
+ static constexpr uint16_t GGML_KLEIDIAI_PACK_VERSION = 1;
56
+ static constexpr size_t GGML_KLEIDIAI_PACK_ALIGN = 64;
57
+
38
58
  struct ggml_kleidiai_context {
39
59
  cpu_feature features;
40
- ggml_kleidiai_kernels * kernels;
41
- } static ctx = { CPU_FEATURE_NONE, NULL };
60
+ ggml_kleidiai_kernels * kernels_q4;
61
+ ggml_kleidiai_kernels * kernels_q8;
62
+ int sme_thread_cap; // <= 0 means “SME disabled/unknown”;
63
+ int thread_hint; // <= 0 means “no hint”
64
+ } static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1 };
42
65
 
43
66
  static const char* cpu_feature_to_string(cpu_feature f) {
44
- switch (f) {
45
- case CPU_FEATURE_NONE: return "NONE";
46
- case CPU_FEATURE_DOTPROD: return "DOTPROD";
47
- case CPU_FEATURE_I8MM: return "I8MM";
48
- case CPU_FEATURE_SVE: return "SVE";
49
- case CPU_FEATURE_SME: return "SME";
50
- default: return "UNKNOWN";
67
+ if (f == CPU_FEATURE_NONE) {
68
+ return "NONE";
69
+ } else if ((f & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
70
+ return "SME";
71
+ } else if ((f & CPU_FEATURE_SVE) == CPU_FEATURE_SVE) {
72
+ return "SVE";
73
+ }
74
+ else if ((f & CPU_FEATURE_I8MM) == CPU_FEATURE_I8MM) {
75
+ return "I8MM";
76
+ } else if ((f & CPU_FEATURE_DOTPROD) == CPU_FEATURE_DOTPROD) {
77
+ return "DOTPROD";
78
+ }
79
+ else {
80
+ return "UNKNOWN";
51
81
  }
52
82
  }
53
83
 
54
- static void init_kleidiai_context(void) {
84
+ static size_t detect_num_smcus() {
85
+ if (!ggml_cpu_has_sme()) {
86
+ return 0;
87
+ }
88
+
89
+ #if defined(__linux__) && defined(__aarch64__)
90
+ // Linux/aarch64: Best-effort count of Streaming Mode Compute Units (SMCUs) via SMIDR_EL1 sysfs.
91
+ size_t num_private = 0;
92
+ std::set<uint32_t> shared_ids;
93
+
94
+ for (size_t cpu = 0;; ++cpu) {
95
+ const std::string path =
96
+ "/sys/devices/system/cpu/cpu" + std::to_string(cpu) +
97
+ "/regs/identification/smidr_el1";
98
+
99
+ std::ifstream file(path);
100
+ if (!file.is_open()) {
101
+ break;
102
+ }
55
103
 
104
+ uint64_t smidr = 0;
105
+ if (!(file >> std::hex >> smidr)) {
106
+ continue;
107
+ }
108
+
109
+ // Arm ARM: SMIDR_EL1
110
+ const uint32_t sh = (uint32_t)((smidr >> 13) & 0x3);
111
+ // Build an "affinity-like" identifier for shared SMCUs.
112
+ // Keep the original packing logic, but isolate it here.
113
+ const uint32_t id = (uint32_t)((smidr & 0xFFFu) | ((smidr >> 20) & 0xFFFFF000u));
114
+
115
+ switch (sh) {
116
+ case 0b10: // private SMCU
117
+ ++num_private;
118
+ break;
119
+ case 0b11: // shared SMCU
120
+ shared_ids.emplace(id);
121
+ break;
122
+ case 0b00:
123
+ // Ambiguous / implementation-defined. Be conservative:
124
+ // treat id==0 as private, otherwise as shared.
125
+ if (id == 0) ++num_private;
126
+ else shared_ids.emplace(id);
127
+ break;
128
+ default:
129
+ break;
130
+ }
131
+ }
132
+
133
+ return num_private + shared_ids.size();
134
+
135
+ #elif defined(__APPLE__) && defined(__aarch64__)
136
+ // table for known M4 variants. Users can override via GGML_KLEIDIAI_SME=<n>.
137
+ char chip_name[256] = {};
138
+ size_t size = sizeof(chip_name);
139
+
140
+ if (sysctlbyname("machdep.cpu.brand_string", chip_name, &size, nullptr, 0) == 0) {
141
+ const std::string brand(chip_name);
142
+
143
+ struct ModelSMCU { const char *match; size_t smcus; };
144
+ static const ModelSMCU table[] = {
145
+ { "M4 Ultra", 2 },
146
+ { "M4 Max", 2 },
147
+ { "M4 Pro", 2 },
148
+ { "M4", 1 },
149
+ };
150
+
151
+ for (const auto &e : table) {
152
+ if (brand.find(e.match) != std::string::npos) {
153
+ return e.smcus;
154
+ }
155
+ }
156
+ }
157
+ return 1;
158
+
159
+ #else
160
+ return 1;
161
+ #endif
162
+ }
163
+
164
+ static int parse_uint_env(const char *s, const char *name, bool *ok) {
165
+ if (!s) { *ok = false; return 0; }
166
+ char *end = nullptr;
167
+ long v = strtol(s, &end, 10);
168
+ if (end == s || *end != '\0') {
169
+ GGML_LOG_WARN("kleidiai: invalid %s='%s' (expected integer)\n", name, s);
170
+ *ok = false;
171
+ return 0;
172
+ }
173
+ if (v < 0 || v > INT_MAX) {
174
+ GGML_LOG_WARN("kleidiai: out-of-range %s='%s'\n", name, s);
175
+ *ok = false;
176
+ return 0;
177
+ }
178
+ *ok = true;
179
+ return (int)v;
180
+ }
181
+
182
+ static void init_kleidiai_context(void) {
56
183
  ggml_critical_section_start();
57
184
  static bool initialized = false;
58
185
 
59
186
  if (!initialized) {
60
187
  initialized = true;
61
- const char *env_var = getenv("GGML_KLEIDIAI_SME");
62
- int sme_enabled = 0;
188
+
189
+ const char *env_sme = getenv("GGML_KLEIDIAI_SME");
190
+ const char *env_threads = getenv("GGML_TOTAL_THREADS");
191
+
192
+ const bool cpu_has_sme = ggml_cpu_has_sme();
193
+ size_t detected_smcus = 0;
63
194
 
64
195
  ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
65
196
  (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
66
- (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
197
+ ((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
67
198
 
68
- if (env_var) {
69
- sme_enabled = atoi(env_var);
199
+ if (env_threads) {
200
+ bool ok = false;
201
+ int hint = parse_uint_env(env_threads, "GGML_TOTAL_THREADS", &ok);
202
+ if (ok && hint > 0) {
203
+ ctx.thread_hint = hint;
204
+ }
70
205
  }
71
206
 
72
- if (sme_enabled != 0) {
73
- ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
207
+ // SME policy:
208
+ // - If CPU doesn't support SME: SME always off.
209
+ // - Else:
210
+ // - env unset => auto-detect cores; enable if detected > 0.
211
+ // - env=0 => force off.
212
+ // - env>0 => force N cores (skip detection).
213
+ int sme_cores = 0;
214
+ bool sme_env_ok = false;
215
+ bool sme_env_set = (env_sme != nullptr);
216
+
217
+ if (!cpu_has_sme) {
218
+ if (sme_env_set) {
219
+ bool ok = false;
220
+ int req = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
221
+ if (ok && req > 0) {
222
+ GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME=%d but SME is not supported on this CPU; disabling SME\n", req);
223
+ }
224
+ }
225
+ sme_cores = 0;
226
+ } else {
227
+ if (sme_env_set) {
228
+ bool ok = false;
229
+ int v = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
230
+ sme_env_ok = ok;
231
+
232
+ if (!ok) {
233
+ GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME set but parsing failed; falling back to runtime SME-core detection\n");
234
+ detected_smcus = detect_num_smcus();
235
+ sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
236
+ } else if (v == 0) {
237
+ sme_cores = 0;
238
+ } else {
239
+ sme_cores = v;
240
+ }
241
+ } else {
242
+ detected_smcus = detect_num_smcus();
243
+ sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
244
+ }
245
+
246
+ if (!sme_env_set && sme_cores == 0) {
247
+ GGML_LOG_WARN("kleidiai: SME supported but runtime SME-core detection returned 0; falling back to NEON\n");
248
+ }
249
+
250
+ if (sme_cores > 0) {
251
+ ctx.features |= CPU_FEATURE_SME;
252
+ }
74
253
  }
75
- ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
76
- #ifndef NDEBUG
77
- if (ctx.kernels) {
78
- GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
254
+
255
+ // Kernel selection
256
+ ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
257
+ ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
258
+
259
+ if (!ctx.kernels_q4) {
260
+ GGML_LOG_INFO("kleidiai: no compatible q4 kernels found for CPU features mask %d\n", (int)ctx.features);
261
+ } else {
262
+ GGML_LOG_INFO("kleidiai: primary q4 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
263
+ }
264
+
265
+ if (!ctx.kernels_q8) {
266
+ GGML_LOG_INFO("kleidiai: no compatible q8 kernels found for CPU features mask %d\n", (int)ctx.features);
267
+ } else {
268
+ GGML_LOG_INFO("kleidiai: primary q8 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
269
+ }
270
+
271
+ ctx.sme_thread_cap = (ctx.features & CPU_FEATURE_SME) ? sme_cores : 0;
272
+
273
+ if (ctx.features & CPU_FEATURE_SME) {
274
+ if (sme_env_set && sme_env_ok && sme_cores > 0) {
275
+ GGML_LOG_INFO("kleidiai: SME enabled (GGML_KLEIDIAI_SME=%d override)\n", sme_cores);
276
+ } else {
277
+ GGML_LOG_INFO("kleidiai: SME enabled (runtime-detected SME cores=%d)\n", sme_cores);
278
+ }
279
+ } else {
280
+ GGML_LOG_INFO("kleidiai: SME disabled\n");
79
281
  }
80
- #endif
81
282
  }
283
+
82
284
  ggml_critical_section_end();
83
285
  }
84
286
 
85
- static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
86
- GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
87
- return tensor->ne[dim];
287
+ static inline int kleidiai_sme_thread_cap() {
288
+ return ctx.sme_thread_cap;
88
289
  }
89
290
 
90
- template <typename Variant, typename Ret, typename... Args, std::size_t... Is>
91
- constexpr bool variant_any_invocable_impl(std::index_sequence<Is...>) {
92
- using V = std::remove_reference_t<Variant>;
93
- return (std::is_invocable_r_v<
94
- Ret,
95
- std::variant_alternative_t<Is, V>,
96
- Args...> || ...);
291
+ static inline size_t align_up(size_t value, size_t alignment) {
292
+ if (alignment == 0) {
293
+ return value;
294
+ }
295
+ const size_t remainder = value % alignment;
296
+ return remainder == 0 ? value : value + (alignment - remainder);
97
297
  }
98
298
 
99
- template <typename Variant, typename Ret, typename... Args>
100
- constexpr bool variant_any_invocable_v =
101
- variant_any_invocable_impl<Variant, Ret, Args...>(
102
- std::make_index_sequence<
103
- std::variant_size_v<std::remove_reference_t<Variant>>>{});
104
-
105
- template<typename Ret, typename Variant, typename... Args>
106
- static inline Ret variant_call(Variant && var, Args&&... args) {
107
- static_assert(variant_any_invocable_v<std::remove_reference_t<Variant>, Ret, Args...>,
108
- "No alternative in Variant is invocable with the provided arguments and return type.");
109
-
110
- return std::visit(
111
- [&](auto && f) -> Ret {
112
- using F = std::decay_t<decltype(f)>;
113
- if constexpr (std::is_invocable_r_v<Ret, F, Args...>) {
114
- return std::invoke(std::forward<decltype(f)>(f), std::forward<Args>(args)...);
115
- } else {
116
- GGML_ABORT("Invalid function type in variant_call");
117
- GGML_UNREACHABLE();
299
+ static inline bool kleidiai_pack_fallback_allowed() {
300
+ if (ctx.sme_thread_cap <= 0) {
301
+ return false;
302
+ }
303
+ if (ctx.thread_hint <= 0) {
304
+ return true;
305
+ }
306
+ return ctx.thread_hint > ctx.sme_thread_cap;
307
+ }
308
+
309
+ struct kleidiai_weight_header {
310
+ uint32_t magic;
311
+ uint16_t version;
312
+ uint16_t slot_count;
313
+ uint64_t offsets[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
314
+ uint64_t sizes[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
315
+ };
316
+
317
+ static inline kleidiai_weight_header * kleidiai_weight_header_from_ptr(void * data) {
318
+ return reinterpret_cast<kleidiai_weight_header *>(data);
319
+ }
320
+
321
+ static inline const kleidiai_weight_header * kleidiai_weight_header_from_ptr(const void * data) {
322
+ return reinterpret_cast<const kleidiai_weight_header *>(data);
323
+ }
324
+
325
+ static inline bool kleidiai_is_weight_header_valid(const kleidiai_weight_header * header) {
326
+ if (!header) {
327
+ return false;
328
+ }
329
+ if (header->magic != GGML_KLEIDIAI_PACK_MAGIC || header->version != GGML_KLEIDIAI_PACK_VERSION) {
330
+ return false;
331
+ }
332
+ if (header->slot_count == 0 || header->slot_count > GGML_KLEIDIAI_MAX_KERNEL_SLOTS) {
333
+ return false;
334
+ }
335
+ return true;
336
+ }
337
+
338
+ static inline uint8_t * kleidiai_weight_slot_ptr(kleidiai_weight_header * header, int slot) {
339
+ if (!kleidiai_is_weight_header_valid(header)) {
340
+ return nullptr;
341
+ }
342
+ if (slot < 0 || slot >= header->slot_count) {
343
+ return nullptr;
344
+ }
345
+ return reinterpret_cast<uint8_t *>(header) + header->offsets[slot];
346
+ }
347
+
348
+ static inline const uint8_t * kleidiai_weight_slot_ptr(const kleidiai_weight_header * header, int slot) {
349
+ if (!kleidiai_is_weight_header_valid(header)) {
350
+ return nullptr;
351
+ }
352
+ if (slot < 0 || slot >= header->slot_count) {
353
+ return nullptr;
354
+ }
355
+ return reinterpret_cast<const uint8_t *>(header) + header->offsets[slot];
356
+ }
357
+
358
+ static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q4() {
359
+ return ctx.kernels_q4;
360
+ }
361
+
362
+ static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q8() {
363
+ return ctx.kernels_q8;
364
+ }
365
+
366
+ template <typename SelectFallback>
367
+ static int kleidiai_collect_kernel_chain_common(
368
+ ggml_kleidiai_kernels * primary,
369
+ cpu_feature features,
370
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out,
371
+ SelectFallback select_fallback) {
372
+ int count = 0;
373
+ if (!primary) {
374
+ return 0;
375
+ }
376
+ out[count++] = primary;
377
+
378
+ if ((primary->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
379
+ const cpu_feature fallback_mask = static_cast<cpu_feature>(features & ~CPU_FEATURE_SME);
380
+ if (fallback_mask != CPU_FEATURE_NONE) {
381
+ ggml_kleidiai_kernels * fallback = select_fallback(fallback_mask);
382
+ if (fallback && fallback != primary &&
383
+ fallback->lhs_type == primary->lhs_type &&
384
+ fallback->rhs_type == primary->rhs_type &&
385
+ fallback->op_type == primary->op_type) {
386
+ out[count++] = fallback;
118
387
  }
119
- },
120
- std::forward<Variant>(var)
121
- );
388
+ }
389
+ }
390
+
391
+ return count;
392
+ }
393
+
394
+ static int kleidiai_collect_kernel_chain(const struct ggml_tensor * op,
395
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
396
+ ggml_kleidiai_kernels * primary = ggml_kleidiai_select_kernels(ctx.features, op);
397
+ return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
398
+ [&](cpu_feature mask) { return ggml_kleidiai_select_kernels(mask, op); });
399
+ }
400
+
401
+ static int kleidiai_collect_q4_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
402
+ ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q4();
403
+ return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
404
+ [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q4_0(mask); });
405
+ }
406
+
407
+ static int kleidiai_collect_q8_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
408
+ ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q8();
409
+ return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
410
+ [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q8_0(mask); });
411
+ }
412
+
413
+ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
414
+ GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
415
+ return tensor->ne[dim];
122
416
  }
123
417
 
124
418
  namespace ggml::cpu::kleidiai {
@@ -144,45 +438,113 @@ class tensor_traits : public ggml::cpu::tensor_traits {
144
438
  if (op->op != GGML_OP_MUL_MAT) {
145
439
  return false;
146
440
  }
147
- ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
148
- GGML_ASSERT(kernels);
149
- bool is_gemv = op->src[1]->ne[1] == 1;
150
- kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
151
- lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
152
441
 
153
- size_t k = op->src[0]->ne[0];
154
- size_t n = op->src[0]->ne[1];
155
- size_t m = op->src[1]->ne[1];
442
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
443
+ const int slot_count = kleidiai_collect_kernel_chain(op, kernel_chain);
444
+ if (slot_count == 0) {
445
+ return false;
446
+ }
156
447
 
157
- size_t mr = kernel->get_mr();
158
- size_t kr = kernel->get_kr();
159
- size_t sr = kernel->get_sr();
448
+ const bool is_gemv = op->src[1]->ne[1] == 1;
449
+ const size_t k = op->src[0]->ne[0];
450
+ const size_t n = op->src[0]->ne[1];
451
+ const size_t m = op->src[1]->ne[1];
160
452
 
161
- if (kernels->rhs_type == GGML_TYPE_Q4_0) {
162
- size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
163
- } else if (kernels->rhs_type == GGML_TYPE_F16) {
453
+ if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) {
454
+ const size_t qk = (op->src[0]->type == GGML_TYPE_Q4_0) ? QK4_0 : QK8_0;
455
+
456
+ size_t cursor = 0;
457
+ bool any_slot = false;
458
+
459
+ for (int slot = 0; slot < slot_count; ++slot) {
460
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
461
+ lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
462
+ kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
463
+
464
+ if (!lhs_info || !lhs_info->packed_size_ex || !kernel) {
465
+ return false;
466
+ }
467
+
468
+ const size_t mr = kernel->get_mr();
469
+ const size_t kr = kernel->get_kr();
470
+ const size_t sr = kernel->get_sr();
471
+
472
+ const size_t packed = lhs_info->packed_size_ex(m, k, qk, mr, kr, sr);
473
+
474
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
475
+ cursor += packed;
476
+ any_slot = true;
477
+ }
478
+
479
+ if (!any_slot) {
480
+ return false;
481
+ }
482
+
483
+ size = cursor;
484
+ return true;
485
+ }
486
+
487
+ if (op->src[0]->type == GGML_TYPE_F16) {
164
488
  const int64_t lhs_batch_size0 = op->src[1]->ne[2];
165
489
  const int64_t rhs_batch_size0 = op->src[0]->ne[2];
490
+ GGML_ASSERT(rhs_batch_size0 > 0);
166
491
  const int64_t r = lhs_batch_size0 / rhs_batch_size0;
167
- size = variant_call<size_t>(lhs_info->packed_size, m * r, k, mr, kr, sr) +
168
- variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
169
- k * n * sizeof(float) + n * sizeof(float);
170
- } else {
171
- GGML_ASSERT(false);
492
+
493
+ size_t cursor = 0;
494
+ bool any_slot = false;
495
+
496
+ for (int slot = 0; slot < slot_count; ++slot) {
497
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
498
+ lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
499
+ kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
500
+ if (!lhs_info || !lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex || !kernel) {
501
+ return false;
502
+ }
503
+
504
+ const size_t mr = kernel->get_mr();
505
+ const size_t kr = kernel->get_kr();
506
+ const size_t sr = kernel->get_sr();
507
+
508
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
509
+ cursor += lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr);
510
+ any_slot = true;
511
+ }
512
+
513
+ for (int slot = 0; slot < slot_count; ++slot) {
514
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
515
+ kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
516
+ if (!kernel || !kernels->rhs_info.packed_size_ex) {
517
+ return false;
518
+ }
519
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
520
+ cursor += kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0);
521
+ }
522
+
523
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
524
+ cursor += k * n * sizeof(float);
525
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
526
+ cursor += n * sizeof(float);
527
+
528
+ if (!any_slot) {
529
+ return false;
530
+ }
531
+
532
+ size = cursor;
533
+ return true;
172
534
  }
173
535
 
174
- return true;
536
+ return false;
175
537
  }
176
538
 
177
539
  bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
178
540
  if (dst->op == GGML_OP_MUL_MAT) {
179
- if (dst->src[0]->type == GGML_TYPE_Q4_0) {
180
- return compute_forward_q4_0(params, dst);
541
+ if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
542
+ return compute_forward_qx(params, dst);
181
543
  } else if (dst->src[0]->type == GGML_TYPE_F16) {
182
544
  return compute_forward_fp16(params, dst);
183
545
  }
184
546
  } else if (dst->op == GGML_OP_GET_ROWS) {
185
- if (dst->src[0]->type == GGML_TYPE_Q4_0) {
547
+ if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
186
548
  return compute_forward_get_rows(params, dst);
187
549
  }
188
550
  }
@@ -196,12 +558,18 @@ class tensor_traits : public ggml::cpu::tensor_traits {
196
558
  GGML_TENSOR_BINARY_OP_LOCALS
197
559
 
198
560
  ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
199
- GGML_ASSERT(kernels);
561
+ if (!kernels) {
562
+ return false;
563
+ }
200
564
 
201
565
  const bool is_gemv = src1->ne[1] == 1;
202
566
  kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
203
567
  lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
204
568
  GGML_ASSERT(kernel);
569
+ if (!kernels->rhs_info.pack_func_ex ||
570
+ !kernel->get_lhs_offset_ex || !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex) {
571
+ return false;
572
+ }
205
573
 
206
574
  const int nth = params->nth;
207
575
  const int ith = params->ith;
@@ -228,10 +596,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
228
596
  const int64_t kr = (int64_t) kernel->get_kr();
229
597
  const int64_t sr = (int64_t) kernel->get_sr();
230
598
 
231
- const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
232
- const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, (size_t)n, (size_t)k);
233
- const size_t kxn_size = (size_t)k * (size_t)n * sizeof(float);
234
- const size_t bias_size = (size_t)n * sizeof(float);
599
+ const size_t lhs_packed_size = lhs_info->packed_size_ex(m, k, 0, mr, kr, sr);
600
+ const size_t rhs_packed_size = kernels->rhs_info.packed_size_ex(n, k, nr, kr, 0);
601
+ const size_t kxn_size = k * n * sizeof(float);
602
+ const size_t bias_size = n * sizeof(float);
235
603
 
236
604
  const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
237
605
  GGML_ASSERT(wsize_required <= params->wsize);
@@ -259,10 +627,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
259
627
  const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
260
628
 
261
629
  // Base packed offset (aligned) and per-row stride in bytes
262
- const size_t base_packed_off = variant_call<size_t>(
263
- lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
264
- const size_t next_block_off = variant_call<size_t>(
265
- lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
630
+ const size_t base_packed_off = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
631
+ const size_t next_block_off = lhs_info->get_packed_offset_ex(m_start + mr, k, 0, mr, kr, sr);
266
632
  const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;
267
633
 
268
634
  int64_t remaining = m_count;
@@ -278,9 +644,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
278
644
  const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
279
645
  void * dst_ptr = lhs_packed + dst_off;
280
646
 
281
- variant_call<void>(lhs_info->pack_func,
282
- (size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr,
283
- /*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr);
647
+ lhs_info->pack_func_ex(take, k, 0, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
284
648
 
285
649
  cur += take;
286
650
  remaining -= take;
@@ -296,10 +660,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
296
660
  reinterpret_cast<const uint16_t *>(rhs_batch_base),
297
661
  rhs_stride);
298
662
 
299
- variant_call<void>(kernels->rhs_info.pack_func,
300
- /*num_groups*/ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr,
301
- /*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)),
302
- rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr);
663
+ kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, n * sizeof(float),
664
+ rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
303
665
  }
304
666
 
305
667
  ggml_barrier(params->threadpool);
@@ -320,20 +682,15 @@ class tensor_traits : public ggml::cpu::tensor_traits {
320
682
  const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
321
683
 
322
684
  // LHS packed base at row 0 (consistent with packing above)
323
- const size_t lhs_packed_offset0 = variant_call<size_t>(
324
- lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
325
- const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k);
326
- const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);
685
+ const size_t lhs_packed_offset0 = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
686
+ const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
687
+ const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);
327
688
 
328
689
  const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
329
690
  const void * rhs_ptr = rhs_packed + rhs_packed_offset;
330
691
  float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
331
692
 
332
- variant_call<void>(kernel->run_kernel,
333
- (size_t)m, (size_t)n_to_process, (size_t)k,
334
- lhs_ptr, rhs_ptr,
335
- dst_ptr, dst_stride, sizeof(float),
336
- -FLT_MAX, FLT_MAX);
693
+ kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
337
694
  }
338
695
  }
339
696
 
@@ -345,108 +702,486 @@ class tensor_traits : public ggml::cpu::tensor_traits {
345
702
  return true;
346
703
  }
347
704
 
348
- bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
349
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
705
+ bool compute_forward_qx(struct ggml_compute_params * params, struct ggml_tensor * dst) {
706
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
350
707
 
351
708
  const ggml_tensor * src0 = dst->src[0];
352
709
  const ggml_tensor * src1 = dst->src[1];
353
710
 
354
711
  GGML_TENSOR_BINARY_OP_LOCALS
355
712
 
356
- ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
357
- GGML_ASSERT(kernels);
713
+ const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
714
+ const bool has_header = kleidiai_is_weight_header_valid(header);
715
+ const bool is_gemv = src1->ne[1] == 1;
716
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
717
+ const int slot_total = kleidiai_collect_kernel_chain(dst, kernel_chain);
358
718
 
359
- bool is_gemv = src1->ne[1] == 1;
360
- kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
361
- lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
719
+ auto weight_for_slot = [&](int slot_index, size_t & size_out) -> const uint8_t * {
720
+ if (slot_index < 0 || slot_index >= slot_total) {
721
+ return nullptr;
722
+ }
723
+ if (has_header) {
724
+ if (slot_index < header->slot_count) {
725
+ size_out = static_cast<size_t>(header->sizes[slot_index]);
726
+ return kleidiai_weight_slot_ptr(header, slot_index);
727
+ }
728
+ return nullptr;
729
+ }
730
+ if (slot_index == 0) {
731
+ size_out = ggml_nbytes(src0);
732
+ return static_cast<const uint8_t *>(src0->data);
733
+ }
734
+ return nullptr;
735
+ };
736
+
737
+ struct runtime_slot {
738
+ int slot_index;
739
+ ggml_kleidiai_kernels * kernels;
740
+ kernel_info * kernel;
741
+ lhs_packing_info * lhs_info;
742
+ size_t mr;
743
+ size_t nr;
744
+ size_t kr;
745
+ size_t sr;
746
+ size_t n_step;
747
+ size_t lhs_packed_size;
748
+ size_t lhs_offset;
749
+ size_t n_offset;
750
+ size_t n_cols;
751
+ int assigned_threads;
752
+ int thread_begin;
753
+ int thread_end;
754
+ const uint8_t * rhs_base;
755
+ };
756
+
757
+ std::array<runtime_slot, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> runtime{};
758
+ int runtime_count = 0;
759
+
760
+ for (int slot = 0; slot < slot_total && runtime_count < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
761
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
762
+ kernel_info * kinfo = is_gemv ? &kernels->gemv : &kernels->gemm;
763
+ lhs_packing_info * linfo = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
764
+ if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex || !linfo->get_offset ||
765
+ !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset) {
766
+ continue;
767
+ }
362
768
 
363
- GGML_ASSERT(kernel);
769
+ size_t rhs_size = 0;
770
+ const uint8_t * rhs_ptr = weight_for_slot(slot, rhs_size);
771
+ if (!rhs_ptr || rhs_size == 0) {
772
+ continue;
773
+ }
364
774
 
365
- const int ith = params->ith;
366
- const int nth_raw = params->nth;
367
- const int nth = nth_raw > 0 ? nth_raw : 1;
775
+ runtime[runtime_count] = {
776
+ slot,
777
+ kernels,
778
+ kinfo,
779
+ linfo,
780
+ kinfo->get_mr(),
781
+ kinfo->get_nr(),
782
+ kinfo->get_kr(),
783
+ kinfo->get_sr(),
784
+ kinfo->get_n_step(),
785
+ 0,
786
+ 0,
787
+ 0,
788
+ 0,
789
+ 0,
790
+ 0,
791
+ 0,
792
+ rhs_ptr
793
+ };
794
+ ++runtime_count;
795
+ }
796
+
797
+ if (runtime_count == 0) {
798
+ ggml_kleidiai_kernels * fallback = ggml_kleidiai_select_kernels(ctx.features, dst);
799
+ if (!fallback) {
800
+ return false;
801
+ }
802
+ kernel_info * kinfo = is_gemv ? &fallback->gemv : &fallback->gemm;
803
+ lhs_packing_info * linfo = is_gemv ? &fallback->gemv_lhs_info : &fallback->gemm_lhs_info;
804
+ rhs_packing_info * rinfo = &fallback->rhs_info;
805
+ if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex ||
806
+ !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset ||
807
+ !rinfo || !rinfo->pack_func_ex || !rinfo->packed_size_ex) {
808
+ return false;
809
+ }
810
+ kernel_chain[0] = fallback;
811
+ runtime[0] = {
812
+ 0,
813
+ fallback,
814
+ kinfo,
815
+ linfo,
816
+ kinfo->get_mr(),
817
+ kinfo->get_nr(),
818
+ kinfo->get_kr(),
819
+ kinfo->get_sr(),
820
+ kinfo->get_n_step(),
821
+ 0,
822
+ 0,
823
+ 0,
824
+ 0,
825
+ 0,
826
+ 0,
827
+ 0,
828
+ nullptr
829
+ };
830
+ size_t rhs_size_fallback = 0;
831
+ const uint8_t * rhs_base = weight_for_slot(0, rhs_size_fallback);
832
+ if (!rhs_base) {
833
+ rhs_base = static_cast<const uint8_t *>(src0->data);
834
+ }
835
+ runtime[0].rhs_base = rhs_base;
836
+ runtime_count = 1;
837
+ }
838
+
839
+ const int nth_total = params->nth > 0 ? params->nth : 1;
840
+ const int ith_total = params->ith;
841
+
842
+ int sme_slot = -1;
843
+ for (int i = 0; i < runtime_count; ++i) {
844
+ if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
845
+ sme_slot = i;
846
+ break;
847
+ }
848
+ }
849
+
850
+ const int sme_cap_limit = ctx.sme_thread_cap;
851
+ const bool use_hybrid = sme_cap_limit > 0 &&
852
+ runtime_count > 1 &&
853
+ nth_total > sme_cap_limit;
854
+ // Heuristic: disable hybrid for very small workloads where per-slot overhead dominates.
855
+ // If rows are small or average columns per thread are small, keep single-slot.
856
+ size_t min_cols_per_thread = 0;
857
+ if (runtime_count > 0 && nth_total > 0) {
858
+ min_cols_per_thread = (size_t) std::max<int64_t>(1, (int64_t)ne01 / (int64_t)nth_total);
859
+ }
860
+ const bool too_small_for_hybrid = (min_cols_per_thread < 2) || (ne11 < 128);
861
+
862
+ const bool hybrid_enabled = use_hybrid && !too_small_for_hybrid;
863
+
864
+ if (!hybrid_enabled) {
865
+ int chosen_slot = 0;
866
+ if (too_small_for_hybrid && sme_slot != -1) {
867
+ chosen_slot = sme_slot;
868
+ } else if (runtime_count > 1 && ctx.sme_thread_cap > 0 && nth_total > ctx.sme_thread_cap) {
869
+ chosen_slot = 1;
870
+ }
871
+ if (chosen_slot != 0 && chosen_slot < runtime_count) {
872
+ runtime[0] = runtime[chosen_slot];
873
+ }
874
+ runtime_count = runtime_count > 0 ? 1 : 0;
875
+
876
+ // Recompute SME slot based on the collapsed runtime[0]
877
+ sme_slot = -1;
878
+ if (runtime_count > 0 &&
879
+ (runtime[0].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
880
+ sme_slot = 0;
881
+ }
882
+ }
883
+
884
+ int sme_cap = kleidiai_sme_thread_cap();
885
+ if (sme_cap < 0) {
886
+ sme_cap = nth_total;
887
+ }
888
+ sme_cap = std::min(sme_cap, nth_total);
889
+
890
+ int threads_remaining = nth_total;
891
+ if (sme_slot != -1) {
892
+ int sme_threads = std::min(std::max(sme_cap, 0), threads_remaining);
893
+ runtime[sme_slot].assigned_threads = sme_threads;
894
+ threads_remaining -= sme_threads;
895
+ }
896
+
897
+ int fallback_indices[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
898
+ int fallback_count = 0;
899
+ for (int i = 0; i < runtime_count; ++i) {
900
+ if (i == sme_slot) {
901
+ continue;
902
+ }
903
+ fallback_indices[fallback_count++] = i;
904
+ }
905
+
906
+ for (int fi = 0; fi < fallback_count; ++fi) {
907
+ if (threads_remaining <= 0) {
908
+ break;
909
+ }
910
+ const int slot_index = fallback_indices[fi];
911
+ const int slots_left = fallback_count - fi;
912
+ int share = (threads_remaining + slots_left - 1) / slots_left;
913
+ share = std::min(share, threads_remaining);
914
+ runtime[slot_index].assigned_threads = share;
915
+ threads_remaining -= share;
916
+ }
917
+
918
+ if (threads_remaining > 0) {
919
+ const int fallback_slot = (sme_slot != -1) ? sme_slot : 0;
920
+ runtime[fallback_slot].assigned_threads += threads_remaining;
921
+ threads_remaining = 0;
922
+ }
923
+
924
+ int thread_cursor = 0;
925
+ for (int i = 0; i < runtime_count; ++i) {
926
+ runtime[i].thread_begin = thread_cursor;
927
+ thread_cursor += runtime[i].assigned_threads;
928
+ runtime[i].thread_end = thread_cursor;
929
+ }
930
+
931
+ if (thread_cursor < nth_total && runtime_count > 0) {
932
+ runtime[runtime_count - 1].assigned_threads += nth_total - thread_cursor;
933
+ runtime[runtime_count - 1].thread_end = nth_total;
934
+ }
935
+
936
+ int local_slot = -1;
937
+ int local_ith = 0;
938
+ for (int i = 0; i < runtime_count; ++i) {
939
+ if (ith_total >= runtime[i].thread_begin && ith_total < runtime[i].thread_end) {
940
+ local_slot = i;
941
+ local_ith = ith_total - runtime[i].thread_begin;
942
+ break;
943
+ }
944
+ }
945
+ if (local_slot == -1) {
946
+ return false;
947
+ }
368
948
 
369
949
  const size_t k = ne00;
370
950
  const size_t m = ne11;
371
951
  const size_t n = ne01;
372
952
 
373
- size_t mr = kernel->get_mr();
374
- size_t kr = kernel->get_kr();
375
- size_t sr = kernel->get_sr();
376
-
377
- const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
378
- uint8_t * lhs_packed = (uint8_t*)params->wdata;
379
- const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
953
+ size_t cursor = 0;
954
+ for (int i = 0; i < runtime_count; ++i) {
955
+ const ggml_type slot_rhs_type = runtime[i].kernels->rhs_type;
956
+ const size_t slot_pack_size_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
957
+ slot_rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
958
+ runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, slot_pack_size_arg, runtime[i].mr, runtime[i].kr, runtime[i].sr);
959
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
960
+ runtime[i].lhs_offset = cursor;
961
+ cursor += runtime[i].lhs_packed_size;
962
+ }
380
963
 
381
- const size_t n_step = kernel->get_n_step();
382
- const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
383
- const size_t n_start = ith * num_n_per_thread;
964
+ GGML_ASSERT(cursor <= params->wsize);
965
+ uint8_t * scratch = static_cast<uint8_t *>(params->wdata);
384
966
 
385
- size_t n_to_process = 0;
386
- if (n_start < n) {
387
- n_to_process = num_n_per_thread;
388
- if ((n_start + n_to_process) > n) {
389
- n_to_process = n - n_start;
967
+ size_t assigned_cols = 0;
968
+ uint64_t weighted_total = 0;
969
+ if (runtime_count > 1 && sme_slot != -1) {
970
+ for (int i = 0; i < runtime_count; ++i) {
971
+ const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
972
+ weighted_total += (uint64_t)runtime[i].assigned_threads * weight;
390
973
  }
391
974
  }
975
+ for (int i = 0; i < runtime_count; ++i) {
976
+ runtime[i].n_offset = assigned_cols;
977
+ if (runtime[i].assigned_threads == 0) {
978
+ runtime[i].n_cols = 0;
979
+ continue;
980
+ }
981
+ const size_t remaining_cols = n - assigned_cols;
982
+ if (remaining_cols == 0) {
983
+ runtime[i].n_cols = 0;
984
+ continue;
985
+ }
986
+ const size_t step = runtime[i].n_step ? runtime[i].n_step : 1;
987
+ size_t target = 0;
988
+ if (weighted_total > 0) {
989
+ const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
990
+ target = (size_t)(((uint64_t)n * runtime[i].assigned_threads * weight) / weighted_total);
991
+ } else {
992
+ target = (size_t)(((uint64_t)n * runtime[i].assigned_threads) / nth_total);
993
+ }
994
+ target = std::min(target, remaining_cols);
995
+ size_t aligned = round_down(target, step);
996
+ if (aligned == 0 && remaining_cols >= step) {
997
+ aligned = step;
998
+ }
999
+ runtime[i].n_cols = aligned;
1000
+ assigned_cols += aligned;
1001
+ }
392
1002
 
393
- // Calculate number of columns to be processed per thread
394
- const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
395
- const size_t m_start = ith * num_m_per_thread;
396
- size_t m_to_process = num_m_per_thread;
397
- if ((m_start + m_to_process) > m) {
398
- m_to_process = m - m_start;
1003
+ if (assigned_cols < n) {
1004
+ for (int i = runtime_count - 1; i >= 0; --i) {
1005
+ if (runtime[i].assigned_threads > 0) {
1006
+ runtime[i].n_cols += n - assigned_cols;
1007
+ break;
1008
+ }
1009
+ }
399
1010
  }
1011
+ const size_t dst_stride = dst->nb[1];
1012
+
1013
+ for (int64_t batch_idx = 0; batch_idx < ne12; ++batch_idx) {
1014
+ const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
1015
+ uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];
400
1016
 
401
- if (m_start < m) {
402
- // Transform LHS
403
- const size_t src_stride = src1->nb[1];
404
- const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
405
- const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
406
- void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
1017
+ if (runtime[local_slot].assigned_threads > 0) {
1018
+ runtime_slot & slot = runtime[local_slot];
1019
+ const ggml_type slot_rhs_type = slot.kernels->rhs_type;
1020
+ const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
1021
+ slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
1022
+ const int64_t m_roundup_mr = kai_roundup((int64_t)m, (int64_t)slot.mr);
1023
+ int64_t max_threads = slot.mr ? (m_roundup_mr / (int64_t)slot.mr) : slot.assigned_threads;
1024
+ max_threads = std::max<int64_t>(1, max_threads);
1025
+ const int64_t use_threads = std::min<int64_t>(slot.assigned_threads, max_threads);
407
1026
 
408
- variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
409
- }
1027
+ if (local_ith < use_threads) {
1028
+ const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / use_threads), slot.mr);
1029
+ const int64_t num_m_per_threadN_1 = (int64_t)m - (use_threads - 1) * num_m_per_thread0;
1030
+
1031
+ const int64_t m_start = (int64_t)local_ith * num_m_per_thread0;
1032
+ const int64_t m_count = (local_ith == use_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
1033
+
1034
+ const size_t base_packed_off = slot.lhs_info->get_packed_offset_ex(m_start, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
1035
+ const size_t next_block_off = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
1036
+ const size_t row_stride_bytes = slot.mr ? (next_block_off - base_packed_off) / slot.mr : 0;
1037
+
1038
+ int64_t remaining = m_count;
1039
+ int64_t cur = m_start;
1040
+
1041
+ uint8_t * lhs_packed = scratch + slot.lhs_offset;
1042
+ while (remaining > 0) {
1043
+ const int64_t row_in_group = cur;
1044
+ const int64_t avail = (int64_t)m - row_in_group;
1045
+ const int64_t take = std::min(avail, remaining);
1046
+
1047
+ const size_t src_off = slot.lhs_info->get_offset(row_in_group, src1->nb[1]);
1048
+ const void * src_ptr = lhs_batch_base + src_off;
1049
+ const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
1050
+ void * dst_ptr = lhs_packed + dst_off;
1051
+
1052
+ slot.lhs_info->pack_func_ex(take, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr);
1053
+
1054
+ cur += take;
1055
+ remaining -= take;
1056
+ }
1057
+ }
1058
+ }
410
1059
 
411
- ggml_barrier(params->threadpool);
1060
+ ggml_barrier(params->threadpool);
412
1061
 
413
- // Perform the operation
414
- const size_t dst_stride = dst->nb[1];
415
- const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
416
- const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
417
- const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
418
- const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
419
- const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
420
- float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
1062
+ runtime_slot & slot = runtime[local_slot];
1063
+ if (slot.n_cols > 0 && slot.assigned_threads > 0) {
1064
+ int64_t active_threads = slot.assigned_threads;
1065
+ const int64_t max_threads = slot.n_step ? (slot.n_cols / slot.n_step) : slot.assigned_threads;
1066
+ if (max_threads > 0) {
1067
+ active_threads = std::min<int64_t>(active_threads, std::max<int64_t>(1, max_threads));
1068
+ }
1069
+ active_threads = std::max<int64_t>(1, active_threads);
1070
+
1071
+ if (local_ith < active_threads) {
1072
+ const size_t step = slot.n_step ? slot.n_step : 1;
1073
+ const size_t chunk0 = round_down((size_t)(slot.n_cols / active_threads), step);
1074
+ const size_t chunkN = slot.n_cols - (active_threads - 1) * chunk0;
1075
+ const size_t local_start = (size_t)local_ith * chunk0;
1076
+ const size_t cols = (local_ith == active_threads - 1) ? chunkN : chunk0;
1077
+
1078
+ if (cols > 0) {
1079
+ const ggml_type slot_rhs_type = slot.kernels->rhs_type;
1080
+ const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
1081
+ slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
1082
+ const size_t slot_rhs_block_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
1083
+ slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
1084
+ const size_t global_start = slot.n_offset + local_start;
1085
+ const size_t lhs_packed_offset = slot.lhs_info->get_packed_offset_ex(0, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
1086
+ const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot_rhs_block_arg);
1087
+ const size_t dst_offset = slot.kernel->get_dst_offset(0, global_start, dst_stride);
1088
+
1089
+ const uint8_t * lhs_ptr = scratch + slot.lhs_offset + lhs_packed_offset;
1090
+ const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset;
1091
+ float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
1092
+
1093
+ slot.kernel->run_kernel_ex(m, cols, k, slot_rhs_block_arg,
1094
+ lhs_ptr,
1095
+ rhs_ptr,
1096
+ dst_ptr,
1097
+ dst_stride,
1098
+ sizeof(float),
1099
+ -FLT_MAX,
1100
+ FLT_MAX);
1101
+ }
1102
+ }
1103
+ }
421
1104
 
422
- if (n_to_process > 0) {
423
- variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
424
- sizeof(float), -FLT_MAX, FLT_MAX);
1105
+ if (batch_idx != ne12 - 1) {
1106
+ ggml_barrier(params->threadpool);
1107
+ }
425
1108
  }
426
1109
 
427
1110
  return true;
428
1111
  }
429
1112
 
430
1113
  bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
431
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
432
- GGML_ASSERT(ctx.kernels);
433
-
1114
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
434
1115
  const ggml_tensor * src0 = dst->src[0];
435
1116
  const ggml_tensor * src1 = dst->src[1];
436
1117
 
437
1118
  GGML_TENSOR_BINARY_OP_LOCALS
438
1119
 
439
- rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
440
- kernel_info * kernel = &ctx.kernels->gemm;
1120
+ const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
1121
+ const bool has_header = kleidiai_is_weight_header_valid(header);
1122
+
1123
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
1124
+ const bool want_q8 = src0->type == GGML_TYPE_Q8_0;
1125
+ const int chain_count = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
1126
+ : kleidiai_collect_q4_chain(kernel_chain);
1127
+
1128
+ ggml_kleidiai_kernels * kernels = nullptr;
1129
+ const uint8_t * packed_base = static_cast<const uint8_t *>(src0->data);
1130
+
1131
+ if (has_header && chain_count > 0) {
1132
+ int select_slot = 0;
1133
+ if (select_slot >= header->slot_count) {
1134
+ select_slot = header->slot_count - 1;
1135
+ }
1136
+ if (select_slot >= 0 && select_slot < chain_count) {
1137
+ kernels = kernel_chain[select_slot];
1138
+ const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, select_slot);
1139
+ if (slot_ptr) {
1140
+ packed_base = slot_ptr;
1141
+ }
1142
+ }
1143
+ }
1144
+
1145
+ if (!kernels && chain_count > 0) {
1146
+ kernels = kernel_chain[0];
1147
+ if (has_header) {
1148
+ const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, 0);
1149
+ if (slot_ptr) {
1150
+ packed_base = slot_ptr;
1151
+ }
1152
+ }
1153
+ }
1154
+
1155
+ if (!kernels) {
1156
+ return false;
1157
+ }
1158
+
1159
+ rhs_packing_info * rhs_info = &kernels->rhs_info;
1160
+ kernel_info * kernel = &kernels->gemm;
1161
+ if (!rhs_info->to_float || !kernel->get_nr) {
1162
+ return false;
1163
+ }
441
1164
 
442
1165
  const int64_t nc = ne00;
443
1166
  const int64_t nr = ggml_nelements(src1);
444
1167
 
1168
+ const ggml_type rhs_type = kernels->rhs_type;
1169
+ size_t block_len = 0;
1170
+ size_t num_bytes_multiplier = 0;
1171
+ if (rhs_type == GGML_TYPE_Q4_0) {
1172
+ block_len = QK4_0;
1173
+ num_bytes_multiplier = sizeof(uint16_t);
1174
+ } else if (rhs_type == GGML_TYPE_Q8_0) {
1175
+ block_len = QK8_0;
1176
+ num_bytes_multiplier = sizeof(float);
1177
+ } else {
1178
+ return false;
1179
+ }
1180
+
445
1181
  const size_t block_rows = kernel->get_nr();
446
1182
  const size_t kr = kernel->get_kr();
447
1183
 
448
- const size_t num_bytes_multiplier = sizeof(uint16_t);
449
- const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
1184
+ const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, block_len);
450
1185
 
451
1186
  const int ith = params->ith;
452
1187
  const int nth = params->nth;
@@ -461,7 +1196,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
461
1196
  GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
462
1197
 
463
1198
  float *out = (float *)((char *)dst->data + i * nb1);
464
- rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
1199
+ rhs_info->to_float(packed_base, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
465
1200
  }
466
1201
 
467
1202
  return true;
@@ -469,21 +1204,136 @@ class tensor_traits : public ggml::cpu::tensor_traits {
469
1204
 
470
1205
  public:
471
1206
  int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
472
- GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
473
- GGML_ASSERT(ctx.kernels);
1207
+ GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0);
474
1208
  const size_t n = tensor->ne[1];
475
1209
  const size_t k = tensor->ne[0];
476
- size_t nr = ctx.kernels->gemm.get_nr();
477
- size_t kr = ctx.kernels->gemm.get_kr();
478
- size_t sr = ctx.kernels->gemm.get_sr();
479
1210
 
480
- struct kai_rhs_pack_qs4cxs1s0_param params;
481
- params.lhs_zero_point = 1;
482
- params.rhs_zero_point = 8;
483
- variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);
1211
+ kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(tensor->data);
1212
+ if (!header) {
1213
+ return -1;
1214
+ }
1215
+
1216
+ header->magic = GGML_KLEIDIAI_PACK_MAGIC;
1217
+ header->version = GGML_KLEIDIAI_PACK_VERSION;
1218
+ header->slot_count = 0;
1219
+
1220
+ uint8_t * base_ptr = static_cast<uint8_t *>(tensor->data);
1221
+ size_t cursor = sizeof(kleidiai_weight_header);
1222
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
1223
+
1224
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
1225
+ const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
1226
+ const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
1227
+ : kleidiai_collect_q4_chain(kernel_chain);
1228
+ const bool allow_fallback = kleidiai_pack_fallback_allowed();
1229
+
1230
+ std::vector<int8_t> qdata;
1231
+ std::vector<float> scales;
1232
+
1233
+ if (want_q8 && slot_total > 0) {
1234
+ qdata.resize(n * k, 0);
1235
+ scales.resize(n, 0.0f);
1236
+
1237
+ const size_t row_stride = tensor->nb[1];
1238
+ const size_t k_blocks = (k + QK8_0 - 1) / QK8_0;
1239
+
1240
+ for (size_t row = 0; row < n; ++row) {
1241
+ const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
1242
+ static_cast<const uint8_t *>(data) + row * row_stride);
1243
+
1244
+ float max_abs = 0.0f;
1245
+ for (size_t block = 0; block < k_blocks; ++block) {
1246
+ const block_q8_0 & blk = row_blocks[block];
1247
+ const float d = GGML_FP16_TO_FP32(blk.d);
1248
+ for (size_t l = 0; l < QK8_0; ++l) {
1249
+ const size_t linear_idx = block * QK8_0 + l;
1250
+ if (linear_idx >= k) {
1251
+ break;
1252
+ }
1253
+ const float value = d * static_cast<float>(blk.qs[l]);
1254
+ max_abs = std::max(max_abs, std::fabs(value));
1255
+ }
1256
+ }
1257
+
1258
+ float scale = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;
1259
+ scales[row] = scale;
1260
+ const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;
1261
+
1262
+ for (size_t block = 0; block < k_blocks; ++block) {
1263
+ const block_q8_0 & blk = row_blocks[block];
1264
+ const float d = GGML_FP16_TO_FP32(blk.d);
1265
+ for (size_t l = 0; l < QK8_0; ++l) {
1266
+ const size_t linear_idx = block * QK8_0 + l;
1267
+ if (linear_idx >= k) {
1268
+ break;
1269
+ }
1270
+ const float value = d * static_cast<float>(blk.qs[l]);
1271
+ int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
1272
+ q = std::clamp(q, -127, 127);
1273
+ qdata[row * k + linear_idx] = static_cast<int8_t>(q);
1274
+ }
1275
+ }
1276
+ }
1277
+ }
1278
+
1279
+ for (int slot = 0; slot < slot_total && slot < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
1280
+ if (!allow_fallback && slot > 0) {
1281
+ break;
1282
+ }
1283
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
1284
+ kernel_info * kernel = &kernels->gemm;
1285
+ rhs_packing_info * rhs_info = &kernels->rhs_info;
1286
+ if (!rhs_info || !rhs_info->pack_func_ex || !rhs_info->packed_size_ex || !kernel) {
1287
+ continue;
1288
+ }
1289
+
1290
+ const size_t nr = kernel->get_nr();
1291
+ const size_t kr = kernel->get_kr();
1292
+ const size_t sr = kernel->get_sr();
1293
+ const ggml_type rhs_type = kernels->rhs_type;
1294
+ const size_t block_len = rhs_type == GGML_TYPE_Q8_0 ? QK8_0 :
1295
+ rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : 0;
1296
+ if (block_len == 0) {
1297
+ continue;
1298
+ }
1299
+
1300
+ const size_t packed_size = rhs_info->packed_size_ex(n, k, nr, kr, block_len);
1301
+ const size_t aligned_cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
1302
+
1303
+ uint8_t * dst_ptr = base_ptr + aligned_cursor;
1304
+
1305
+ if (rhs_type == GGML_TYPE_Q4_0) {
1306
+ struct kai_rhs_pack_qs4cxs1s0_param params;
1307
+ params.lhs_zero_point = 1;
1308
+ params.rhs_zero_point = 8;
1309
+ rhs_info->pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
1310
+ static_cast<const uint8_t *>(data), nullptr, nullptr,
1311
+ dst_ptr, 0, &params);
1312
+ } else if (rhs_type == GGML_TYPE_Q8_0) {
1313
+ struct kai_rhs_pack_qsi8cx_params params;
1314
+ params.lhs_zero_point = 1;
1315
+ params.scale_multiplier = 1.0f;
1316
+ rhs_info->pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
1317
+ qdata.data(), nullptr, scales.data(),
1318
+ dst_ptr, 0, &params);
1319
+ } else {
1320
+ continue;
1321
+ }
1322
+
1323
+ header->offsets[header->slot_count] = aligned_cursor;
1324
+ header->sizes[header->slot_count] = packed_size;
1325
+ ++header->slot_count;
1326
+
1327
+ cursor = aligned_cursor + packed_size;
1328
+ }
1329
+
1330
+ if (header->slot_count == 0) {
1331
+ header->magic = 0;
1332
+ header->version = 0;
1333
+ memcpy(tensor->data, data, data_size);
1334
+ }
484
1335
 
485
1336
  return 0;
486
- GGML_UNUSED(data_size);
487
1337
  }
488
1338
  };
489
1339
 
@@ -513,9 +1363,8 @@ static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t bu
513
1363
  }
514
1364
 
515
1365
  static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
516
- return "CPU_KLEIDIAI";
517
-
518
1366
  GGML_UNUSED(buft);
1367
+ return "CPU_KLEIDIAI";
519
1368
  }
520
1369
 
521
1370
  static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -534,33 +1383,80 @@ static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(
534
1383
  }
535
1384
 
536
1385
  static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
537
- return TENSOR_ALIGNMENT;
538
-
539
1386
  GGML_UNUSED(buft);
1387
+ return TENSOR_ALIGNMENT;
540
1388
  }
541
1389
 
542
1390
  static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
543
- GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
544
- GGML_ASSERT(ctx.kernels);
1391
+ GGML_UNUSED(buft);
545
1392
 
546
- const size_t n = tensor->ne[1];
547
- const size_t k = tensor->ne[0];
548
- const size_t nr = ctx.kernels->gemm.get_nr();
549
- const size_t kr = ctx.kernels->gemm.get_kr();
1393
+ if (tensor->type != GGML_TYPE_Q4_0 && tensor->type != GGML_TYPE_Q8_0) {
1394
+ return ggml_nbytes(tensor);
1395
+ }
550
1396
 
551
- return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
1397
+ const size_t n = tensor->ne[1];
1398
+ const size_t k = tensor->ne[0];
552
1399
 
553
- GGML_UNUSED(buft);
1400
+ size_t cursor = sizeof(kleidiai_weight_header);
1401
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
1402
+
1403
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
1404
+ const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
1405
+ const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
1406
+ : kleidiai_collect_q4_chain(kernel_chain);
1407
+ const bool allow_fallback = kleidiai_pack_fallback_allowed();
1408
+
1409
+ size_t slot_count = 0;
1410
+ for (int slot = 0; slot < slot_total; ++slot) {
1411
+ if (!allow_fallback && slot > 0) {
1412
+ break;
1413
+ }
1414
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
1415
+ if (!kernels) {
1416
+ continue;
1417
+ }
1418
+ kernel_info * kernel = &kernels->gemm;
1419
+ rhs_packing_info * rhs_info = &kernels->rhs_info;
1420
+ if (!kernel || !rhs_info || !rhs_info->packed_size_ex) {
1421
+ continue;
1422
+ }
1423
+
1424
+ const ggml_type rhs_type = kernels->rhs_type;
1425
+ const size_t block_len = rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
1426
+ rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
1427
+ if (block_len == 0) {
1428
+ continue;
1429
+ }
1430
+
1431
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
1432
+ cursor += rhs_info->packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), block_len);
1433
+ ++slot_count;
1434
+ }
1435
+
1436
+ if (slot_count == 0) {
1437
+ return ggml_nbytes(tensor);
1438
+ }
1439
+
1440
+ return std::max(cursor, ggml_nbytes(tensor));
554
1441
  }
555
1442
 
556
1443
  namespace ggml::cpu::kleidiai {
557
1444
  class extra_buffer_type : ggml::cpu::extra_buffer_type {
558
1445
  bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
1446
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
1447
+ const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
559
1448
  if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
560
- op->src[0]->type == GGML_TYPE_Q4_0 &&
1449
+ (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
561
1450
  op->src[0]->buffer &&
562
1451
  (ggml_n_dims(op->src[0]) == 2) &&
563
- op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
1452
+ op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() &&
1453
+ slot_total > 0) {
1454
+ if (op->src[0]->type == GGML_TYPE_Q4_0 && ctx.kernels_q4 == nullptr) {
1455
+ return false;
1456
+ }
1457
+ if (op->src[0]->type == GGML_TYPE_Q8_0 && ctx.kernels_q8 == nullptr) {
1458
+ return false;
1459
+ }
564
1460
  if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
565
1461
  return false;
566
1462
  }
@@ -576,14 +1472,17 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
576
1472
  if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
577
1473
  if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
578
1474
  return (ggml::cpu::tensor_traits *) op->src[0]->extra;
579
- }
580
- else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) {
581
- if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
582
- (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
583
- return nullptr;
1475
+ } else {
1476
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
1477
+ const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
1478
+ const bool has_kernel = slot_total > 0;
1479
+ if (has_kernel && op->src[1]->ne[1] > 1) {
1480
+ if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
1481
+ (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
1482
+ return nullptr;
1483
+ }
1484
+ return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
584
1485
  }
585
-
586
- return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
587
1486
  }
588
1487
  }
589
1488
  return nullptr;