whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -1,20 +1,31 @@
1
- // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
1
+ // SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
2
2
  // SPDX-License-Identifier: MIT
3
3
  //
4
4
  #include <arm_neon.h>
5
5
  #include <assert.h>
6
+ #include <stdio.h>
6
7
  #include <atomic>
7
8
  #include <cfloat>
8
- #include <cmath>
9
9
  #include <algorithm>
10
+ #include <cmath>
10
11
  #include <stdexcept>
11
12
  #include <stdint.h>
12
13
  #include <string.h>
13
14
  #include <string>
14
15
  #include <vector>
16
+ #include <array>
17
+ #include <cstddef>
18
+ #include <cstdint>
19
+ #include <fstream>
20
+ #include <set>
21
+ #include <iostream>
22
+ #include <climits>
15
23
  #if defined(__linux__)
16
24
  #include <asm/hwcap.h>
17
25
  #include <sys/auxv.h>
26
+ #include <sys/types.h>
27
+ #include <sys/stat.h>
28
+ #include <unistd.h>
18
29
  #elif defined(__APPLE__)
19
30
  #include <string_view>
20
31
  #include <sys/sysctl.h>
@@ -27,6 +38,7 @@
27
38
  #include "kleidiai.h"
28
39
 
29
40
  #include "ggml-cpu.h"
41
+ #include "ggml-cpu-impl.h"
30
42
  #include "ggml-impl.h"
31
43
  #include "ggml-backend-impl.h"
32
44
  #include "ggml-threading.h"
@@ -39,11 +51,19 @@
39
51
  #define GGML_COMMON_DECL_CPP
40
52
  #include "ggml-common.h"
41
53
 
54
+ static constexpr int GGML_KLEIDIAI_MAX_KERNEL_SLOTS = 2;
55
+ static constexpr uint32_t GGML_KLEIDIAI_PACK_MAGIC = 0x4b4c4149; // "KLAI"
56
+ static constexpr uint16_t GGML_KLEIDIAI_PACK_VERSION = 1;
57
+ static constexpr size_t GGML_KLEIDIAI_PACK_ALIGN = 64;
58
+
42
59
  struct ggml_kleidiai_context {
43
60
  cpu_feature features;
44
61
  ggml_kleidiai_kernels * kernels_q4;
45
62
  ggml_kleidiai_kernels * kernels_q8;
46
- } static ctx = { CPU_FEATURE_NONE, NULL, NULL };
63
+ int sme_thread_cap; // <= 0 means “SME disabled/unknown”;
64
+ int thread_hint; // <= 0 means “no hint”
65
+ int chunk_multiplier;
66
+ } static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1, 4 };
47
67
 
