whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -1,20 +1,31 @@
1
- // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
1
+ // SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
2
2
  // SPDX-License-Identifier: MIT
3
3
  //
4
4
  #include <arm_neon.h>
5
5
  #include <assert.h>
6
+ #include <stdio.h>
6
7
  #include <atomic>
7
8
  #include <cfloat>
8
- #include <cmath>
9
9
  #include <algorithm>
10
+ #include <cmath>
10
11
  #include <stdexcept>
11
12
  #include <stdint.h>
12
13
  #include <string.h>
13
14
  #include <string>
14
15
  #include <vector>
16
+ #include <array>
17
+ #include <cstddef>
18
+ #include <cstdint>
19
+ #include <fstream>
20
+ #include <set>
21
+ #include <iostream>
22
+ #include <climits>
15
23
  #if defined(__linux__)
16
24
  #include <asm/hwcap.h>
17
25
  #include <sys/auxv.h>
26
+ #include <sys/types.h>
27
+ #include <sys/stat.h>
28
+ #include <unistd.h>
18
29
  #elif defined(__APPLE__)
19
30
  #include <string_view>
20
31
  #include <sys/sysctl.h>
@@ -39,11 +50,18 @@
39
50
  #define GGML_COMMON_DECL_CPP
40
51
  #include "ggml-common.h"
41
52
 
53
+ static constexpr int GGML_KLEIDIAI_MAX_KERNEL_SLOTS = 2;
54
+ static constexpr uint32_t GGML_KLEIDIAI_PACK_MAGIC = 0x4b4c4149; // "KLAI"
55
+ static constexpr uint16_t GGML_KLEIDIAI_PACK_VERSION = 1;
56
+ static constexpr size_t GGML_KLEIDIAI_PACK_ALIGN = 64;
57
+
42
58
  struct ggml_kleidiai_context {
43
59
  cpu_feature features;
44
60
  ggml_kleidiai_kernels * kernels_q4;
45
61
  ggml_kleidiai_kernels * kernels_q8;
46
- } static ctx = { CPU_FEATURE_NONE, NULL, NULL };
62
+ int sme_thread_cap; // <= 0 means “SME disabled/unknown”;
63
+ int thread_hint; // <= 0 means “no hint”
64
+ } static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1 };
47
65
 
