whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -25,9 +25,8 @@
25
25
  #define UNUSED GGML_UNUSED
26
26
 
27
27
  #if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD))
28
- static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
29
- int16x8_t * out_mins,
30
- int8_t * out_scales) {
28
+ // Helper for decoding scales and mins of Q4_K and Q5_K block formats
29
+ static inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) {
31
30
  constexpr uint32_t kmask1 = 0x3f3f3f3f;
32
31
  constexpr uint32_t kmask2 = 0x0f0f0f0f;
33
32
  constexpr uint32_t kmask3 = 0x03030303;
@@ -499,6 +498,81 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
499
498
  ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
500
499
  }
501
500
 
501
+ void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
502
+ const int qk = QK8_0;
503
+ const int nb = n / qk;
504
+ const int ncols_interleaved = 4;
505
+ const int blocklen = 4;
506
+
507
+ assert (n % qk == 0);
508
+ assert (nc % ncols_interleaved == 0);
509
+
510
+ UNUSED(s);
511
+ UNUSED(bs);
512
+ UNUSED(vx);
513
+ UNUSED(vy);
514
+ UNUSED(nr);
515
+ UNUSED(nc);
516
+ UNUSED(nb);
517
+ UNUSED(ncols_interleaved);
518
+ UNUSED(blocklen);
519
+
520
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
521
+ const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4);
522
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
523
+ float * res_ptr = s;
524
+
525
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
526
+ const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
527
+
528
+ float32x4_t sumf = vdupq_n_f32(0);
529
+ for (int l = 0; l < nb; l++) {
530
+ uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
531
+ uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
532
+ uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
533
+ uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
534
+
535
+ int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
536
+ int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
537
+ int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
538
+ int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
539
+ int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
540
+ int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
541
+ int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
542
+ int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
543
+
544
+ int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
545
+ int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
546
+
547
+ int32x4_t sumi = vdupq_n_s32(0);
548
+ sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
549
+ sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
550
+ sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
551
+ sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
552
+ sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
553
+ sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
554
+ sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
555
+ sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
556
+
557
+ float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
558
+ float32x4_t b_d = {
559
+ GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]),
560
+ GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]),
561
+ GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]),
562
+ GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]),
563
+ };
564
+ float32x4_t d = a_d * b_d;
565
+
566
+ sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
567
+ }
568
+
569
+ vst1q_f32(res_ptr + x * 4, sumf);
570
+ }
571
+ return;
572
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
573
+ ggml_gemv_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
574
+ }
575
+
502
576
  void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
503
577
  constexpr int qk = QK_K;
504
578
  const int nb = n / qk;
@@ -561,7 +635,7 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
561
635
  for (int i = 0; i < 2; i++) {
562
636
  int8_t aux_q4sb[8];
563
637
  const int offset = sb * 24 + i * 12;
564
- decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
638
+ decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
565
639
  q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
566
640
  }
567
641
 
@@ -701,13 +775,13 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
701
775
  for (int i = 0; i < 2; i++) {
702
776
  int8_t aux_q4sb[8];
703
777
  const int offset = sb * 24 + i * 12;
704
- decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
778
+ decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
705
779
  q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
706
780
  }
707
781
 
708
782
  const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;
709
783
 
710
- // Load the 64 quants from q8K duplicated to use vecdots with the interelaved columns
784
+ // Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns
711
785
  // but still need the qs to use the low and hi bits from q4
712
786
  const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
713
787
  int8x16_t q8_qs[8];
@@ -786,17 +860,18 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
786
860
  ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
787
861
  }
788
862
 
789
- void ggml_gemv_q8_0_4x4_q8_0(int n,
863
+ void ggml_gemv_q5_K_8x4_q8_K(int n,
790
864
  float * GGML_RESTRICT s,
791
865
  size_t bs,
792
866
  const void * GGML_RESTRICT vx,
793
867
  const void * GGML_RESTRICT vy,
794
868
  int nr,
795
869
  int nc) {
796
- const int qk = QK8_0;
797
- const int nb = n / qk;
798
- const int ncols_interleaved = 4;
799
- const int blocklen = 4;
870
+ constexpr int qk = QK_K;
871
+ const int nb = n / qk;
872
+
873
+ constexpr int ncols_interleaved = 8;
874
+ constexpr int blocklen = 4;
800
875
 
801
876
  assert(n % qk == 0);
802
877
  assert(nc % ncols_interleaved == 0);
@@ -806,55 +881,156 @@ void ggml_gemv_q8_0_4x4_q8_0(int n,
806
881
  UNUSED(blocklen);
807
882
 
808
883
  #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
809
- const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
884
+ constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
885
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
886
+ const uint8x16_t mone = vdupq_n_u8(1);
887
+ const uint8x16_t mtwo = vdupq_n_u8(2);
888
+
889
+ // 1x8 tile = 2 x 4
890
+ float32x4_t acc_f32[col_groups];
891
+
892
+ const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
893
+
894
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
895
+ const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
896
+
897
+ for (int i = 0; i < col_groups; i++) {
898
+ acc_f32[i] = vdupq_n_f32(0);
899
+ }
810
900
 
811
- for (int c = 0; c < nc; c += ncols_interleaved) {
812
- const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
813
- float32x4_t acc = vdupq_n_f32(0);
814
901
  for (int b = 0; b < nb; b++) {
815
- int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
816
- int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
817
- float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
902
+ float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3
903
+ float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7
904
+ float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
905
+ float32x4_t sb_scale_0123 = vmulq_f32(q5_d_0, q8_d);
906
+ float32x4_t sb_scale_4567 = vmulq_f32(q5_d_1, q8_d);
907
+ float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3
908
+ float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7
909
+ float32x4_t sb_min_0123 = vmulq_f32(q5_dmin_0, q8_d);
910
+ float32x4_t sb_min_4567 = vmulq_f32(q5_dmin_1, q8_d);
818
911
 
819
- int8x16x2_t a = vld1q_s8_x2(a_ptr->qs);
820
- float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
912
+ // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
913
+ int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
914
+ int32x4_t acc_lo[col_groups];
915
+ int32x4_t acc_hi[col_groups];
821
916
 
822
- int32x4_t ret = vdupq_n_s32(0);
917
+ // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
918
+ const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
919
+ int16_t bsums_arr[8];
920
+ vst1q_s16(bsums_arr, bsums);
823
921
 
824
- ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0);
825
- ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1);
826
- ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2);
827
- ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3);
922
+ uint8x16_t qh[col_groups][8];
923
+ for (int c = 0; c < col_groups; c++) {
924
+ for (int i = 0; i < 8; i++) {
925
+ qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
926
+ }
927
+ }
828
928
 
829
- ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0);
830
- ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1);
831
- ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2);
832
- ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3);
929
+ for (int sb = 0; sb < QK_K / 64; sb++) {
930
+ for (int i = 0; i < col_groups; i++) {
931
+ acc_lo[i] = vdupq_n_s32(0);
932
+ acc_hi[i] = vdupq_n_s32(0);
933
+ }
934
+ // Need scales for the low and high nibbles
935
+ // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
936
+ int16x8_t q5sb_mins[2];
937
+ int16x8_t q5sb_scales[2];
938
+ for (int i = 0; i < 2; i++) {
939
+ int8_t aux_q5sb[8];
940
+ const int offset = sb * 24 + i * 12;
941
+ decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
942
+ q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
943
+ }
833
944
 
834
- acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
835
- a_ptr++;
836
- b_ptr++;
837
- }
838
- vst1q_f32(s, acc);
839
- s += ncols_interleaved;
840
- }
841
- return;
945
+ int8x16_t q8_qs[4];
946
+ for (int i = 0; i < 4; i++) {
947
+ q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
948
+ }
949
+
950
+ for (int c = 0; c < col_groups; c++) {
951
+ uint8x16_t q5_cols[8];
952
+ uint8x16_t hbit_lo[8];
953
+ uint8x16_t hbit_hi[8];
954
+ int8x16_t q5_lo[8];
955
+ int8x16_t q5_hi[8];
956
+
957
+ for (int i = 0; i < 8; i++) {
958
+ q5_cols[i] = vld1q_u8(q5_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
959
+ hbit_lo[i] = vandq_u8(qh[c][i], mone);
960
+ hbit_hi[i] = vshlq_n_u8(vandq_u8(qh[c][i], mtwo), 3);
961
+ qh[c][i] = vshrq_n_u8(qh[c][i], 2);
962
+ q5_lo[i] = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_cols[i], m4b), hbit_lo[i], 4));
963
+ q5_hi[i] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_cols[i], 4), hbit_hi[i]));
964
+ }
965
+
966
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[0], q8_qs[0], 0);
967
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[1], q8_qs[0], 1);
968
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[2], q8_qs[0], 2);
969
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[3], q8_qs[0], 3);
970
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[4], q8_qs[1], 0);
971
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[5], q8_qs[1], 1);
972
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[6], q8_qs[1], 2);
973
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[7], q8_qs[1], 3);
974
+
975
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[0], q8_qs[2], 0);
976
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[1], q8_qs[2], 1);
977
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[2], q8_qs[2], 2);
978
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[3], q8_qs[2], 3);
979
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[4], q8_qs[3], 0);
980
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[5], q8_qs[3], 1);
981
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[6], q8_qs[3], 2);
982
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[7], q8_qs[3], 3);
983
+ }
984
+
985
+ // Scales
986
+ // row c0123 blk0 and blk1
987
+ const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]);
988
+ const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]);
989
+ const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
990
+ vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
991
+ acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
992
+ // row c4567 blk0 and blk1
993
+ const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]);
994
+ const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]);
995
+ const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
996
+ vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
997
+ acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
998
+
999
+ // Bias Correction
1000
+ const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
1001
+ const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
1002
+
1003
+ bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
1004
+ bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
1005
+ bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
1006
+ bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
1007
+ } // for sb
1008
+
1009
+ acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
1010
+ acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
1011
+ } // for b
842
1012
 
1013
+ int base = x * ncols_interleaved;
1014
+ vst1q_f32(s + base, acc_f32[0]);
1015
+ vst1q_f32(s + base + 4, acc_f32[1]);
1016
+ } // for x
1017
+ return;
843
1018
  #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
844
- ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
1019
+ ggml_gemv_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
845
1020
  }
846
1021
 
847
- void ggml_gemv_q8_0_4x8_q8_0(int n,
1022
+ void ggml_gemv_q5_K_8x8_q8_K(int n,
848
1023
  float * GGML_RESTRICT s,
849
1024
  size_t bs,
850
1025
  const void * GGML_RESTRICT vx,
851
1026
  const void * GGML_RESTRICT vy,
852
1027
  int nr,
853
1028
  int nc) {
854
- const int qk = QK8_0;
855
- const int nb = n / qk;
856
- const int ncols_interleaved = 4;
857
- const int blocklen = 8;
1029
+ constexpr int qk = QK_K;
1030
+ const int nb = n / qk;
1031
+
1032
+ constexpr int ncols_interleaved = 8;
1033
+ constexpr int blocklen = 8;
858
1034
 
859
1035
  assert(n % qk == 0);
860
1036
  assert(nc % ncols_interleaved == 0);
@@ -864,269 +1040,1003 @@ void ggml_gemv_q8_0_4x8_q8_0(int n,
864
1040
  UNUSED(blocklen);
865
1041
 
866
1042
  #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
867
- const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
1043
+ constexpr int col_pairs = ncols_interleaved / 2;
1044
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
1045
+ const uint8x16_t mone = vdupq_n_u8(1);
1046
+ const uint8x16_t mtwo = vdupq_n_u8(2);
868
1047
 
869
- for (int c = 0; c < nc; c += ncols_interleaved) {
870
- const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
871
- float32x4_t acc = vdupq_n_f32(0);
1048
+ // 1x8 tile = 2 x 4
1049
+ float32x4_t acc_f32[ncols_interleaved / 4];
872
1050
 
873
- for (int b = 0; b < nb; b++) {
874
- int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
875
- int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
876
- float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
1051
+ const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
877
1052
 
878
- int8x8x4_t a_chunks = vld1_s8_x4(a_ptr->qs);
879
- int8x16_t a0 = vcombine_s8(a_chunks.val[0], a_chunks.val[0]);
880
- int8x16_t a1 = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);
881
- int8x16_t a2 = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);
882
- int8x16_t a3 = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);
883
- float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
1053
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
1054
+ const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
884
1055
 
885
- int32x4_t ret0 = vdupq_n_s32(0);
886
- int32x4_t ret1 = vdupq_n_s32(0);
1056
+ for (int i = 0; i < ncols_interleaved / 4; i++) {
1057
+ acc_f32[i] = vdupq_n_f32(0);
1058
+ }
887
1059
 
888
- // 0..7
889
- ret0 = vdotq_s32(ret0, b_low.val[0], a0);
890
- ret1 = vdotq_s32(ret1, b_low.val[1], a0);
891
- // 8..15
892
- ret0 = vdotq_s32(ret0, b_low.val[2], a1);
893
- ret1 = vdotq_s32(ret1, b_low.val[3], a1);
894
- // 16..23
895
- ret0 = vdotq_s32(ret0, b_high.val[0], a2);
896
- ret1 = vdotq_s32(ret1, b_high.val[1], a2);
897
- // 24..31
898
- ret0 = vdotq_s32(ret0, b_high.val[2], a3);
899
- ret1 = vdotq_s32(ret1, b_high.val[3], a3);
1060
+ for (int b = 0; b < nb; b++) {
1061
+ float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3
1062
+ float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7
1063
+ float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
1064
+ float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d);
1065
+ float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d);
1066
+ float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3
1067
+ float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7
1068
+ float32x4_t sb_min_0 = vmulq_f32(q5_dmin_0, q8_d);
1069
+ float32x4_t sb_min_1 = vmulq_f32(q5_dmin_1, q8_d);
900
1070
 
901
- int32x4_t ret = vpaddq_s32(ret0, ret1);
1071
+ // 2 sb each iteration
1072
+ int32x4_t acc_lo[col_pairs];
1073
+ int32x4_t acc_hi[col_pairs];
902
1074
 
903
- acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
904
- a_ptr++;
905
- b_ptr++;
906
- }
907
- vst1q_f32(s, acc);
908
- s += ncols_interleaved;
909
- }
910
- return;
1075
+ // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
1076
+ const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
1077
+ int16_t bsums_arr[8];
1078
+ vst1q_s16(bsums_arr, bsums);
911
1079
 
912
- #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
913
- ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
914
- }
1080
+ // Load qh once per block and shift after each subblock
1081
+ const uint8_t * qh_base = q5_ptr[b].qh;
1082
+ uint8x16_t qh[col_pairs][4];
1083
+ for (int cp = 0; cp < col_pairs; cp++) {
1084
+ qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
1085
+ qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
1086
+ qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
1087
+ qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
1088
+ }
915
1089
 