48
68
  static const char* cpu_feature_to_string(cpu_feature f) {
49
69
  if (f == CPU_FEATURE_NONE) {
@@ -63,41 +83,388 @@ static const char* cpu_feature_to_string(cpu_feature f) {
63
83
  }
64
84
  }
65
85
 
66
- static void init_kleidiai_context(void) {
86
+ static size_t detect_num_smcus() {
87
+ if (!ggml_cpu_has_sme()) {
88
+ return 0;
89
+ }
90
+
91
+ #if defined(__linux__) && defined(__aarch64__)
92
+ // Linux/aarch64: Best-effort count of Streaming Mode Compute Units (SMCUs) via SMIDR_EL1 sysfs.
93
+ size_t num_private = 0;
94
+ std::set<uint32_t> shared_ids;
95
+
96
+ for (size_t cpu = 0;; ++cpu) {
97
+ const std::string path =
98
+ "/sys/devices/system/cpu/cpu" + std::to_string(cpu) +
99
+ "/regs/identification/smidr_el1";
100
+
101
+ std::ifstream file(path);
102
+ if (!file.is_open()) {
103
+ break;
104
+ }
105
+
106
+ uint64_t smidr = 0;
107
+ if (!(file >> std::hex >> smidr)) {
108
+ continue;
109
+ }
110
+
111
+ // Arm ARM: SMIDR_EL1
112
+ const uint32_t sh = (uint32_t)((smidr >> 13) & 0x3);
113
+ // Build an "affinity-like" identifier for shared SMCUs.
114
+ // Keep the original packing logic, but isolate it here.
115
+ const uint32_t id = (uint32_t)((smidr & 0xFFFu) | ((smidr >> 20) & 0xFFFFF000u));
116
+
117
+ switch (sh) {
118
+ case 0b10: // private SMCU
119
+ ++num_private;
120
+ break;
121
+ case 0b11: // shared SMCU
122
+ shared_ids.emplace(id);
123
+ break;
124
+ case 0b00:
125
+ // Ambiguous / implementation-defined. Be conservative:
126
+ // treat id==0 as private, otherwise as shared.
127
+ if (id == 0) ++num_private;
128
+ else shared_ids.emplace(id);
129
+ break;
130
+ default:
131
+ break;
132
+ }
133
+ }
134
+
135
+ return num_private + shared_ids.size();
136
+
137
+ #elif defined(__APPLE__) && defined(__aarch64__)
138
+ // table for known M4 variants. Users can override via GGML_KLEIDIAI_SME=<n>.
139
+ char chip_name[256] = {};
140
+ size_t size = sizeof(chip_name);
141
+
142
+ if (sysctlbyname("machdep.cpu.brand_string", chip_name, &size, nullptr, 0) == 0) {
143
+ const std::string brand(chip_name);
144
+
145
+ struct ModelSMCU { const char *match; size_t smcus; };
146
+ static const ModelSMCU table[] = {
147
+ { "M4 Ultra", 2 },
148
+ { "M4 Max", 2 },
149
+ { "M4 Pro", 2 },
150
+ { "M4", 1 },
151
+ };
152
+
153
+ for (const auto &e : table) {
154
+ if (brand.find(e.match) != std::string::npos) {
155
+ return e.smcus;
156
+ }
157
+ }
158
+ }
159
+ return 1;
160
+
161
+ #else
162
+ return 1;
163
+ #endif
164
+ }
67
165
 
166
+ static int parse_uint_env(const char *s, const char *name, bool *ok) {
167
+ if (!s) { *ok = false; return 0; }
168
+ char *end = nullptr;
169
+ long v = strtol(s, &end, 10);
170
+ if (end == s || *end != '\0') {
171
+ GGML_LOG_WARN("kleidiai: invalid %s='%s' (expected integer)\n", name, s);
172
+ *ok = false;
173
+ return 0;
174
+ }
175
+ if (v < 0 || v > INT_MAX) {
176
+ GGML_LOG_WARN("kleidiai: out-of-range %s='%s'\n", name, s);
177
+ *ok = false;
178
+ return 0;
179
+ }
180
+ *ok = true;
181
+ return (int)v;
182
+ }
183
+
184
+ static void init_kleidiai_context(void) {
68
185
  ggml_critical_section_start();
69
186
  static bool initialized = false;
70
187
 
71
188
  if (!initialized) {
72
189
  initialized = true;
73
- const char *env_var = getenv("GGML_KLEIDIAI_SME");
74
- int sme_enabled = 0;
190
+
191
+ const char *env_sme = getenv("GGML_KLEIDIAI_SME");
192
+ const char *env_threads = getenv("GGML_TOTAL_THREADS");
193
+ const char *env_chunk_mult = getenv("GGML_KLEIDIAI_CHUNK_MULTIPLIER");
194
+
195
+ const bool cpu_has_sme = ggml_cpu_has_sme();
196
+ size_t detected_smcus = 0;
75
197
 
76
198
  ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
77
199
  (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
78
200
  ((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
79
201
 
80
- if (env_var) {
81
- sme_enabled = atoi(env_var);
202
+ if (env_threads) {
203
+ bool ok = false;
204
+ int hint = parse_uint_env(env_threads, "GGML_TOTAL_THREADS", &ok);
205
+ if (ok && hint > 0) {
206
+ ctx.thread_hint = hint;
207
+ }
82
208
  }
83
209
 
84
- if (sme_enabled != 0) {
85
- ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
210
+ if (env_chunk_mult) {
211
+ bool ok = false;
212
+ int multiplier = parse_uint_env(env_chunk_mult, "GGML_KLEIDIAI_CHUNK_MULTIPLIER", &ok);
213
+ if (ok && multiplier > 0) {
214
+ ctx.chunk_multiplier = multiplier;
215
+ }
86
216
  }
217
+
218
+ // SME policy:
219
+ // - If CPU doesn't support SME: SME always off.
220
+ // - Else:
221
+ // - env unset => auto-detect cores; enable if detected > 0.
222
+ // - env=0 => force off.
223
+ // - env>0 => force N cores (skip detection).
224
+ int sme_cores = 0;
225
+ bool sme_env_ok = false;
226
+ bool sme_env_set = (env_sme != nullptr);
227
+
228
+ if (!cpu_has_sme) {
229
+ if (sme_env_set) {
230
+ bool ok = false;
231
+ int req = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
232
+ if (ok && req > 0) {
233
+ GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME=%d but SME is not supported on this CPU; disabling SME\n", req);
234
+ }
235
+ }
236
+ sme_cores = 0;
237
+ } else {
238
+ if (sme_env_set) {
239
+ bool ok = false;
240
+ int v = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
241
+ sme_env_ok = ok;
242
+
243
+ if (!ok) {
244
+ GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME set but parsing failed; falling back to runtime SME-core detection\n");
245
+ detected_smcus = detect_num_smcus();
246
+ sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
247
+ } else if (v == 0) {
248
+ sme_cores = 0;
249
+ } else {
250
+ sme_cores = v;
251
+ }
252
+ } else {
253
+ detected_smcus = detect_num_smcus();
254
+ sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
255
+ }
256
+
257
+ if (!sme_env_set && sme_cores == 0) {
258
+ GGML_LOG_WARN("kleidiai: SME supported but runtime SME-core detection returned 0; falling back to NEON\n");
259
+ }
260
+
261
+ if (sme_cores > 0) {
262
+ ctx.features |= CPU_FEATURE_SME;
263
+ }
264
+ }
265
+
266
+ // Kernel selection
87
267
  ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
88
268
  ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
89
- #ifndef NDEBUG
90
- if (ctx.kernels_q4) {
91
- GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
269
+
270
+ if (!ctx.kernels_q4) {
271
+ GGML_LOG_INFO("kleidiai: no compatible q4 kernels found for CPU features mask %d\n", (int)ctx.features);
272
+ } else {
273
+ GGML_LOG_INFO("kleidiai: primary q4 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
92
274
  }
93
- if (ctx.kernels_q8) {
94
- GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
275
+
276
+ if (!ctx.kernels_q8) {
277
+ GGML_LOG_INFO("kleidiai: no compatible q8 kernels found for CPU features mask %d\n", (int)ctx.features);
278
+ } else {
279
+ GGML_LOG_INFO("kleidiai: primary q8 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
280
+ }
281
+
282
+ ctx.sme_thread_cap = (ctx.features & CPU_FEATURE_SME) ? sme_cores : 0;
283
+
284
+ if (ctx.features & CPU_FEATURE_SME) {
285
+ if (sme_env_set && sme_env_ok && sme_cores > 0) {
286
+ GGML_LOG_INFO("kleidiai: SME enabled (GGML_KLEIDIAI_SME=%d override)\n", sme_cores);
287
+ } else {
288
+ GGML_LOG_INFO("kleidiai: SME enabled (runtime-detected SME cores=%d)\n", sme_cores);
289
+ }
290
+ } else {
291
+ GGML_LOG_INFO("kleidiai: SME disabled\n");
95
292
  }
96
- #endif
97
293
  }
294
+
98
295
  ggml_critical_section_end();
99
296
  }
100
297
 
298
+ static inline int kleidiai_sme_thread_cap() {
299
+ return ctx.sme_thread_cap;
300
+ }
301
+
302
+ static inline size_t align_up(size_t value, size_t alignment) {
303
+ if (alignment == 0) {
304
+ return value;
305
+ }
306
+ const size_t remainder = value % alignment;
307
+ return remainder == 0 ? value : value + (alignment - remainder);
308
+ }
309
+
310
+ static inline size_t gcd_size(size_t a, size_t b) {
311
+ while (b != 0) {
312
+ const size_t t = a % b;
313
+ a = b;
314
+ b = t;
315
+ }
316
+ return a;
317
+ }
318
+
319
+ static inline bool lcm_size(size_t a, size_t b, size_t & result) {
320
+ if (a == 0 || b == 0) {
321
+ result = 0;
322
+ return false;
323
+ }
324
+ const size_t g = gcd_size(a, b);
325
+ const size_t q = a / g;
326
+ if (q > SIZE_MAX / b) {
327
+ return false;
328
+ }
329
+ result = q * b;
330
+ return true;
331
+ }
332
+
333
+ static inline size_t ceil_div_size(size_t a, size_t b) {
334
+ return b == 0 ? 0 : (a + b - 1) / b;
335
+ }
336
+
337
+ struct kleidiai_block_args {
338
+ size_t lhs_bl;
339
+ size_t rhs_bl;
340
+ size_t pack_bl;
341
+ };
342
+
343
+ static inline kleidiai_block_args kleidiai_get_block_args(ggml_type rhs_type) {
344
+ switch (rhs_type) {
345
+ case GGML_TYPE_Q4_0:
346
+ return { QK4_0, QK4_0, QK4_0 };
347
+ case GGML_TYPE_Q8_0:
348
+ return { 0, 0, QK8_0 };
349
+ default:
350
+ return { 0, 0, 0 };
351
+ }
352
+ }
353
+
354
+ static inline bool kleidiai_pack_fallback_allowed() {
355
+ if (ctx.sme_thread_cap <= 0) {
356
+ return false;
357
+ }
358
+ if (ctx.thread_hint <= 0) {
359
+ return true;
360
+ }
361
+ return ctx.thread_hint > ctx.sme_thread_cap;
362
+ }
363
+
364
+ struct kleidiai_weight_header {
365
+ uint32_t magic;
366
+ uint16_t version;
367
+ uint16_t slot_count;
368
+ uint64_t offsets[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
369
+ uint64_t sizes[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
370
+ };
371
+
372
+ static inline kleidiai_weight_header * kleidiai_weight_header_from_ptr(void * data) {
373
+ return reinterpret_cast<kleidiai_weight_header *>(data);
374
+ }
375
+
376
+ static inline const kleidiai_weight_header * kleidiai_weight_header_from_ptr(const void * data) {
377
+ return reinterpret_cast<const kleidiai_weight_header *>(data);
378
+ }
379
+
380
+ static inline bool kleidiai_is_weight_header_valid(const kleidiai_weight_header * header) {
381
+ if (!header) {
382
+ return false;
383
+ }
384
+ if (header->magic != GGML_KLEIDIAI_PACK_MAGIC || header->version != GGML_KLEIDIAI_PACK_VERSION) {
385
+ return false;
386
+ }
387
+ if (header->slot_count == 0 || header->slot_count > GGML_KLEIDIAI_MAX_KERNEL_SLOTS) {
388
+ return false;
389
+ }
390
+ return true;
391
+ }
392
+
393
+ static inline uint8_t * kleidiai_weight_slot_ptr(kleidiai_weight_header * header, int slot) {
394
+ if (!kleidiai_is_weight_header_valid(header)) {
395
+ return nullptr;
396
+ }
397
+ if (slot < 0 || slot >= header->slot_count) {
398
+ return nullptr;
399
+ }
400
+ return reinterpret_cast<uint8_t *>(header) + header->offsets[slot];
401
+ }
402
+
403
+ static inline const uint8_t * kleidiai_weight_slot_ptr(const kleidiai_weight_header * header, int slot) {
404
+ if (!kleidiai_is_weight_header_valid(header)) {
405
+ return nullptr;
406
+ }
407
+ if (slot < 0 || slot >= header->slot_count) {
408
+ return nullptr;
409
+ }
410
+ return reinterpret_cast<const uint8_t *>(header) + header->offsets[slot];
411
+ }
412
+
413
+ static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q4() {
414
+ return ctx.kernels_q4;
415
+ }
416
+
417
+ static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q8() {
418
+ return ctx.kernels_q8;
419
+ }
420
+
421
+ template <typename SelectFallback>
422
+ static int kleidiai_collect_kernel_chain_common(
423
+ ggml_kleidiai_kernels * primary,
424
+ cpu_feature features,
425
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out,
426
+ SelectFallback select_fallback) {
427
+ int count = 0;
428
+ if (!primary) {
429
+ return 0;
430
+ }
431
+ out[count++] = primary;
432
+
433
+ if ((primary->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
434
+ const cpu_feature fallback_mask = static_cast<cpu_feature>(features & ~CPU_FEATURE_SME);
435
+ if (fallback_mask != CPU_FEATURE_NONE) {
436
+ ggml_kleidiai_kernels * fallback = select_fallback(fallback_mask);
437
+ if (fallback && fallback != primary &&
438
+ fallback->lhs_type == primary->lhs_type &&
439
+ fallback->rhs_type == primary->rhs_type &&
440
+ fallback->op_type == primary->op_type) {
441
+ out[count++] = fallback;
442
+ }
443
+ }
444
+ }
445
+
446
+ return count;
447
+ }
448
+
449
+ static int kleidiai_collect_kernel_chain(const struct ggml_tensor * op,
450
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
451
+ ggml_kleidiai_kernels * primary = ggml_kleidiai_select_kernels(ctx.features, op);
452
+ return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
453
+ [&](cpu_feature mask) { return ggml_kleidiai_select_kernels(mask, op); });
454
+ }
455
+
456
+ static int kleidiai_collect_q4_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
457
+ ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q4();
458
+ return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
459
+ [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q4_0(mask); });
460
+ }
461
+
462
+ static int kleidiai_collect_q8_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
463
+ ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q8();
464
+ return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
465
+ [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q8_0(mask); });
466
+ }
467
+
101
468
  static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
102
469
  GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
103
470
  return tensor->ne[dim];
@@ -126,49 +493,108 @@ class tensor_traits : public ggml::cpu::tensor_traits {
126
493
  if (op->op != GGML_OP_MUL_MAT) {
127
494
  return false;
128
495
  }
129
- ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
130
- if (!kernels) {
496
+
497
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
498
+ const int slot_count = kleidiai_collect_kernel_chain(op, kernel_chain);
499
+ if (slot_count == 0) {
131
500
  return false;
132
501
  }
133
- bool is_gemv = op->src[1]->ne[1] == 1;
134
- kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
135
- lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
136
502
 
137
- size_t k = op->src[0]->ne[0];
138
- size_t n = op->src[0]->ne[1];
139
- size_t m = op->src[1]->ne[1];
140
-
141
- size_t mr = kernel->get_mr();
142
- size_t kr = kernel->get_kr();
143
- size_t sr = kernel->get_sr();
144
-
145
- if (kernels->rhs_type == GGML_TYPE_Q4_0) {
146
- if (!lhs_info->packed_size_ex) return false;
147
- size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
148
- } else if (kernels->rhs_type == GGML_TYPE_Q8_0) {
149
- if (!lhs_info->packed_size_ex) return false;
150
- size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr);
151
- } else if (kernels->rhs_type == GGML_TYPE_F16) {
152
- if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
503
+ const bool is_gemv = op->src[1]->ne[1] == 1;
504
+ const size_t k = op->src[0]->ne[0];
505
+ const size_t n = op->src[0]->ne[1];
506
+ const size_t m = op->src[1]->ne[1];
507
+
508
+ if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) {
509
+ const size_t qk = (op->src[0]->type == GGML_TYPE_Q4_0) ? QK4_0 : QK8_0;
510
+
511
+ size_t cursor = 0;
512
+ bool any_slot = false;
513
+
514
+ for (int slot = 0; slot < slot_count; ++slot) {
515
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
516
+ lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
517
+ kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
518
+
519
+ if (!lhs_info || !lhs_info->packed_size_ex || !kernel) {
520
+ return false;
521
+ }
522
+
523
+ const size_t mr = kernel->get_mr();
524
+ const size_t kr = kernel->get_kr();
525
+ const size_t sr = kernel->get_sr();
526
+
527
+ const size_t packed = lhs_info->packed_size_ex(m, k, qk, mr, kr, sr);
528
+
529
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
530
+ cursor += packed;
531
+ any_slot = true;
532
+ }
533
+
534
+ if (!any_slot) {
535
+ return false;
536
+ }
537
+
538
+ size = cursor;
539
+ return true;
540
+ }
541
+
542
+ if (op->src[0]->type == GGML_TYPE_F16) {
153
543
  const int64_t lhs_batch_size0 = op->src[1]->ne[2];
154
544
  const int64_t rhs_batch_size0 = op->src[0]->ne[2];
545
+ GGML_ASSERT(rhs_batch_size0 > 0);
155
546
  const int64_t r = lhs_batch_size0 / rhs_batch_size0;
156
- size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) +
157
- kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) +
158
- k * n * sizeof(float) + n * sizeof(float);
159
- } else {
160
- return false;
547
+
548
+ size_t cursor = 0;
549
+ bool any_slot = false;
550
+
551
+ for (int slot = 0; slot < slot_count; ++slot) {
552
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
553
+ lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
554
+ kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
555
+ if (!lhs_info || !lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex || !kernel) {
556
+ return false;
557
+ }
558
+
559
+ const size_t mr = kernel->get_mr();
560
+ const size_t kr = kernel->get_kr();
561
+ const size_t sr = kernel->get_sr();
562
+
563
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
564
+ cursor += lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr);
565
+ any_slot = true;
566
+ }
567
+
568
+ for (int slot = 0; slot < slot_count; ++slot) {
569
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
570
+ kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
571
+ if (!kernel || !kernels->rhs_info.packed_size_ex) {
572
+ return false;
573
+ }
574
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
575
+ cursor += kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0);
576
+ }
577
+
578
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
579
+ cursor += k * n * sizeof(float);
580
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
581
+ cursor += n * sizeof(float);
582
+
583
+ if (!any_slot) {
584
+ return false;
585
+ }
586
+
587
+ size = cursor;
588
+ return true;
161
589
  }
162
590
 
163
- return true;
591
+ return false;
164
592
  }
165
593
 
166
594
  bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
167
595
  if (dst->op == GGML_OP_MUL_MAT) {
168
- if (dst->src[0]->type == GGML_TYPE_Q4_0) {
169
- return compute_forward_q4_0(params, dst);
170
- } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
171
- return compute_forward_q8_0(params, dst);
596
+ if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
597
+ return compute_forward_qx(params, dst);
172
598
  } else if (dst->src[0]->type == GGML_TYPE_F16) {
173
599
  return compute_forward_fp16(params, dst);
174
600
  }
@@ -331,204 +757,412 @@ class tensor_traits : public ggml::cpu::tensor_traits {
331
757
  return true;
332
758
  }
333
759
 
334
- bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
335
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
760
+ bool compute_forward_qx(struct ggml_compute_params * params, struct ggml_tensor * dst) {
761
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
336
762
 
337
763
  const ggml_tensor * src0 = dst->src[0];
338
764
  const ggml_tensor * src1 = dst->src[1];
339
765
 
340
766
  GGML_TENSOR_BINARY_OP_LOCALS
341
767
 
342
- ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
343
- if (!kernels) {
344
- return false;
345
- }
768
+ const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
769
+ const bool has_header = kleidiai_is_weight_header_valid(header);
770
+ const bool is_gemv = src1->ne[1] == 1;
771
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
772
+ const int slot_total = kleidiai_collect_kernel_chain(dst, kernel_chain);
346
773
 
347
- bool is_gemv = src1->ne[1] == 1;
348
- kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
349
- lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
774
+ auto weight_for_slot = [&](int slot_index, size_t & size_out) -> const uint8_t * {
775
+ if (slot_index < 0 || slot_index >= slot_total) {
776
+ return nullptr;
777
+ }
778
+ if (has_header) {
779
+ if (slot_index < header->slot_count) {
780
+ size_out = static_cast<size_t>(header->sizes[slot_index]);
781
+ return kleidiai_weight_slot_ptr(header, slot_index);
782
+ }
783
+ return nullptr;
784
+ }
785
+ if (slot_index == 0) {
786
+ size_out = ggml_nbytes(src0);
787
+ return static_cast<const uint8_t *>(src0->data);
788
+ }
789
+ return nullptr;
790
+ };
791
+
792
+ struct runtime_slot {
793
+ int slot_index;
794
+ ggml_kleidiai_kernels * kernels;
795
+ kernel_info * kernel;
796
+ lhs_packing_info * lhs_info;
797
+ size_t mr;
798
+ size_t nr;
799
+ size_t kr;
800
+ size_t sr;
801
+ size_t n_step;
802
+ size_t lhs_packed_size;
803
+ size_t lhs_offset;
804
+ size_t lhs_bl;
805
+ size_t rhs_bl;
806
+ size_t pack_bl;
807
+ size_t lhs_packed_offset0;
808
+ int assigned_threads;
809
+ int thread_begin;
810
+ int thread_end;
811
+ const uint8_t * rhs_base;
812
+ };
813
+
814
+ std::array<runtime_slot, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> runtime{};
815
+ int runtime_count = 0;
816
+
817
+ for (int slot = 0; slot < slot_total && runtime_count < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
818
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
819
+ kernel_info * kinfo = is_gemv ? &kernels->gemv : &kernels->gemm;
820
+ lhs_packing_info * linfo = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
821
+ if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex || !linfo->get_offset ||
822
+ !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset) {
823
+ continue;
824
+ }
350
825
 
351
- GGML_ASSERT(kernel);
352
- if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
353
- !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
826
+ size_t rhs_size = 0;
827
+ const uint8_t * rhs_ptr = weight_for_slot(slot, rhs_size);
828
+ if (!rhs_ptr || rhs_size == 0) {
829
+ continue;
830
+ }
831
+
832
+ const kleidiai_block_args block_args = kleidiai_get_block_args(kernels->rhs_type);
833
+
834
+ runtime[runtime_count] = {
835
+ slot,
836
+ kernels,
837
+ kinfo,
838
+ linfo,
839
+ kinfo->get_mr(),
840
+ kinfo->get_nr(),
841
+ kinfo->get_kr(),
842
+ kinfo->get_sr(),
843
+ kinfo->get_n_step(),
844
+ 0,
845
+ 0,
846
+ block_args.lhs_bl,
847
+ block_args.rhs_bl,
848
+ block_args.pack_bl,
849
+ 0,
850
+ 0,
851
+ 0,
852
+ 0,
853
+ rhs_ptr
854
+ };
855
+ ++runtime_count;
856
+ }
857
+
858
+ if (runtime_count == 0) {
859
+ GGML_LOG_WARN("kleidiai: no runtime kernel slot available for supported op %s\n", dst->name);
354
860
  return false;
355
861
  }
356
862
 
357
- const int ith = params->ith;
358
- const int nth_raw = params->nth;
359
- const int nth = nth_raw > 0 ? nth_raw : 1;
863
+ const int nth_total = params->nth > 0 ? params->nth : 1;
864
+ const int ith_total = params->ith;
360
865
 
361
- const size_t k = ne00;
362
- const size_t m = ne11;
363
- const size_t n = ne01;
866
+ int sme_slot = -1;
867
+ for (int i = 0; i < runtime_count; ++i) {
868
+ if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
869
+ sme_slot = i;
870
+ break;
871
+ }
872
+ }
873
+ int non_sme_slot = -1;
874
+ for (int i = 0; i < runtime_count; ++i) {
875
+ if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) != CPU_FEATURE_SME) {
876
+ non_sme_slot = i;
877
+ break;
878
+ }
879
+ }
364
880
 
365
- size_t mr = kernel->get_mr();
366
- size_t kr = kernel->get_kr();
367
- size_t sr = kernel->get_sr();
881
+ const int sme_cap_limit = ctx.sme_thread_cap;
882
+ const bool use_hybrid = sme_cap_limit > 0 &&
883
+ runtime_count > 1 &&
884
+ nth_total > sme_cap_limit;
885
+ // Heuristic: disable hybrid for very small workloads where per-slot overhead dominates.
886
+ // If rows are small or average columns per thread are small, keep single-slot.
887
+ size_t min_cols_per_thread = 0;
888
+ if (runtime_count > 0 && nth_total > 0) {
889
+ min_cols_per_thread = (size_t) std::max<int64_t>(1, (int64_t)ne01 / (int64_t)nth_total);
890
+ }
891
+ const bool too_small_for_hybrid = (min_cols_per_thread < 2) || (ne11 < 128);
368
892
 
369
- const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
370
- uint8_t * lhs_packed = (uint8_t*)params->wdata;
371
- const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
893
+ const bool hybrid_enabled = use_hybrid && !too_small_for_hybrid;
372
894
 
373
- const size_t n_step = kernel->get_n_step();
374
- const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
375
- const size_t n_start = ith * num_n_per_thread;
895
+ if (!hybrid_enabled) {
896
+ int chosen_slot = 0;
897
+ if (too_small_for_hybrid && sme_slot != -1) {
898
+ chosen_slot = nth_total > sme_cap_limit && non_sme_slot != -1 ? non_sme_slot : sme_slot;
899
+ } else if (runtime_count > 1 && ctx.sme_thread_cap > 0 && nth_total > ctx.sme_thread_cap) {
900
+ chosen_slot = 1;
901
+ }
902
+ if (chosen_slot != 0 && chosen_slot < runtime_count) {
903
+ runtime[0] = runtime[chosen_slot];
904
+ runtime[0].assigned_threads = 0;
905
+ runtime[0].thread_begin = 0;
906
+ runtime[0].thread_end = 0;
907
+ }
908
+ runtime_count = runtime_count > 0 ? 1 : 0;
376
909
 
377
- size_t n_to_process = 0;
378
- if (n_start < n) {
379
- n_to_process = num_n_per_thread;
380
- if ((n_start + n_to_process) > n) {
381
- n_to_process = n - n_start;
910
+ // Recompute SME slot based on the collapsed runtime[0]
911
+ sme_slot = -1;
912
+ if (runtime_count > 0 &&
913
+ (runtime[0].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
914
+ sme_slot = 0;
382
915
  }
383
916
  }
384
917
 
385
- // Calculate number of columns to be processed per thread
386
- const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
387
- const size_t m_start = ith * num_m_per_thread;
388
- size_t m_to_process = num_m_per_thread;
389
- if ((m_start + m_to_process) > m) {
390
- m_to_process = m - m_start;
918
+ int sme_cap = kleidiai_sme_thread_cap();
919
+ if (sme_cap < 0) {
920
+ sme_cap = nth_total;
391
921
  }
922
+ sme_cap = std::min(sme_cap, nth_total);
392
923
 
393
- if (m_start < m) {
394
- // Transform LHS
395
- const size_t src_stride = src1->nb[1];
396
- const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
397
- const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr);
398
- void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
399
-
400
- // Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer
401
- lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
924
+ int threads_remaining = nth_total;
925
+ if (sme_slot != -1) {
926
+ int sme_threads = std::min(std::max(sme_cap, 0), threads_remaining);
927
+ runtime[sme_slot].assigned_threads = sme_threads;
928
+ threads_remaining -= sme_threads;
402
929
  }
403
930
 
404
- ggml_barrier(params->threadpool);
931
+ int fallback_indices[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
932
+ int fallback_count = 0;
933
+ // The current hybrid chain is bounded to SME + one non-SME fallback slot.
934
+ GGML_ASSERT(GGML_KLEIDIAI_MAX_KERNEL_SLOTS == 2);
935
+ for (int i = 0; i < runtime_count; ++i) {
936
+ if (i == sme_slot) {
937
+ continue;
938
+ }
939
+ fallback_indices[fallback_count++] = i;
940
+ }
405
941
 
406
- // Perform the operation
407
- const size_t dst_stride = dst->nb[1];
408
- const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr);
409
- const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0);
410
- const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
411
- const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
412
- const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
413
- float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
942
+ for (int fi = 0; fi < fallback_count; ++fi) {
943
+ if (threads_remaining <= 0) {
944
+ break;
945
+ }
946
+ const int slot_index = fallback_indices[fi];
947
+ const int slots_left = fallback_count - fi;
948
+ int share = (threads_remaining + slots_left - 1) / slots_left;
949
+ share = std::min(share, threads_remaining);
950
+ runtime[slot_index].assigned_threads = share;
951
+ threads_remaining -= share;
952
+ }
414
953
 
415
- if (n_to_process > 0) {
416
- kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
417
- sizeof(float), -FLT_MAX, FLT_MAX);
954
+ if (threads_remaining > 0) {
955
+ const int fallback_slot = (sme_slot != -1) ? sme_slot : 0;
956
+ runtime[fallback_slot].assigned_threads += threads_remaining;
957
+ threads_remaining = 0;
418
958
  }
419
959
 
420
- return true;
421
- }
960
+ int thread_cursor = 0;
961
+ for (int i = 0; i < runtime_count; ++i) {
962
+ runtime[i].thread_begin = thread_cursor;
963
+ thread_cursor += runtime[i].assigned_threads;
964
+ runtime[i].thread_end = thread_cursor;
965
+ }
422
966
 
423
- bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
424
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0);
967
+ if (thread_cursor < nth_total && runtime_count > 0) {
968
+ runtime[runtime_count - 1].assigned_threads += nth_total - thread_cursor;
969
+ runtime[runtime_count - 1].thread_end = nth_total;
970
+ }
425
971
 
426
- const ggml_tensor * src0 = dst->src[0];
427
- const ggml_tensor * src1 = dst->src[1];
972
+ int local_slot = -1;
973
+ int local_ith = 0;
974
+ for (int i = 0; i < runtime_count; ++i) {
975
+ if (ith_total >= runtime[i].thread_begin && ith_total < runtime[i].thread_end) {
976
+ local_slot = i;
977
+ local_ith = ith_total - runtime[i].thread_begin;
978
+ break;
979
+ }
980
+ }
981
+ if (local_slot == -1) {
982
+ return false;
983
+ }
428
984
 
429
- GGML_TENSOR_BINARY_OP_LOCALS
985
+ const size_t k = ne00;
986
+ const size_t m = ne11;
987
+ const size_t n = ne01;
430
988
 
431
- ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
432
- if (!kernels) {
433
- return false;
989
+ size_t cursor = 0;
990
+ for (int i = 0; i < runtime_count; ++i) {
991
+ runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, runtime[i].pack_bl, runtime[i].mr, runtime[i].kr, runtime[i].sr);
992
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
993
+ runtime[i].lhs_offset = cursor;
994
+ runtime[i].lhs_packed_offset0 = runtime[i].lhs_info->get_packed_offset_ex(0, k, runtime[i].lhs_bl, runtime[i].mr, runtime[i].kr, runtime[i].sr);
995
+ cursor += runtime[i].lhs_packed_size;
434
996
  }
435
997
 
436
- bool is_gemv = src1->ne[1] == 1;
437
- kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
438
- lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
998
+ GGML_ASSERT(cursor <= params->wsize);
999
+ uint8_t * scratch = static_cast<uint8_t *>(params->wdata);
439
1000
 
440
- if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
441
- !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
442
- return false;
1001
+ size_t common_step = 1;
1002
+ for (int i = 0; i < runtime_count; ++i) {
1003
+ if (runtime[i].assigned_threads == 0) {
1004
+ continue;
1005
+ }
1006
+ size_t next_step = 0;
1007
+ if (!lcm_size(common_step, runtime[i].n_step ? runtime[i].n_step : 1, next_step)) {
1008
+ return false;
1009
+ }
1010
+ common_step = next_step;
1011
+ }
1012
+ GGML_ASSERT(common_step > 0);
1013
+
1014
+ const bool disable_chunking = ggml_is_numa();
1015
+ const size_t chunk_multiplier = std::max(1, ctx.chunk_multiplier);
1016
+ const size_t chunk_divisor = (nth_total == 1 || disable_chunking) ? (size_t)nth_total : (size_t)nth_total * chunk_multiplier;
1017
+ size_t chunk_cols = align_up(std::max<size_t>(1, ceil_div_size(n, chunk_divisor)), common_step);
1018
+ if (chunk_cols == 0) {
1019
+ chunk_cols = common_step;
443
1020
  }
1021
+ // If common_step is larger than n, the loop below runs one valid tail chunk
1022
+ // with cols == n.
1023
+ const size_t nchunk_size = std::max<size_t>(1, ceil_div_size(n, chunk_cols));
1024
+ GGML_ASSERT(nchunk_size <= (size_t)INT_MAX);
1025
+ const int nchunk = (int)nchunk_size;
1026
+ const size_t dst_stride = dst->nb[1];
444
1027
 
445
- const int ith = params->ith;
446
- const int nth_raw = params->nth;
447
- const int nth = nth_raw > 0 ? nth_raw : 1;
1028
+ auto run_chunk = [&](runtime_slot & slot, size_t global_start, size_t cols, uint8_t * dst_batch_base) {
1029
+ const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot.rhs_bl);
1030
+ const size_t dst_offset = slot.kernel->get_dst_offset(0, global_start, dst_stride);
1031
+
1032
+ const uint8_t * lhs_ptr = scratch + slot.lhs_offset + slot.lhs_packed_offset0;
1033
+ const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset;
1034
+ float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
1035
+
1036
+ slot.kernel->run_kernel_ex(m, cols, k, slot.rhs_bl,
1037
+ lhs_ptr,
1038
+ rhs_ptr,
1039
+ dst_ptr,
1040
+ dst_stride,
1041
+ sizeof(float),
1042
+ -FLT_MAX,
1043
+ FLT_MAX);
1044
+ };
1045
+
1046
+ for (int64_t batch_idx = 0; batch_idx < ne12; ++batch_idx) {
1047
+ const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
1048
+ uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];
448
1049
 
449
- const size_t k = ne00;
450
- const size_t m = ne11;
451
- const size_t n = ne01;
1050
+ if (runtime[local_slot].assigned_threads > 0) {
1051
+ runtime_slot & slot = runtime[local_slot];
1052
+ const int64_t m_roundup_mr = kai_roundup((int64_t)m, (int64_t)slot.mr);
1053
+ int64_t max_threads = slot.mr ? (m_roundup_mr / (int64_t)slot.mr) : slot.assigned_threads;
1054
+ max_threads = std::max<int64_t>(1, max_threads);
1055
+ const int64_t use_threads = std::min<int64_t>(slot.assigned_threads, max_threads);
1056
+
1057
+ if (local_ith < use_threads) {
1058
+ const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / use_threads), slot.mr);
1059
+ const int64_t num_m_per_threadN_1 = (int64_t)m - (use_threads - 1) * num_m_per_thread0;
1060
+
1061
+ const int64_t m_start = (int64_t)local_ith * num_m_per_thread0;
1062
+ const int64_t m_count = (local_ith == use_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
1063
+
1064
+ const size_t base_packed_off = slot.lhs_info->get_packed_offset_ex(m_start, k, slot.lhs_bl, slot.mr, slot.kr, slot.sr);
1065
+ const size_t next_block_off = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot.lhs_bl, slot.mr, slot.kr, slot.sr);
1066
+ const size_t row_stride_bytes = slot.mr ? (next_block_off - base_packed_off) / slot.mr : 0;
1067
+
1068
+ int64_t remaining = m_count;
1069
+ int64_t cur = m_start;
452
1070
 
453
- size_t mr = kernel->get_mr();
454
- size_t kr = kernel->get_kr();
455
- size_t sr = kernel->get_sr();
1071
+ uint8_t * lhs_packed = scratch + slot.lhs_offset;
1072
+ while (remaining > 0) {
1073
+ const int64_t row_in_group = cur;
1074
+ const int64_t avail = (int64_t)m - row_in_group;
1075
+ const int64_t take = std::min(avail, remaining);
456
1076
 
457
- const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
458
- uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
459
- const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
1077
+ const size_t src_off = slot.lhs_info->get_offset(row_in_group, src1->nb[1]);
1078
+ const void * src_ptr = lhs_batch_base + src_off;
1079
+ const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
1080
+ void * dst_ptr = lhs_packed + dst_off;
460
1081
 
461
- const size_t n_step = kernel->get_n_step();
462
- const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
463
- const size_t n_start = ith * num_n_per_thread;
1082
+ slot.lhs_info->pack_func_ex(take, k, slot.lhs_bl, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr);
464
1083
 
465
- size_t n_to_process = 0;
466
- if (n_start < n) {
467
- n_to_process = num_n_per_thread;
468
- if ((n_start + n_to_process) > n) {
469
- n_to_process = n - n_start;
1084
+ cur += take;
1085
+ remaining -= take;
1086
+ }
1087
+ }
470
1088
  }
471
- }
472
1089
 
473
- const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
474
- const size_t m_start = ith * num_m_per_thread;
475
- size_t m_to_process = num_m_per_thread;
476
- if ((m_start + m_to_process) > m) {
477
- m_to_process = m - m_start;
478
- }
1090
+ if (ith_total == 0) {
1091
+ ggml_threadpool_chunk_set(params->threadpool, nth_total);
1092
+ }
479
1093
 
480
- if (m_start < m) {
481
- const size_t src_stride = src1->nb[1];
482
- const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
483
- const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
484
- void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
1094
+ // Publishes both LHS packing and the initialized dynamic chunk queue.
1095
+ ggml_barrier(params->threadpool);
485
1096
 
486
- lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
487
- }
1097
+ runtime_slot & slot = runtime[local_slot];
1098
+ int current_chunk = ith_total;
1099
+ while (current_chunk < nchunk) {
1100
+ const size_t global_start = (size_t)current_chunk * chunk_cols;
1101
+ if (global_start >= n) {
1102
+ break;
1103
+ }
488
1104
 
489
- ggml_barrier(params->threadpool);
1105
+ const size_t cols = std::min(chunk_cols, n - global_start);
1106
+ if (cols > 0) {
1107
+ // KleidiAI GEMM/GEMV kernels accept arbitrary final tail widths;
1108
+ // only non-tail chunks are guaranteed to be n_step-aligned.
1109
+ run_chunk(slot, global_start, cols, dst_batch_base);
1110
+ }
490
1111
 
491
- const size_t dst_stride = dst->nb[1];
492
- const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
493
- const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
494
- const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
495
- const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
496
- const void * lhs_ptr = static_cast<const void *>(lhs_packed + lhs_packed_offset);
497
- float * dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
1112
+ current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
1113
+ }
498
1114
 
499
- if (n_to_process > 0) {
500
- kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
501
- sizeof(float), -FLT_MAX, FLT_MAX);
1115
+ if (batch_idx != ne12 - 1) {
1116
+ ggml_barrier(params->threadpool);
1117
+ }
502
1118
  }