48
66
  static const char* cpu_feature_to_string(cpu_feature f) {
49
67
  if (f == CPU_FEATURE_NONE) {
@@ -63,41 +81,335 @@ static const char* cpu_feature_to_string(cpu_feature f) {
63
81
  }
64
82
  }
65
83
 
66
- static void init_kleidiai_context(void) {
84
+ static size_t detect_num_smcus() {
85
+ if (!ggml_cpu_has_sme()) {
86
+ return 0;
87
+ }
88
+
89
+ #if defined(__linux__) && defined(__aarch64__)
90
+ // Linux/aarch64: Best-effort count of Streaming Mode Compute Units (SMCUs) via SMIDR_EL1 sysfs.
91
+ size_t num_private = 0;
92
+ std::set<uint32_t> shared_ids;
93
+
94
+ for (size_t cpu = 0;; ++cpu) {
95
+ const std::string path =
96
+ "/sys/devices/system/cpu/cpu" + std::to_string(cpu) +
97
+ "/regs/identification/smidr_el1";
98
+
99
+ std::ifstream file(path);
100
+ if (!file.is_open()) {
101
+ break;
102
+ }
103
+
104
+ uint64_t smidr = 0;
105
+ if (!(file >> std::hex >> smidr)) {
106
+ continue;
107
+ }
108
+
109
+ // Arm ARM: SMIDR_EL1
110
+ const uint32_t sh = (uint32_t)((smidr >> 13) & 0x3);
111
+ // Build an "affinity-like" identifier for shared SMCUs.
112
+ // Keep the original packing logic, but isolate it here.
113
+ const uint32_t id = (uint32_t)((smidr & 0xFFFu) | ((smidr >> 20) & 0xFFFFF000u));
114
+
115
+ switch (sh) {
116
+ case 0b10: // private SMCU
117
+ ++num_private;
118
+ break;
119
+ case 0b11: // shared SMCU
120
+ shared_ids.emplace(id);
121
+ break;
122
+ case 0b00:
123
+ // Ambiguous / implementation-defined. Be conservative:
124
+ // treat id==0 as private, otherwise as shared.
125
+ if (id == 0) ++num_private;
126
+ else shared_ids.emplace(id);
127
+ break;
128
+ default:
129
+ break;
130
+ }
131
+ }
132
+
133
+ return num_private + shared_ids.size();
134
+
135
+ #elif defined(__APPLE__) && defined(__aarch64__)
136
+ // table for known M4 variants. Users can override via GGML_KLEIDIAI_SME=<n>.
137
+ char chip_name[256] = {};
138
+ size_t size = sizeof(chip_name);
139
+
140
+ if (sysctlbyname("machdep.cpu.brand_string", chip_name, &size, nullptr, 0) == 0) {
141
+ const std::string brand(chip_name);
142
+
143
+ struct ModelSMCU { const char *match; size_t smcus; };
144
+ static const ModelSMCU table[] = {
145
+ { "M4 Ultra", 2 },
146
+ { "M4 Max", 2 },
147
+ { "M4 Pro", 2 },
148
+ { "M4", 1 },
149
+ };
67
150
 
151
+ for (const auto &e : table) {
152
+ if (brand.find(e.match) != std::string::npos) {
153
+ return e.smcus;
154
+ }
155
+ }
156
+ }
157
+ return 1;
158
+
159
+ #else
160
+ return 1;
161
+ #endif
162
+ }
163
+
164
+ static int parse_uint_env(const char *s, const char *name, bool *ok) {
165
+ if (!s) { *ok = false; return 0; }
166
+ char *end = nullptr;
167
+ long v = strtol(s, &end, 10);
168
+ if (end == s || *end != '\0') {
169
+ GGML_LOG_WARN("kleidiai: invalid %s='%s' (expected integer)\n", name, s);
170
+ *ok = false;
171
+ return 0;
172
+ }
173
+ if (v < 0 || v > INT_MAX) {
174
+ GGML_LOG_WARN("kleidiai: out-of-range %s='%s'\n", name, s);
175
+ *ok = false;
176
+ return 0;
177
+ }
178
+ *ok = true;
179
+ return (int)v;
180
+ }
181
+
182
+ static void init_kleidiai_context(void) {
68
183
  ggml_critical_section_start();
69
184
  static bool initialized = false;
70
185
 
71
186
  if (!initialized) {
72
187
  initialized = true;
73
- const char *env_var = getenv("GGML_KLEIDIAI_SME");
74
- int sme_enabled = 0;
188
+
189
+ const char *env_sme = getenv("GGML_KLEIDIAI_SME");
190
+ const char *env_threads = getenv("GGML_TOTAL_THREADS");
191
+
192
+ const bool cpu_has_sme = ggml_cpu_has_sme();
193
+ size_t detected_smcus = 0;
75
194
 
76
195
  ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
77
196
  (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
78
197
  ((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
79
198
 
80
- if (env_var) {
81
- sme_enabled = atoi(env_var);
199
+ if (env_threads) {
200
+ bool ok = false;
201
+ int hint = parse_uint_env(env_threads, "GGML_TOTAL_THREADS", &ok);
202
+ if (ok && hint > 0) {
203
+ ctx.thread_hint = hint;
204
+ }
82
205
  }
83
206
 
84
- if (sme_enabled != 0) {
85
- ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
207
+ // SME policy:
208
+ // - If CPU doesn't support SME: SME always off.
209
+ // - Else:
210
+ // - env unset => auto-detect cores; enable if detected > 0.
211
+ // - env=0 => force off.
212
+ // - env>0 => force N cores (skip detection).
213
+ int sme_cores = 0;
214
+ bool sme_env_ok = false;
215
+ bool sme_env_set = (env_sme != nullptr);
216
+
217
+ if (!cpu_has_sme) {
218
+ if (sme_env_set) {
219
+ bool ok = false;
220
+ int req = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
221
+ if (ok && req > 0) {
222
+ GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME=%d but SME is not supported on this CPU; disabling SME\n", req);
223
+ }
224
+ }
225
+ sme_cores = 0;
226
+ } else {
227
+ if (sme_env_set) {
228
+ bool ok = false;
229
+ int v = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
230
+ sme_env_ok = ok;
231
+
232
+ if (!ok) {
233
+ GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME set but parsing failed; falling back to runtime SME-core detection\n");
234
+ detected_smcus = detect_num_smcus();
235
+ sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
236
+ } else if (v == 0) {
237
+ sme_cores = 0;
238
+ } else {
239
+ sme_cores = v;
240
+ }
241
+ } else {
242
+ detected_smcus = detect_num_smcus();
243
+ sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
244
+ }
245
+
246
+ if (!sme_env_set && sme_cores == 0) {
247
+ GGML_LOG_WARN("kleidiai: SME supported but runtime SME-core detection returned 0; falling back to NEON\n");
248
+ }
249
+
250
+ if (sme_cores > 0) {
251
+ ctx.features |= CPU_FEATURE_SME;
252
+ }
86
253
  }
254
+
255
+ // Kernel selection
87
256
  ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
88
257
  ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
89
- #ifndef NDEBUG
90
- if (ctx.kernels_q4) {
91
- GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
258
+
259
+ if (!ctx.kernels_q4) {
260
+ GGML_LOG_INFO("kleidiai: no compatible q4 kernels found for CPU features mask %d\n", (int)ctx.features);
261
+ } else {
262
+ GGML_LOG_INFO("kleidiai: primary q4 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
263
+ }
264
+
265
+ if (!ctx.kernels_q8) {
266
+ GGML_LOG_INFO("kleidiai: no compatible q8 kernels found for CPU features mask %d\n", (int)ctx.features);
267
+ } else {
268
+ GGML_LOG_INFO("kleidiai: primary q8 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
92
269
  }
93
- if (ctx.kernels_q8) {
94
- GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
270
+
271
+ ctx.sme_thread_cap = (ctx.features & CPU_FEATURE_SME) ? sme_cores : 0;
272
+
273
+ if (ctx.features & CPU_FEATURE_SME) {
274
+ if (sme_env_set && sme_env_ok && sme_cores > 0) {
275
+ GGML_LOG_INFO("kleidiai: SME enabled (GGML_KLEIDIAI_SME=%d override)\n", sme_cores);
276
+ } else {
277
+ GGML_LOG_INFO("kleidiai: SME enabled (runtime-detected SME cores=%d)\n", sme_cores);
278
+ }
279
+ } else {
280
+ GGML_LOG_INFO("kleidiai: SME disabled\n");
95
281
  }
96
- #endif
97
282
  }
283
+
98
284
  ggml_critical_section_end();
99
285
  }
100
286
 
287
+ static inline int kleidiai_sme_thread_cap() {
288
+ return ctx.sme_thread_cap;
289
+ }
290
+
291
+ static inline size_t align_up(size_t value, size_t alignment) {
292
+ if (alignment == 0) {
293
+ return value;
294
+ }
295
+ const size_t remainder = value % alignment;
296
+ return remainder == 0 ? value : value + (alignment - remainder);
297
+ }
298
+
299
+ static inline bool kleidiai_pack_fallback_allowed() {
300
+ if (ctx.sme_thread_cap <= 0) {
301
+ return false;
302
+ }
303
+ if (ctx.thread_hint <= 0) {
304
+ return true;
305
+ }
306
+ return ctx.thread_hint > ctx.sme_thread_cap;
307
+ }
308
+
309
+ struct kleidiai_weight_header {
310
+ uint32_t magic;
311
+ uint16_t version;
312
+ uint16_t slot_count;
313
+ uint64_t offsets[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
314
+ uint64_t sizes[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
315
+ };
316
+
317
+ static inline kleidiai_weight_header * kleidiai_weight_header_from_ptr(void * data) {
318
+ return reinterpret_cast<kleidiai_weight_header *>(data);
319
+ }
320
+
321
+ static inline const kleidiai_weight_header * kleidiai_weight_header_from_ptr(const void * data) {
322
+ return reinterpret_cast<const kleidiai_weight_header *>(data);
323
+ }
324
+
325
+ static inline bool kleidiai_is_weight_header_valid(const kleidiai_weight_header * header) {
326
+ if (!header) {
327
+ return false;
328
+ }
329
+ if (header->magic != GGML_KLEIDIAI_PACK_MAGIC || header->version != GGML_KLEIDIAI_PACK_VERSION) {
330
+ return false;
331
+ }
332
+ if (header->slot_count == 0 || header->slot_count > GGML_KLEIDIAI_MAX_KERNEL_SLOTS) {
333
+ return false;
334
+ }
335
+ return true;
336
+ }
337
+
338
+ static inline uint8_t * kleidiai_weight_slot_ptr(kleidiai_weight_header * header, int slot) {
339
+ if (!kleidiai_is_weight_header_valid(header)) {
340
+ return nullptr;
341
+ }
342
+ if (slot < 0 || slot >= header->slot_count) {
343
+ return nullptr;
344
+ }
345
+ return reinterpret_cast<uint8_t *>(header) + header->offsets[slot];
346
+ }
347
+
348
+ static inline const uint8_t * kleidiai_weight_slot_ptr(const kleidiai_weight_header * header, int slot) {
349
+ if (!kleidiai_is_weight_header_valid(header)) {
350
+ return nullptr;
351
+ }
352
+ if (slot < 0 || slot >= header->slot_count) {
353
+ return nullptr;
354
+ }
355
+ return reinterpret_cast<const uint8_t *>(header) + header->offsets[slot];
356
+ }
357
+
358
+ static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q4() {
359
+ return ctx.kernels_q4;
360
+ }
361
+
362
+ static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q8() {
363
+ return ctx.kernels_q8;
364
+ }
365
+
366
+ template <typename SelectFallback>
367
+ static int kleidiai_collect_kernel_chain_common(
368
+ ggml_kleidiai_kernels * primary,
369
+ cpu_feature features,
370
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out,
371
+ SelectFallback select_fallback) {
372
+ int count = 0;
373
+ if (!primary) {
374
+ return 0;
375
+ }
376
+ out[count++] = primary;
377
+
378
+ if ((primary->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
379
+ const cpu_feature fallback_mask = static_cast<cpu_feature>(features & ~CPU_FEATURE_SME);
380
+ if (fallback_mask != CPU_FEATURE_NONE) {
381
+ ggml_kleidiai_kernels * fallback = select_fallback(fallback_mask);
382
+ if (fallback && fallback != primary &&
383
+ fallback->lhs_type == primary->lhs_type &&
384
+ fallback->rhs_type == primary->rhs_type &&
385
+ fallback->op_type == primary->op_type) {
386
+ out[count++] = fallback;
387
+ }
388
+ }
389
+ }
390
+
391
+ return count;
392
+ }
393
+
394
+ static int kleidiai_collect_kernel_chain(const struct ggml_tensor * op,
395
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
396
+ ggml_kleidiai_kernels * primary = ggml_kleidiai_select_kernels(ctx.features, op);
397
+ return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
398
+ [&](cpu_feature mask) { return ggml_kleidiai_select_kernels(mask, op); });
399
+ }
400
+
401
+ static int kleidiai_collect_q4_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
402
+ ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q4();
403
+ return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
404
+ [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q4_0(mask); });
405
+ }
406
+
407
+ static int kleidiai_collect_q8_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
408
+ ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q8();
409
+ return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
410
+ [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q8_0(mask); });
411
+ }
412
+
101
413
  static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
102
414
  GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
103
415
  return tensor->ne[dim];
@@ -126,49 +438,108 @@ class tensor_traits : public ggml::cpu::tensor_traits {
126
438
  if (op->op != GGML_OP_MUL_MAT) {
127
439
  return false;
128
440
  }
129
- ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
130
- if (!kernels) {
441
+
442
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
443
+ const int slot_count = kleidiai_collect_kernel_chain(op, kernel_chain);
444
+ if (slot_count == 0) {
131
445
  return false;
132
446
  }
133
- bool is_gemv = op->src[1]->ne[1] == 1;
134
- kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
135
- lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
136
447
 
137
- size_t k = op->src[0]->ne[0];
138
- size_t n = op->src[0]->ne[1];
139
- size_t m = op->src[1]->ne[1];
140
-
141
- size_t mr = kernel->get_mr();
142
- size_t kr = kernel->get_kr();
143
- size_t sr = kernel->get_sr();
144
-
145
- if (kernels->rhs_type == GGML_TYPE_Q4_0) {
146
- if (!lhs_info->packed_size_ex) return false;
147
- size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
148
- } else if (kernels->rhs_type == GGML_TYPE_Q8_0) {
149
- if (!lhs_info->packed_size_ex) return false;
150
- size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr);
151
- } else if (kernels->rhs_type == GGML_TYPE_F16) {
152
- if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
448
+ const bool is_gemv = op->src[1]->ne[1] == 1;
449
+ const size_t k = op->src[0]->ne[0];
450
+ const size_t n = op->src[0]->ne[1];
451
+ const size_t m = op->src[1]->ne[1];
452
+
453
+ if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) {
454
+ const size_t qk = (op->src[0]->type == GGML_TYPE_Q4_0) ? QK4_0 : QK8_0;
455
+
456
+ size_t cursor = 0;
457
+ bool any_slot = false;
458
+
459
+ for (int slot = 0; slot < slot_count; ++slot) {
460
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
461
+ lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
462
+ kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
463
+
464
+ if (!lhs_info || !lhs_info->packed_size_ex || !kernel) {
465
+ return false;
466
+ }
467
+
468
+ const size_t mr = kernel->get_mr();
469
+ const size_t kr = kernel->get_kr();
470
+ const size_t sr = kernel->get_sr();
471
+
472
+ const size_t packed = lhs_info->packed_size_ex(m, k, qk, mr, kr, sr);
473
+
474
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
475
+ cursor += packed;
476
+ any_slot = true;
477
+ }
478
+
479
+ if (!any_slot) {
480
+ return false;
481
+ }
482
+
483
+ size = cursor;
484
+ return true;
485
+ }
486
+
487
+ if (op->src[0]->type == GGML_TYPE_F16) {
153
488
  const int64_t lhs_batch_size0 = op->src[1]->ne[2];
154
489
  const int64_t rhs_batch_size0 = op->src[0]->ne[2];
490
+ GGML_ASSERT(rhs_batch_size0 > 0);
155
491
  const int64_t r = lhs_batch_size0 / rhs_batch_size0;
156
- size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) +
157
- kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) +
158
- k * n * sizeof(float) + n * sizeof(float);
159
- } else {
160
- return false;
492
+
493
+ size_t cursor = 0;
494
+ bool any_slot = false;
495
+
496
+ for (int slot = 0; slot < slot_count; ++slot) {
497
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
498
+ lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
499
+ kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
500
+ if (!lhs_info || !lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex || !kernel) {
501
+ return false;
502
+ }
503
+
504
+ const size_t mr = kernel->get_mr();
505
+ const size_t kr = kernel->get_kr();
506
+ const size_t sr = kernel->get_sr();
507
+
508
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
509
+ cursor += lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr);
510
+ any_slot = true;
511
+ }
512
+
513
+ for (int slot = 0; slot < slot_count; ++slot) {
514
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
515
+ kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
516
+ if (!kernel || !kernels->rhs_info.packed_size_ex) {
517
+ return false;
518
+ }
519
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
520
+ cursor += kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0);
521
+ }
522
+
523
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
524
+ cursor += k * n * sizeof(float);
525
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
526
+ cursor += n * sizeof(float);
527
+
528
+ if (!any_slot) {
529
+ return false;
530
+ }
531
+
532
+ size = cursor;
533
+ return true;
161
534
  }
162
535
 
163
- return true;
536
+ return false;
164
537
  }
165
538
 
166
539
  bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
167
540
  if (dst->op == GGML_OP_MUL_MAT) {
168
- if (dst->src[0]->type == GGML_TYPE_Q4_0) {
169
- return compute_forward_q4_0(params, dst);
170
- } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
171
- return compute_forward_q8_0(params, dst);
541
+ if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
542
+ return compute_forward_qx(params, dst);
172
543
  } else if (dst->src[0]->type == GGML_TYPE_F16) {
173
544
  return compute_forward_fp16(params, dst);
174
545
  }
@@ -331,204 +702,457 @@ class tensor_traits : public ggml::cpu::tensor_traits {
331
702
  return true;
332
703
  }
333
704
 
334
- bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
335
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
705
+ bool compute_forward_qx(struct ggml_compute_params * params, struct ggml_tensor * dst) {
706
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
336
707
 
337
708
  const ggml_tensor * src0 = dst->src[0];
338
709
  const ggml_tensor * src1 = dst->src[1];
339
710
 
340
711
  GGML_TENSOR_BINARY_OP_LOCALS
341
712
 
342
- ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
343
- if (!kernels) {
344
- return false;
345
- }
346
-
347
- bool is_gemv = src1->ne[1] == 1;
348
- kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
349
- lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
350
-
351
- GGML_ASSERT(kernel);
352
- if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
353
- !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
354
- return false;
355
- }
713
+ const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
714
+ const bool has_header = kleidiai_is_weight_header_valid(header);
715
+ const bool is_gemv = src1->ne[1] == 1;
716
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
717
+ const int slot_total = kleidiai_collect_kernel_chain(dst, kernel_chain);
356
718
 
357
- const int ith = params->ith;
358
- const int nth_raw = params->nth;
359
- const int nth = nth_raw > 0 ? nth_raw : 1;
719
+ auto weight_for_slot = [&](int slot_index, size_t & size_out) -> const uint8_t * {
720
+ if (slot_index < 0 || slot_index >= slot_total) {
721
+ return nullptr;
722
+ }
723
+ if (has_header) {
724
+ if (slot_index < header->slot_count) {
725
+ size_out = static_cast<size_t>(header->sizes[slot_index]);
726
+ return kleidiai_weight_slot_ptr(header, slot_index);
727
+ }
728
+ return nullptr;
729
+ }
730
+ if (slot_index == 0) {
731
+ size_out = ggml_nbytes(src0);
732
+ return static_cast<const uint8_t *>(src0->data);
733
+ }
734
+ return nullptr;
735
+ };
736
+
737
+ struct runtime_slot {
738
+ int slot_index;
739
+ ggml_kleidiai_kernels * kernels;
740
+ kernel_info * kernel;
741
+ lhs_packing_info * lhs_info;
742
+ size_t mr;
743
+ size_t nr;
744
+ size_t kr;
745
+ size_t sr;
746
+ size_t n_step;
747
+ size_t lhs_packed_size;
748
+ size_t lhs_offset;
749
+ size_t n_offset;
750
+ size_t n_cols;
751
+ int assigned_threads;
752
+ int thread_begin;
753
+ int thread_end;
754
+ const uint8_t * rhs_base;
755
+ };
756
+
757
+ std::array<runtime_slot, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> runtime{};
758
+ int runtime_count = 0;
759
+
760
+ for (int slot = 0; slot < slot_total && runtime_count < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
761
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
762
+ kernel_info * kinfo = is_gemv ? &kernels->gemv : &kernels->gemm;
763
+ lhs_packing_info * linfo = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
764
+ if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex || !linfo->get_offset ||
765
+ !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset) {
766
+ continue;
767
+ }
360
768
 
361
- const size_t k = ne00;
362
- const size_t m = ne11;
363
- const size_t n = ne01;
769
+ size_t rhs_size = 0;
770
+ const uint8_t * rhs_ptr = weight_for_slot(slot, rhs_size);
771
+ if (!rhs_ptr || rhs_size == 0) {
772
+ continue;
773
+ }
364
774
 
365
- size_t mr = kernel->get_mr();
366
- size_t kr = kernel->get_kr();
367
- size_t sr = kernel->get_sr();
775
+ runtime[runtime_count] = {
776
+ slot,
777
+ kernels,
778
+ kinfo,
779
+ linfo,
780
+ kinfo->get_mr(),
781
+ kinfo->get_nr(),
782
+ kinfo->get_kr(),
783
+ kinfo->get_sr(),
784
+ kinfo->get_n_step(),
785
+ 0,
786
+ 0,
787
+ 0,
788
+ 0,
789
+ 0,
790
+ 0,
791
+ 0,
792
+ rhs_ptr
793
+ };
794
+ ++runtime_count;
795
+ }
368
796
 
369
- const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
370
- uint8_t * lhs_packed = (uint8_t*)params->wdata;
371
- const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
797
+ if (runtime_count == 0) {
798
+ ggml_kleidiai_kernels * fallback = ggml_kleidiai_select_kernels(ctx.features, dst);
799
+ if (!fallback) {
800
+ return false;
801
+ }
802
+ kernel_info * kinfo = is_gemv ? &fallback->gemv : &fallback->gemm;
803
+ lhs_packing_info * linfo = is_gemv ? &fallback->gemv_lhs_info : &fallback->gemm_lhs_info;
804
+ rhs_packing_info * rinfo = &fallback->rhs_info;
805
+ if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex ||
806
+ !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset ||
807
+ !rinfo || !rinfo->pack_func_ex || !rinfo->packed_size_ex) {
808
+ return false;
809
+ }
810
+ kernel_chain[0] = fallback;
811
+ runtime[0] = {
812
+ 0,
813
+ fallback,
814
+ kinfo,
815
+ linfo,
816
+ kinfo->get_mr(),
817
+ kinfo->get_nr(),
818
+ kinfo->get_kr(),
819
+ kinfo->get_sr(),
820
+ kinfo->get_n_step(),
821
+ 0,
822
+ 0,
823
+ 0,
824
+ 0,
825
+ 0,
826
+ 0,
827
+ 0,
828
+ nullptr
829
+ };
830
+ size_t rhs_size_fallback = 0;
831
+ const uint8_t * rhs_base = weight_for_slot(0, rhs_size_fallback);
832
+ if (!rhs_base) {
833
+ rhs_base = static_cast<const uint8_t *>(src0->data);
834
+ }
835
+ runtime[0].rhs_base = rhs_base;
836
+ runtime_count = 1;
837
+ }
372
838
 
373
- const size_t n_step = kernel->get_n_step();
374
- const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
375
- const size_t n_start = ith * num_n_per_thread;
839
+ const int nth_total = params->nth > 0 ? params->nth : 1;
840
+ const int ith_total = params->ith;
376
841
 
377
- size_t n_to_process = 0;
378
- if (n_start < n) {
379
- n_to_process = num_n_per_thread;
380
- if ((n_start + n_to_process) > n) {
381
- n_to_process = n - n_start;
842
+ int sme_slot = -1;
843
+ for (int i = 0; i < runtime_count; ++i) {
844
+ if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
845
+ sme_slot = i;
846
+ break;
382
847
  }
383
848
  }
384
849
 
385
- // Calculate number of columns to be processed per thread
386
- const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
387
- const size_t m_start = ith * num_m_per_thread;
388
- size_t m_to_process = num_m_per_thread;
389
- if ((m_start + m_to_process) > m) {
390
- m_to_process = m - m_start;
850
+ const int sme_cap_limit = ctx.sme_thread_cap;
851
+ const bool use_hybrid = sme_cap_limit > 0 &&
852
+ runtime_count > 1 &&
853
+ nth_total > sme_cap_limit;
854
+ // Heuristic: disable hybrid for very small workloads where per-slot overhead dominates.
855
+ // If rows are small or average columns per thread are small, keep single-slot.
856
+ size_t min_cols_per_thread = 0;
857
+ if (runtime_count > 0 && nth_total > 0) {
858
+ min_cols_per_thread = (size_t) std::max<int64_t>(1, (int64_t)ne01 / (int64_t)nth_total);
391
859
  }
860
+ const bool too_small_for_hybrid = (min_cols_per_thread < 2) || (ne11 < 128);
392
861
 
393
- if (m_start < m) {
394
- // Transform LHS
395
- const size_t src_stride = src1->nb[1];
396
- const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
397
- const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr);
398
- void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
399
-
400
- // Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer
401
- lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
402
- }
862
+ const bool hybrid_enabled = use_hybrid && !too_small_for_hybrid;
403
863
 
404
- ggml_barrier(params->threadpool);
864
+ if (!hybrid_enabled) {
865
+ int chosen_slot = 0;
866
+ if (too_small_for_hybrid && sme_slot != -1) {
867
+ chosen_slot = sme_slot;
868
+ } else if (runtime_count > 1 && ctx.sme_thread_cap > 0 && nth_total > ctx.sme_thread_cap) {
869
+ chosen_slot = 1;
870
+ }
871
+ if (chosen_slot != 0 && chosen_slot < runtime_count) {
872
+ runtime[0] = runtime[chosen_slot];
873
+ }
874
+ runtime_count = runtime_count > 0 ? 1 : 0;
405
875
 
406
- // Perform the operation
407
- const size_t dst_stride = dst->nb[1];
408
- const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr);
409
- const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0);
410
- const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
411
- const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
412
- const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
413
- float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
876
+ // Recompute SME slot based on the collapsed runtime[0]
877
+ sme_slot = -1;
878
+ if (runtime_count > 0 &&
879
+ (runtime[0].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
880
+ sme_slot = 0;
881
+ }
882
+ }
414
883
 
415
- if (n_to_process > 0) {
416
- kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
417
- sizeof(float), -FLT_MAX, FLT_MAX);
884
+ int sme_cap = kleidiai_sme_thread_cap();
885
+ if (sme_cap < 0) {
886
+ sme_cap = nth_total;
418
887
  }
888
+ sme_cap = std::min(sme_cap, nth_total);
419
889
 
420
- return true;
421
- }
890
+ int threads_remaining = nth_total;
891
+ if (sme_slot != -1) {
892
+ int sme_threads = std::min(std::max(sme_cap, 0), threads_remaining);
893
+ runtime[sme_slot].assigned_threads = sme_threads;
894
+ threads_remaining -= sme_threads;
895
+ }
422
896
 
423
- bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
424
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0);
897
+ int fallback_indices[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
898
+ int fallback_count = 0;
899
+ for (int i = 0; i < runtime_count; ++i) {
900
+ if (i == sme_slot) {
901
+ continue;
902
+ }
903
+ fallback_indices[fallback_count++] = i;
904
+ }
425
905
 
426
- const ggml_tensor * src0 = dst->src[0];
427
- const ggml_tensor * src1 = dst->src[1];
906
+ for (int fi = 0; fi < fallback_count; ++fi) {
907
+ if (threads_remaining <= 0) {
908
+ break;
909
+ }
910
+ const int slot_index = fallback_indices[fi];
911
+ const int slots_left = fallback_count - fi;
912
+ int share = (threads_remaining + slots_left - 1) / slots_left;
913
+ share = std::min(share, threads_remaining);
914
+ runtime[slot_index].assigned_threads = share;
915
+ threads_remaining -= share;
916
+ }
428
917
 
429
- GGML_TENSOR_BINARY_OP_LOCALS
918
+ if (threads_remaining > 0) {
919
+ const int fallback_slot = (sme_slot != -1) ? sme_slot : 0;
920
+ runtime[fallback_slot].assigned_threads += threads_remaining;
921
+ threads_remaining = 0;
922
+ }
430
923
 
431
- ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
432
- if (!kernels) {
433
- return false;
924
+ int thread_cursor = 0;
925
+ for (int i = 0; i < runtime_count; ++i) {
926
+ runtime[i].thread_begin = thread_cursor;
927
+ thread_cursor += runtime[i].assigned_threads;
928
+ runtime[i].thread_end = thread_cursor;
434
929
  }
435
930
 
436
- bool is_gemv = src1->ne[1] == 1;
437
- kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
438
- lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
931
+ if (thread_cursor < nth_total && runtime_count > 0) {
932
+ runtime[runtime_count - 1].assigned_threads += nth_total - thread_cursor;
933
+ runtime[runtime_count - 1].thread_end = nth_total;
934
+ }
439
935
 
440
- if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
441
- !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
936
+ int local_slot = -1;
937
+ int local_ith = 0;
938
+ for (int i = 0; i < runtime_count; ++i) {
939
+ if (ith_total >= runtime[i].thread_begin && ith_total < runtime[i].thread_end) {
940
+ local_slot = i;
941
+ local_ith = ith_total - runtime[i].thread_begin;
942
+ break;
943
+ }
944
+ }
945
+ if (local_slot == -1) {
442
946
  return false;
443
947
  }
444
948
 
445
- const int ith = params->ith;
446
- const int nth_raw = params->nth;
447
- const int nth = nth_raw > 0 ? nth_raw : 1;
448
-
449
949
  const size_t k = ne00;
450
950
  const size_t m = ne11;
451
951
  const size_t n = ne01;
452
952
 
453
- size_t mr = kernel->get_mr();
454
- size_t kr = kernel->get_kr();
455
- size_t sr = kernel->get_sr();
456
-
457
- const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
458
- uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
459
- const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
953
+ size_t cursor = 0;
954
+ for (int i = 0; i < runtime_count; ++i) {
955
+ const ggml_type slot_rhs_type = runtime[i].kernels->rhs_type;
956
+ const size_t slot_pack_size_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
957
+ slot_rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
958
+ runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, slot_pack_size_arg, runtime[i].mr, runtime[i].kr, runtime[i].sr);
959
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
960
+ runtime[i].lhs_offset = cursor;
961
+ cursor += runtime[i].lhs_packed_size;
962
+ }
460
963
 
461
- const size_t n_step = kernel->get_n_step();
462
- const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
463
- const size_t n_start = ith * num_n_per_thread;
964
+ GGML_ASSERT(cursor <= params->wsize);
965
+ uint8_t * scratch = static_cast<uint8_t *>(params->wdata);
464
966
 
465
- size_t n_to_process = 0;
466
- if (n_start < n) {
467
- n_to_process = num_n_per_thread;
468
- if ((n_start + n_to_process) > n) {
469
- n_to_process = n - n_start;
967
+ size_t assigned_cols = 0;
968
+ uint64_t weighted_total = 0;
969
+ if (runtime_count > 1 && sme_slot != -1) {
970
+ for (int i = 0; i < runtime_count; ++i) {
971
+ const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
972
+ weighted_total += (uint64_t)runtime[i].assigned_threads * weight;
470
973
  }
471
974
  }
975
+ for (int i = 0; i < runtime_count; ++i) {
976
+ runtime[i].n_offset = assigned_cols;
977
+ if (runtime[i].assigned_threads == 0) {
978
+ runtime[i].n_cols = 0;
979
+ continue;
980
+ }
981
+ const size_t remaining_cols = n - assigned_cols;
982
+ if (remaining_cols == 0) {
983
+ runtime[i].n_cols = 0;
984
+ continue;
985
+ }
986
+ const size_t step = runtime[i].n_step ? runtime[i].n_step : 1;
987
+ size_t target = 0;
988
+ if (weighted_total > 0) {
989
+ const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
990
+ target = (size_t)(((uint64_t)n * runtime[i].assigned_threads * weight) / weighted_total);
991
+ } else {
992
+ target = (size_t)(((uint64_t)n * runtime[i].assigned_threads) / nth_total);
993
+ }
994
+ target = std::min(target, remaining_cols);
995
+ size_t aligned = round_down(target, step);
996
+ if (aligned == 0 && remaining_cols >= step) {
997
+ aligned = step;
998
+ }
999
+ runtime[i].n_cols = aligned;
1000
+ assigned_cols += aligned;
1001
+ }
472
1002
 
473
- const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
474
- const size_t m_start = ith * num_m_per_thread;
475
- size_t m_to_process = num_m_per_thread;
476
- if ((m_start + m_to_process) > m) {
477
- m_to_process = m - m_start;
1003
+ if (assigned_cols < n) {
1004
+ for (int i = runtime_count - 1; i >= 0; --i) {
1005
+ if (runtime[i].assigned_threads > 0) {
1006
+ runtime[i].n_cols += n - assigned_cols;
1007
+ break;
1008
+ }
1009
+ }
478
1010
  }
1011
+ const size_t dst_stride = dst->nb[1];
479
1012
 
480
- if (m_start < m) {
481
- const size_t src_stride = src1->nb[1];
482
- const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
483
- const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
484
- void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
1013
+ for (int64_t batch_idx = 0; batch_idx < ne12; ++batch_idx) {
1014
+ const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
1015
+ uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];
485
1016
 
486
- lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
487
- }
1017
+ if (runtime[local_slot].assigned_threads > 0) {
1018
+ runtime_slot & slot = runtime[local_slot];
1019
+ const ggml_type slot_rhs_type = slot.kernels->rhs_type;
1020
+ const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
1021
+ slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
1022
+ const int64_t m_roundup_mr = kai_roundup((int64_t)m, (int64_t)slot.mr);
1023
+ int64_t max_threads = slot.mr ? (m_roundup_mr / (int64_t)slot.mr) : slot.assigned_threads;
1024
+ max_threads = std::max<int64_t>(1, max_threads);
1025
+ const int64_t use_threads = std::min<int64_t>(slot.assigned_threads, max_threads);
488
1026
 
489
- ggml_barrier(params->threadpool);
1027
+ if (local_ith < use_threads) {
1028
+ const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / use_threads), slot.mr);
1029
+ const int64_t num_m_per_threadN_1 = (int64_t)m - (use_threads - 1) * num_m_per_thread0;
490
1030
 
491
- const size_t dst_stride = dst->nb[1];
492
- const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
493
- const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
494
- const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
495
- const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
496
- const void * lhs_ptr = static_cast<const void *>(lhs_packed + lhs_packed_offset);
497
- float * dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
1031
+ const int64_t m_start = (int64_t)local_ith * num_m_per_thread0;
1032
+ const int64_t m_count = (local_ith == use_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
1033
+
1034
+ const size_t base_packed_off = slot.lhs_info->get_packed_offset_ex(m_start, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
1035
+ const size_t next_block_off = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
1036
+ const size_t row_stride_bytes = slot.mr ? (next_block_off - base_packed_off) / slot.mr : 0;
1037
+
1038
+ int64_t remaining = m_count;
1039
+ int64_t cur = m_start;
1040
+
1041
+ uint8_t * lhs_packed = scratch + slot.lhs_offset;
1042
+ while (remaining > 0) {
1043
+ const int64_t row_in_group = cur;
1044
+ const int64_t avail = (int64_t)m - row_in_group;
1045
+ const int64_t take = std::min(avail, remaining);
1046
+
1047
+ const size_t src_off = slot.lhs_info->get_offset(row_in_group, src1->nb[1]);
1048
+ const void * src_ptr = lhs_batch_base + src_off;
1049
+ const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
1050
+ void * dst_ptr = lhs_packed + dst_off;
1051
+
1052
+ slot.lhs_info->pack_func_ex(take, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr);
1053
+
1054
+ cur += take;
1055
+ remaining -= take;
1056
+ }
1057
+ }
1058
+ }
1059
+
1060
+ ggml_barrier(params->threadpool);
498
1061
 
499
- if (n_to_process > 0) {
500
- kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
501
- sizeof(float), -FLT_MAX, FLT_MAX);
1062
+ runtime_slot & slot = runtime[local_slot];
1063
+ if (slot.n_cols > 0 && slot.assigned_threads > 0) {
1064
+ int64_t active_threads = slot.assigned_threads;
1065
+ const int64_t max_threads = slot.n_step ? (slot.n_cols / slot.n_step) : slot.assigned_threads;
1066
+ if (max_threads > 0) {
1067
+ active_threads = std::min<int64_t>(active_threads, std::max<int64_t>(1, max_threads));
1068
+ }
1069
+ active_threads = std::max<int64_t>(1, active_threads);
1070
+
1071
+ if (local_ith < active_threads) {
1072
+ const size_t step = slot.n_step ? slot.n_step : 1;
1073
+ const size_t chunk0 = round_down((size_t)(slot.n_cols / active_threads), step);
1074
+ const size_t chunkN = slot.n_cols - (active_threads - 1) * chunk0;
1075
+ const size_t local_start = (size_t)local_ith * chunk0;
1076
+ const size_t cols = (local_ith == active_threads - 1) ? chunkN : chunk0;
1077
+
1078
+ if (cols > 0) {
1079
+ const ggml_type slot_rhs_type = slot.kernels->rhs_type;
1080
+ const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
1081
+ slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
1082
+ const size_t slot_rhs_block_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
1083
+ slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
1084
+ const size_t global_start = slot.n_offset + local_start;
1085
+ const size_t lhs_packed_offset = slot.lhs_info->get_packed_offset_ex(0, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
1086
+ const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot_rhs_block_arg);
1087
+ const size_t dst_offset = slot.kernel->get_dst_offset(0, global_start, dst_stride);
1088
+
1089
+ const uint8_t * lhs_ptr = scratch + slot.lhs_offset + lhs_packed_offset;
1090
+ const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset;
1091
+ float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
1092
+
1093
+ slot.kernel->run_kernel_ex(m, cols, k, slot_rhs_block_arg,
1094
+ lhs_ptr,
1095
+ rhs_ptr,
1096
+ dst_ptr,
1097
+ dst_stride,
1098
+ sizeof(float),
1099
+ -FLT_MAX,
1100
+ FLT_MAX);
1101
+ }
1102
+ }
1103
+ }
1104
+
1105
+ if (batch_idx != ne12 - 1) {
1106
+ ggml_barrier(params->threadpool);
1107
+ }
502
1108
  }
503
1109
 
504
1110
  return true;
505
1111
  }
506
1112
 
507
1113
  bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
1114
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
508
1115
  const ggml_tensor * src0 = dst->src[0];
509
1116
  const ggml_tensor * src1 = dst->src[1];
510
1117
 
511
1118
  GGML_TENSOR_BINARY_OP_LOCALS
512
1119
 
1120
+ const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
1121
+ const bool has_header = kleidiai_is_weight_header_valid(header);
1122
+
1123
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
1124
+ const bool want_q8 = src0->type == GGML_TYPE_Q8_0;
1125
+ const int chain_count = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
1126
+ : kleidiai_collect_q4_chain(kernel_chain);
1127
+
513
1128
  ggml_kleidiai_kernels * kernels = nullptr;
514
- size_t block_len = 0;
515
- size_t num_bytes_multiplier = 0;
1129
+ const uint8_t * packed_base = static_cast<const uint8_t *>(src0->data);
516
1130
 
517
- if (dst->src[0]->type == GGML_TYPE_Q4_0) {
518
- if (!ctx.kernels_q4) {
519
- return false;
1131
+ if (has_header && chain_count > 0) {
1132
+ int select_slot = 0;
1133
+ if (select_slot >= header->slot_count) {
1134
+ select_slot = header->slot_count - 1;
520
1135
  }
521
- kernels = ctx.kernels_q4;
522
- block_len = QK4_0;
523
- num_bytes_multiplier = sizeof(uint16_t);
524
- } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
525
- if (!ctx.kernels_q8) {
526
- return false;
1136
+ if (select_slot >= 0 && select_slot < chain_count) {
1137
+ kernels = kernel_chain[select_slot];
1138
+ const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, select_slot);
1139
+ if (slot_ptr) {
1140
+ packed_base = slot_ptr;
1141
+ }
527
1142
  }
528
- kernels = ctx.kernels_q8;
529
- block_len = QK8_0;
530
- num_bytes_multiplier = sizeof(float);
531
- } else {
1143
+ }
1144
+
1145
+ if (!kernels && chain_count > 0) {
1146
+ kernels = kernel_chain[0];
1147
+ if (has_header) {
1148
+ const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, 0);
1149
+ if (slot_ptr) {
1150
+ packed_base = slot_ptr;
1151
+ }
1152
+ }
1153
+ }
1154
+
1155
+ if (!kernels) {
532
1156
  return false;
533
1157
  }
534
1158
 
@@ -541,6 +1165,19 @@ class tensor_traits : public ggml::cpu::tensor_traits {
541
1165
  const int64_t nc = ne00;
542
1166
  const int64_t nr = ggml_nelements(src1);
543
1167
 
1168
+ const ggml_type rhs_type = kernels->rhs_type;
1169
+ size_t block_len = 0;
1170
+ size_t num_bytes_multiplier = 0;
1171
+ if (rhs_type == GGML_TYPE_Q4_0) {
1172
+ block_len = QK4_0;
1173
+ num_bytes_multiplier = sizeof(uint16_t);
1174
+ } else if (rhs_type == GGML_TYPE_Q8_0) {
1175
+ block_len = QK8_0;
1176
+ num_bytes_multiplier = sizeof(float);
1177
+ } else {
1178
+ return false;
1179
+ }
1180
+
544
1181
  const size_t block_rows = kernel->get_nr();
545
1182
  const size_t kr = kernel->get_kr();
546
1183
 
@@ -559,7 +1196,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
559
1196
  GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
560
1197
 
561
1198
  float *out = (float *)((char *)dst->data + i * nb1);
562
- rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
1199
+ rhs_info->to_float(packed_base, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
563
1200
  }
564
1201
 
565
1202
  return true;
@@ -567,36 +1204,39 @@ class tensor_traits : public ggml::cpu::tensor_traits {
567
1204
 
568
1205
  public:
569
1206
  int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
1207
+ GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0);
570
1208
  const size_t n = tensor->ne[1];
571
1209
  const size_t k = tensor->ne[0];
572
1210
 
573
- if (tensor->type == GGML_TYPE_Q4_0) {
574
- if (!ctx.kernels_q4) {
575
- return -1;
576
- }
577
- size_t nr = ctx.kernels_q4->gemm.get_nr();
578
- size_t kr = ctx.kernels_q4->gemm.get_kr();
579
- size_t sr = ctx.kernels_q4->gemm.get_sr();
1211
+ kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(tensor->data);
1212
+ if (!header) {
1213
+ return -1;
1214
+ }
580
1215
 
581
- struct kai_rhs_pack_qs4cxs1s0_param params;
582
- params.lhs_zero_point = 1;
583
- params.rhs_zero_point = 8;
584
- ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
585
- static_cast<const uint8_t *>(data),
586
- nullptr, nullptr, tensor->data, 0, &params);
587
- GGML_UNUSED(data_size);
588
- return 0;
589
- } else if (tensor->type == GGML_TYPE_Q8_0) {
590
- if (!ctx.kernels_q8) {
591
- return -1;
592
- }
1216
+ header->magic = GGML_KLEIDIAI_PACK_MAGIC;
1217
+ header->version = GGML_KLEIDIAI_PACK_VERSION;
1218
+ header->slot_count = 0;
1219
+
1220
+ uint8_t * base_ptr = static_cast<uint8_t *>(tensor->data);
1221
+ size_t cursor = sizeof(kleidiai_weight_header);
1222
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
1223
+
1224
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
1225
+ const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
1226
+ const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
1227
+ : kleidiai_collect_q4_chain(kernel_chain);
1228
+ const bool allow_fallback = kleidiai_pack_fallback_allowed();
1229
+
1230
+ std::vector<int8_t> qdata;
1231
+ std::vector<float> scales;
1232
+
1233
+ if (want_q8 && slot_total > 0) {
1234
+ qdata.resize(n * k, 0);
1235
+ scales.resize(n, 0.0f);
593
1236
 
594
1237
  const size_t row_stride = tensor->nb[1];
595
1238
  const size_t k_blocks = (k + QK8_0 - 1) / QK8_0;
596
1239
 
597
- std::vector<int8_t> qdata(n * k, 0);
598
- std::vector<float> scales(n, 0.0f);
599
-
600
1240
  for (size_t row = 0; row < n; ++row) {
601
1241
  const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
602
1242
  static_cast<const uint8_t *>(data) + row * row_stride);
@@ -610,7 +1250,7 @@ public:
610
1250
  if (linear_idx >= k) {
611
1251
  break;
612
1252
  }
613
- const float value = d * blk.qs[l];
1253
+ const float value = d * static_cast<float>(blk.qs[l]);
614
1254
  max_abs = std::max(max_abs, std::fabs(value));
615
1255
  }
616
1256
  }
@@ -627,31 +1267,73 @@ public:
627
1267
  if (linear_idx >= k) {
628
1268
  break;
629
1269
  }
630
- const float value = d * blk.qs[l];
1270
+ const float value = d * static_cast<float>(blk.qs[l]);
631
1271
  int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
632
1272
  q = std::clamp(q, -127, 127);
633
1273
  qdata[row * k + linear_idx] = static_cast<int8_t>(q);
634
1274
  }
635
1275
  }
636
1276
  }