916
- void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
917
- const int qk = QK8_0;
918
- const int nb = n / qk;
919
- const int ncols_interleaved = 4;
920
- const int blocklen = 4;
1090
+ for (int sb = 0; sb < QK_K / 64; sb++) {
1091
+ for (int i = 0; i < col_pairs; i++) {
1092
+ acc_lo[i] = vdupq_n_s32(0);
1093
+ acc_hi[i] = vdupq_n_s32(0);
1094
+ }
1095
+ // Need scales for the low and high nibbles
1096
+ // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
1097
+ int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later
1098
+ int16x8_t q5sb_scales[2];
1099
+ for (int i = 0; i < 2; i++) {
1100
+ int8_t aux_q5sb[8];
1101
+ const int offset = sb * 24 + i * 12;
1102
+ decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
1103
+ q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
1104
+ }
921
1105
 
922
- assert (n % qk == 0);
923
- assert (nr % 4 == 0);
924
- assert (nc % ncols_interleaved == 0);
1106
+ const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K;
1107
+
1108
+ // Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns
1109
+ const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
1110
+ int8x16_t q8_qs[8];
1111
+ for (int i = 0; i < 8; i++) {
1112
+ q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
1113
+ }
1114
+
1115
+ // Q5s column pair loop unrolled
1116
+ {
1117
+ // Cols 01
1118
+ uint8x16_t qs_0 = vld1q_u8(qs_base);
1119
+ uint8x16_t qs_1 = vld1q_u8(qs_base + 64);
1120
+ uint8x16_t qs_2 = vld1q_u8(qs_base + 128);
1121
+ uint8x16_t qs_3 = vld1q_u8(qs_base + 192);
1122
+
1123
+ uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone);
1124
+ uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone);
1125
+ uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone);
1126
+ uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone);
1127
+ uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3);
1128
+ uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3);
1129
+ uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3);
1130
+ uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3);
1131
+
1132
+ qh[0][0] = vshrq_n_u8(qh[0][0], 2);
1133
+ qh[0][1] = vshrq_n_u8(qh[0][1], 2);
1134
+ qh[0][2] = vshrq_n_u8(qh[0][2], 2);
1135
+ qh[0][3] = vshrq_n_u8(qh[0][3], 2);
1136
+
1137
+ acc_lo[0] = ggml_vdotq_s32(
1138
+ acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
1139
+ acc_lo[0] = ggml_vdotq_s32(
1140
+ acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
1141
+ acc_lo[0] = ggml_vdotq_s32(
1142
+ acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
1143
+ acc_lo[0] = ggml_vdotq_s32(
1144
+ acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
1145
+ acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
1146
+ q8_qs[4]);
1147
+ acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
1148
+ q8_qs[5]);
1149
+ acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
1150
+ q8_qs[6]);
1151
+ acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
1152
+ q8_qs[7]);
1153
+
1154
+ // Cols 23
1155
+ qs_0 = vld1q_u8(qs_base + 16);
1156
+ qs_1 = vld1q_u8(qs_base + 80);
1157
+ qs_2 = vld1q_u8(qs_base + 144);
1158
+ qs_3 = vld1q_u8(qs_base + 208);
1159
+
1160
+ hbit_lo_0 = vandq_u8(qh[1][0], mone);
1161
+ hbit_lo_1 = vandq_u8(qh[1][1], mone);
1162
+ hbit_lo_2 = vandq_u8(qh[1][2], mone);
1163
+ hbit_lo_3 = vandq_u8(qh[1][3], mone);
1164
+ hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3);
1165
+ hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3);
1166
+ hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3);
1167
+ hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3);
1168
+
1169
+ qh[1][0] = vshrq_n_u8(qh[1][0], 2);
1170
+ qh[1][1] = vshrq_n_u8(qh[1][1], 2);
1171
+ qh[1][2] = vshrq_n_u8(qh[1][2], 2);
1172
+ qh[1][3] = vshrq_n_u8(qh[1][3], 2);
1173
+
1174
+ acc_lo[1] = ggml_vdotq_s32(
1175
+ acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
1176
+ acc_lo[1] = ggml_vdotq_s32(
1177
+ acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
1178
+ acc_lo[1] = ggml_vdotq_s32(
1179
+ acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
1180
+ acc_lo[1] = ggml_vdotq_s32(
1181
+ acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
1182
+ acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
1183
+ q8_qs[4]);
1184
+ acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
1185
+ q8_qs[5]);
1186
+ acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
1187
+ q8_qs[6]);
1188
+ acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
1189
+ q8_qs[7]);
1190
+
1191
+ // Cols 45
1192
+ qs_0 = vld1q_u8(qs_base + 32);
1193
+ qs_1 = vld1q_u8(qs_base + 96);
1194
+ qs_2 = vld1q_u8(qs_base + 160);
1195
+ qs_3 = vld1q_u8(qs_base + 224);
1196
+
1197
+ hbit_lo_0 = vandq_u8(qh[2][0], mone);
1198
+ hbit_lo_1 = vandq_u8(qh[2][1], mone);
1199
+ hbit_lo_2 = vandq_u8(qh[2][2], mone);
1200
+ hbit_lo_3 = vandq_u8(qh[2][3], mone);
1201
+ hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3);
1202
+ hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3);
1203
+ hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3);
1204
+ hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3);
1205
+
1206
+ qh[2][0] = vshrq_n_u8(qh[2][0], 2);
1207
+ qh[2][1] = vshrq_n_u8(qh[2][1], 2);
1208
+ qh[2][2] = vshrq_n_u8(qh[2][2], 2);
1209
+ qh[2][3] = vshrq_n_u8(qh[2][3], 2);
1210
+
1211
+ acc_lo[2] = ggml_vdotq_s32(
1212
+ acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
1213
+ acc_lo[2] = ggml_vdotq_s32(
1214
+ acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
1215
+ acc_lo[2] = ggml_vdotq_s32(
1216
+ acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
1217
+ acc_lo[2] = ggml_vdotq_s32(
1218
+ acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
1219
+ acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
1220
+ q8_qs[4]);
1221
+ acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
1222
+ q8_qs[5]);
1223
+ acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
1224
+ q8_qs[6]);
1225
+ acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
1226
+ q8_qs[7]);
1227
+
1228
+ // Cols 45
1229
+ qs_0 = vld1q_u8(qs_base + 48);
1230
+ qs_1 = vld1q_u8(qs_base + 112);
1231
+ qs_2 = vld1q_u8(qs_base + 176);
1232
+ qs_3 = vld1q_u8(qs_base + 240);
1233
+
1234
+ hbit_lo_0 = vandq_u8(qh[3][0], mone);
1235
+ hbit_lo_1 = vandq_u8(qh[3][1], mone);
1236
+ hbit_lo_2 = vandq_u8(qh[3][2], mone);
1237
+ hbit_lo_3 = vandq_u8(qh[3][3], mone);
1238
+ hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3);
1239
+ hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3);
1240
+ hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3);
1241
+ hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3);
1242
+
1243
+ qh[3][0] = vshrq_n_u8(qh[3][0], 2);
1244
+ qh[3][1] = vshrq_n_u8(qh[3][1], 2);
1245
+ qh[3][2] = vshrq_n_u8(qh[3][2], 2);
1246
+ qh[3][3] = vshrq_n_u8(qh[3][3], 2);
1247
+
1248
+ acc_lo[3] = ggml_vdotq_s32(
1249
+ acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
1250
+ acc_lo[3] = ggml_vdotq_s32(
1251
+ acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
1252
+ acc_lo[3] = ggml_vdotq_s32(
1253
+ acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
1254
+ acc_lo[3] = ggml_vdotq_s32(
1255
+ acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
1256
+ acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
1257
+ q8_qs[4]);
1258
+ acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
1259
+ q8_qs[5]);
1260
+ acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
1261
+ q8_qs[6]);
1262
+ acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
1263
+ q8_qs[7]);
1264
+ }
1265
+
1266
+ // Prepare bsum vectors for bias computation
1267
+ // Each pair of subblocks share the same bsums
1268
+ int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
1269
+ int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
1270
+
1271
+ // Iterates over a pair of column pairs (4 columns) to use a single 128 register
1272
+ // p = 0 -> 0123 p2 -> 4567
1273
+ for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
1274
+ int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]);
1275
+ int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]);
1276
+ int16x4_t group_mins_lo = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]);
1277
+ int16x4_t group_mins_hi = p == 0 ? vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]);
1278
+ float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
1279
+ float32x4_t sb_min = p == 0 ? sb_min_0 : sb_min_1;
1280
+
1281
+ // 0123 or 4567
1282
+ float32x4_t sumf_0 =
1283
+ vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
1284
+ acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
1285
+
1286
+ float32x4_t sumf_1 =
1287
+ vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
1288
+ acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
1289
+
1290
+ // FUSED BIAS: Compute and subtract bias immediately
1291
+ // bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min
1292
+ int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo);
1293
+ bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi);
1294
+ float32x4_t bias_f32 = vcvtq_f32_s32(bias);
1295
+ acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32);
1296
+ }
1297
+ } // for sb
1298
+ } // for b
1299
+
1300
+ int base = x * ncols_interleaved;
1301
+ vst1q_f32(s + base, acc_f32[0]);
1302
+ vst1q_f32(s + base + 4, acc_f32[1]);
1303
+ } // for x
1304
+ return;
1305
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1306
+ ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
1307
+ }
1308
+
1309
+ void ggml_gemv_q6_K_8x4_q8_K(int n,
1310
+ float * GGML_RESTRICT s,
1311
+ size_t bs,
1312
+ const void * GGML_RESTRICT vx,
1313
+ const void * GGML_RESTRICT vy,
1314
+ int nr,
1315
+ int nc) {
1316
+ constexpr int qk = QK_K;
1317
+ const int nb = n / qk;
1318
+
1319
+ constexpr int ncols_interleaved = 8;
1320
+ constexpr int blocklen = 4;
1321
+
1322
+ assert(n % qk == 0);
1323
+ assert(nc % ncols_interleaved == 0);
925
1324
 
926
- UNUSED(s);
927
- UNUSED(bs);
928
- UNUSED(vx);
929
- UNUSED(vy);
930
- UNUSED(nr);
931
- UNUSED(nc);
932
1325
  UNUSED(nb);
933
1326
  UNUSED(ncols_interleaved);
934
1327
  UNUSED(blocklen);
935
1328
 
936
- #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
937
- const void * b_ptr = vx;
938
- const void * a_ptr = vy;
939
- float * res_ptr = s;
940
- size_t res_stride = bs * sizeof(float);
1329
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1330
+ constexpr int col_groups = ncols_interleaved / 4;
1331
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
1332
+ const uint8x16_t mask_lo = vdupq_n_u8(0x03);
1333
+ const uint8x16_t mask_hi = vdupq_n_u8(0x30);
941
1334
 
942
- __asm__ __volatile__(
943
- "mov x10, %x[nr]\n"
944
- "mov x9, #0x88\n"
945
- "cmp x10, #0x10\n"
946
- "mul x9, %x[nb], x9\n"
947
- "blt 4f\n"
948
- "1:" // Row loop
949
- "add x28, %x[b_ptr], #0x8\n"
950
- "mov x27, %x[nc]\n"
951
- "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
952
- "2:" // Column loop
953
- "add x25, %x[a_ptr], #0x8\n"
954
- "movi v15.16b, #0x0\n"
955
- "movi v19.16b, #0x0\n"
956
- "mov x24, %x[nb]\n"
957
- "add x23, x25, x9\n"
958
- "movi v18.16b, #0x0\n"
959
- "movi v14.16b, #0x0\n"
960
- "add x22, x23, x9\n"
961
- "movi v11.16b, #0x0\n"
962
- "movi v13.16b, #0x0\n"
963
- "add x21, x22, x9\n"
964
- "movi v23.16b, #0x0\n"
965
- "movi v16.16b, #0x0\n"
966
- "movi v25.16b, #0x0\n"
967
- "movi v7.16b, #0x0\n"
968
- "movi v0.16b, #0x0\n"
969
- "movi v4.16b, #0x0\n"
970
- "movi v5.16b, #0x0\n"
971
- "movi v21.16b, #0x0\n"
972
- "movi v8.16b, #0x0\n"
973
- "movi v1.16b, #0x0\n"
974
- "3:" // Block loop
975
- "ldr q3, [x28, #0x0]\n"
976
- "ldr q31, [x25, #0x0]\n"
977
- "movi v28.16b, #0x4\n"
978
- "movi v10.4s, #0x0\n"
979
- "ldr q22, [x28, #0x10]\n"
980
- "ldr q6, [x25, #0x10]\n"
981
- "movi v29.4s, #0x0\n"
982
- "movi v9.4s, #0x0\n"
983
- "ldr q27, [x28, #0x20]\n"
984
- "ldr q30, [x28, #0x30]\n"
985
- "movi v20.4s, #0x0\n"
986
- "movi v24.16b, #0xf0\n"
987
- "ldr d2, [x25, #-0x8]\n"
988
- "ldr d26, [x23, #-0x8]\n"
989
- "sshl v12.16b, v3.16b, v28.16b\n"
990
- "sub x20, x28, #0x8\n"
991
- "ldr d17, [x20, #0x0]\n"
992
- "and v3.16b, v3.16b, v24.16b\n"
993
- "subs x24, x24, #0x1\n"
994
- "add x28, x28, #0x48\n"
995
- ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n"
996
- ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n"
997
- ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n"
998
- ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n"
999
- "sshl v31.16b, v22.16b, v28.16b\n"
1000
- "and v22.16b, v22.16b, v24.16b\n"
1001
- "fcvtl v17.4s, v17.4h\n"
1002
- "fcvtl v2.4s, v2.4h\n"
1003
- "fcvtl v26.4s, v26.4h\n"
1004
- ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n"
1005
- ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n"
1006
- ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n"
1007
- ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n"
1008
- "sshl v6.16b, v27.16b, v28.16b\n"
1009
- "sshl v28.16b, v30.16b, v28.16b\n"
1010
- "and v27.16b, v27.16b, v24.16b\n"
1011
- "and v30.16b, v30.16b, v24.16b\n"
1012
- "ldr q24, [x25, #0x20]\n"
1013
- ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n"
1014
- ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
1015
- ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n"
1016
- ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n"
1017
- "ldr q24, [x25, #0x30]\n"
1018
- ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n"
1019
- ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n"
1020
- ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n"
1021
- ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n"
1022
- "ldr q24, [x25, #0x40]\n"
1023
- ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n"
1024
- ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
1025
- ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n"
1026
- ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n"
1027
- "ldr q24, [x25, #0x50]\n"
1028
- ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n"
1029
- ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n"
1030
- ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n"
1031
- ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n"
1032
- "ldr q24, [x25, #0x60]\n"
1033
- ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n"
1034
- ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
1035
- ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n"
1036
- ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n"
1037
- "ldr q24, [x25, #0x70]\n"
1038
- "add x25, x25, #0x88\n"
1039
- ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n"
1040
- ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n"
1041
- ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n"
1042
- ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n"
1043
- "fmul v24.4s, v17.4s, v2.s[0]\n"
1044
- "scvtf v10.4s, v10.4s, #0x4\n"
1045
- "scvtf v29.4s, v29.4s, #0x4\n"
1046
- "scvtf v9.4s, v9.4s, #0x4\n"
1047
- "scvtf v20.4s, v20.4s, #0x4\n"
1048
- "fmla v15.4s, v10.4s, v24.4s\n"
1049
- "ldr q24, [x23, #0x0]\n"
1050
- "fmul v10.4s, v17.4s, v2.s[1]\n"
1051
- "fmla v19.4s, v29.4s, v10.4s\n"
1052
- "ldr q10, [x23, #0x10]\n"
1053
- "fmul v29.4s, v17.4s, v2.s[2]\n"
1054
- "fmul v2.4s, v17.4s, v2.s[3]\n"
1055
- "fmla v18.4s, v9.4s, v29.4s\n"
1056
- "movi v9.4s, #0x0\n"
1057
- "movi v29.4s, #0x0\n"
1058
- ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n"
1059
- ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n"
1060
- "fmla v14.4s, v20.4s, v2.4s\n"
1061
- "movi v20.4s, #0x0\n"
1062
- "movi v2.4s, #0x0\n"
1063
- ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n"
1064
- ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
1065
- "ldr q24, [x23, #0x20]\n"
1066
- ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n"
1067
- ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n"
1068
- ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n"
1069
- ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n"
1070
- "ldr q10, [x23, #0x30]\n"
1071
- ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n"
1072
- ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
1073
- ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n"
1074
- ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
1075
- "ldr q24, [x23, #0x40]\n"
1076
- ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n"
1077
- ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n"
1078
- ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n"
1079
- ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n"
1080
- "ldr q10, [x23, #0x50]\n"
1081
- ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n"
1082
- ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
1083
- ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n"
1084
- ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
1085
- "ldr q24, [x23, #0x60]\n"
1086
- ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n"
1087
- ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n"
1088
- ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n"
1089
- ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n"
1090
- "ldr q10, [x23, #0x70]\n"
1091
- "add x23, x23, #0x88\n"
1092
- ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n"
1093
- ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
1094
- ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n"
1095
- ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
1096
- "ldr q24, [x22, #0x0]\n"
1097
- ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n"
1098
- ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n"
1099
- ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n"
1100
- ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n"
1101
- "fmul v10.4s, v17.4s, v26.s[0]\n"
1102
- "scvtf v9.4s, v9.4s, #0x4\n"
1103
- "scvtf v29.4s, v29.4s, #0x4\n"
1104
- "scvtf v20.4s, v20.4s, #0x4\n"
1105
- "scvtf v2.4s, v2.4s, #0x4\n"
1106
- "fmla v11.4s, v9.4s, v10.4s\n"
1107
- "ldr q9, [x22, #0x10]\n"
1108
- "fmul v10.4s, v17.4s, v26.s[1]\n"
1109
- "fmla v13.4s, v29.4s, v10.4s\n"
1110
- "ldr d29, [x22, #-0x8]\n"
1111
- "fmul v10.4s, v17.4s, v26.s[2]\n"
1112
- "fmul v26.4s, v17.4s, v26.s[3]\n"
1113
- "fcvtl v29.4s, v29.4h\n"
1114
- "fmla v23.4s, v20.4s, v10.4s\n"
1115
- "movi v20.4s, #0x0\n"
1116
- "movi v10.4s, #0x0\n"
1117
- "fmla v16.4s, v2.4s, v26.4s\n"
1118
- "movi v26.4s, #0x0\n"
1119
- "movi v2.4s, #0x0\n"
1120
- ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n"
1121
- ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
1122
- ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n"
1123
- ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
1124
- "ldr q24, [x22, #0x20]\n"
1125
- ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n"
1126
- ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
1127
- ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n"
1128
- ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n"
1129
- "ldr q9, [x22, #0x30]\n"
1335
+ // 1x8 tile = 2 x 4
1336
+ float32x4_t acc_f32[2];
1337
+
1338
+ const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
1339
+
1340
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
1341
+ const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
1342
+
1343
+ for (int i = 0; i < col_groups; i++) {
1344
+ acc_f32[i] = vdupq_n_f32(0);
1345
+ }
1346
+
1347
+ for (int b = 0; b < nb; b++) {
1348
+ float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3
1349
+ float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7
1350
+ float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
1351
+ float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
1352
+ float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);
1353
+
1354
+ int32x4_t acc[col_groups];
1355
+ for (int i = 0; i < col_groups; i++) {
1356
+ acc[i] = vdupq_n_s32(0);
1357
+ }
1358
+
1359
+ // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
1360
+ // Reused for bias and dequantization later
1361
+ int16_t q6_scales[16 * 8];
1362
+ for (int i = 0; i < 16; i++) {
1363
+ int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
1364
+ vst1q_s16(q6_scales + i * 8, scales);
1365
+ }
1366
+
1367
+ // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
1368
+ int32x4_t bias_lo = vdupq_n_s32(0);
1369
+ int32x4_t bias_hi = vdupq_n_s32(0);
1370
+
1371
+ // Load bsums in chunks of 4 to process with vectorized operations
1372
+ for (int i = 0; i < 16; i += 4) {
1373
+ int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i);
1374
+ int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
1375
+ int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
1376
+ int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
1377
+ int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
1378
+ int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
1379
+ int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
1380
+ int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
1381
+ int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);
1382
+
1383
+ bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
1384
+ bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
1385
+ bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
1386
+ bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
1387
+ bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
1388
+ bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
1389
+ bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
1390
+ bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
1391
+ }
1392
+ bias_lo = vshlq_n_s32(bias_lo, 5);
1393
+ bias_hi = vshlq_n_s32(bias_hi, 5);
1394
+
1395
+ // Process two 128-value halves per superblock
1396
+ for (int half = 0; half < 2; half++) {
1397
+ const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
1398
+ const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
1399
+
1400
+ // A subblock (sb) is a set of weights that share the scale
1401
+ // Since q6_K scales are per 16 elements
1402
+ // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
1403
+ for (int sb = 0; sb < QK_K / 64; sb++) {
1404
+ const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
1405
+ const int8_t * q8_base_h = q8_base_l + 64;
1406
+
1407
+ // Load and duplicate q8 values (each register covers four interleaved columns of q6)
1408
+ int8x16_t q8_l[4];
1409
+ int8x16_t q8_h[4];
1410
+ for (int i = 0; i < 4; i++) {
1411
+ q8_l[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_l + i * 4));
1412
+ q8_h[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_h + i * 4));
1413
+ }
1414
+
1415
+ const int ql_off_base = sb * QK_K / 2;
1416
+ const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes
1417
+
1418
+ // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
1419
+ uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
1420
+ uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
1421
+ uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
1422
+ uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
1423
+
1424
+ // Adjust qh for subblocks 2 and 3 (shift right by 2)
1425
+ if (sb > 1) {
1426
+ q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
1427
+ q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
1428
+ q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
1429
+ q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
1430
+ q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
1431
+ q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
1432
+ q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
1433
+ q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
1434
+ }
1435
+
1436
+ const uint8x16_t q6_ql[8] = { q6_ql_0.val[0], q6_ql_0.val[1], q6_ql_0.val[2], q6_ql_0.val[3],
1437
+ q6_ql_1.val[0], q6_ql_1.val[1], q6_ql_1.val[2], q6_ql_1.val[3] };
1438
+ const uint8x16_t q6_qh[8] = { q6_qh_0.val[0], q6_qh_0.val[1], q6_qh_0.val[2], q6_qh_0.val[3],
1439
+ q6_qh_1.val[0], q6_qh_1.val[1], q6_qh_1.val[2], q6_qh_1.val[3] };
1440
+
1441
+ // Process column groups (0-3, 4-7)
1442
+ for (int g = 0; g < col_groups; g++) {
1443
+ int32x4_t sb_acc_l = vdupq_n_s32(0);
1444
+ int32x4_t sb_acc_h = vdupq_n_s32(0);
1445
+
1446
+ for (int chunk = 0; chunk < 4; chunk++) {
1447
+ const int idx = chunk * 2 + g;
1448
+
1449
+ const uint8x16_t q6_qs_l = q6_ql[idx];
1450
+ const uint8x16_t q6_qs_h = q6_qh[idx];
1451
+
1452
+ // Extract high 2 bits for upper nibble reconstruction
1453
+ const uint8x16_t q6_qs_hh = vandq_u8(q6_qs_h, mask_hi);
1454
+
1455
+ // q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
1456
+ const int8x16_t q6_l =
1457
+ vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_l, m4b), vandq_u8(q6_qs_h, mask_lo), 4));
1458
+ const int8x16_t q6_h = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_l, 4), q6_qs_hh));
1459
+
1460
+ sb_acc_l = vdotq_s32(sb_acc_l, q6_l, q8_l[chunk]);
1461
+ sb_acc_h = vdotq_s32(sb_acc_h, q6_h, q8_h[chunk]);
1462
+ }
1463
+
1464
+ const int scale_idx_l = half * 8 + sb;
1465
+ const int scale_idx_h = half * 8 + sb + 4;
1466
+
1467
+ const int32x4_t scale_vec_l = vmovl_s16(vld1_s16(q6_scales + scale_idx_l * 8 + g * 4));
1468
+ const int32x4_t scale_vec_h = vmovl_s16(vld1_s16(q6_scales + scale_idx_h * 8 + g * 4));
1469
+
1470
+ acc[g] = vmlaq_s32(acc[g], sb_acc_l, scale_vec_l);
1471
+ acc[g] = vmlaq_s32(acc[g], sb_acc_h, scale_vec_h);
1472
+ }
1473
+ }
1474
+ } // for half
1475
+
1476
+ // Bias correction
1477
+ acc[0] = vsubq_s32(acc[0], bias_lo);
1478
+ acc[1] = vsubq_s32(acc[1], bias_hi);
1479
+
1480
+ // Apply superblock scale (no mins for q6_K)
1481
+ // acc[g] has [c0, c1, c2, c3]
1482
+ float32x4_t w_0123 = vmulq_f32(vcvtq_f32_s32(acc[0]), sb_scale_0);
1483
+ float32x4_t w_4567 = vmulq_f32(vcvtq_f32_s32(acc[1]), sb_scale_1);
1484
+
1485
+ acc_f32[0] = vaddq_f32(acc_f32[0], w_0123);
1486
+ acc_f32[1] = vaddq_f32(acc_f32[1], w_4567);
1487
+ } // for b
1488
+
1489
+ int base = x * ncols_interleaved;
1490
+ vst1q_f32(s + base, acc_f32[0]);
1491
+ vst1q_f32(s + base + 4, acc_f32[1]);
1492
+ } // for x
1493
+ return;
1494
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1495
+ ggml_gemv_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
1496
+ }
1497
+
1498
+ void ggml_gemv_q6_K_8x8_q8_K(int n,
1499
+ float * GGML_RESTRICT s,
1500
+ size_t bs,
1501
+ const void * GGML_RESTRICT vx,
1502
+ const void * GGML_RESTRICT vy,
1503
+ int nr,
1504
+ int nc) {
1505
+ constexpr int qk = QK_K;
1506
+ const int nb = n / qk;
1507
+
1508
+ constexpr int ncols_interleaved = 8;
1509
+ constexpr int blocklen = 8;
1510
+
1511
+ assert(n % qk == 0);
1512
+ assert(nc % ncols_interleaved == 0);
1513
+
1514
+ UNUSED(nb);
1515
+ UNUSED(ncols_interleaved);
1516
+ UNUSED(blocklen);
1517
+
1518
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1519
+ constexpr int col_pairs = ncols_interleaved / 2;
1520
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
1521
+ const uint8x16_t mask_lo = vdupq_n_u8(0x03);
1522
+ const uint8x16_t mask_hi = vdupq_n_u8(0x30);
1523
+
1524
+ // 1x8 tile = 2 x 4
1525
+ float32x4_t acc_f32[2];
1526
+
1527
+ const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
1528
+
1529
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
1530
+ const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
1531
+
1532
+ acc_f32[0] = vdupq_n_f32(0);
1533
+ acc_f32[1] = vdupq_n_f32(0);
1534
+
1535
+ for (int b = 0; b < nb; b++) {
1536
+ float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3
1537
+ float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7
1538
+ float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
1539
+ float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
1540
+ float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);
1541
+
1542
+ int32x2_t acc[col_pairs];
1543
+ for (int i = 0; i < col_pairs; i++) {
1544
+ acc[i] = vdup_n_s32(0);
1545
+ }
1546
+
1547
+ // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
1548
+ // Reused for bias and dequantization later
1549
+ int16_t q6_scales[16 * 8];
1550
+ for (int i = 0; i < 16; i++) {
1551
+ int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
1552
+ vst1q_s16(q6_scales + i * 8, scales);
1553
+ }
1554
+
1555
+ // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
1556
+ int32x4_t bias_lo = vdupq_n_s32(0);
1557
+ int32x4_t bias_hi = vdupq_n_s32(0);
1558
+
1559
+ // Load bsums in chunks of 4 to process with vectorized operations
1560
+ for (int i = 0; i < 16; i += 4) {
1561
+ int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i);
1562
+ int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
1563
+ int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
1564
+ int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
1565
+ int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
1566
+ int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
1567
+ int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
1568
+ int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
1569
+ int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);
1570
+
1571
+ bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
1572
+ bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
1573
+ bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
1574
+ bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
1575
+ bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
1576
+ bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
1577
+ bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
1578
+ bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
1579
+ }
1580
+ bias_lo = vshlq_n_s32(bias_lo, 5);
1581
+ bias_hi = vshlq_n_s32(bias_hi, 5);
1582
+
1583
+ // Process two 128-value halves per superblock
1584
+ for (int half = 0; half < 2; half++) {
1585
+ const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
1586
+ const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
1587
+
1588
+ // A subblock (sb) is a set of weights that share the scale
1589
+ // Since q6_K scales are per 16 elements
1590
+ // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
1591
+ for (int sb = 0; sb < QK_K / 64; sb++) {
1592
+ const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
1593
+ const int8_t * q8_base_h = q8_base_l + 64;
1594
+
1595
+ // Load and duplicate q8 values (each register covers two interleaved columns of q6)
1596
+ int8x16_t q8_l[2];
1597
+ int8x16_t q8_h[2];
1598
+ for (int i = 0; i < 2; i++) {
1599
+ q8_l[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_l + i * 8));
1600
+ q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8));
1601
+ }
1602
+
1603
+ const int ql_off_base = sb * QK_K / 2;
1604
+ const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes
1605
+
1606
+ // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
1607
+ uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
1608
+ uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
1609
+ uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
1610
+ uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
1611
+
1612
+ // Adjust qh for subblocks 2 and 3 (shift right by 2)
1613
+ if (sb > 1) {
1614
+ q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
1615
+ q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
1616
+ q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
1617
+ q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
1618
+ q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
1619
+ q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
1620
+ q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
1621
+ q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
1622
+ }
1623
+
1624
+ // Process column pairs (0-1, 2-3, 4-5, 6-7)
1625
+ for (int cp = 0; cp < col_pairs; cp++) {
1626
+ const uint8x16_t q6_qs_cp_0_l = q6_ql_0.val[cp];
1627
+ const uint8x16_t q6_qs_cp_1_l = q6_ql_1.val[cp];
1628
+ const uint8x16_t q6_qs_cp_0_h = q6_qh_0.val[cp];
1629
+ const uint8x16_t q6_qs_cp_1_h = q6_qh_1.val[cp];
1630
+
1631
+ // Extract high 2 bits for upper nibble reconstruction
1632
+ const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi);
1633
+ const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi);
1634
+
1635
+ // q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
1636
+ const int8x16_t q6_l0 = vreinterpretq_s8_u8(
1637
+ vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4));
1638
+ const int8x16_t q6_l1 = vreinterpretq_s8_u8(
1639
+ vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4));
1640
+ const int8x16_t q6_h0 =
1641
+ vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh));
1642
+ const int8x16_t q6_h1 =
1643
+ vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh));
1644
+
1645
+ int32x4_t sb_acc_l = vdupq_n_s32(0);
1646
+ sb_acc_l = vdotq_s32(sb_acc_l, q6_l0, q8_l[0]);
1647
+ sb_acc_l = vdotq_s32(sb_acc_l, q6_l1, q8_l[1]);
1648
+
1649
+ int32x4_t sb_acc_h = vdupq_n_s32(0);
1650
+ sb_acc_h = vdotq_s32(sb_acc_h, q6_h0, q8_h[0]);
1651
+ sb_acc_h = vdotq_s32(sb_acc_h, q6_h1, q8_h[1]);
1652
+
1653
+ // Pairwise add to get per-column sums: [col0, col1]
1654
+ int32x2_t sum_l = vpadd_s32(vget_low_s32(sb_acc_l), vget_high_s32(sb_acc_l));
1655
+ int32x2_t sum_h = vpadd_s32(vget_low_s32(sb_acc_h), vget_high_s32(sb_acc_h));
1656
+
1657
+ const int scale_idx_l = half * 8 + sb;
1658
+ const int scale_idx_h = half * 8 + sb + 4;
1659
+
1660
+ // Access scales using array indexing (scales are interleaved by column)
1661
+ const int32x2_t scale_vec_l = { (int32_t) q6_scales[scale_idx_l * 8 + cp * 2],
1662
+ (int32_t) q6_scales[scale_idx_l * 8 + cp * 2 + 1] };
1663
+ const int32x2_t scale_vec_h = { (int32_t) q6_scales[scale_idx_h * 8 + cp * 2],
1664
+ (int32_t) q6_scales[scale_idx_h * 8 + cp * 2 + 1] };
1665
+
1666
+ // Accumulate scaled results
1667
+ acc[cp] = vmla_s32(acc[cp], sum_l, scale_vec_l);
1668
+ acc[cp] = vmla_s32(acc[cp], sum_h, scale_vec_h);
1669
+ }
1670
+ }
1671
+ } // for half
1672
+
1673
+ // Bias correction
1674
+ acc[0] = vsub_s32(acc[0], vget_low_s32(bias_lo));
1675
+ acc[1] = vsub_s32(acc[1], vget_high_s32(bias_lo));
1676
+ acc[2] = vsub_s32(acc[2], vget_low_s32(bias_hi));
1677
+ acc[3] = vsub_s32(acc[3], vget_high_s32(bias_hi));
1678
+
1679
+ // Apply superblock scale (no mins for q6_K)
1680
+ // acc[cp] has [c0, c1]
1681
+ float32x2_t w_01 = vmul_f32(vcvt_f32_s32(acc[0]), vget_low_f32(sb_scale_0));
1682
+ float32x2_t w_23 = vmul_f32(vcvt_f32_s32(acc[1]), vget_high_f32(sb_scale_0));
1683
+ float32x2_t w_45 = vmul_f32(vcvt_f32_s32(acc[2]), vget_low_f32(sb_scale_1));
1684
+ float32x2_t w_67 = vmul_f32(vcvt_f32_s32(acc[3]), vget_high_f32(sb_scale_1));
1685
+
1686
+ acc_f32[0] = vaddq_f32(acc_f32[0], vcombine_f32(w_01, w_23));
1687
+ acc_f32[1] = vaddq_f32(acc_f32[1], vcombine_f32(w_45, w_67));
1688
+ } // for b
1689
+
1690
+ int base = x * ncols_interleaved;
1691
+ vst1q_f32(s + base, acc_f32[0]);
1692
+ vst1q_f32(s + base + 4, acc_f32[1]);
1693
+ } // for x
1694
+ return;
1695
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1696
+ ggml_gemv_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
1697
+ }
1698
+
1699
+ void ggml_gemv_q8_0_4x4_q8_0(int n,
1700
+ float * GGML_RESTRICT s,
1701
+ size_t bs,
1702
+ const void * GGML_RESTRICT vx,
1703
+ const void * GGML_RESTRICT vy,
1704
+ int nr,
1705
+ int nc) {
1706
+ const int qk = QK8_0;
1707
+ const int nb = n / qk;
1708
+ const int ncols_interleaved = 4;
1709
+ const int blocklen = 4;
1710
+
1711
+ assert(n % qk == 0);
1712
+ assert(nc % ncols_interleaved == 0);
1713
+
1714
+ UNUSED(nb);
1715
+ UNUSED(ncols_interleaved);
1716
+ UNUSED(blocklen);
1717
+
1718
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1719
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
1720
+
1721
+ for (int c = 0; c < nc; c += ncols_interleaved) {
1722
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
1723
+ float32x4_t acc = vdupq_n_f32(0);
1724
+ for (int b = 0; b < nb; b++) {
1725
+ int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
1726
+ int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
1727
+ float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
1728
+
1729
+ int8x16x2_t a = vld1q_s8_x2(a_ptr->qs);
1730
+ float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
1731
+
1732
+ int32x4_t ret = vdupq_n_s32(0);
1733
+
1734
+ ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0);
1735
+ ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1);
1736
+ ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2);
1737
+ ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3);
1738
+
1739
+ ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0);
1740
+ ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1);
1741
+ ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2);
1742
+ ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3);
1743
+
1744
+ acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
1745
+ a_ptr++;
1746
+ b_ptr++;
1747
+ }
1748
+ vst1q_f32(s, acc);
1749
+ s += ncols_interleaved;
1750
+ }
1751
+ return;
1752
+
1753
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1754
+ ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
1755
+ }
1756
+
1757
+ void ggml_gemv_q8_0_4x8_q8_0(int n,
1758
+ float * GGML_RESTRICT s,
1759
+ size_t bs,
1760
+ const void * GGML_RESTRICT vx,
1761
+ const void * GGML_RESTRICT vy,
1762
+ int nr,
1763
+ int nc) {
1764
+ const int qk = QK8_0;
1765
+ const int nb = n / qk;
1766
+ const int ncols_interleaved = 4;
1767
+ const int blocklen = 8;
1768
+
1769
+ assert(n % qk == 0);
1770
+ assert(nc % ncols_interleaved == 0);
1771
+
1772
+ UNUSED(nb);
1773
+ UNUSED(ncols_interleaved);
1774
+ UNUSED(blocklen);
1775
+
1776
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1777
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
1778
+
1779
+ for (int c = 0; c < nc; c += ncols_interleaved) {
1780
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
1781
+ float32x4_t acc = vdupq_n_f32(0);
1782
+
1783
+ for (int b = 0; b < nb; b++) {
1784
+ int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
1785
+ int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
1786
+ float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
1787
+
1788
+ int8x8x4_t a_chunks = vld1_s8_x4(a_ptr->qs);
1789
+ int8x16_t a0 = vcombine_s8(a_chunks.val[0], a_chunks.val[0]);
1790
+ int8x16_t a1 = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);
1791
+ int8x16_t a2 = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);
1792
+ int8x16_t a3 = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);
1793
+ float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
1794
+
1795
+ int32x4_t ret0 = vdupq_n_s32(0);
1796
+ int32x4_t ret1 = vdupq_n_s32(0);
1797
+
1798
+ // 0..7
1799
+ ret0 = vdotq_s32(ret0, b_low.val[0], a0);
1800
+ ret1 = vdotq_s32(ret1, b_low.val[1], a0);
1801
+ // 8..15
1802
+ ret0 = vdotq_s32(ret0, b_low.val[2], a1);
1803
+ ret1 = vdotq_s32(ret1, b_low.val[3], a1);
1804
+ // 16..23
1805
+ ret0 = vdotq_s32(ret0, b_high.val[0], a2);
1806
+ ret1 = vdotq_s32(ret1, b_high.val[1], a2);
1807
+ // 24..31
1808
+ ret0 = vdotq_s32(ret0, b_high.val[2], a3);
1809
+ ret1 = vdotq_s32(ret1, b_high.val[3], a3);
1810
+
1811
+ int32x4_t ret = vpaddq_s32(ret0, ret1);
1812
+
1813
+ acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
1814
+ a_ptr++;
1815
+ b_ptr++;
1816
+ }
1817
+ vst1q_f32(s, acc);
1818
+ s += ncols_interleaved;
1819
+ }
1820
+ return;
1821
+
1822
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1823
+ ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
1824
+ }
1825
+
1826
+ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
1827
+ const int qk = QK8_0;
1828
+ const int nb = n / qk;
1829
+ const int ncols_interleaved = 4;
1830
+ const int blocklen = 4;
1831
+
1832
+ assert (n % qk == 0);
1833
+ assert (nr % 4 == 0);
1834
+ assert (nc % ncols_interleaved == 0);
1835
+
1836
+ UNUSED(s);
1837
+ UNUSED(bs);
1838
+ UNUSED(vx);
1839
+ UNUSED(vy);
1840
+ UNUSED(nr);
1841
+ UNUSED(nc);
1842
+ UNUSED(nb);
1843
+ UNUSED(ncols_interleaved);
1844
+ UNUSED(blocklen);
1845
+
1846
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1847
+ const void * b_ptr = vx;
1848
+ const void * a_ptr = vy;
1849
+ float * res_ptr = s;
1850
+ size_t res_stride = bs * sizeof(float);
1851
+
1852
+ __asm__ __volatile__(
1853
+ "mov x10, %x[nr]\n"
1854
+ "mov x9, #0x88\n"
1855
+ "cmp x10, #0x10\n"
1856
+ "mul x9, %x[nb], x9\n"
1857
+ "blt 4f\n"
1858
+ "1:" // Row loop
1859
+ "add x28, %x[b_ptr], #0x8\n"
1860
+ "mov x27, %x[nc]\n"
1861
+ "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
1862
+ "2:" // Column loop
1863
+ "add x25, %x[a_ptr], #0x8\n"
1864
+ "movi v15.16b, #0x0\n"
1865
+ "movi v19.16b, #0x0\n"
1866
+ "mov x24, %x[nb]\n"
1867
+ "add x23, x25, x9\n"
1868
+ "movi v18.16b, #0x0\n"
1869
+ "movi v14.16b, #0x0\n"
1870
+ "add x22, x23, x9\n"
1871
+ "movi v11.16b, #0x0\n"
1872
+ "movi v13.16b, #0x0\n"
1873
+ "add x21, x22, x9\n"
1874
+ "movi v23.16b, #0x0\n"
1875
+ "movi v16.16b, #0x0\n"
1876
+ "movi v25.16b, #0x0\n"
1877
+ "movi v7.16b, #0x0\n"
1878
+ "movi v0.16b, #0x0\n"
1879
+ "movi v4.16b, #0x0\n"
1880
+ "movi v5.16b, #0x0\n"
1881
+ "movi v21.16b, #0x0\n"
1882
+ "movi v8.16b, #0x0\n"
1883
+ "movi v1.16b, #0x0\n"
1884
+ "3:" // Block loop
1885
+ "ldr q3, [x28, #0x0]\n"
1886
+ "ldr q31, [x25, #0x0]\n"
1887
+ "movi v28.16b, #0x4\n"
1888
+ "movi v10.4s, #0x0\n"
1889
+ "ldr q22, [x28, #0x10]\n"
1890
+ "ldr q6, [x25, #0x10]\n"
1891
+ "movi v29.4s, #0x0\n"
1892
+ "movi v9.4s, #0x0\n"
1893
+ "ldr q27, [x28, #0x20]\n"
1894
+ "ldr q30, [x28, #0x30]\n"
1895
+ "movi v20.4s, #0x0\n"
1896
+ "movi v24.16b, #0xf0\n"
1897
+ "ldr d2, [x25, #-0x8]\n"
1898
+ "ldr d26, [x23, #-0x8]\n"
1899
+ "sshl v12.16b, v3.16b, v28.16b\n"
1900
+ "sub x20, x28, #0x8\n"
1901
+ "ldr d17, [x20, #0x0]\n"
1902
+ "and v3.16b, v3.16b, v24.16b\n"
1903
+ "subs x24, x24, #0x1\n"
1904
+ "add x28, x28, #0x48\n"
1905
+ ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n"
1906
+ ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n"
1907
+ ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n"
1908
+ ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n"
1909
+ "sshl v31.16b, v22.16b, v28.16b\n"
1910
+ "and v22.16b, v22.16b, v24.16b\n"
1911
+ "fcvtl v17.4s, v17.4h\n"
1912
+ "fcvtl v2.4s, v2.4h\n"
1913
+ "fcvtl v26.4s, v26.4h\n"
1914
+ ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n"
1915
+ ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n"
1916
+ ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n"
1917
+ ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n"
1918
+ "sshl v6.16b, v27.16b, v28.16b\n"
1919
+ "sshl v28.16b, v30.16b, v28.16b\n"
1920
+ "and v27.16b, v27.16b, v24.16b\n"
1921
+ "and v30.16b, v30.16b, v24.16b\n"
1922
+ "ldr q24, [x25, #0x20]\n"
1923
+ ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n"
1924
+ ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
1925
+ ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n"
1926
+ ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n"
1927
+ "ldr q24, [x25, #0x30]\n"
1928
+ ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n"
1929
+ ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n"
1930
+ ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n"
1931
+ ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n"
1932
+ "ldr q24, [x25, #0x40]\n"
1933
+ ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n"
1934
+ ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
1935
+ ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n"
1936
+ ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n"
1937
+ "ldr q24, [x25, #0x50]\n"
1938
+ ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n"
1939
+ ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n"
1940
+ ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n"
1941
+ ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n"
1942
+ "ldr q24, [x25, #0x60]\n"
1943
+ ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n"
1944
+ ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
1945
+ ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n"
1946
+ ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n"
1947
+ "ldr q24, [x25, #0x70]\n"
1948
+ "add x25, x25, #0x88\n"
1949
+ ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n"
1950
+ ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n"
1951
+ ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n"
1952
+ ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n"
1953
+ "fmul v24.4s, v17.4s, v2.s[0]\n"
1954
+ "scvtf v10.4s, v10.4s, #0x4\n"
1955
+ "scvtf v29.4s, v29.4s, #0x4\n"
1956
+ "scvtf v9.4s, v9.4s, #0x4\n"
1957
+ "scvtf v20.4s, v20.4s, #0x4\n"
1958
+ "fmla v15.4s, v10.4s, v24.4s\n"
1959
+ "ldr q24, [x23, #0x0]\n"
1960
+ "fmul v10.4s, v17.4s, v2.s[1]\n"
1961
+ "fmla v19.4s, v29.4s, v10.4s\n"
1962
+ "ldr q10, [x23, #0x10]\n"
1963
+ "fmul v29.4s, v17.4s, v2.s[2]\n"
1964
+ "fmul v2.4s, v17.4s, v2.s[3]\n"
1965
+ "fmla v18.4s, v9.4s, v29.4s\n"
1966
+ "movi v9.4s, #0x0\n"
1967
+ "movi v29.4s, #0x0\n"
1968
+ ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n"
1969
+ ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n"
1970
+ "fmla v14.4s, v20.4s, v2.4s\n"
1971
+ "movi v20.4s, #0x0\n"
1972
+ "movi v2.4s, #0x0\n"
1973
+ ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n"
1974
+ ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
1975
+ "ldr q24, [x23, #0x20]\n"
1976
+ ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n"
1977
+ ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n"
1978
+ ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n"
1979
+ ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n"
1980
+ "ldr q10, [x23, #0x30]\n"
1981
+ ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n"
1982
+ ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
1983
+ ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n"
1984
+ ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
1985
+ "ldr q24, [x23, #0x40]\n"
1986
+ ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n"
1987
+ ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n"
1988
+ ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n"
1989
+ ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n"
1990
+ "ldr q10, [x23, #0x50]\n"
1991
+ ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n"
1992
+ ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
1993
+ ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n"
1994
+ ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
1995
+ "ldr q24, [x23, #0x60]\n"
1996
+ ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n"
1997
+ ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n"
1998
+ ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n"
1999
+ ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n"
2000
+ "ldr q10, [x23, #0x70]\n"
2001
+ "add x23, x23, #0x88\n"
2002
+ ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n"
2003
+ ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
2004
+ ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n"
2005
+ ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
2006
+ "ldr q24, [x22, #0x0]\n"
2007
+ ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n"
2008
+ ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n"
2009
+ ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n"
2010
+ ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n"
2011
+ "fmul v10.4s, v17.4s, v26.s[0]\n"
2012
+ "scvtf v9.4s, v9.4s, #0x4\n"
2013
+ "scvtf v29.4s, v29.4s, #0x4\n"
2014
+ "scvtf v20.4s, v20.4s, #0x4\n"
2015
+ "scvtf v2.4s, v2.4s, #0x4\n"
2016
+ "fmla v11.4s, v9.4s, v10.4s\n"
2017
+ "ldr q9, [x22, #0x10]\n"
2018
+ "fmul v10.4s, v17.4s, v26.s[1]\n"
2019
+ "fmla v13.4s, v29.4s, v10.4s\n"
2020
+ "ldr d29, [x22, #-0x8]\n"
2021
+ "fmul v10.4s, v17.4s, v26.s[2]\n"
2022
+ "fmul v26.4s, v17.4s, v26.s[3]\n"
2023
+ "fcvtl v29.4s, v29.4h\n"
2024
+ "fmla v23.4s, v20.4s, v10.4s\n"
2025
+ "movi v20.4s, #0x0\n"
2026
+ "movi v10.4s, #0x0\n"
2027
+ "fmla v16.4s, v2.4s, v26.4s\n"
2028
+ "movi v26.4s, #0x0\n"
2029
+ "movi v2.4s, #0x0\n"
2030
+ ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n"
2031
+ ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
2032
+ ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n"
2033
+ ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
2034
+ "ldr q24, [x22, #0x20]\n"
2035
+ ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n"
2036
+ ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
2037
+ ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n"
2038
+ ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n"
2039
+ "ldr q9, [x22, #0x30]\n"
1130
2040
  ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n"
1131
2041
  ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n"
1132
2042
  ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n"
@@ -2247,89 +3157,1372 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
2247
3157
  );