503
1119
 
504
1120
  return true;
505
1121
  }
506
1122
 
507
1123
  bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
1124
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
508
1125
  const ggml_tensor * src0 = dst->src[0];
509
1126
  const ggml_tensor * src1 = dst->src[1];
510
1127
 
511
1128
  GGML_TENSOR_BINARY_OP_LOCALS
512
1129
 
1130
+ const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
1131
+ const bool has_header = kleidiai_is_weight_header_valid(header);
1132
+
1133
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
1134
+ const bool want_q8 = src0->type == GGML_TYPE_Q8_0;
1135
+ const int chain_count = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
1136
+ : kleidiai_collect_q4_chain(kernel_chain);
1137
+
513
1138
  ggml_kleidiai_kernels * kernels = nullptr;
514
- size_t block_len = 0;
515
- size_t num_bytes_multiplier = 0;
1139
+ const uint8_t * packed_base = static_cast<const uint8_t *>(src0->data);
516
1140
 
517
- if (dst->src[0]->type == GGML_TYPE_Q4_0) {
518
- if (!ctx.kernels_q4) {
519
- return false;
1141
+ if (has_header && chain_count > 0) {
1142
+ int select_slot = 0;
1143
+ if (select_slot >= header->slot_count) {
1144
+ select_slot = header->slot_count - 1;
520
1145
  }
521
- kernels = ctx.kernels_q4;
522
- block_len = QK4_0;
523
- num_bytes_multiplier = sizeof(uint16_t);
524
- } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
525
- if (!ctx.kernels_q8) {
526
- return false;
1146
+ if (select_slot >= 0 && select_slot < chain_count) {
1147
+ kernels = kernel_chain[select_slot];
1148
+ const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, select_slot);
1149
+ if (slot_ptr) {
1150
+ packed_base = slot_ptr;
1151
+ }
527
1152
  }
528
- kernels = ctx.kernels_q8;
529
- block_len = QK8_0;
530
- num_bytes_multiplier = sizeof(float);
531
- } else {
1153
+ }
1154
+
1155
+ if (!kernels && chain_count > 0) {
1156
+ kernels = kernel_chain[0];
1157
+ if (has_header) {
1158
+ const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, 0);
1159
+ if (slot_ptr) {
1160
+ packed_base = slot_ptr;
1161
+ }
1162
+ }
1163
+ }
1164
+
1165
+ if (!kernels) {
532
1166
  return false;
533
1167
  }