1277
+ }
1278
+
1279
+ for (int slot = 0; slot < slot_total && slot < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
1280
+ if (!allow_fallback && slot > 0) {
1281
+ break;
1282
+ }
1283
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
1284
+ kernel_info * kernel = &kernels->gemm;
1285
+ rhs_packing_info * rhs_info = &kernels->rhs_info;
1286
+ if (!rhs_info || !rhs_info->pack_func_ex || !rhs_info->packed_size_ex || !kernel) {
1287
+ continue;
1288
+ }
1289
+
1290
+ const size_t nr = kernel->get_nr();
1291
+ const size_t kr = kernel->get_kr();
1292
+ const size_t sr = kernel->get_sr();
1293
+ const ggml_type rhs_type = kernels->rhs_type;
1294
+ const size_t block_len = rhs_type == GGML_TYPE_Q8_0 ? QK8_0 :
1295
+ rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : 0;
1296
+ if (block_len == 0) {
1297
+ continue;
1298
+ }
637
1299
 
638
- size_t nr = ctx.kernels_q8->gemm.get_nr();
639
- size_t kr = ctx.kernels_q8->gemm.get_kr();
640
- size_t sr = ctx.kernels_q8->gemm.get_sr();
1300
+ const size_t packed_size = rhs_info->packed_size_ex(n, k, nr, kr, block_len);
1301
+ const size_t aligned_cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
1302
+
1303
+ uint8_t * dst_ptr = base_ptr + aligned_cursor;
1304
+
1305
+ if (rhs_type == GGML_TYPE_Q4_0) {
1306
+ struct kai_rhs_pack_qs4cxs1s0_param params;
1307
+ params.lhs_zero_point = 1;
1308
+ params.rhs_zero_point = 8;
1309
+ rhs_info->pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
1310
+ static_cast<const uint8_t *>(data), nullptr, nullptr,
1311
+ dst_ptr, 0, &params);
1312
+ } else if (rhs_type == GGML_TYPE_Q8_0) {
1313
+ struct kai_rhs_pack_qsi8cx_params params;
1314
+ params.lhs_zero_point = 1;
1315
+ params.scale_multiplier = 1.0f;
1316
+ rhs_info->pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
1317
+ qdata.data(), nullptr, scales.data(),
1318
+ dst_ptr, 0, &params);
1319
+ } else {
1320
+ continue;
1321
+ }
1322
+
1323
+ header->offsets[header->slot_count] = aligned_cursor;
1324
+ header->sizes[header->slot_count] = packed_size;
1325
+ ++header->slot_count;
641
1326
 