2248
3158
  return;
2249
3159
  }
2250
- #endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
3160
+ #endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
3161
+
3162
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
3163
+ ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
3164
+ }
3165
+
3166
+ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
3167
+ const int qk = QK8_0;
3168
+ const int nb = n / qk;
3169
+ const int ncols_interleaved = 4;
3170
+ const int blocklen = 4;
3171
+
3172
+ assert (n % qk == 0);
3173
+ assert (nr % 4 == 0);
3174
+ assert (nc % ncols_interleaved == 0);
3175
+
3176
+ UNUSED(s);
3177
+ UNUSED(bs);
3178
+ UNUSED(vx);
3179
+ UNUSED(vy);
3180
+ UNUSED(nr);
3181
+ UNUSED(nc);
3182
+ UNUSED(nb);
3183
+ UNUSED(ncols_interleaved);
3184
+ UNUSED(blocklen);
3185
+
3186
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
3187
+ const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
3188
+
3189
+ for (int y = 0; y < nr / 4; y++) {
3190
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
3191
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
3192
+ const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
3193
+
3194
+ float32x4_t sumf[4];
3195
+ for (int m = 0; m < 4; m++) {
3196
+ sumf[m] = vdupq_n_f32(0);
3197
+ }
3198
+
3199
+ for (int l = 0; l < nb; l++) {
3200
+ float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
3201
+ float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
3202
+
3203
+ int32x4_t sumi_0 = vdupq_n_s32(0);
3204
+ int32x4_t sumi_1 = vdupq_n_s32(0);
3205
+ int32x4_t sumi_2 = vdupq_n_s32(0);
3206
+ int32x4_t sumi_3 = vdupq_n_s32(0);
3207
+
3208
+ for (int k = 0; k < 4; k++) {
3209
+ int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
3210
+ int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
3211
+
3212
+ uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
3213
+ int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
3214
+ int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
3215
+
3216
+ sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
3217
+ sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
3218
+ sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
3219
+ sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
3220
+ sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
3221
+ sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
3222
+ sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
3223
+ sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
3224
+ }
3225
+
3226
+ sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
3227
+ sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
3228
+ sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
3229
+ sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
3230
+ }
3231
+
3232
+ for (int m = 0; m < 4; m++) {
3233
+ vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
3234
+ }
3235
+ }
3236
+ }
3237
+ return;
3238
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
3239
+ ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
3240
+ }
3241
+
3242
+ void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
3243
+ const int qk = QK8_0;
3244
+ const int nb = n / qk;
3245
+ const int ncols_interleaved = 4;
3246
+ const int blocklen = 4;
3247
+
3248
+ assert (n % qk == 0);
3249
+ assert (nr % 4 == 0);
3250
+ assert (nc % ncols_interleaved == 0);
3251
+
3252
+ UNUSED(s);
3253
+ UNUSED(bs);
3254
+ UNUSED(vx);
3255
+ UNUSED(vy);
3256
+ UNUSED(nr);
3257
+ UNUSED(nc);
3258
+ UNUSED(nb);
3259
+ UNUSED(ncols_interleaved);
3260
+ UNUSED(blocklen);
3261
+
3262
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
3263
+ const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4);
3264
+
3265
+ for (int y = 0; y < nr / 4; y++) {
3266
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
3267
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
3268
+ const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
3269
+
3270
+ float32x4_t sumf[4];
3271
+ for (int m = 0; m < 4; m++) {
3272
+ sumf[m] = vdupq_n_f32(0);
3273
+ }
3274
+
3275
+ for (int l = 0; l < nb; l++) {
3276
+ float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
3277
+ float32x4_t b_d = {
3278
+ GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]),
3279
+ GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]),
3280
+ GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]),
3281
+ GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]),
3282
+ };
3283
+
3284
+ int32x4_t sumi_0 = vdupq_n_s32(0);
3285
+ int32x4_t sumi_1 = vdupq_n_s32(0);
3286
+ int32x4_t sumi_2 = vdupq_n_s32(0);
3287
+ int32x4_t sumi_3 = vdupq_n_s32(0);
3288
+
3289
+ for (int k = 0; k < 4; k++) {
3290
+ int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
3291
+ int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
3292
+
3293
+ uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
3294
+ int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
3295
+ int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
3296
+
3297
+ sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
3298
+ sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
3299
+ sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
3300
+ sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
3301
+ sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
3302
+ sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
3303
+ sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
3304
+ sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
3305
+ }
3306
+
3307
+ sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
3308
+ sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
3309
+ sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
3310
+ sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
3311
+ }
3312
+
3313
+ for (int m = 0; m < 4; m++) {
3314
+ vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
3315
+ }
3316
+ }
3317
+ }
3318
+ return;
3319
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
3320
+ ggml_gemm_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
3321
+ }
3322
+
3323
+ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
3324
+ constexpr int qk = QK_K;
3325
+ const int nb = n / qk;
3326
+
3327
+ constexpr int ncols_interleaved = 8;
3328
+ constexpr int blocklen = 4;
3329
+
3330
+ assert(n % qk == 0);
3331
+ assert(nr % 4 == 0);
3332
+ assert(nc % ncols_interleaved == 0);
3333
+
3334
+ UNUSED(nb);
3335
+ UNUSED(ncols_interleaved);
3336
+ UNUSED(blocklen);
3337
+
3338
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
3339
+ constexpr int q8_k_blocklen = 4;
3340
+ constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs
3341
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
3342
+
3343
+ // 8 accumulators: 2 row pairs × 4 col pairs
3344
+ float32x4_t acc_f32[acc_size];
3345
+
3346
+ for (int y = 0; y < nr / q8_k_blocklen; y++) {
3347
+ const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
3348
+
3349
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
3350
+ const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
3351
+
3352
+ for (int i = 0; i < acc_size; i++) {
3353
+ acc_f32[i] = vdupq_n_f32(0);
3354
+ }
3355
+
3356
+ for (int b = 0; b < nb; b++) {
3357
+ // d4 0 1 2 3, 4 5 6 7
3358
+ float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));
3359
+ float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
3360
+ // d8 0 1 2 3
3361
+ float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
3362
+ // mins
3363
+ float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
3364
+ float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
3365
+
3366
+ // Precomputation of scales and mins
3367
+ float32x4_t sbd_scale_0123[q8_k_blocklen];
3368
+ float32x4_t sbd_scale_4567[q8_k_blocklen];
3369
+ float32x4_t sbd_min_0123[q8_k_blocklen];
3370
+ float32x4_t sbd_min_4567[q8_k_blocklen];
3371
+
3372
+ sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
3373
+ sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
3374
+ sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
3375
+ sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
3376
+
3377
+ sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
3378
+ sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
3379
+ sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
3380
+ sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
3381
+
3382
+ sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
3383
+ sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
3384
+ sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
3385
+ sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
3386
+
3387
+ sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
3388
+ sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
3389
+ sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
3390
+ sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
3391
+
3392
+ // Precomputation of bsums, each vpaddq calcs all the bsums for each row
3393
+ const int16x8_t bsums[q8_k_blocklen] = {
3394
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
3395
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
3396
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
3397
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
3398
+ };
3399
+ int16_t bsums_arr[QK_K / 64][8];
3400
+ for (int q8_row = 0; q8_row < 4; q8_row++) {
3401
+ vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
3402
+ }
3403
+
3404
+ // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
3405
+ int32x4_t bias_acc[acc_size];
3406
+ for (int i = 0; i < acc_size; i++) {
3407
+ bias_acc[i] = vdupq_n_s32(0);
3408
+ }
3409
+
3410
+ for (int sb = 0; sb < QK_K / 64; sb++) {
3411
+ // Int accumulators for qs vecdot (4 row x 2 col quartets)
3412
+ int32x4_t acc_lo[acc_size];
3413
+ int32x4_t acc_hi[acc_size];
3414
+ for (int i = 0; i < acc_size; i++) {
3415
+ acc_lo[i] = vdupq_n_s32(0);
3416
+ acc_hi[i] = vdupq_n_s32(0);
3417
+ }
3418
+ // Need scales for the low and high nibbles
3419
+ // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
3420
+ int16x8_t q4sb_scales[2];
3421
+ int16x8_t q4sb_mins[2];
3422
+ for (int i = 0; i < 2; i++) {
3423
+ int8_t aux_q4sb[8];
3424
+ const int offset = sb * 24 + i * 12;
3425
+ decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
3426
+ q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
3427
+ }
3428
+
3429
+ constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
3430
+ for (int k = 0; k < reads_per_sb; k++) {
3431
+ const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
3432
+ const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
3433
+
3434
+ // 0..3 & 32..35
3435
+ const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);
3436
+ const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
3437
+
3438
+ const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));
3439
+ const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));
3440
+
3441
+ acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
3442
+ acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
3443
+ acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
3444
+ acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
3445
+
3446
+ acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
3447
+ acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
3448
+ acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
3449
+ acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
3450
+
3451
+ const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));
3452
+ const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));
3453
+
3454
+ acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
3455
+ acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
3456
+ acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
3457
+ acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
3458
+
3459
+ acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
3460
+ acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
3461
+ acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
3462
+ acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
3463
+ }
3464
+
3465
+ // Scale and bias application
3466
+ // acc is stored interleaved to match output layout
3467
+ const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
3468
+ const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
3469
+ const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
3470
+ const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
3471
+ for (int row = 0; row < q8_k_blocklen; row++) {
3472
+ // Bias correction
3473
+ // row c0123 blk0 and blk1
3474
+ const float32x4_t sumf_0123 =
3475
+ vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
3476
+ vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
3477
+ acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
3478
+
3479
+ // row c4567 blk0 and blk1
3480
+ const float32x4_t sumf_4567 =
3481
+ vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
3482
+ vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
3483
+ acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
3484
+
3485
+ // Bias
3486
+ const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
3487
+ const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
3488
+
3489
+ // row c0123 blk0 and blk1
3490
+ bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
3491
+ bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
3492
+
3493
+ // row c4567 blk0 and blk1
3494
+ bias_acc[2 * row + 1] =
3495
+ vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
3496
+ bias_acc[2 * row + 1] =
3497
+ vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
3498
+ }
3499
+ } // for sb
3500
+
3501
+ for (int row = 0; row < q8_k_blocklen; row++) {
3502
+ acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
3503
+ acc_f32[2 * row + 1] =
3504
+ vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
3505
+ }
3506
+ } // for b
3507
+
3508
+ for (int i = 0; i < q8_k_blocklen; i++) {
3509
+ int row = y * q8_k_blocklen + i;
3510
+ for (int j = 0; j < 2; j++) {
3511
+ int col = x * ncols_interleaved + j * 4;
3512
+ int offset = row * bs + col;
3513
+ vst1q_f32(s + offset, acc_f32[2 * i + j]);
3514
+ }
3515
+ }
3516
+ } // for x
3517
+ } // for y
3518
+ return;
3519
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
3520
+ ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
3521
+ }
3522
+
3523
+ void ggml_gemm_q5_K_8x4_q8_K(int n, // n: reduction length, must be a multiple of QK_K
3524
+ float * GGML_RESTRICT s, // s: output; element (row, col) is stored at s[row * bs + col]
3525
+ size_t bs, // bs: row stride of s, in floats
3526
+ const void * GGML_RESTRICT vx, // vx: block_q5_Kx8 weights, nb blocks per group of 8 columns
3527
+ const void * GGML_RESTRICT vy, // vy: block_q8_Kx4 activations, nb blocks per group of 4 rows
3528
+ int nr, // nr: number of output rows, must be a multiple of 4
3529
+ int nc) { // nc: number of output columns, multiple of 8. Fills the nr x nc tile of s; NEON+dotprod fast path below, generic fallback otherwise
3530
+ constexpr int qk = QK_K;
3531
+ const int nb = n / qk;
3532
+
3533
+ constexpr int ncols_interleaved = 8;
3534
+ constexpr int blocklen = 4;
3535
+
3536
+ assert(n % qk == 0);
3537
+ assert(nr % 4 == 0);
3538
+ assert(nc % ncols_interleaved == 0);
3539
+
3540
+ UNUSED(nb);
3541
+ UNUSED(ncols_interleaved);
3542
+ UNUSED(blocklen);
3543
+
3544
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
3545
+ constexpr int q8_k_blocklen = 4;
3546
+ constexpr int acc_size = 2 * 4; // 4 q8 rows x 2 column quartets (c0123, c4567)
3547
+ constexpr int col_groups = ncols_interleaved / 4;
3548
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
3549
+ const uint8x16_t mone = vdupq_n_u8(1);
3550
+ const uint8x16_t mtwo = vdupq_n_u8(2);
3551
+
3552
+ // 8 float accumulators, acc_f32[2 * row + j]: row = q8 row 0..3, j = 0 -> cols 0..3, j = 1 -> cols 4..7
3553
+ float32x4_t acc_f32[acc_size];
3554
+
3555
+ for (int y = 0; y < nr / q8_k_blocklen; y++) {
3556
+ const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
3557
+
3558
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
3559
+ const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
3560
+
3561
+ for (int i = 0; i < acc_size; i++) {
3562
+ acc_f32[i] = vdupq_n_f32(0);
3563
+ }
3564
+
3565
+ for (int b = 0; b < nb; b++) {
3566
+ // q5 super-block scale d for columns 0..3 and 4..7 (fp16 -> fp32)
3567
+ float32x4_t q5_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));
3568
+ float32x4_t q5_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));
3569
+ // q8 scale d for rows 0..3
3570
+ float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
3571
+ // q5 dmin (scale applied to the mins) for columns 0..3 and 4..7 (fp16 -> fp32)
3572
+ float32x4_t q5_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));
3573
+ float32x4_t q5_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4));
3574
+
3575
+ // Precompute, for each q8 row r (lane r of q8_d_0123): d5 * d8 (scales) and dmin5 * d8 (mins)
3576
+ float32x4_t sbd_scale_0123[q8_k_blocklen];
3577
+ float32x4_t sbd_scale_4567[q8_k_blocklen];
3578
+ float32x4_t sbd_min_0123[q8_k_blocklen];
3579
+ float32x4_t sbd_min_4567[q8_k_blocklen];
3580
+
3581
+ sbd_scale_0123[0] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 0);
3582
+ sbd_scale_4567[0] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 0);
3583
+ sbd_min_0123[0] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 0);
3584
+ sbd_min_4567[0] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 0);
3585
+
3586
+ sbd_scale_0123[1] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 1);
3587
+ sbd_scale_4567[1] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 1);
3588
+ sbd_min_0123[1] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 1);
3589
+ sbd_min_4567[1] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 1);
3590
+
3591
+ sbd_scale_0123[2] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 2);
3592
+ sbd_scale_4567[2] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 2);
3593
+ sbd_min_0123[2] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 2);
3594
+ sbd_min_4567[2] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 2);
3595
+
3596
+ sbd_scale_0123[3] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 3);
3597
+ sbd_scale_4567[3] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 3);
3598
+ sbd_min_0123[3] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 3);
3599
+ sbd_min_4567[3] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 3);
3600
+
3601
+ // Pairwise-add adjacent q8 bsums (vpaddq_s16); results are read below as bsums_arr[sb][2 * row] and bsums_arr[sb][2 * row + 1]
3602
+ const int16x8_t bsums[q8_k_blocklen] = {
3603
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
3604
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
3605
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
3606
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
3607
+ };
3608
+ int16_t bsums_arr[QK_K / 64][8];
3609
+ for (int q8_row = 0; q8_row < 4; q8_row++) { // NOTE(review): bound 4 assumes QK_K / 64 == 4 (first dim of bsums_arr)
3610
+ vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
3611
+ }
3612
+
3613
+ // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123, [3]->r1 4567, ..
3614
+ int32x4_t bias_acc[acc_size];
3615
+ for (int i = 0; i < acc_size; i++) {
3616
+ bias_acc[i] = vdupq_n_s32(0);
3617
+ }
3618
+
3619
+ uint8x16_t qh[col_groups][8]; // high-bit bytes per column group; 2 bits consumed per sb (qh is shifted right by 2 in the k loop)
3620
+ for (int c = 0; c < col_groups; c++) {
3621
+ for (int i = 0; i < 8; i++) {
3622
+ qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
3623
+ }
3624
+ }
3625
+
3626
+ for (int sb = 0; sb < QK_K / 64; sb++) {
3627
+ // Int accumulators for qs vecdot (4 rows * 2 col quartets); acc_lo: low nibbles vs q8_blk0, acc_hi: high nibbles vs q8_blk1
3628
+ int32x4_t acc_lo[acc_size];
3629
+ int32x4_t acc_hi[acc_size];
3630
+ for (int i = 0; i < acc_size; i++) {
3631
+ acc_lo[i] = vdupq_n_s32(0);
3632
+ acc_hi[i] = vdupq_n_s32(0);
3633
+ }
3634
+ // Need scales for the low and high nibbles
3635
+ // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
3636
+ int16x8_t q5sb_scales[2];
3637
+ int16x8_t q5sb_mins[2];
3638
+ for (int i = 0; i < 2; i++) {
3639
+ int8_t aux_q5sb[8];
3640
+ const int offset = sb * 24 + i * 12;
3641
+ decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
3642
+ q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
3643
+ }
3644
+
3645
+ constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
3646
+ for (int k = 0; k < reads_per_sb; k++) {
3647
+ const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
3648
+ const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
3649
+
3650
+ // 0..3 & 32..35
3651
+ const uint8x16_t q5_0123 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k);
3652
+ const uint8x16_t q5_4567 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k + 16);
3653
+
3654
+ // NOTE: This is the only difference with q4_K: qh bit 0 becomes bit 4 of the low-nibble value, qh bit 1 becomes bit 4 of the high-nibble value, then qh is shifted right by 2 for the next sb
3655
+ const uint8x16_t hbit_lo_0123 = vandq_u8(qh[0][k], mone);
3656
+ const uint8x16_t hbit_hi_0123 = vshlq_n_u8(vandq_u8(qh[0][k], mtwo), 3);
3657
+ qh[0][k] = vshrq_n_u8(qh[0][k], 2);
3658
+ const uint8x16_t hbit_lo_4567 = vandq_u8(qh[1][k], mone);
3659
+ const uint8x16_t hbit_hi_4567 = vshlq_n_u8(vandq_u8(qh[1][k], mtwo), 3);
3660
+ qh[1][k] = vshrq_n_u8(qh[1][k], 2);
3661
+ // From here, same as q4_K
3662
+
3663
+ const int8x16_t q5_0123_lo =
3664
+ vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_0123, m4b), hbit_lo_0123, 4));
3665
+ const int8x16_t q5_0123_hi =
3666
+ vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_0123, 4), hbit_hi_0123));
3667
+
3668
+ acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q5_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
3669
+ acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q5_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
3670
+ acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q5_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
3671
+ acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q5_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
3672
+
3673
+ acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q5_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
3674
+ acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q5_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
3675
+ acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q5_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
3676
+ acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q5_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
3677
+
3678
+ const int8x16_t q5_4567_lo =
3679
+ vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4));
3680
+ const int8x16_t q5_4567_hi =
3681
+ vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567));
3682
+
3683
+ acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q5_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
3684
+ acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q5_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
3685
+ acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q5_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
3686
+ acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q5_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
3687
+
3688
+ acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q5_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
3689
+ acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q5_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
3690
+ acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q5_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
3691
+ acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q5_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
3692
+ }
3693
+
3694
+ // Scale and bias application
3695
+ // acc is stored interleaved to match output layout
3696
+ const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]);
3697
+ const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]);
3698
+ const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]);
3699
+ const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]);
3700
+ for (int row = 0; row < q8_k_blocklen; row++) {
3701
+ // Scale application: per-subblock scales times the int dot products, then FMA with d5 * d8
3702
+ // row c0123 blk0 and blk1
3703
+ const float32x4_t sumf_0123 =
3704
+ vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
3705
+ vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
3706
+ acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
3707
+
3708
+ // row c4567 blk0 and blk1
3709
+ const float32x4_t sumf_4567 =
3710
+ vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
3711
+ vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
3712
+ acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
3713
+
3714
+ // Bias: accumulate bsums * mins; scaled by dmin5 * d8 and subtracted after the sb loop
3715
+ const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
3716
+ const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
3717
+
3718
+ // row c0123 blk0 and blk1
3719
+ bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
3720
+ bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
3721
+
3722
+ // row c4567 blk0 and blk1
3723
+ bias_acc[2 * row + 1] =
3724
+ vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
3725
+ bias_acc[2 * row + 1] =
3726
+ vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
3727
+ }
3728
+ } // for sb
3729
+
3730
+ for (int row = 0; row < q8_k_blocklen; row++) { // subtract (dmin5 * d8) * sum(bsums * mins) for this block
3731
+ acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
3732
+ acc_f32[2 * row + 1] =
3733
+ vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
3734
+ }
3735
+ } // for b
3736
+
3737
+ for (int i = 0; i < q8_k_blocklen; i++) { // store the 4 x 8 tile: acc_f32[2 * i + j] -> row i, cols 4j..4j+3
3738
+ int row = y * q8_k_blocklen + i;
3739
+ for (int j = 0; j < 2; j++) {
3740
+ int col = x * ncols_interleaved + j * 4;
3741
+ int offset = row * bs + col;
3742
+ vst1q_f32(s + offset, acc_f32[2 * i + j]);
3743
+ }
3744
+ }
3745
+ } // for x
3746
+ } // for y
3747
+ return;
3748
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
3749
+ ggml_gemm_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); // fallback when the NEON dotprod path above is not compiled in
3750
+ }
3751
+
3752
+ void ggml_gemm_q4_K_8x8_q8_K(int n,
3753
+ float * GGML_RESTRICT s,
3754
+ size_t bs,
3755
+ const void * GGML_RESTRICT vx,
3756
+ const void * GGML_RESTRICT vy,
3757
+ int nr,
3758
+ int nc) {
3759
+ constexpr int qk = QK_K;
3760
+ const int nb = n / qk;
3761
+
3762
+ constexpr int ncols_interleaved = 8;
3763
+ constexpr int blocklen = 8;
3764
+
3765
+ assert(n % qk == 0);
3766
+ assert(nr % 4 == 0);
3767
+ assert(nc % ncols_interleaved == 0);
3768
+
3769
+ UNUSED(nb);
3770
+ UNUSED(ncols_interleaved);
3771
+ UNUSED(blocklen);
3772
+
3773
+ #if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
3774
+ if (svcntb() * 8 == 256) {
3775
+ constexpr int q8_k_blocklen = 4;
3776
+ const svuint8_t m4b_1 = svdup_n_u8(0x0f);
3777
+ // 8 accumulators: 2 row pairs × 4 col pairs
3778
+ svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67;
3779
+ uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 };
3780
+ svbool_t pg = svptrue_pat_b32(SV_VL8);
3781
+ svuint32_t idx = svld1(pg, idx_arr);
3782
+
3783
+ static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7};
3784
+ svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data);
3785
+
3786
+ for (int y = 0; y < nr / q8_k_blocklen; y++) {
3787
+ const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
3788
+
3789
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
3790
+ const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
3791
+
3792
+ acc_f32_01 = svdup_n_f32(0);
3793
+ acc_f32_23 = svdup_n_f32(0);
3794
+ acc_f32_45 = svdup_n_f32(0);
3795
+ acc_f32_67 = svdup_n_f32(0);
3796
+
3797
+ for (int b = 0; b < nb; b++) {
3798
+ // bsums pairs belongs to the same q8_k subblock
3799
+ // 64 elements loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum
3800
+ const int16x8_t bsums[4]{
3801
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
3802
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
3803
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
3804
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
3805
+ };
3806
+
3807
+ int32_t bsums_arr32[4][8];
3808
+
3809
+ for (int q8_row = 0; q8_row < 4; q8_row++) {
3810
+ int16x8_t v16 = bsums[q8_row];
3811
+
3812
+ // low 4
3813
+ int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16));
3814
+ vst1q_s32(&bsums_arr32[q8_row][0], v32_lo);
3815
+
3816
+ // high 4
3817
+ int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16));
3818
+ vst1q_s32(&bsums_arr32[q8_row][4], v32_hi);
3819
+ }
3820
+
3821
+ svint32_t sb_acc_0 = svdup_n_s32(0);
3822
+ svint32_t sb_acc_2 = svdup_n_s32(0);
3823
+
3824
+ svint32_t acc_00 = svdup_n_s32(0);
3825
+ svint32_t acc_11 = svdup_n_s32(0);
3826
+ svint32_t acc_22 = svdup_n_s32(0);
3827
+ svint32_t acc_33 = svdup_n_s32(0);
3828
+ svint32_t acc_44 = svdup_n_s32(0);
3829
+ svint32_t acc_55 = svdup_n_s32(0);
3830
+ svint32_t acc_66 = svdup_n_s32(0);
3831
+ svint32_t acc_77 = svdup_n_s32(0);
3832
+
3833
+ svint32_t bias_acc_00 = svdup_n_s32(0);
3834
+ svint32_t bias_acc_22 = svdup_n_s32(0);
3835
+ svint32_t bias_acc_44 = svdup_n_s32(0);
3836
+ svint32_t bias_acc_66 = svdup_n_s32(0);
3837
+
3838
+ for (int sb = 0; sb < QK_K / 64; sb++) {
3839
+ // Need scales for the low and high nibbles
3840
+ // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
3841
+ svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3;
3842
+ svint32_t q4sb_mins_0, q4sb_mins_1;
3843
+ {
3844
+ // 2-superblock I am working on
3845
+ const int offset = sb * 24 + 0 * 12;
3846
+ const uint8_t * scales_in = &q4_ptr[b].scales[offset];
3847
+
3848
+ const int offset1 = sb * 24 + 12;
3849
+ const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1];
3850
+
3851
+ constexpr uint32_t kmask1 = 0x3f3f3f3f;
3852
+ constexpr uint32_t kmask2 = 0x0f0f0f0f;
3853
+ constexpr uint32_t kmask3 = 0x03030303;
3854
+ constexpr uint8_t scales_size = 12;
3855
+
3856
+ uint32_t sm[3];
3857
+ memcpy(sm, scales_in, scales_size);
3858
+
3859
+ uint32_t sm1[3];
3860
+ memcpy(sm1, scales_in1, scales_size);
3861
+
3862
+ const uint32_t mins_0_3 = sm[1] & kmask1;
3863
+ const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
3864
+
3865
+ const uint32_t mins_0_3_1 = sm1[1] & kmask1;
3866
+ const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4);
3867
+
3868
+ svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7));
3869
+ svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1));
3870
+
3871
+ /* reinterpret u32 → u8 */
3872
+ svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp);
3873
+ svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1);
3874
+
3875
+ /* widen u8 → u16->u32 (lower half only) */
3876
+ svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8));
3877
+ svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1));
3878
+
3879
+ q4sb_mins_0 = svreinterpret_s32_u32(mins_u16);
3880
+ q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1);
3881
+
3882
+ uint32_t scales_u32_0 = sm[0] & kmask1;
3883
+ uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
3884
+ uint32_t scales_u32_2 = sm1[0] & kmask1;
3885
+ uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4);
3886
+
3887
+ svuint32_t S01 = svdup_n_u32(scales_u32_0);
3888
+ svuint32_t S23 = svdup_n_u32(scales_u32_1);
3889
+ svuint32_t R01 = svdup_n_u32(scales_u32_2);
3890
+ svuint32_t R23 = svdup_n_u32(scales_u32_3);
3891
+
3892
+ svint8_t S01_b = svreinterpret_s8_u32(S01);
3893
+ svint8_t S23_b = svreinterpret_s8_u32(S23);
3894
+ svint8_t R01_b = svreinterpret_s8_u32(R01);
3895
+ svint8_t R23_b = svreinterpret_s8_u32(R23);
3896
+
3897
+ svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b)));
3898
+ svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b)));
3899
+ svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b)));
3900
+ svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b)));
3901
+
3902
+ block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx);
3903
+ block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx);
3904
+ block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx);
3905
+ block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx);
3906
+ }
3907
+
3908
+ const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256;
3909
+
3910
+ // Load 32-byte per row pair, 1 subblock each time
3911
+ // predicate for activating higher lanes for 16 int8 elements
3912
+ const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
3913
+ // predicate for activating lower lanes for 16 int8 elements
3914
+ const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
3915
+
3916
+ svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));
3917
+ svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));
3918
+ svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));
3919
+ svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));
3920
+
3921
+ svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));
3922
+ svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));
3923
+ svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));
3924
+ svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));
3925
+
3926
+ // Q4s columns iterated in pairs (01, 23, 45, 67)
3927
+ for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
3928
+
3929
+ sb_acc_0 = svdup_n_s32(0);
3930
+ sb_acc_2 = svdup_n_s32(0);
3931
+
3932
+ svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);
3933
+ svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);
3934
+ svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);
3935
+ svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);
3936
+
3937
+ svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));
3938
+ svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
3939
+ svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));
3940
+ svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));
3941
+
3942
+ sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0);
3943
+ sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2);
3944
+
3945
+ sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4);
3946
+ sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6);
3947
+
3948
+ sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1);
3949
+ sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3);
3950
+
3951
+ sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5);
3952
+ sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7);
3953
+
3954
+ if(cp == 0) {
3955
+ acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0);
3956
+ acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0);
3957
+ }
3958
+ if(cp == 1) {
3959
+ acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1);
3960
+ acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1);
3961
+ }
3962
+ if(cp == 2) {
3963
+ acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2);
3964
+ acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2);
3965
+ }
3966
+ if(cp == 3) {
3967
+ acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3);
3968
+ acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3);
3969
+ }
3970
+ }
3971
+
3972
+ bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0);
3973
+ bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1);
3974
+
3975
+ bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0);
3976
+ bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1);
3977
+
3978
+ bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0);
3979
+ bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1);
3980
+
3981
+ bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0);
3982
+ bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1);
3983
+ } // for sb
3984
+
3985
+
3986
+ acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4));
3987
+ acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4));
3988
+ acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4));
3989
+ acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4));
3990
+ acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4));
3991
+ acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4));
3992
+ acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4));
3993
+ acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4));
3994
+
3995
+ svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1);
3996
+ svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1);
3997
+
3998
+ svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1);
3999
+ svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1);
4000
+
4001
+ // Broadcast q8 scalar
4002
+ svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]);
4003
+
4004
+ svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0)));
4005
+
4006
+ svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0)));
4007
+
4008
+ svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
4009
+ svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
4010
+
4011
+ acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1);
4012
+ acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1);
4013
+
4014
+ q8_d = svdup_f32(q8_ptr[b].d[1]);
4015
+
4016
+ scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
4017
+ dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
4018
+
4019
+ acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1);
4020
+ acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1);
4021
+
4022
+ q8_d = svdup_f32(q8_ptr[b].d[2]);
4023
+
4024
+
4025
+ scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
4026
+ dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
4027
+
4028
+ acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1);
4029
+ acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1);
4030
+
4031
+ q8_d = svdup_f32(q8_ptr[b].d[3]);
4032
+
4033
+ scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
4034
+ dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
4035
+
4036
+ acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1);
4037
+ acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1);
4038
+
4039
+ } // for b
4040
+
4041
+ // With the previous reorder, the tile is already in the correct memory layout.
4042
+ // Predicate for exactly 4 lanes
4043
+ svbool_t pg4 = svptrue_pat_b32(SV_VL4);
4044
+ for (int i = 0; i < q8_k_blocklen; i++) {
4045
+ int row = y * q8_k_blocklen + i;
4046
+ for (int j = 0; j < 2; j++) {
4047
+ int col = x * ncols_interleaved + j * 4;
4048
+ int offset = row * bs + col;
4049
+
4050
+ if (i == 0 && j == 0) {
4051
+ // acc_f32_0 → lower half of acc_f32_01
4052
+ svst1_f32(pg4, s + offset, acc_f32_01);
4053
+ } else if (i == 0 && j == 1) {
4054
+ // acc_f32_1 → upper half of acc_f32_01
4055
+ svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4));
4056
+ } else if (i == 1 && j == 0) {
4057
+ // acc_f32_2
4058
+ svst1_f32(pg4, s + offset, acc_f32_23);
4059
+ } else if (i == 1 && j == 1) {
4060
+ // acc_f32_3
4061
+ svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4));
4062
+ } else if (i == 2 && j == 0) {
4063
+ // acc_f32_4
4064
+ svst1_f32(pg4, s + offset, acc_f32_45);
4065
+ } else if (i == 2 && j == 1) {
4066
+ // acc_f32_5
4067
+ svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4));
4068
+ } else if (i == 3 && j == 0) {
4069
+ // acc_f32_6
4070
+ svst1_f32(pg4, s + offset, acc_f32_67);
4071
+ } else if (i == 3 && j == 1) {
4072
+ // acc_f32_7
4073
+ svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4));
4074
+ }
4075
+ }
4076
+ }
4077
+ } // for x
4078
+ } // for y
4079
+ return;
4080
+ }
4081
+ #endif // SVE compile-time end
4082
+
4083
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
4084
+ constexpr int q8_k_blocklen = 4;
4085
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
4086
+
4087
+ // 8 accumulators: 2 row pairs × 4 col pairs
4088
+ float32x4_t acc_f32[blocklen];
4089
+
4090
+ for (int y = 0; y < nr / q8_k_blocklen; y++) {
4091
+ const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
4092
+
4093
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
4094
+ const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
4095
+
4096
+ for (int i = 0; i < blocklen; i++) {
4097
+ acc_f32[i] = vdupq_n_f32(0);
4098
+ }
4099
+
4100
+ for (int b = 0; b < nb; b++) {
4101
+ // bsums pairs belongs to the same q8_k subblock
4102
+ const int16x8_t bsums[4]{
4103
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
4104
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
4105
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
4106
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
4107
+ };
4108
+ int16_t bsums_arr[4][8];
4109
+ for (int q8_row = 0; q8_row < 4; q8_row++) {
4110
+ vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
4111
+ }
4112
+
4113
+ int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results
4114
+ int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
4115
+ int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
4116
+ for (int i = 0; i < 8; i++) {
4117
+ acc[i] = vdupq_n_s32(0);
4118
+ bias_acc[i] = vdupq_n_s32(0);
4119
+ }
4120
+
4121
+ for (int sb = 0; sb < QK_K / 64; sb++) {
4122
+ // Need scales for the low and high nibbles
4123
+ // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
4124
+ int8_t q4sb_scales[2][8];
4125
+ int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later
4126
+ for (int i = 0; i < 2; i++) {
4127
+ const int offset = sb * 24 + i * 12;
4128
+ decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
4129
+ }
4130
+
4131
+ // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
4132
+ const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
4133
+
4134
+ int8x16_t q8_qs_01[8];
4135
+ int8x16_t q8_qs_23[8];
4136
+
4137
+ // Load 32-byte per row pair, 1 subblock each time
4138
+ for (int i = 0; i < 8; i++) {
4139
+ const int offset = i * 32; // 16 for row 01, 16 for row 23
4140
+ q8_qs_01[i] = vld1q_s8(q8_base + offset);
4141
+ q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
4142
+ }
4143
+
4144
+ const int8x16_t q8s[2][8] = {
4145
+ { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
4146
+ q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
4147
+ { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
4148
+ q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
4149
+ };
4150
+
4151
+ // Q4s columns iterated in pairs (01, 23, 45, 67)
4152
+ for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
4153
+ for (int i = 0; i < 4; i++) {
4154
+ sb_acc[i] = vdupq_n_s32(0);
4155
+ }
4156
+
4157
+ uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
4158
+ uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
4159
+ uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
4160
+ uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
4161
+ const int8x16_t q4_nibbles[2][4] = {
4162
+ {
4163
+ vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
4164
+ vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
4165
+ vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
4166
+ vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
4167
+ },
4168
+ {
4169
+ vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
4170
+ vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
4171
+ vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
4172
+ vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
4173
+ }
4174
+ };
4175
+
4176
+ // Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8
4177
+ // for each of the internal 32 qs subblock (blk)
4178
+ for (int rp = 0; rp < 2; rp++) {
4179
+ for (int blk = 0; blk < 2; blk++) {
4180
+ const int8x16_t * q8 = &q8s[rp][4 * blk];
4181
+ const int8x16_t * q4 = q4_nibbles[blk];
4182
+ int32x4_t acc = sb_acc[2 * rp + blk];
4183
+ // mul add for each qs in the same subblock
4184
+ for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
4185
+ acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
4186
+ }
4187
+ sb_acc[2 * rp + blk] = acc;
4188
+ }
4189
+ }
4190
+
4191
+ // Scales[i] corresponds to column i
4192
+ const int scale_offset = cp * 2;
4193
+ const int32_t scale_00 = q4sb_scales[0][scale_offset];
4194
+ const int32_t scale_01 = q4sb_scales[0][scale_offset + 1];
4195
+ const int32_t scale_10 = q4sb_scales[1][scale_offset];
4196
+ const int32_t scale_11 = q4sb_scales[1][scale_offset + 1];
4197
+ const int32x4_t block_scale_0 = vcombine_s32(vdup_n_s32(scale_00), vdup_n_s32(scale_01));
4198
+ const int32x4_t block_scale_1 = vcombine_s32(vdup_n_s32(scale_10), vdup_n_s32(scale_11));
4199
+
4200
+ acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale_0);
4201
+ acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale_0);
4202
+ acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale_1);
4203
+ acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale_1);
4204
+ }
4205
+
4206
+ // Multiply Acc bsum + mins
4207
+ for (int q8_row = 0; q8_row < 4; q8_row++) {
4208
+ // Each pair of subblocks share the same bsums
4209
+ // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
4210
+ int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
4211
+ int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
4212
+
4213
+ bias_acc[2 * q8_row] =
4214
+ vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
4215
+ bias_acc[2 * q8_row] =
4216
+ vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
4217
+ bias_acc[2 * q8_row + 1] =
4218
+ vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
4219
+ bias_acc[2 * q8_row + 1] =
4220
+ vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
4221
+ }
4222
+ } // for sb
4223
+
4224
+ // Reorder of i8mm output with bias and output layout
4225
+ for (int i = 0; i < 8; i++) {
4226
+ int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
4227
+ acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
4228
+ }
4229
+ int32x4_t reorder_acc[8] = {
4230
+ vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
4231
+ vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
4232
+ vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
4233
+ vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
4234
+ vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
4235
+ vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
4236
+ vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
4237
+ vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
4238
+ };
4239
+
4240
+ for (int i = 0; i < q8_k_blocklen; i++) {
4241
+ for (int j = 0; j < 2; j++) {
4242
+ float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
4243
+ float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
4244
+ const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d);
4245
+
4246
+ float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
4247
+ const float32x4_t scale = vmulq_f32(q4_d, q8_d);
4248
+
4249
+ acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
4250
+ acc_f32[2 * i + j] =
4251
+ vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
4252
+ }
4253
+ }
4254
+ } // for b
2251
4255
 