534
1168
 
@@ -541,6 +1175,19 @@ class tensor_traits : public ggml::cpu::tensor_traits {
541
1175
  const int64_t nc = ne00;
542
1176
  const int64_t nr = ggml_nelements(src1);
543
1177
 
1178
+ const ggml_type rhs_type = kernels->rhs_type;
1179
+ size_t block_len = 0;
1180
+ size_t num_bytes_multiplier = 0;
1181
+ if (rhs_type == GGML_TYPE_Q4_0) {
1182
+ block_len = QK4_0;
1183
+ num_bytes_multiplier = sizeof(uint16_t);
1184
+ } else if (rhs_type == GGML_TYPE_Q8_0) {
1185
+ block_len = QK8_0;
1186
+ num_bytes_multiplier = sizeof(float);
1187
+ } else {
1188
+ return false;
1189
+ }
1190
+
544
1191
  const size_t block_rows = kernel->get_nr();
545
1192
  const size_t kr = kernel->get_kr();
546
1193
 
@@ -559,7 +1206,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
559
1206
  GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
560
1207
 
561
1208
  float *out = (float *)((char *)dst->data + i * nb1);
562
- rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
1209
+ rhs_info->to_float(packed_base, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
563
1210
  }
564
1211
 
565
1212
  return true;
@@ -567,36 +1214,39 @@ class tensor_traits : public ggml::cpu::tensor_traits {
567
1214
 
568
1215
  public:
569
1216
  int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
1217
+ GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0);
570
1218
  const size_t n = tensor->ne[1];
571
1219
  const size_t k = tensor->ne[0];
572
1220
 
573
- if (tensor->type == GGML_TYPE_Q4_0) {
574
- if (!ctx.kernels_q4) {
575
- return -1;
576
- }
577
- size_t nr = ctx.kernels_q4->gemm.get_nr();
578
- size_t kr = ctx.kernels_q4->gemm.get_kr();
579
- size_t sr = ctx.kernels_q4->gemm.get_sr();
1221
+ kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(tensor->data);
1222
+ if (!header) {
1223
+ return -1;
1224
+ }
580
1225
 
581
- struct kai_rhs_pack_qs4cxs1s0_param params;
582
- params.lhs_zero_point = 1;
583
- params.rhs_zero_point = 8;
584
- ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
585
- static_cast<const uint8_t *>(data),
586
- nullptr, nullptr, tensor->data, 0, &params);
587
- GGML_UNUSED(data_size);
588
- return 0;
589
- } else if (tensor->type == GGML_TYPE_Q8_0) {
590
- if (!ctx.kernels_q8) {
591
- return -1;
592
- }
1226
+ header->magic = GGML_KLEIDIAI_PACK_MAGIC;
1227
+ header->version = GGML_KLEIDIAI_PACK_VERSION;
1228
+ header->slot_count = 0;
1229
+
1230
+ uint8_t * base_ptr = static_cast<uint8_t *>(tensor->data);
1231
+ size_t cursor = sizeof(kleidiai_weight_header);
1232
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
1233
+
1234
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
1235
+ const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
1236
+ const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
1237
+ : kleidiai_collect_q4_chain(kernel_chain);
1238
+ const bool allow_fallback = kleidiai_pack_fallback_allowed();
1239
+
1240
+ std::vector<int8_t> qdata;
1241
+ std::vector<float> scales;
1242
+
1243
+ if (want_q8 && slot_total > 0) {
1244
+ qdata.resize(n * k, 0);
1245
+ scales.resize(n, 0.0f);
593
1246
 
594
1247
  const size_t row_stride = tensor->nb[1];
595
1248
  const size_t k_blocks = (k + QK8_0 - 1) / QK8_0;
596
1249
 
597
- std::vector<int8_t> qdata(n * k, 0);
598
- std::vector<float> scales(n, 0.0f);
599
-
600
1250
  for (size_t row = 0; row < n; ++row) {
601
1251
  const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
602
1252
  static_cast<const uint8_t *>(data) + row * row_stride);
@@ -610,7 +1260,7 @@ public:
610
1260
  if (linear_idx >= k) {
611
1261
  break;
612
1262
  }
613
- const float value = d * blk.qs[l];
1263
+ const float value = d * static_cast<float>(blk.qs[l]);
614
1264
  max_abs = std::max(max_abs, std::fabs(value));
615
1265
  }
616
1266
  }