642
- struct kai_rhs_pack_qsi8cx_params params;
643
- params.lhs_zero_point = 1;
644
- params.scale_multiplier = 1.0f;
1327
+ cursor = aligned_cursor + packed_size;
1328
+ }
645
1329
 
646
- ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
647
- qdata.data(), nullptr, scales.data(),
648
- tensor->data, 0, &params);
649
- GGML_UNUSED(data_size);
650
- return 0;
1330
+ if (header->slot_count == 0) {
1331
+ header->magic = 0;
1332
+ header->version = 0;
1333
+ memcpy(tensor->data, data, data_size);
651
1334
  }
652
1335
 
653
- GGML_UNUSED(data_size);
654
- return -1;
1336
+ return 0;
655
1337
  }
656
1338
  };
657
1339
 
@@ -681,9 +1363,8 @@ static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t bu
681
1363
  }
682
1364
 
683
1365
  static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
684
- return "CPU_KLEIDIAI";
685
-
686
1366
  GGML_UNUSED(buft);
1367
+ return "CPU_KLEIDIAI";
687
1368
  }
688
1369
 
689
1370
  static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -702,49 +1383,78 @@ static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(
702
1383
  }
703
1384
 
704
1385
  static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
705
- return TENSOR_ALIGNMENT;
706
-
707
1386
  GGML_UNUSED(buft);