2252
- #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
2253
- ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
4256
+ // With the previous reorder, the tile is already in the correct memory layout.
4257
+ for (int i = 0; i < q8_k_blocklen; i++) {
4258
+ int row = y * q8_k_blocklen + i;
4259
+ for (int j = 0; j < 2; j++) {
4260
+ int col = x * ncols_interleaved + j * 4;
4261
+ int offset = row * bs + col;
4262
+ vst1q_f32(s + offset, acc_f32[2 * i + j]);
4263
+ }
4264
+ }
4265
+ } // for x
4266
+ } // for y
4267
+ return;
4268
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
4269
+ ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
2254
4270
  }
2255
4271
 
2256
- void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
2257
- const int qk = QK8_0;
2258
- const int nb = n / qk;
2259
- const int ncols_interleaved = 4;
2260
- const int blocklen = 4;
4272
+ void ggml_gemm_q5_K_8x8_q8_K(int n, // values per row; asserted to be a multiple of QK_K (nb = n / QK_K blocks)
4273
+ float * GGML_RESTRICT s, // output: row r, column c is written at s[r * bs + c]
4274
+ size_t bs, // row stride of s, counted in floats
4275
+ const void * GGML_RESTRICT vx, // block_q5_Kx8 data: nb blocks per group of 8 interleaved columns
4276
+ const void * GGML_RESTRICT vy, // block_q8_Kx4 data: nb blocks per group of 4 interleaved rows
4277
+ int nr, // number of rows; asserted to be a multiple of 4
4278
+ int nc) { // number of columns; asserted to be a multiple of ncols_interleaved (8)
4279
+ constexpr int qk = QK_K;
4280
+ const int nb = n / qk;
2261
4281
 
2262
- assert (n % qk == 0);
2263
- assert (nr % 4 == 0);
2264
- assert (nc % ncols_interleaved == 0);
4282
+ constexpr int ncols_interleaved = 8;
4283
+ constexpr int blocklen = 8;
4284
+
4285
+ assert(n % qk == 0);
4286
+ assert(nr % 4 == 0);
4287
+ assert(nc % ncols_interleaved == 0);
2265
4288
 
2266
- UNUSED(s);
2267
- UNUSED(bs);
2268
- UNUSED(vx);
2269
- UNUSED(vy);
2270
- UNUSED(nr);
2271
- UNUSED(nc);
2272
4289
  UNUSED(nb);
2273
4290
  UNUSED(ncols_interleaved);
2274
4291
  UNUSED(blocklen);
2275
4292
 
2276
- #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
2277
- const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
4293
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
4294
+ constexpr int q8_k_blocklen = 4;
4295
+ constexpr int col_pairs = ncols_interleaved / 2;
4296
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
4297
+ const uint8x16_t mone = vdupq_n_u8(1);
4298
+ const uint8x16_t mtwo = vdupq_n_u8(2);
4299
+
4300
+ // 8 accumulators: 4 q8 rows × 2 column groups (0-3, 4-7), indexed [2 * row + group]
4301
+ float32x4_t acc_f32[blocklen];
4302
+
4303
+ for (int y = 0; y < nr / q8_k_blocklen; y++) {
4304
+ const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
2278
4305
 
2279
- for (int y = 0; y < nr / 4; y++) {
2280
- const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
2281
4306
  for (int x = 0; x < nc / ncols_interleaved; x++) {
2282
- const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
4307
+ const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
2283
4308
 
2284
- float32x4_t sumf[4];
2285
- for (int m = 0; m < 4; m++) {
2286
- sumf[m] = vdupq_n_f32(0);
4309
+ for (int i = 0; i < blocklen; i++) {
4310
+ acc_f32[i] = vdupq_n_f32(0);
2287
4311
  }
2288
4312
 
2289
- for (int l = 0; l < nb; l++) {
2290
- float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
2291
- float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
4313
+ for (int b = 0; b < nb; b++) {
4314
+ // bsums pairs belong to the same q8_k subblock; vpaddq_s16 folds each adjacent pair into one sum
4315
+ const int16x8_t bsums[4]{
4316
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
4317
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
4318
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
4319
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
4320
+ };
4321
+ int16_t bsums_arr[4][8]; // read below as bsums_arr[sb][2 * q8_row + half]; half 0 pairs with q5sb_mins[0], half 1 with q5sb_mins[1]
4322
+ for (int q8_row = 0; q8_row < 4; q8_row++) {
4323
+ vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
4324
+ }
2292
4325
 
2293
- int32x4_t sumi_0 = vdupq_n_s32(0);
2294
- int32x4_t sumi_1 = vdupq_n_s32(0);
2295
- int32x4_t sumi_2 = vdupq_n_s32(0);
2296
- int32x4_t sumi_3 = vdupq_n_s32(0);
4326
+ int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results
4327
+ int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
4328
+ int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
4329
+ for (int i = 0; i < 8; i++) {
4330
+ acc[i] = vdupq_n_s32(0);
4331
+ bias_acc[i] = vdupq_n_s32(0);
4332
+ }
2297
4333
 
2298
- for (int k = 0; k < 4; k++) {
2299
- int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
2300
- int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
4334
+ // Load qh once per block; every lane is shifted right by 2 after each subblock, so bit 0 (low nibble half) and bit 1 (high nibble half) always hold the high bits of the current subblock
4335
+ const uint8_t * qh_base = q5_ptr[b].qh;
4336
+ uint8x16_t qh[col_pairs][4];
4337
+ for (int cp = 0; cp < col_pairs; cp++) {
4338
+ qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
4339
+ qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
4340
+ qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
4341
+ qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
4342
+ }
2301
4343
 
2302
- uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
2303
- int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
2304
- int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
4344
+ for (int sb = 0; sb < QK_K / 64; sb++) {
4345
+ // Need scales for the low and high nibbles
4346
+ // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
4347
+ int8_t q5sb_scales[2][8];
4348
+ int16x8_t q5sb_mins[2]; // int16 because it is needed by the vmlal_s16 bias accumulation later
4349
+ for (int i = 0; i < 2; i++) {
4350
+ const int offset = sb * 24 + i * 12;
4351
+ decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]);
4352
+ }
2305
4353
 
2306
- sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
2307
- sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
2308
- sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
2309
- sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
2310
- sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
2311
- sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
2312
- sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
2313
- sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
4354
+ // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
4355
+ const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
4356
+
4357
+ int8x16_t q8_qs_01[8];
4358
+ int8x16_t q8_qs_23[8];
4359
+
4360
+ // Load 8 x 32 bytes (16 for rows 01 and 16 for rows 23 per iteration): one 64-value subblock for each of the 4 rows
4361
+ for (int i = 0; i < 8; i++) {
4362
+ const int offset = i * 32; // 16 for row 01, 16 for row 23
4363
+ q8_qs_01[i] = vld1q_s8(q8_base + offset);
4364
+ q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
4365
+ }
4366
+
4367
+ const int8x16_t q8s[2][8] = {
4368
+ { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6],
4369
+ q8_qs_01[7] },
4370
+ { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6],
4371
+ q8_qs_23[7] },
4372
+ };
4373
+
4374
+ // Q5s columns iterated in pairs (01, 23, 45, 67)
4375
+ for (int cp = 0; cp < col_pairs; cp++) {
4376
+ for (int i = 0; i < 4; i++) {
4377
+ sb_acc[i] = vdupq_n_s32(0);
4378
+ }
4379
+
4380
+ uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
4381
+ uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
4382
+ uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
4383
+ uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
4384
+
4385
+ // This is the only part of the algorithm that differs from Q4_K:
4386
+ // extract the high bits and place them at bit 4 to form 5-bit weights
4387
+ uint8x16_t hbit_lo_0 = vandq_u8(qh[cp][0], mone);
4388
+ uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3);
4389
+ qh[cp][0] = vshrq_n_u8(qh[cp][0], 2);
4390
+ // Same as Q4_K: i8mm (vmmlaq_s32) multiply-accumulate of the weights against the q8 row pairs
4391
+ const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4));
4392
+ int32x4_t acc_0 = sb_acc[0];
4393
+ acc_0 = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]);
4394
+ int32x4_t acc_2 = sb_acc[2];
4395
+ acc_2 = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]);
4396
+ const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0));
4397
+ int32x4_t acc_1 = sb_acc[1];
4398
+ acc_1 = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]);
4399
+ int32x4_t acc_3 = sb_acc[3];
4400
+ acc_3 = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]);
4401
+
4402
+ // Repeat for the other 3 value groups (8..15, 16..23, 24..31 and their +32 counterparts)
4403
+ uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3);
4404
+ uint8x16_t hbit_lo_1 = vandq_u8(qh[cp][1], mone);
4405
+ qh[cp][1] = vshrq_n_u8(qh[cp][1], 2);
4406
+ const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4));
4407
+ acc_0 = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]);
4408
+ acc_2 = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]);
4409
+ const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1));
4410
+ acc_1 = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]);
4411
+ acc_3 = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]);
4412
+
4413
+ uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3);
4414
+ uint8x16_t hbit_lo_2 = vandq_u8(qh[cp][2], mone);
4415
+ qh[cp][2] = vshrq_n_u8(qh[cp][2], 2);
4416
+ const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4));
4417
+ acc_0 = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]);
4418
+ acc_2 = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]);
4419
+ const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2));
4420
+ acc_1 = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]);
4421
+ acc_3 = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]);
4422
+
4423
+ uint8x16_t hbit_lo_3 = vandq_u8(qh[cp][3], mone);
4424
+ uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3);
4425
+ qh[cp][3] = vshrq_n_u8(qh[cp][3], 2);
4426
+ const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4));
4427
+ acc_0 = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]);
4428
+ sb_acc[0] = acc_0;
4429
+ acc_2 = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]);
4430
+ sb_acc[2] = acc_2;
4431
+
4432
+ // Scales[i] corresponds to column i
4433
+ const int scale_offset = cp * 2;
4434
+ const int32_t s0 = q5sb_scales[0][scale_offset];
4435
+ const int32_t s1 = q5sb_scales[0][scale_offset + 1];
4436
+ const int32x4_t block_scale = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1));
4437
+ acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale);
4438
+ acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale);
4439
+
4440
+ const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3));
4441
+ acc_1 = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]);
4442
+ sb_acc[1] = acc_1;
4443
+ acc_3 = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]);
4444
+ sb_acc[3] = acc_3;
4445
+
4446
+ const int32_t s2 = q5sb_scales[1][scale_offset];
4447
+ const int32_t s3 = q5sb_scales[1][scale_offset + 1];
4448
+ const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3));
4449
+ acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale2);
4450
+ acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2);
4451
+ }
4452
+
4453
+ // Accumulate bsums * mins; this bias is subtracted (scaled by dmin * d8) after the sb loop
4454
+ for (int q8_row = 0; q8_row < 4; q8_row++) {
4455
+ // The lo / hi bsums of each row pair with the low / high nibble mins (q5sb_mins[0] / [1])
4456
+ // Load scalar bsum -> broadcast to a 64-bit vector (vdup_n_s16).
4457
+ int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
4458
+ int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
4459
+
4460
+ bias_acc[2 * q8_row] =
4461
+ vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
4462
+ bias_acc[2 * q8_row] =
4463
+ vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
4464
+ bias_acc[2 * q8_row + 1] =
4465
+ vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
4466
+ bias_acc[2 * q8_row + 1] =
4467
+ vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
4468
+ }
4469
+ } // for sb
4470
+
4471
+ // Reorder the i8mm outputs so reorder_acc[2 * row + group] matches the bias_acc / acc_f32 layout
4472
+ for (int i = 0; i < 8; i++) {
4473
+ int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
4474
+ acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
2314
4475
  }