@@ -627,31 +1277,73 @@ public:
627
1277
  if (linear_idx >= k) {
628
1278
  break;
629
1279
  }
630
- const float value = d * blk.qs[l];
1280
+ const float value = d * static_cast<float>(blk.qs[l]);
631
1281
  int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
632
1282
  q = std::clamp(q, -127, 127);
633
1283
  qdata[row * k + linear_idx] = static_cast<int8_t>(q);
634
1284
  }
635
1285
  }
636
1286
  }
1287
+ }
637
1288
 
638
- size_t nr = ctx.kernels_q8->gemm.get_nr();
639
- size_t kr = ctx.kernels_q8->gemm.get_kr();
640
- size_t sr = ctx.kernels_q8->gemm.get_sr();
1289
+ for (int slot = 0; slot < slot_total && slot < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
1290
+ if (!allow_fallback && slot > 0) {
1291
+ break;
1292
+ }
1293
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
1294
+ kernel_info * kernel = &kernels->gemm;
1295
+ rhs_packing_info * rhs_info = &kernels->rhs_info;
1296
+ if (!rhs_info || !rhs_info->pack_func_ex || !rhs_info->packed_size_ex || !kernel) {
1297
+ continue;
1298
+ }
641
1299
 
642
- struct kai_rhs_pack_qsi8cx_params params;
643
- params.lhs_zero_point = 1;
644
- params.scale_multiplier = 1.0f;
1300
+ const size_t nr = kernel->get_nr();
1301
+ const size_t kr = kernel->get_kr();
1302
+ const size_t sr = kernel->get_sr();
1303
+ const ggml_type rhs_type = kernels->rhs_type;
1304
+ const size_t block_len = rhs_type == GGML_TYPE_Q8_0 ? QK8_0 :
1305
+ rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : 0;
1306
+ if (block_len == 0) {
1307
+ continue;
1308
+ }
1309
+
1310
+ const size_t packed_size = rhs_info->packed_size_ex(n, k, nr, kr, block_len);
1311
+ const size_t aligned_cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
1312
+
1313
+ uint8_t * dst_ptr = base_ptr + aligned_cursor;
1314
+
1315
+ if (rhs_type == GGML_TYPE_Q4_0) {
1316
+ struct kai_rhs_pack_qs4cxs1s0_param params;
1317
+ params.lhs_zero_point = 1;
1318
+ params.rhs_zero_point = 8;
1319
+ rhs_info->pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
1320
+ static_cast<const uint8_t *>(data), nullptr, nullptr,
1321
+ dst_ptr, 0, &params);
1322
+ } else if (rhs_type == GGML_TYPE_Q8_0) {
1323
+ struct kai_rhs_pack_qsi8cx_params params;
1324
+ params.lhs_zero_point = 1;
1325
+ params.scale_multiplier = 1.0f;
1326
+ rhs_info->pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
1327
+ qdata.data(), nullptr, scales.data(),
1328
+ dst_ptr, 0, &params);
1329
+ } else {
1330
+ continue;
1331
+ }
645
1332
 