1387
+ return TENSOR_ALIGNMENT;
708
1388
  }
709
1389
 
710
1390
  static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
711
1391
  GGML_UNUSED(buft);
712
1392
 
1393
+ if (tensor->type != GGML_TYPE_Q4_0 && tensor->type != GGML_TYPE_Q8_0) {
1394
+ return ggml_nbytes(tensor);
1395
+ }
1396
+
713
1397
  const size_t n = tensor->ne[1];
714
1398
  const size_t k = tensor->ne[0];
715
1399
 
716
- ggml_kleidiai_kernels * kernels = nullptr;
717
- size_t block_len = 0;
718
-
719
- if (tensor->type == GGML_TYPE_Q4_0) {
720
- GGML_ASSERT(ctx.kernels_q4);
721
- kernels = ctx.kernels_q4;
722
- block_len = QK4_0;
723
- } else if (tensor->type == GGML_TYPE_Q8_0) {
724
- GGML_ASSERT(ctx.kernels_q8);
725
- kernels = ctx.kernels_q8;
726
- block_len = QK8_0;
727
- } else {
728
- return 0;
1400
+ size_t cursor = sizeof(kleidiai_weight_header);
1401
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
1402
+
1403
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
1404
+ const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
1405
+ const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
1406
+ : kleidiai_collect_q4_chain(kernel_chain);
1407
+ const bool allow_fallback = kleidiai_pack_fallback_allowed();
1408
+
1409
+ size_t slot_count = 0;
1410
+ for (int slot = 0; slot < slot_total; ++slot) {
1411
+ if (!allow_fallback && slot > 0) {
1412
+ break;
1413
+ }
1414
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
1415
+ if (!kernels) {
1416
+ continue;
1417
+ }
1418
+ kernel_info * kernel = &kernels->gemm;
1419
+ rhs_packing_info * rhs_info = &kernels->rhs_info;
1420
+ if (!kernel || !rhs_info || !rhs_info->packed_size_ex) {
1421
+ continue;
1422
+ }
1423
+
1424
+ const ggml_type rhs_type = kernels->rhs_type;
1425
+ const size_t block_len = rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
1426
+ rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
1427
+ if (block_len == 0) {
1428
+ continue;
1429
+ }
1430
+
1431
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
1432
+ cursor += rhs_info->packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), block_len);
1433
+ ++slot_count;
729
1434
  }