4476
+ int32x4_t reorder_acc[8] = {
4477
+ vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
4478
+ vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
4479
+ vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
4480
+ vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
4481
+ vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
4482
+ vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
4483
+ vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
4484
+ vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
4485
+ };
2315
4486
 
2316
- sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
2317
- sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
2318
- sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
2319
- sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
2320
- }
4487
+ for (int i = 0; i < q8_k_blocklen; i++) {
4488
+ for (int j = 0; j < 2; j++) {
4489
+ float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
4490
+ float32x4_t q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4)));
4491
+ const float32x4_t dmins = vmulq_f32(q5_dmin, q8_d);
2321
4492
 
2322
- for (int m = 0; m < 4; m++) {
2323
- vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
4493
+ float32x4_t q5_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4)));
4494
+ const float32x4_t scale = vmulq_f32(q5_d, q8_d);
4495
+
4496
+ acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
4497
+ acc_f32[2 * i + j] =
4498
+ vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
4499
+ }
4500
+ }
4501
+ } // for b
4502
+
4503
+ // With the previous reorder, the tile is already in the correct memory layout.
4504
+ for (int i = 0; i < q8_k_blocklen; i++) {
4505
+ int row = y * q8_k_blocklen + i;
4506
+ for (int j = 0; j < 2; j++) {
4507
+ int col = x * ncols_interleaved + j * 4;
4508
+ int offset = row * bs + col;
4509
+ vst1q_f32(s + offset, acc_f32[2 * i + j]);
4510
+ }
2324
4511
  }