646
- ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
647
- qdata.data(), nullptr, scales.data(),
648
- tensor->data, 0, &params);
649
- GGML_UNUSED(data_size);
650
- return 0;
1333
+ header->offsets[header->slot_count] = aligned_cursor;
1334
+ header->sizes[header->slot_count] = packed_size;
1335
+ ++header->slot_count;
1336
+
1337
+ cursor = aligned_cursor + packed_size;
1338
+ }
1339
+
1340
+ if (header->slot_count == 0) {
1341
+ header->magic = 0;
1342
+ header->version = 0;
1343
+ memcpy(tensor->data, data, data_size);
651
1344
  }
652
1345
 
653
- GGML_UNUSED(data_size);
654
- return -1;
1346
+ return 0;
655
1347
  }
656
1348
  };
657
1349
 
@@ -681,9 +1373,8 @@ static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t bu
681
1373
  }
682
1374
 
683
1375
  static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
684
- return "CPU_KLEIDIAI";
685
-
686
1376
  GGML_UNUSED(buft);
1377
+ return "CPU_KLEIDIAI";
687
1378
  }
688
1379
 
689
1380
  static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -702,56 +1393,85 @@ static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(
702
1393
  }
703
1394
 
704
1395
  static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
705
- return TENSOR_ALIGNMENT;
706
-
707
1396
  GGML_UNUSED(buft);