730
1435
 
731
- const size_t nr = kernels->gemm.get_nr();
732
- const size_t kr = kernels->gemm.get_kr();
733
- const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len);
734
- const size_t raw = ggml_nbytes(tensor);
1436
+ if (slot_count == 0) {
1437
+ return ggml_nbytes(tensor);
1438
+ }
735
1439
 
736
- return packed > raw ? packed : raw;
1440
+ return std::max(cursor, ggml_nbytes(tensor));
737
1441
  }
738
1442
 
739
1443
  namespace ggml::cpu::kleidiai {
740
1444
  class extra_buffer_type : ggml::cpu::extra_buffer_type {
741
1445
  bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
1446
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
1447
+ const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
742
1448
  if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
743
1449
  (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
744
1450
  op->src[0]->buffer &&
745
1451
  (ggml_n_dims(op->src[0]) == 2) &&
746
- op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
747
- if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) {
1452
+ op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() &&
1453
+ slot_total > 0) {
1454
+ if (op->src[0]->type == GGML_TYPE_Q4_0 && ctx.kernels_q4 == nullptr) {
1455
+ return false;
1456
+ }
1457
+ if (op->src[0]->type == GGML_TYPE_Q8_0 && ctx.kernels_q8 == nullptr) {
748
1458
  return false;
749
1459
  }
750
1460
  if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
@@ -762,14 +1472,17 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
762
1472
  if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
763
1473
  if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
764
1474
  return (ggml::cpu::tensor_traits *) op->src[0]->extra;
765
- }
766
- else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) {
767
- if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
768
- (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
769
- return nullptr;
1475
+ } else {
1476
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
1477
+ const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
1478
+ const bool has_kernel = slot_total > 0;
1479
+ if (has_kernel && op->src[1]->ne[1] > 1) {
1480
+ if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
1481
+ (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
1482
+ return nullptr;
1483
+ }
1484
+ return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
770
1485
  }
771
-
772
- return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
773
1486
  }
774
1487
  }
775
1488
  return nullptr;