2325
- }
2326
- }
4512
+ } // for x
4513
+ } // for y
2327
4514
  return;
2328
- #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
2329
- ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
4515
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
4516
+ ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
2330
4517
  }
2331
4518
 
2332
- void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
4519
+ void ggml_gemm_q6_K_8x4_q8_K(int n,
4520
+ float * GGML_RESTRICT s,
4521
+ size_t bs,
4522
+ const void * GGML_RESTRICT vx,
4523
+ const void * GGML_RESTRICT vy,
4524
+ int nr,
4525
+ int nc) {
2333
4526
  constexpr int qk = QK_K;
2334
4527
  const int nb = n / qk;
2335
4528
 
@@ -2346,171 +4539,167 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
2346
4539
 
2347
4540
  #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
2348
4541
  constexpr int q8_k_blocklen = 4;
2349
- constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs
2350
- const uint8x16_t m4b = vdupq_n_u8(0x0f);
2351
-
2352
- // 8 accumulators: 2 row pairs × 4 col pairs
2353
- float32x4_t acc_f32[acc_size];
2354
-
2355
- for (int y = 0; y < nr / q8_k_blocklen; y++) {
2356
- const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
2357
-
2358
- for (int x = 0; x < nc / ncols_interleaved; x++) {
2359
- const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
2360
-
2361
- for (int i = 0; i < acc_size; i++) {
2362
- acc_f32[i] = vdupq_n_f32(0);
2363
- }
2364
-
2365
- for (int b = 0; b < nb; b++) {
2366
- // d4 0 1 2 3, 4 5 6 7
2367
- float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));
2368
- float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
2369
- // d8 0 1 2 3
2370
- float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
2371
- // mins
2372
- float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
2373
- float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
2374
-
2375
- // Precomputation of scales and mins
2376
- float32x4_t sbd_scale_0123[q8_k_blocklen];
2377
- float32x4_t sbd_scale_4567[q8_k_blocklen];
2378
- float32x4_t sbd_min_0123[q8_k_blocklen];
2379
- float32x4_t sbd_min_4567[q8_k_blocklen];
2380
-
2381
- sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
2382
- sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
2383
- sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
2384
- sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
2385
-
2386
- sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
2387
- sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
2388
- sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
2389
- sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
2390
-
2391
- sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
2392
- sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
2393
- sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
2394
- sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
4542
+ constexpr int col_groups = ncols_interleaved / 4;
4543
+ constexpr int acc_size = q8_k_blocklen * col_groups; // 4 rows, 2 column groups
4544
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
4545
+ const uint8x16_t mask_lo = vdupq_n_u8(0x03);
4546
+ const uint8x16_t mask_hi = vdupq_n_u8(0x30);
4547
+ const int8x16_t m32s = vdupq_n_s8(32);
2395
4548
 
2396
- sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
2397
- sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
2398
- sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
2399
- sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
4549
+ float32x4_t acc_f32[acc_size];
2400
4550
 
2401
- // Precomputation of bsums, each vpaddq calcs all the bsums for each row
2402
- const int16x8_t bsums[q8_k_blocklen] = {
2403
- vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
2404
- vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
2405
- vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
2406
- vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
2407
- };
2408
- int16_t bsums_arr[QK_K / 64][8];
2409
- for (int q8_row = 0; q8_row < 4; q8_row++) {
2410
- vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
2411
- }
4551
+ for (int y = 0; y < nr / q8_k_blocklen; y++) {
4552
+ const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
2412
4553
 
2413
- // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
2414
- int32x4_t bias_acc[acc_size];
2415
- for (int i = 0; i < acc_size; i++) {
2416
- bias_acc[i] = vdupq_n_s32(0);
2417
- }
4554
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
4555
+ const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
2418
4556
 
2419
- for (int sb = 0; sb < QK_K / 64; sb++) {
2420
- // Int accumulators for qs vecdot (4 row x 2 col quartets)
2421
- int32x4_t acc_lo[acc_size];
2422
- int32x4_t acc_hi[acc_size];
2423
- for (int i = 0; i < acc_size; i++) {
2424
- acc_lo[i] = vdupq_n_s32(0);
2425
- acc_hi[i] = vdupq_n_s32(0);
2426
- }
2427
- // Need scales for the low and high nibbles
2428
- // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
2429
- int16x8_t q4sb_scales[2];
2430
- int16x8_t q4sb_mins[2];
2431
- for (int i = 0; i < 2; i++) {
2432
- int8_t aux_q4sb[8];
2433
- const int offset = sb * 24 + i * 12;
2434
- decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
2435
- q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
2436
- }
4557
+ for (int i = 0; i < acc_size; i++) {
4558
+ acc_f32[i] = vdupq_n_f32(0);
4559
+ }
2437
4560
 
2438
- constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
2439
- for (int k = 0; k < reads_per_sb; k++) {
2440
- const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
2441
- const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
4561
+ for (int b = 0; b < nb; b++) {
4562
+ float32x4_t q6_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));
4563
+ float32x4_t q6_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));
4564
+ float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
2442
4565
 
2443
- // 0..3 & 32..35
2444
- const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);
2445
- const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
4566
+ float32x4_t sbd_scale_0123[q8_k_blocklen];
4567
+ float32x4_t sbd_scale_4567[q8_k_blocklen];
2446
4568
 
2447
- const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));
2448
- const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));
4569
+ sbd_scale_0123[0] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 0);
4570
+ sbd_scale_4567[0] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 0);
4571
+ sbd_scale_0123[1] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 1);
4572
+ sbd_scale_4567[1] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 1);
4573
+ sbd_scale_0123[2] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 2);
4574
+ sbd_scale_4567[2] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 2);
4575
+ sbd_scale_0123[3] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 3);
4576
+ sbd_scale_4567[3] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 3);
2449
4577
 