1397
+ return TENSOR_ALIGNMENT;
708
1398
  }
709
1399
 
710
1400
  static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
711
1401
  GGML_UNUSED(buft);
712
1402
 
1403
+ if (tensor->type != GGML_TYPE_Q4_0 && tensor->type != GGML_TYPE_Q8_0) {
1404
+ return ggml_nbytes(tensor);
1405
+ }
1406
+
713
1407
  const size_t n = tensor->ne[1];
714
1408
  const size_t k = tensor->ne[0];
715
1409
 
716
- ggml_kleidiai_kernels * kernels = nullptr;
717
- size_t block_len = 0;
718
-
719
- if (tensor->type == GGML_TYPE_Q4_0) {
720
- GGML_ASSERT(ctx.kernels_q4);
721
- kernels = ctx.kernels_q4;
722
- block_len = QK4_0;
723
- } else if (tensor->type == GGML_TYPE_Q8_0) {
724
- GGML_ASSERT(ctx.kernels_q8);
725
- kernels = ctx.kernels_q8;
726
- block_len = QK8_0;
727
- } else {
728
- return 0;
1410
+ size_t cursor = sizeof(kleidiai_weight_header);
1411
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
1412
+
1413
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
1414
+ const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
1415
+ const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
1416
+ : kleidiai_collect_q4_chain(kernel_chain);
1417
+ const bool allow_fallback = kleidiai_pack_fallback_allowed();
1418
+
1419
+ size_t slot_count = 0;
1420
+ for (int slot = 0; slot < slot_total; ++slot) {
1421
+ if (!allow_fallback && slot > 0) {
1422
+ break;
1423
+ }
1424
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
1425
+ if (!kernels) {
1426
+ continue;
1427
+ }
1428
+ kernel_info * kernel = &kernels->gemm;
1429
+ rhs_packing_info * rhs_info = &kernels->rhs_info;
1430
+ if (!kernel || !rhs_info || !rhs_info->packed_size_ex) {
1431
+ continue;
1432
+ }
1433
+
1434
+ const ggml_type rhs_type = kernels->rhs_type;
1435
+ const size_t block_len = rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
1436
+ rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
1437
+ if (block_len == 0) {
1438
+ continue;
1439
+ }
1440
+
1441
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
1442
+ cursor += rhs_info->packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), block_len);
1443
+ ++slot_count;
729
1444
  }