2450
- acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
2451
- acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
2452
- acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
2453
- acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
4578
+ int32x4_t acc_s32[acc_size];
4579
+ for (int i = 0; i < acc_size; i++) {
4580
+ acc_s32[i] = vdupq_n_s32(0);
4581
+ }
2454
4582
 
2455
- acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
2456
- acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
2457
- acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
2458
- acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
4583
+ int16_t q6_scales[8 * 16];
4584
+ for (int i = 0; i < 16; i++) {
4585
+ int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
4586
+ vst1q_s16(q6_scales + i * 8, scales);
4587
+ }
2459
4588
 
2460
- const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));
2461
- const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));
4589
+ for (int half = 0; half < 2; half++) {
4590
+ const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
4591
+ const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
2462
4592
 
2463
- acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
2464
- acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
2465
- acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
2466
- acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
4593
+ for (int sb = 0; sb < QK_K / 64; sb++) {
4594
+ int32x4_t acc_lo[acc_size];
4595
+ int32x4_t acc_hi[acc_size];
4596
+ for (int i = 0; i < acc_size; i++) {
4597
+ acc_lo[i] = vdupq_n_s32(0);
4598
+ acc_hi[i] = vdupq_n_s32(0);
4599
+ }
2467
4600
 
2468
- acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
2469
- acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
2470
- acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
2471
- acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
2472
- }
4601
+ const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
4602
+ const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;
4603
+
4604
+ // 4 rows * 16 elements per scale
4605
+ // 4 reads of 16 bytes each
4606
+ constexpr int reads_per_sb = 4;
4607
+ int8x16_t q8_l[reads_per_sb];
4608
+ int8x16_t q8_h[reads_per_sb];
4609
+ for (int k = 0; k < reads_per_sb; k++) {
4610
+ q8_l[k] = vld1q_s8(q8_base_l + 16 * k);
4611
+ q8_h[k] = vld1q_s8(q8_base_h + 16 * k);
4612
+ }
2473
4613
 
2474
- // Scale and bias application
2475
- // acc is stored interleaved to match output layout
2476
- const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
2477
- const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
2478
- const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
2479
- const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
2480
- for (int row = 0; row < q8_k_blocklen; row++) {
2481
- // Bias correction
2482
- // row c0123 blk0 and blk1
2483
- const float32x4_t sumf_0123 =
2484
- vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
2485
- vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
2486
- acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
4614
+ const int ql_off_base = sb * QK_K / 2;
4615
+ const int qh_off_base = ql_off_base & 255;
2487
4616
 
2488
- // row c4567 blk0 and blk1
2489
- const float32x4_t sumf_4567 =
2490
- vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
2491
- vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
2492
- acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
4617
+ uint8x16_t q6_ql_0123[reads_per_sb];
4618
+ uint8x16_t q6_ql_4567[reads_per_sb];
4619
+ uint8x16_t q6_qh_0123[reads_per_sb];
4620
+ uint8x16_t q6_qh_4567[reads_per_sb];
2493
4621
 
2494
- // Bias
2495
- const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
2496
- const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
4622
+ for (int k = 0; k < reads_per_sb; k++) {
4623
+ q6_ql_0123[k] = vld1q_u8(ql_base + ql_off_base + k * 32);
4624
+ q6_ql_4567[k] = vld1q_u8(ql_base + ql_off_base + k * 32 + 16);
4625
+ q6_qh_0123[k] = vld1q_u8(qh_base + qh_off_base + k * 32);
4626
+ q6_qh_4567[k] = vld1q_u8(qh_base + qh_off_base + k * 32 + 16);
4627
+ }
2497
4628
 
2498
- // row c0123 blk0 and blk1
2499
- bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
2500
- bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
4629
+ if (sb > 1) {
4630
+ for (int k = 0; k < reads_per_sb; k++) {
4631
+ q6_qh_0123[k] = vshrq_n_u8(q6_qh_0123[k], 2);
4632
+ q6_qh_4567[k] = vshrq_n_u8(q6_qh_4567[k], 2);
4633
+ }
4634
+ }
2501
4635
 
2502
- // row c4567 blk0 and blk1
2503
- bias_acc[2 * row + 1] =
2504
- vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
2505
- bias_acc[2 * row + 1] =
2506
- vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
4636
+ for (int k = 0; k < reads_per_sb; k++) {
4637
+ // q = (ql | qh) - 32
4638
+ const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123[k], mask_lo);
4639
+ const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123[k], mask_hi);
4640
+ const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567[k], mask_lo);
4641
+ const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567[k], mask_hi);
4642
+
4643
+ const int8x16_t q6_0123_lo = vsubq_s8(
4644
+ vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123[k], m4b), hbit_lo_0123, 4)), m32s);
4645
+ const int8x16_t q6_0123_hi = vsubq_s8(
4646
+ vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123[k], 4), hbit_hi_0123)), m32s);
4647
+
4648
+ acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l[k], 0); // 0..3 r0 c0123
4649
+ acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l[k], 1); // 0..3 r1 c0123
4650
+ acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l[k], 2); // 0..3 r2 c0123
4651
+ acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l[k], 3); // 0..3 r3 c0123
4652
+
4653
+ acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h[k], 0); // 64..67 r0 c0123
4654
+ acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h[k], 1); // 64..67 r1 c0123
4655
+ acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h[k], 2); // 64..67 r2 c0123
4656
+ acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h[k], 3); // 64..67 r3 c0123
4657
+
4658
+ const int8x16_t q6_4567_lo = vsubq_s8(
4659
+ vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567[k], m4b), hbit_lo_4567, 4)), m32s);
4660
+ const int8x16_t q6_4567_hi = vsubq_s8(
4661
+ vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567[k], 4), hbit_hi_4567)), m32s);
4662
+
4663
+ acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l[k], 0); // 0..3 r0 c4567
4664
+ acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l[k], 1); // 0..3 r1 c4567
4665
+ acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l[k], 2); // 0..3 r2 c4567
4666
+ acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l[k], 3); // 0..3 r3 c4567
4667
+
4668
+ acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h[k], 0); // 64..67 r0 c4567
4669
+ acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h[k], 1); // 64..67 r1 c4567
4670
+ acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h[k], 2); // 64..67 r2 c4567
4671
+ acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h[k], 3); // 64..67 r3 c4567
4672
+ }
4673
+
4674
+ // Scale and bias
4675
+ const int scale_idx_l = half * 8 + sb;
4676
+ const int scale_idx_h = half * 8 + sb + 4;
4677
+
4678
+ for (int g = 0; g < col_groups; g++) {
4679
+ const int16x4_t scales_l16 = vld1_s16(q6_scales + scale_idx_l * 8 + g * 4);
4680
+ const int16x4_t scales_h16 = vld1_s16(q6_scales + scale_idx_h * 8 + g * 4);
4681
+ const int32x4_t scale_vec_l = vmovl_s16(scales_l16);
4682
+ const int32x4_t scale_vec_h = vmovl_s16(scales_h16);
4683
+ const int acc_offset = g * q8_k_blocklen;
4684
+
4685
+ for (int row = 0; row < q8_k_blocklen; row++) {
4686
+ const int idx = row * 2 + g;
4687
+ acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_lo[acc_offset + row], scale_vec_l);
4688
+ acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_hi[acc_offset + row], scale_vec_h);
4689
+ }
4690
+ }
2507
4691
  }
2508
- } // for sb
4692
+ }
2509
4693
 
4694
+ // Finally we apply the superblock scales
2510
4695
  for (int row = 0; row < q8_k_blocklen; row++) {
2511
- acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
2512
- acc_f32[2 * row + 1] =
2513
- vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
4696
+ const int idx0 = 2 * row;
4697
+ const int idx1 = 2 * row + 1;
4698
+ const int32x4_t acc_0123 = acc_s32[idx0];
4699
+ const int32x4_t acc_4567 = acc_s32[idx1];
4700
+
4701
+ acc_f32[idx0] = vmlaq_f32(acc_f32[idx0], vcvtq_f32_s32(acc_0123), sbd_scale_0123[row]);
4702
+ acc_f32[idx1] = vmlaq_f32(acc_f32[idx1], vcvtq_f32_s32(acc_4567), sbd_scale_4567[row]);
2514
4703
  }
2515
4704
  } // for b
2516
4705
 
@@ -2526,10 +4715,10 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
2526
4715
  } // for y
2527
4716
  return;
2528
4717
  #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
2529
- ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
4718
+ ggml_gemm_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
2530
4719
  }
2531
4720
 
2532
- void ggml_gemm_q4_K_8x8_q8_K(int n,
4721
+ void ggml_gemm_q6_K_8x8_q8_K(int n,
2533
4722
  float * GGML_RESTRICT s,
2534
4723
  size_t bs,
2535
4724
  const void * GGML_RESTRICT vx,
@@ -2553,144 +4742,155 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
2553
4742
  #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
2554
4743
  constexpr int q8_k_blocklen = 4;
2555
4744
  const uint8x16_t m4b = vdupq_n_u8(0x0f);
4745
+ const uint8x16_t mask_lo = vdupq_n_u8(0x03);
4746
+ const uint8x16_t mask_hi = vdupq_n_u8(0x30);
4747
+ const int8x16_t m32s = vdupq_n_s8(32);
2556
4748
 
2557
- // 8 accumulators: 2 row pairs × 4 col pairs
4749
+ // 8 accumulators: 4 q8 rows × 2 col groups (0-3, 4-7)
2558
4750
  float32x4_t acc_f32[blocklen];
2559
4751
 
2560
4752
  for (int y = 0; y < nr / q8_k_blocklen; y++) {
2561
4753
  const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
2562
4754
 
2563
4755
  for (int x = 0; x < nc / ncols_interleaved; x++) {
2564
- const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
4756
+ const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
2565
4757
 
2566
4758
  for (int i = 0; i < blocklen; i++) {
2567
4759
  acc_f32[i] = vdupq_n_f32(0);
2568
4760
  }
2569
4761
 
2570
4762
  for (int b = 0; b < nb; b++) {
2571
- // bsums pairs belongs to the same q8_k subblock
2572
- const int16x8_t bsums[4]{
2573
- vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
2574
- vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
2575
- vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
2576
- vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
2577
- };
2578
- int16_t bsums_arr[4][8];
2579
- for (int q8_row = 0; q8_row < 4; q8_row++) {
2580
- vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
2581
- }
2582
-
2583
- int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results
2584
- int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
2585
- int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
4763
+ int32x4_t acc[8]; // rows 01 stored in [0][1][2][3], rows 23 stored in [4][5][6][7]
2586
4764
  for (int i = 0; i < 8; i++) {
2587
- acc[i] = vdupq_n_s32(0);
2588
- bias_acc[i] = vdupq_n_s32(0);
4765
+ acc[i] = vdupq_n_s32(0);
2589
4766
  }
2590
4767
 
2591
- for (int sb = 0; sb < QK_K / 64; sb++) {
2592
- // Need scales for the low and high nibbles
2593
- // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
2594
- int8_t q4sb_scales[2][8];
2595
- int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later
2596
- for (int i = 0; i < 2; i++) {
2597
- const int offset = sb * 24 + i * 12;
2598
- decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
2599
- }
2600
-
2601
- // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
2602
- const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
4768
+ // Q6_K has simple 8-bit scales, 16 per block (one per 16 values)
4769
+ // Reused for bias and dequantization later
4770
+ int16_t q6_scales[16 * 8];
4771
+ for (int i = 0; i < 16; ++i) {
4772
+ int16x8_t s16 = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
4773
+ vst1q_s16(q6_scales + i * 8, s16);
4774
+ }
2603
4775
 
2604
- int8x16_t q8_qs_01[8];
2605
- int8x16_t q8_qs_23[8];
4776
+ // Process two 128-value halves per superblock
4777
+ for (int half = 0; half < 2; half++) {
4778
+
4779
+ const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
4780
+ const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
4781
+
4782
+ // A subblock (sb) is a set of weights that share the scale
4783
+ // Since q6_K scales are per 16 elements
4784
+ // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
4785
+ for (int sb = 0; sb < QK_K / 64; sb++) {
4786
+ // Q6_K weight index increasing by 64 instead of 32 requires
4787
+ // loading various q8 memory regions
4788
+ const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
4789
+ const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;
4790
+
4791
+ int8x16_t q8_l_01[2];
4792
+ int8x16_t q8_l_23[2];
4793
+ for (int i = 0; i < 2; i++) {
4794
+ const int offset = i * 32;
4795
+ q8_l_01[i] = vld1q_s8(q8_base_l + offset); // 0..7 & 8..15 (r01)
4796
+ q8_l_23[i] = vld1q_s8(q8_base_l + offset + 16); // 0..7 & 8..15 (r23)
4797
+ }
2606
4798
 
2607
- // Load 32-byte per row pair, 1 subblock each time
2608
- for (int i = 0; i < 8; i++) {
2609
- const int offset = i * 32; // 16 for row 01, 16 for row 23
2610
- q8_qs_01[i] = vld1q_s8(q8_base + offset);
2611
- q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
2612
- }
4799
+ int8x16_t q8_h_01[2];
4800
+ int8x16_t q8_h_23[2];
4801
+ for (int i = 0; i < 2; i++) {
4802
+ const int offset = i * 32;
4803
+ q8_h_01[i] = vld1q_s8(q8_base_h + offset);
4804
+ q8_h_23[i] = vld1q_s8(q8_base_h + offset + 16);
4805
+ }
2613
4806
 
2614
- const int8x16_t q8s[2][8] = {
2615
- { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
2616
- q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
2617
- { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
2618
- q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
2619
- };
4807
+ const int ql_off_base = sb * QK_K / 2;
2620
4808
 
2621
- // Q4s columns iterated in pairs (01, 23, 45, 67)
2622
- for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
2623
- for (int i = 0; i < 4; i++) {
2624
- sb_acc[i] = vdupq_n_s32(0);
4809
+ uint8x16_t q6_ql_0[4];
4810
+ uint8x16_t q6_ql_1[4];
4811
+ for (int k = 0; k < 4; k++) {
4812
+ q6_ql_0[k] = vld1q_u8(ql_base + ql_off_base + 16 * k);
4813
+ q6_ql_1[k] = vld1q_u8(ql_base + ql_off_base + 64 + 16 * k);
2625
4814
  }
2626
4815
 
2627
- uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
2628
- uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
2629
- uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
2630
- uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
2631
- const int8x16_t q4_nibbles[2][4] = {
2632
- {
2633
- vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
2634
- vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
2635
- vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
2636
- vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
2637
- },
2638
- {
2639
- vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
2640
- vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
2641
- vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
2642
- vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
2643
- }
2644
- };
4816
+ const int qh_off_base = (sb * QK_K / 2) & 255; // wrap after 256 bytes
4817
+ uint8x16_t q6_qh_0[4];
4818
+ uint8x16_t q6_qh_1[4];
4819
+ for (int k = 0; k < 4; k++) {
4820
+ q6_qh_0[k] = vld1q_u8(qh_base + qh_off_base + 16 * k);
4821
+ q6_qh_1[k] = vld1q_u8(qh_base + qh_off_base + 64 + 16 * k);
4822
+ }
2645
4823
 
2646
- // Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8
2647
- // for each of the internal 32 qs subblock (blk)
2648
- for (int rp = 0; rp < 2; rp++) {
2649
- for (int blk = 0; blk < 2; blk++) {
2650
- const int8x16_t * q8 = &q8s[rp][4 * blk];
2651
- const int8x16_t * q4 = q4_nibbles[blk];
2652
- int32x4_t acc = sb_acc[2 * rp + blk];
2653
- // mul add for each qs in the same subblock
2654
- for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
2655
- acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
2656
- }
2657
- sb_acc[2 * rp + blk] = acc;
4824
+ // Adjust for the proper high bits (Sb 2 and 3)
4825
+ if (sb > 1) {
4826
+ for (int k = 0; k < 4; k++) {
4827
+ q6_qh_0[k] = vshrq_n_u8(q6_qh_0[k], 2);
4828
+ q6_qh_1[k] = vshrq_n_u8(q6_qh_1[k], 2);
2658
4829
  }
2659
4830
  }
2660
4831
 
2661
- // Scales[i] corresponds to column i
2662
- const int scale_offset = cp * 2;
2663
- for (int blk = 0; blk < 2; blk++) {
2664
- const int32x4_t block_scale = {
2665
- (int32_t) q4sb_scales[blk][scale_offset],
2666
- (int32_t) q4sb_scales[blk][scale_offset],
2667
- (int32_t) q4sb_scales[blk][scale_offset + 1],
2668
- (int32_t) q4sb_scales[blk][scale_offset + 1],
4832
+ // Process column pairs (0-1, 2-3, 4-5, 6-7)
4833
+ for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
4834
+ const uint8x16_t q6_qs_cp_0_l = q6_ql_0[cp];
4835
+ const uint8x16_t q6_qs_cp_1_l = q6_ql_1[cp];
4836
+ const uint8x16_t q6_qs_cp_0_h = q6_qh_0[cp];
4837
+ const uint8x16_t q6_qs_cp_1_h = q6_qh_1[cp];
4838
+
4839
+ // Extract high 2 bits for upper nibble reconstruction
4840
+ const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi);
4841
+ const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi);
4842
+
4843
+ // q6 = (low4 | high2<<4) - 32
4844
+ // Use vsliq_n_u8 to combine shift-left-insert in one instruction (like Q5_K)
4845
+ const int8x16_t q6_l0 = vsubq_s8(
4846
+ vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4)),
4847
+ m32s);
4848
+ const int8x16_t q6_l1 = vsubq_s8(
4849
+ vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4)),
4850
+ m32s);
4851
+ const int8x16_t q6_h0 = vsubq_s8(
4852
+ vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh)), m32s);
4853
+ const int8x16_t q6_h1 = vsubq_s8(
4854
+ vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh)), m32s);
4855
+
4856
+ // row pair 0, base_l
4857
+ int32x4_t sb_acc_0l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_01[0]);
4858
+ sb_acc_0l = vmmlaq_s32(sb_acc_0l, q6_l1, q8_l_01[1]);
4859
+ // row pair 0, base_h
4860
+ int32x4_t sb_acc_0h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_01[0]);
4861
+ sb_acc_0h = vmmlaq_s32(sb_acc_0h, q6_h1, q8_h_01[1]);
4862
+ // row pair 1, base_l
4863
+ int32x4_t sb_acc_1l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_23[0]);
4864
+ sb_acc_1l = vmmlaq_s32(sb_acc_1l, q6_l1, q8_l_23[1]);
4865
+ // row pair 1, base_h
4866
+ int32x4_t sb_acc_1h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_23[0]);
4867
+ sb_acc_1h = vmmlaq_s32(sb_acc_1h, q6_h1, q8_h_23[1]);
4868
+
4869
+ const int scale_idx_l = half * 8 + sb;
4870
+ const int scale_idx_h = half * 8 + sb + 4;
4871
+
4872
+ const int32x4_t scale_vec_l = {
4873
+ q6_scales[scale_idx_l * 8 + cp * 2 + 0],
4874
+ q6_scales[scale_idx_l * 8 + cp * 2 + 0],
4875
+ q6_scales[scale_idx_l * 8 + cp * 2 + 1],
4876
+ q6_scales[scale_idx_l * 8 + cp * 2 + 1],
4877
+ };
4878
+ const int32x4_t scale_vec_h = {
4879
+ q6_scales[scale_idx_h * 8 + cp * 2 + 0],
4880
+ q6_scales[scale_idx_h * 8 + cp * 2 + 0],
4881
+ q6_scales[scale_idx_h * 8 + cp * 2 + 1],
4882
+ q6_scales[scale_idx_h * 8 + cp * 2 + 1],
2669
4883
  };
2670
- acc[cp] = vmlaq_s32(acc[cp], sb_acc[blk], block_scale);
2671
- acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale);
2672
- }
2673
- }
2674
-
2675
- // Multiply Acc bsum + mins
2676
- for (int q8_row = 0; q8_row < 4; q8_row++) {
2677
- // Each pair of subblocks share the same bsums
2678
- // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
2679
- int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
2680
- int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
2681
4884
 
2682
- bias_acc[2 * q8_row] =
2683
- vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
2684
- bias_acc[2 * q8_row] =
2685
- vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
2686
- bias_acc[2 * q8_row + 1] =
2687
- vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
2688
- bias_acc[2 * q8_row + 1] =
2689
- vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
4885
+ acc[cp] = vmlaq_s32(acc[cp], sb_acc_0l, scale_vec_l);
4886
+ acc[cp] = vmlaq_s32(acc[cp], sb_acc_0h, scale_vec_h);
4887
+ acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1l, scale_vec_l);
4888
+ acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1h, scale_vec_h);
4889
+ }
2690
4890
  }
2691
- } // for sb
4891
+ } // for half
2692
4892
 
2693
- // Reorder of i8mm output with bias and output layout
4893
+ // Reorder i8mm output to match memory layout
2694
4894
  for (int i = 0; i < 8; i++) {
2695
4895
  int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
2696
4896
  acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
@@ -2706,23 +4906,20 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
2706
4906
  vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
2707
4907
  };
2708
4908
 
4909
+ // Apply superblock scale (no mins for q6_K)
2709
4910
  for (int i = 0; i < q8_k_blocklen; i++) {
2710
4911
  for (int j = 0; j < 2; j++) {
2711
- float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
2712
- float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
2713
- const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d);
2714
-
2715
- float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
2716
- const float32x4_t scale = vmulq_f32(q4_d, q8_d);
4912
+ float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
4913
+ float32x4_t q6_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q6_ptr[b].d + j * 4)));
4914
+ const float32x4_t scale = vmulq_f32(q6_d, q8_d);
2717
4915
 
2718
- acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
2719
4916
  acc_f32[2 * i + j] =
2720
4917
  vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
2721
4918
  }
2722
4919
  }
2723
4920
  } // for b
2724
4921
 
2725
- // With the previous reorder, the tile is already in the correct memory layout.
4922
+ // Store results
2726
4923
  for (int i = 0; i < q8_k_blocklen; i++) {
2727
4924
  int row = y * q8_k_blocklen + i;
2728
4925
  for (int j = 0; j < 2; j++) {
@@ -2735,10 +4932,9 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
2735
4932
  } // for y
2736
4933
  return;
2737
4934
  #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
2738
- ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
4935
+ ggml_gemm_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
2739
4936
  }
2740
4937
 
2741
-
2742
4938
  void ggml_gemm_q8_0_4x4_q8_0(int n,
2743
4939
  float * GGML_RESTRICT s,
2744
4940
  size_t bs,
@@ -2827,6 +5023,71 @@ void ggml_gemm_q8_0_4x8_q8_0(int n,
2827
5023
  UNUSED(ncols_interleaved);
2828
5024
  UNUSED(blocklen);
2829
5025
 
5026
+ #if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
5027
+ if (svcntb() * 8 == 256) {
5028
+ const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
5029
+
5030
+ static const uint32_t idx_arr[8] = {0, 1, 4, 5, 2, 3, 6, 7};
5031
+ svuint32_t idx = svld1(svptrue_b32(), idx_arr);
5032
+ static const uint32_t idx_arr1[8] = {0, 1, 2, 3, 1, 2, 3, 0};
5033
+ svuint32_t idx_sc1 = svld1(svptrue_b32(), idx_arr1);
5034
+ static const uint32_t idx_arr2[8] = {0, 1, 2, 3, 0, 1, 2, 3};
5035
+ svuint32_t idx_sc2 = svld1(svptrue_b32(), idx_arr2);
5036
+
5037
+ for (int y = 0; y < nr; y += 4) {
5038
+ const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb;
5039
+
5040
+ for (int x = 0; x < nc; x += ncols_interleaved) {
5041
+ const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb;
5042
+ const block_q8_0x4 * a_ptr = a_ptr_base;
5043
+
5044
+ svfloat32_t acc_f32_01 = svdup_f32(0);
5045
+ svfloat32_t acc_f32_23 = svdup_f32(0);
5046
+
5047
+ for (int b = 0; b < nb; b++) {
5048
+
5049
+ svint32_t acc_01 = svdup_s32(0);
5050
+ svint32_t acc_23 = svdup_s32(0);
5051
+
5052
+ // Process 4 chunks of 8 positions each
5053
+ for (int chunk = 0; chunk < 4; chunk++) {
5054
+ svint8_t s_a01 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32);
5055
+ svint8_t s_a23 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32 + 16);
5056
+ svint8_t s_b0123 = svld1_s8(svptrue_b8(), b_ptr->qs + chunk * 32);
5057
+
5058
+ acc_01 = svmmla_s32(acc_01, s_a01, s_b0123);
5059
+ acc_23 = svmmla_s32(acc_23, s_a23, s_b0123);
5060
+ }
5061
+
5062
+ // Reorder outputs from 2×2 tiles to row-major
5063
+ // acc[01] = [r0c0, r0c1, r1c0, r1c1, r0c2, r0c3, r1c2, r1c3]
5064
+ // acc[23] = [r2c0, r2c1, r3c0, r3c1, r2c2, r2c3, r3c2, r3c3]
5065
+
5066
+ svint32_t row01 = svtbl_s32(acc_01, idx);
5067
+ svint32_t row23 = svtbl_s32(acc_23, idx);
5068
+
5069
+ svfloat16_t temp1 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) a_ptr->d);
5070
+ svfloat16_t temp2 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) b_ptr->d);
5071
+ svfloat32_t sv_a_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp1, temp1)), idx_sc1);
5072
+ svfloat32_t sv_b_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp2, temp2)), idx_sc2);
5073
+
5074
+ acc_f32_01 = svmla_f32_x(svptrue_b32(), acc_f32_01, svcvt_f32_s32_x(svptrue_b32(), row01), svmul_lane_f32(sv_b_d, sv_a_d, 0));
5075
+ acc_f32_23 = svmla_f32_x(svptrue_b32(), acc_f32_23, svcvt_f32_s32_x(svptrue_b32(), row23), svmul_lane_f32(sv_b_d, sv_a_d, 2));
5076
+ a_ptr++;
5077
+ b_ptr++;
5078
+ }
5079
+
5080
+ svbool_t pg4 = svptrue_pat_b32(SV_VL4);
5081
+ svst1_f32(pg4, s + (y+0) * bs + x, acc_f32_01);
5082
+ svst1_f32(pg4, s + (y+1) * bs + x, svext_f32(acc_f32_01, acc_f32_01, 4));
5083
+ svst1_f32(pg4, s + (y+2) * bs + x, acc_f32_23);
5084
+ svst1_f32(pg4, s + (y+3) * bs + x, svext_f32(acc_f32_23, acc_f32_23, 4));
5085
+ }
5086
+ }
5087
+ return;
5088
+ }
5089
+ #endif // SVE compile-time end
5090
+
2830
5091
  #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
2831
5092
  const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
2832
5093