730
1445
 
731
- const size_t nr = kernels->gemm.get_nr();
732
- const size_t kr = kernels->gemm.get_kr();
733
- const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len);
734
- const size_t raw = ggml_nbytes(tensor);
1446
+ if (slot_count == 0) {
1447
+ return ggml_nbytes(tensor);
1448
+ }
735
1449
 
736
- return packed > raw ? packed : raw;
1450
+ return std::max(cursor, ggml_nbytes(tensor));
737
1451
  }
738
1452
 
739
1453
  namespace ggml::cpu::kleidiai {
740
1454
  class extra_buffer_type : ggml::cpu::extra_buffer_type {
741
1455
  bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
1456
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
1457
+ const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
742
1458
  if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
743
1459
  (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
744
1460
  op->src[0]->buffer &&
745
1461
  (ggml_n_dims(op->src[0]) == 2) &&
746
- op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
747
- if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) {
1462
+ op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() &&
1463
+ slot_total > 0) {
1464
+ if (op->src[0]->type == GGML_TYPE_Q4_0 && ctx.kernels_q4 == nullptr) {
1465
+ return false;
1466
+ }
1467
+ if (op->src[0]->type == GGML_TYPE_Q8_0 && ctx.kernels_q8 == nullptr) {
748
1468
  return false;
749
1469
  }
750
1470
  if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
751
1471
  return false;
752
1472
  }
753
1473
  if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
754
- ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
1474
+ ggml_ne(op->src[1], 3) == 1) {
755
1475
  return true;
756
1476
  }
757
1477
  }
@@ -762,14 +1482,19 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
762
1482
  if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
763
1483
  if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
764
1484
  return (ggml::cpu::tensor_traits *) op->src[0]->extra;
765
- }
766
- else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) {
767
- if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
768
- (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
1485
+ } else {
1486
+ if (op->src[0]->type != GGML_TYPE_F16) {
769
1487
  return nullptr;
770
1488
  }
771
-
772
- return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
1489
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
1490
+ const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
1491
+ if (slot_total > 0 && op->src[1]->ne[1] > 1) {
1492
+ if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
1493
+ (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
1494
+ return nullptr;
1495
+ }
1496
+ return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
1497
+ }
773
1498
  }
774
1499
  }
775
1500
  return nullptr;