whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp
@@ -25,9 +25,8 @@
 #define UNUSED GGML_UNUSED
 
 #if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD))
-static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
-                                             int16x8_t *     out_mins,
-                                             int8_t *        out_scales) {
+// Helper for decoding scales and mins of Q4_K and Q5_K block formats
+static inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) {
     constexpr uint32_t kmask1 = 0x3f3f3f3f;
     constexpr uint32_t kmask2 = 0x0f0f0f0f;
     constexpr uint32_t kmask3 = 0x03030303;
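
Note: the renamed helper above vectorizes the standard K-quant 6-bit unpacking. Each Q4_K/Q5_K super-block packs 8 sub-block scales and 8 mins into 12 bytes at 6 bits apiece, and the three kmask constants split those fields four bytes at a time. A scalar sketch of the same unpacking, mirroring the get_scale_min_k4 reference helper in ggml's ggml-quants.c (the layout comments come from that reference, not from this diff):

    #include <stdint.h>

    // Scalar sketch: unpack scale j and min j (0 <= j < 8) from the packed
    // 12-byte field. Bytes 0-3 carry the low 6 bits of scales 0-3, bytes 4-7
    // the low 6 bits of mins 0-3, bytes 8-11 the low 4 bits of scales/mins 4-7,
    // and the top 2 bits of bytes 0-7 supply the upper 2 bits of scales/mins 4-7.
    static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {
        if (j < 4) {
            *d = q[j] & 63;
            *m = q[j + 4] & 63;
        } else {
            *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
            *m = (q[j + 4] >>  4) | ((q[j] >> 6) << 4);
        }
    }
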
@@ -499,6 +498,81 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4);
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    float * res_ptr = s;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
+
+        float32x4_t sumf = vdupq_n_f32(0);
+        for (int l = 0; l < nb; l++) {
+            uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
+            uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
+            uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
+            uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
+
+            int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
+            int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
+            int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
+            int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
+            int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
+            int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
+            int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
+            int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
+
+            int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
+            int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
+
+            int32x4_t sumi = vdupq_n_s32(0);
+            sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
+            sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
+            sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
+            sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
+            sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
+            sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
+            sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
+            sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
+
+            float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *) &a_ptr[l].d));
+            float32x4_t b_d = {
+                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]),
+                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]),
+                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]),
+                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]),
+            };
+            float32x4_t d = a_d * b_d;
+
+            sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
+        }
+
+        vst1q_f32(res_ptr + x * 4, sumf);
+    }
+    return;
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    ggml_gemv_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     constexpr int qk = QK_K;
     const int nb = n / qk;
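
The new ggml_gemv_mxfp4_4x4_q8_0 kernel above dequantizes MXFP4 purely by table lookup: each 4-bit code selects an entry of kvalues_mxfp4 via vqtbl1q_s8, and every 32-value block shares a single E8M0 (power-of-two) exponent in e[]. A scalar sketch of one block's dot product against a Q8_0 block, assuming ggml's usual convention that the table holds the FP4 (E2M1) values doubled and that GGML_CPU_E8M0_TO_FP32_HALF therefore halves the decoded scale:

    #include <math.h>
    #include <stdint.h>

    // FP4 (E2M1) magnitudes 0, 0.5, 1, 1.5, 2, 3, 4, 6 stored doubled so they stay integral.
    static const int8_t kvalues_mxfp4[16] = {0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12};

    // One 32-element MXFP4 block dotted with one Q8_0 block (scalar sketch).
    static float mxfp4_dot_q8_block(const uint8_t qs[16], uint8_t e,
                                    const int8_t q8[32], float q8_d) {
        int sumi = 0;
        for (int i = 0; i < 16; i++) {
            sumi += kvalues_mxfp4[qs[i] & 0x0F] * q8[i];      // low nibbles -> values 0..15
            sumi += kvalues_mxfp4[qs[i] >>   4] * q8[i + 16]; // high nibbles -> values 16..31
        }
        const float d = ldexpf(0.5f, (int) e - 127);  // E8M0 scale, halved to undo the doubling
        return d * q8_d * (float) sumi;
    }
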
@@ -561,7 +635,7 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
             for (int i = 0; i < 2; i++) {
                 int8_t aux_q4sb[8];
                 const int offset = sb * 24 + i * 12;
-                decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
                 q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
             }
 
@@ -701,13 +775,13 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
             for (int i = 0; i < 2; i++) {
                 int8_t aux_q4sb[8];
                 const int offset = sb * 24 + i * 12;
-                decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
                 q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
             }
 
             const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;
 
-            // Load the 64 quants from q8K duplicated to use vecdots with the interelaved columns
+            // Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns
             // but still need the qs to use the low and hi bits from q4
             const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
             int8x16_t q8_qs[8];
@@ -786,17 +860,18 @@
     ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
-void ggml_gemv_q8_0_4x4_q8_0(int n,
+void ggml_gemv_q5_K_8x4_q8_K(int n,
                              float * GGML_RESTRICT s,
                              size_t bs,
                              const void * GGML_RESTRICT vx,
                              const void * GGML_RESTRICT vy,
                              int nr,
                              int nc) {
-    const int qk = QK8_0;
-    const int nb = n / qk;
-    const int ncols_interleaved = 4;
-    const int blocklen = 4;
+    constexpr int qk = QK_K;
+    const int nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen = 4;
 
     assert(n % qk == 0);
     assert(nc % ncols_interleaved == 0);
@@ -806,55 +881,156 @@ void ggml_gemv_q8_0_4x4_q8_0(int n,
806
881
  UNUSED(blocklen);
807
882
 
808
883
  #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
809
- const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
884
+ constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
885
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
886
+ const uint8x16_t mone = vdupq_n_u8(1);
887
+ const uint8x16_t mtwo = vdupq_n_u8(2);
888
+
889
+ // 1x8 tile = 2 x 4
890
+ float32x4_t acc_f32[col_groups];
891
+
892
+ const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
893
+
894
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
895
+ const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
896
+
897
+ for (int i = 0; i < col_groups; i++) {
898
+ acc_f32[i] = vdupq_n_f32(0);
899
+ }
810
900
 
811
- for (int c = 0; c < nc; c += ncols_interleaved) {
812
- const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
813
- float32x4_t acc = vdupq_n_f32(0);
814
901
  for (int b = 0; b < nb; b++) {
815
- int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
816
- int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
817
- float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
902
+ float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3
903
+ float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7
904
+ float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
905
+ float32x4_t sb_scale_0123 = vmulq_f32(q5_d_0, q8_d);
906
+ float32x4_t sb_scale_4567 = vmulq_f32(q5_d_1, q8_d);
907
+ float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3
908
+ float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7
909
+ float32x4_t sb_min_0123 = vmulq_f32(q5_dmin_0, q8_d);
910
+ float32x4_t sb_min_4567 = vmulq_f32(q5_dmin_1, q8_d);
818
911
 
819
- int8x16x2_t a = vld1q_s8_x2(a_ptr->qs);
820
- float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
912
+ // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
913
+ int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
914
+ int32x4_t acc_lo[col_groups];
915
+ int32x4_t acc_hi[col_groups];
821
916
 
822
- int32x4_t ret = vdupq_n_s32(0);
917
+ // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
918
+ const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
919
+ int16_t bsums_arr[8];
920
+ vst1q_s16(bsums_arr, bsums);
823
921
 
824
- ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0);
825
- ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1);
826
- ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2);
827
- ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3);
922
+ uint8x16_t qh[col_groups][8];
923
+ for (int c = 0; c < col_groups; c++) {
924
+ for (int i = 0; i < 8; i++) {
925
+ qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
926
+ }
927
+ }
828
928
 
829
- ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0);
830
- ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1);
831
- ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2);
832
- ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3);
929
+ for (int sb = 0; sb < QK_K / 64; sb++) {
930
+ for (int i = 0; i < col_groups; i++) {
931
+ acc_lo[i] = vdupq_n_s32(0);
932
+ acc_hi[i] = vdupq_n_s32(0);
933
+ }
934
+ // Need scales for the low and high nibbles
935
+ // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
936
+ int16x8_t q5sb_mins[2];
937
+ int16x8_t q5sb_scales[2];
938
+ for (int i = 0; i < 2; i++) {
939
+ int8_t aux_q5sb[8];
940
+ const int offset = sb * 24 + i * 12;
941
+ decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
942
+ q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
943
+ }
833
944
 
834
- acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
835
- a_ptr++;
836
- b_ptr++;
837
- }
838
- vst1q_f32(s, acc);
839
- s += ncols_interleaved;
840
- }
841
- return;
945
+ int8x16_t q8_qs[4];
946
+ for (int i = 0; i < 4; i++) {
947
+ q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
948
+ }
949
+
950
+ for (int c = 0; c < col_groups; c++) {
951
+ uint8x16_t q5_cols[8];
952
+ uint8x16_t hbit_lo[8];
953
+ uint8x16_t hbit_hi[8];
954
+ int8x16_t q5_lo[8];
955
+ int8x16_t q5_hi[8];
956
+
957
+ for (int i = 0; i < 8; i++) {
958
+ q5_cols[i] = vld1q_u8(q5_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
959
+ hbit_lo[i] = vandq_u8(qh[c][i], mone);
960
+ hbit_hi[i] = vshlq_n_u8(vandq_u8(qh[c][i], mtwo), 3);
961
+ qh[c][i] = vshrq_n_u8(qh[c][i], 2);
962
+ q5_lo[i] = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_cols[i], m4b), hbit_lo[i], 4));
963
+ q5_hi[i] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_cols[i], 4), hbit_hi[i]));
964
+ }
965
+
966
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[0], q8_qs[0], 0);
967
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[1], q8_qs[0], 1);
968
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[2], q8_qs[0], 2);
969
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[3], q8_qs[0], 3);
970
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[4], q8_qs[1], 0);
971
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[5], q8_qs[1], 1);
972
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[6], q8_qs[1], 2);
973
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[7], q8_qs[1], 3);
974
+
975
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[0], q8_qs[2], 0);
976
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[1], q8_qs[2], 1);
977
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[2], q8_qs[2], 2);
978
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[3], q8_qs[2], 3);
979
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[4], q8_qs[3], 0);
980
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[5], q8_qs[3], 1);
981
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[6], q8_qs[3], 2);
982
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[7], q8_qs[3], 3);
983
+ }
984
+
985
+ // Scales
986
+ // row c0123 blk0 and blk1
987
+ const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]);
988
+ const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]);
989
+ const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
990
+ vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
991
+ acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
992
+ // row c4567 blk0 and blk1
993
+ const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]);
994
+ const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]);
995
+ const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
996
+ vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
997
+ acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
998
+
999
+ // Bias Correction
1000
+ const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
1001
+ const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
1002
+
1003
+ bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
1004
+ bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
1005
+ bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
1006
+ bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
1007
+ } // for sb
1008
+
1009
+ acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
1010
+ acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
1011
+ } // for b
842
1012
 
1013
+ int base = x * ncols_interleaved;
1014
+ vst1q_f32(s + base, acc_f32[0]);
1015
+ vst1q_f32(s + base + 4, acc_f32[1]);
1016
+ } // for x
1017
+ return;
843
1018
  #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
844
- ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
1019
+ ggml_gemv_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
845
1020
  }
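Both q5_K GEMV kernels in this hunk reduce to the same per-subblock identity. A scalar sketch of it, with hypothetical helper and parameter names, assuming the usual q5_K convention (w = d*scale*q - dmin*min, quants in [0, 31]); this is an illustrative model, not part of the upstream sources:

#include <stdint.h>

// Scalar sketch: one q5_K sub-block of n weights against q8_K activations,
// mirroring the accumulate-then-subtract-bias structure of the NEON kernel.
// All names here are hypothetical.
static float q5k_subblock(float d, float dmin, int sc, int mn, float dq8,
                          const uint8_t * q,  // 5-bit quants, 0..31
                          const int8_t * y, int n) {
    int32_t dot = 0, bsum = 0;
    for (int i = 0; i < n; i++) {
        dot  += (int32_t) q[i] * y[i]; // integer dot product (vdotq above)
        bsum += y[i];                  // arrives precomputed as q8_K bsums
    }
    // w = d*sc*q - dmin*mn, so sum(w * dq8*y) splits into a scale term and
    // a bias term that depends only on bsum.
    return (d * dq8) * (float) (sc * dot) - (dmin * dq8) * (float) (mn * bsum);
}

This is why the kernels never materialize dequantized weights: the d/dmin factors are applied once per block, after the integer accumulation.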
846
1021
 
847
- void ggml_gemv_q8_0_4x8_q8_0(int n,
1022
+ void ggml_gemv_q5_K_8x8_q8_K(int n,
848
1023
  float * GGML_RESTRICT s,
849
1024
  size_t bs,
850
1025
  const void * GGML_RESTRICT vx,
851
1026
  const void * GGML_RESTRICT vy,
852
1027
  int nr,
853
1028
  int nc) {
854
- const int qk = QK8_0;
855
- const int nb = n / qk;
856
- const int ncols_interleaved = 4;
857
- const int blocklen = 8;
1029
+ constexpr int qk = QK_K;
1030
+ const int nb = n / qk;
1031
+
1032
+ constexpr int ncols_interleaved = 8;
1033
+ constexpr int blocklen = 8;
858
1034
 
859
1035
  assert(n % qk == 0);
860
1036
  assert(nc % ncols_interleaved == 0);
@@ -864,269 +1040,1003 @@ void ggml_gemv_q8_0_4x8_q8_0(int n,
864
1040
  UNUSED(blocklen);
865
1041
 
866
1042
  #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
867
- const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
1043
+ constexpr int col_pairs = ncols_interleaved / 2;
1044
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
1045
+ const uint8x16_t mone = vdupq_n_u8(1);
1046
+ const uint8x16_t mtwo = vdupq_n_u8(2);
868
1047
 
869
- for (int c = 0; c < nc; c += ncols_interleaved) {
870
- const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
871
- float32x4_t acc = vdupq_n_f32(0);
1048
+ // 1x8 tile = 2 x 4
1049
+ float32x4_t acc_f32[ncols_interleaved / 4];
872
1050
 
873
- for (int b = 0; b < nb; b++) {
874
- int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
875
- int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
876
- float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
1051
+ const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
877
1052
 
878
- int8x8x4_t a_chunks = vld1_s8_x4(a_ptr->qs);
879
- int8x16_t a0 = vcombine_s8(a_chunks.val[0], a_chunks.val[0]);
880
- int8x16_t a1 = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);
881
- int8x16_t a2 = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);
882
- int8x16_t a3 = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);
883
- float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
1053
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
1054
+ const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
884
1055
 
885
- int32x4_t ret0 = vdupq_n_s32(0);
886
- int32x4_t ret1 = vdupq_n_s32(0);
1056
+ for (int i = 0; i < ncols_interleaved / 4; i++) {
1057
+ acc_f32[i] = vdupq_n_f32(0);
1058
+ }
887
1059
 
888
- // 0..7
889
- ret0 = vdotq_s32(ret0, b_low.val[0], a0);
890
- ret1 = vdotq_s32(ret1, b_low.val[1], a0);
891
- // 8..15
892
- ret0 = vdotq_s32(ret0, b_low.val[2], a1);
893
- ret1 = vdotq_s32(ret1, b_low.val[3], a1);
894
- // 16..23
895
- ret0 = vdotq_s32(ret0, b_high.val[0], a2);
896
- ret1 = vdotq_s32(ret1, b_high.val[1], a2);
897
- // 24..31
898
- ret0 = vdotq_s32(ret0, b_high.val[2], a3);
899
- ret1 = vdotq_s32(ret1, b_high.val[3], a3);
1060
+ for (int b = 0; b < nb; b++) {
1061
+ float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3
1062
+ float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7
1063
+ float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
1064
+ float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d);
1065
+ float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d);
1066
+ float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3
1067
+ float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7
1068
+ float32x4_t sb_min_0 = vmulq_f32(q5_dmin_0, q8_d);
1069
+ float32x4_t sb_min_1 = vmulq_f32(q5_dmin_1, q8_d);
900
1070
 
901
- int32x4_t ret = vpaddq_s32(ret0, ret1);
1071
+                // 2 subblocks handled per iteration
1072
+ int32x4_t acc_lo[col_pairs];
1073
+ int32x4_t acc_hi[col_pairs];
902
1074
 
903
- acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
904
- a_ptr++;
905
- b_ptr++;
906
- }
907
- vst1q_f32(s, acc);
908
- s += ncols_interleaved;
909
- }
910
- return;
1075
+            // Each bsum covers 16 elements; a pairwise add leaves the 8 bsums of the entire block
1076
+ const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
1077
+ int16_t bsums_arr[8];
1078
+ vst1q_s16(bsums_arr, bsums);
911
1079
 
912
- #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
913
- ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
914
- }
1080
+ // Load qh once per block and shift after each subblock
1081
+ const uint8_t * qh_base = q5_ptr[b].qh;
1082
+ uint8x16_t qh[col_pairs][4];
1083
+ for (int cp = 0; cp < col_pairs; cp++) {
1084
+ qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
1085
+ qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
1086
+ qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
1087
+ qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
1088
+ }
915
1089
 
916
- void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
917
- const int qk = QK8_0;
918
- const int nb = n / qk;
919
- const int ncols_interleaved = 4;
920
- const int blocklen = 4;
1090
+ for (int sb = 0; sb < QK_K / 64; sb++) {
1091
+ for (int i = 0; i < col_pairs; i++) {
1092
+ acc_lo[i] = vdupq_n_s32(0);
1093
+ acc_hi[i] = vdupq_n_s32(0);
1094
+ }
1095
+ // Need scales for the low and high nibbles
1096
+ // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
1097
+                int16x8_t q5sb_mins[2];  // int16 as it's needed for bias_acc later
1098
+ int16x8_t q5sb_scales[2];
1099
+ for (int i = 0; i < 2; i++) {
1100
+ int8_t aux_q5sb[8];
1101
+ const int offset = sb * 24 + i * 12;
1102
+ decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
1103
+ q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
1104
+ }
921
1105
 
922
- assert (n % qk == 0);
923
- assert (nr % 4 == 0);
924
- assert (nc % ncols_interleaved == 0);
1106
+ const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K;
1107
+
1108
+                // Load the 64 q8_K quants, duplicated so the vector dot products line up with the interleaved columns
1109
+ const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
1110
+ int8x16_t q8_qs[8];
1111
+ for (int i = 0; i < 8; i++) {
1112
+ q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
1113
+ }
1114
+
1115
+ // Q5s column pair loop unrolled
1116
+ {
1117
+ // Cols 01
1118
+ uint8x16_t qs_0 = vld1q_u8(qs_base);
1119
+ uint8x16_t qs_1 = vld1q_u8(qs_base + 64);
1120
+ uint8x16_t qs_2 = vld1q_u8(qs_base + 128);
1121
+ uint8x16_t qs_3 = vld1q_u8(qs_base + 192);
1122
+
1123
+ uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone);
1124
+ uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone);
1125
+ uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone);
1126
+ uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone);
1127
+ uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3);
1128
+ uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3);
1129
+ uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3);
1130
+ uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3);
1131
+
1132
+ qh[0][0] = vshrq_n_u8(qh[0][0], 2);
1133
+ qh[0][1] = vshrq_n_u8(qh[0][1], 2);
1134
+ qh[0][2] = vshrq_n_u8(qh[0][2], 2);
1135
+ qh[0][3] = vshrq_n_u8(qh[0][3], 2);
1136
+
1137
+ acc_lo[0] = ggml_vdotq_s32(
1138
+ acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
1139
+ acc_lo[0] = ggml_vdotq_s32(
1140
+ acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
1141
+ acc_lo[0] = ggml_vdotq_s32(
1142
+ acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
1143
+ acc_lo[0] = ggml_vdotq_s32(
1144
+ acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
1145
+ acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
1146
+ q8_qs[4]);
1147
+ acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
1148
+ q8_qs[5]);
1149
+ acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
1150
+ q8_qs[6]);
1151
+ acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
1152
+ q8_qs[7]);
1153
+
1154
+ // Cols 23
1155
+ qs_0 = vld1q_u8(qs_base + 16);
1156
+ qs_1 = vld1q_u8(qs_base + 80);
1157
+ qs_2 = vld1q_u8(qs_base + 144);
1158
+ qs_3 = vld1q_u8(qs_base + 208);
1159
+
1160
+ hbit_lo_0 = vandq_u8(qh[1][0], mone);
1161
+ hbit_lo_1 = vandq_u8(qh[1][1], mone);
1162
+ hbit_lo_2 = vandq_u8(qh[1][2], mone);
1163
+ hbit_lo_3 = vandq_u8(qh[1][3], mone);
1164
+ hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3);
1165
+ hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3);
1166
+ hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3);
1167
+ hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3);
1168
+
1169
+ qh[1][0] = vshrq_n_u8(qh[1][0], 2);
1170
+ qh[1][1] = vshrq_n_u8(qh[1][1], 2);
1171
+ qh[1][2] = vshrq_n_u8(qh[1][2], 2);
1172
+ qh[1][3] = vshrq_n_u8(qh[1][3], 2);
1173
+
1174
+ acc_lo[1] = ggml_vdotq_s32(
1175
+ acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
1176
+ acc_lo[1] = ggml_vdotq_s32(
1177
+ acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
1178
+ acc_lo[1] = ggml_vdotq_s32(
1179
+ acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
1180
+ acc_lo[1] = ggml_vdotq_s32(
1181
+ acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
1182
+ acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
1183
+ q8_qs[4]);
1184
+ acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
1185
+ q8_qs[5]);
1186
+ acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
1187
+ q8_qs[6]);
1188
+ acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
1189
+ q8_qs[7]);
1190
+
1191
+ // Cols 45
1192
+ qs_0 = vld1q_u8(qs_base + 32);
1193
+ qs_1 = vld1q_u8(qs_base + 96);
1194
+ qs_2 = vld1q_u8(qs_base + 160);
1195
+ qs_3 = vld1q_u8(qs_base + 224);
1196
+
1197
+ hbit_lo_0 = vandq_u8(qh[2][0], mone);
1198
+ hbit_lo_1 = vandq_u8(qh[2][1], mone);
1199
+ hbit_lo_2 = vandq_u8(qh[2][2], mone);
1200
+ hbit_lo_3 = vandq_u8(qh[2][3], mone);
1201
+ hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3);
1202
+ hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3);
1203
+ hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3);
1204
+ hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3);
1205
+
1206
+ qh[2][0] = vshrq_n_u8(qh[2][0], 2);
1207
+ qh[2][1] = vshrq_n_u8(qh[2][1], 2);
1208
+ qh[2][2] = vshrq_n_u8(qh[2][2], 2);
1209
+ qh[2][3] = vshrq_n_u8(qh[2][3], 2);
1210
+
1211
+ acc_lo[2] = ggml_vdotq_s32(
1212
+ acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
1213
+ acc_lo[2] = ggml_vdotq_s32(
1214
+ acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
1215
+ acc_lo[2] = ggml_vdotq_s32(
1216
+ acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
1217
+ acc_lo[2] = ggml_vdotq_s32(
1218
+ acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
1219
+ acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
1220
+ q8_qs[4]);
1221
+ acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
1222
+ q8_qs[5]);
1223
+ acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
1224
+ q8_qs[6]);
1225
+ acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
1226
+ q8_qs[7]);
1227
+
1228
+                    // Cols 67
1229
+ qs_0 = vld1q_u8(qs_base + 48);
1230
+ qs_1 = vld1q_u8(qs_base + 112);
1231
+ qs_2 = vld1q_u8(qs_base + 176);
1232
+ qs_3 = vld1q_u8(qs_base + 240);
1233
+
1234
+ hbit_lo_0 = vandq_u8(qh[3][0], mone);
1235
+ hbit_lo_1 = vandq_u8(qh[3][1], mone);
1236
+ hbit_lo_2 = vandq_u8(qh[3][2], mone);
1237
+ hbit_lo_3 = vandq_u8(qh[3][3], mone);
1238
+ hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3);
1239
+ hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3);
1240
+ hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3);
1241
+ hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3);
1242
+
1243
+ qh[3][0] = vshrq_n_u8(qh[3][0], 2);
1244
+ qh[3][1] = vshrq_n_u8(qh[3][1], 2);
1245
+ qh[3][2] = vshrq_n_u8(qh[3][2], 2);
1246
+ qh[3][3] = vshrq_n_u8(qh[3][3], 2);
1247
+
1248
+ acc_lo[3] = ggml_vdotq_s32(
1249
+ acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
1250
+ acc_lo[3] = ggml_vdotq_s32(
1251
+ acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
1252
+ acc_lo[3] = ggml_vdotq_s32(
1253
+ acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
1254
+ acc_lo[3] = ggml_vdotq_s32(
1255
+ acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
1256
+ acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
1257
+ q8_qs[4]);
1258
+ acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
1259
+ q8_qs[5]);
1260
+ acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
1261
+ q8_qs[6]);
1262
+ acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
1263
+ q8_qs[7]);
1264
+ }
1265
+
1266
+ // Prepare bsum vectors for bias computation
1267
+                // Each pair of subblocks shares the same bsums
1268
+ int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
1269
+ int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
1270
+
1271
+                // Iterate over two column pairs (4 columns) at a time to use a single 128-bit register
1272
+                // p = 0 -> cols 0123, p = 2 -> cols 4567
1273
+ for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
1274
+ int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]);
1275
+ int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]);
1276
+ int16x4_t group_mins_lo = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]);
1277
+ int16x4_t group_mins_hi = p == 0 ? vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]);
1278
+ float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
1279
+ float32x4_t sb_min = p == 0 ? sb_min_0 : sb_min_1;
1280
+
1281
+ // 0123 or 4567
1282
+ float32x4_t sumf_0 =
1283
+ vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
1284
+ acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
1285
+
1286
+ float32x4_t sumf_1 =
1287
+ vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
1288
+ acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
1289
+
1290
+ // FUSED BIAS: Compute and subtract bias immediately
1291
+ // bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min
1292
+ int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo);
1293
+ bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi);
1294
+ float32x4_t bias_f32 = vcvtq_f32_s32(bias);
1295
+ acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32);
1296
+ }
1297
+ } // for sb
1298
+ } // for b
1299
+
1300
+ int base = x * ncols_interleaved;
1301
+ vst1q_f32(s + base, acc_f32[0]);
1302
+ vst1q_f32(s + base + 4, acc_f32[1]);
1303
+ } // for x
1304
+ return;
1305
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1306
+ ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
1307
+ }
1308
+
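The vsliq_n_u8/vorrq_u8 bit surgery in both q5_K paths reads naturally byte-wise. A minimal scalar model (hypothetical helper name, illustrative only):

#include <stdint.h>

// Per-byte model of the 5-bit reconstruction above (sketch).
static inline void decode_q5_byte(uint8_t qs, uint8_t qh,
                                  uint8_t * lo, uint8_t * hi) {
    *lo = (uint8_t) ((qs & 0x0f) | ((qh & 1) << 4)); // vsliq_n_u8(qs & m4b, qh & mone, 4)
    *hi = (uint8_t) ((qs >> 4)   | ((qh & 2) << 3)); // vorrq_u8(qs >> 4, (qh & mtwo) << 3)
    // Both results are 5-bit quants in [0, 31]; the kernel then shifts qh
    // right by 2 so the next sub-block sees its own pair of high bits.
}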
1309
+ void ggml_gemv_q6_K_8x4_q8_K(int n,
1310
+ float * GGML_RESTRICT s,
1311
+ size_t bs,
1312
+ const void * GGML_RESTRICT vx,
1313
+ const void * GGML_RESTRICT vy,
1314
+ int nr,
1315
+ int nc) {
1316
+ constexpr int qk = QK_K;
1317
+ const int nb = n / qk;
1318
+
1319
+ constexpr int ncols_interleaved = 8;
1320
+ constexpr int blocklen = 4;
1321
+
1322
+ assert(n % qk == 0);
1323
+ assert(nc % ncols_interleaved == 0);
925
1324
 
926
- UNUSED(s);
927
- UNUSED(bs);
928
- UNUSED(vx);
929
- UNUSED(vy);
930
- UNUSED(nr);
931
- UNUSED(nc);
932
1325
  UNUSED(nb);
933
1326
  UNUSED(ncols_interleaved);
934
1327
  UNUSED(blocklen);
935
1328
 
936
- #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
937
- const void * b_ptr = vx;
938
- const void * a_ptr = vy;
939
- float * res_ptr = s;
940
- size_t res_stride = bs * sizeof(float);
1329
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1330
+ constexpr int col_groups = ncols_interleaved / 4;
1331
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
1332
+ const uint8x16_t mask_lo = vdupq_n_u8(0x03);
1333
+ const uint8x16_t mask_hi = vdupq_n_u8(0x30);
941
1334
 
942
- __asm__ __volatile__(
943
- "mov x10, %x[nr]\n"
944
- "mov x9, #0x88\n"
945
- "cmp x10, #0x10\n"
946
- "mul x9, %x[nb], x9\n"
947
- "blt 4f\n"
948
- "1:" // Row loop
949
- "add x28, %x[b_ptr], #0x8\n"
950
- "mov x27, %x[nc]\n"
951
- "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
952
- "2:" // Column loop
953
- "add x25, %x[a_ptr], #0x8\n"
954
- "movi v15.16b, #0x0\n"
955
- "movi v19.16b, #0x0\n"
956
- "mov x24, %x[nb]\n"
957
- "add x23, x25, x9\n"
958
- "movi v18.16b, #0x0\n"
959
- "movi v14.16b, #0x0\n"
960
- "add x22, x23, x9\n"
961
- "movi v11.16b, #0x0\n"
962
- "movi v13.16b, #0x0\n"
963
- "add x21, x22, x9\n"
964
- "movi v23.16b, #0x0\n"
965
- "movi v16.16b, #0x0\n"
966
- "movi v25.16b, #0x0\n"
967
- "movi v7.16b, #0x0\n"
968
- "movi v0.16b, #0x0\n"
969
- "movi v4.16b, #0x0\n"
970
- "movi v5.16b, #0x0\n"
971
- "movi v21.16b, #0x0\n"
972
- "movi v8.16b, #0x0\n"
973
- "movi v1.16b, #0x0\n"
974
- "3:" // Block loop
975
- "ldr q3, [x28, #0x0]\n"
976
- "ldr q31, [x25, #0x0]\n"
977
- "movi v28.16b, #0x4\n"
978
- "movi v10.4s, #0x0\n"
979
- "ldr q22, [x28, #0x10]\n"
980
- "ldr q6, [x25, #0x10]\n"
981
- "movi v29.4s, #0x0\n"
982
- "movi v9.4s, #0x0\n"
983
- "ldr q27, [x28, #0x20]\n"
984
- "ldr q30, [x28, #0x30]\n"
985
- "movi v20.4s, #0x0\n"
986
- "movi v24.16b, #0xf0\n"
987
- "ldr d2, [x25, #-0x8]\n"
988
- "ldr d26, [x23, #-0x8]\n"
989
- "sshl v12.16b, v3.16b, v28.16b\n"
990
- "sub x20, x28, #0x8\n"
991
- "ldr d17, [x20, #0x0]\n"
992
- "and v3.16b, v3.16b, v24.16b\n"
993
- "subs x24, x24, #0x1\n"
994
- "add x28, x28, #0x48\n"
995
- ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n"
996
- ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n"
997
- ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n"
998
- ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n"
999
- "sshl v31.16b, v22.16b, v28.16b\n"
1000
- "and v22.16b, v22.16b, v24.16b\n"
1001
- "fcvtl v17.4s, v17.4h\n"
1002
- "fcvtl v2.4s, v2.4h\n"
1003
- "fcvtl v26.4s, v26.4h\n"
1004
- ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n"
1005
- ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n"
1006
- ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n"
1007
- ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n"
1008
- "sshl v6.16b, v27.16b, v28.16b\n"
1009
- "sshl v28.16b, v30.16b, v28.16b\n"
1010
- "and v27.16b, v27.16b, v24.16b\n"
1011
- "and v30.16b, v30.16b, v24.16b\n"
1012
- "ldr q24, [x25, #0x20]\n"
1013
- ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n"
1014
- ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
1015
- ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n"
1016
- ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n"
1017
- "ldr q24, [x25, #0x30]\n"
1018
- ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n"
1019
- ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n"
1020
- ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n"
1021
- ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n"
1022
- "ldr q24, [x25, #0x40]\n"
1023
- ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n"
1024
- ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
1025
- ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n"
1026
- ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n"
1027
- "ldr q24, [x25, #0x50]\n"
1028
- ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n"
1029
- ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n"
1030
- ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n"
1031
- ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n"
1032
- "ldr q24, [x25, #0x60]\n"
1033
- ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n"
1034
- ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
1035
- ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n"
1036
- ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n"
1037
- "ldr q24, [x25, #0x70]\n"
1038
- "add x25, x25, #0x88\n"
1039
- ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n"
1040
- ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n"
1041
- ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n"
1042
- ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n"
1043
- "fmul v24.4s, v17.4s, v2.s[0]\n"
1044
- "scvtf v10.4s, v10.4s, #0x4\n"
1045
- "scvtf v29.4s, v29.4s, #0x4\n"
1046
- "scvtf v9.4s, v9.4s, #0x4\n"
1047
- "scvtf v20.4s, v20.4s, #0x4\n"
1048
- "fmla v15.4s, v10.4s, v24.4s\n"
1049
- "ldr q24, [x23, #0x0]\n"
1050
- "fmul v10.4s, v17.4s, v2.s[1]\n"
1051
- "fmla v19.4s, v29.4s, v10.4s\n"
1052
- "ldr q10, [x23, #0x10]\n"
1053
- "fmul v29.4s, v17.4s, v2.s[2]\n"
1054
- "fmul v2.4s, v17.4s, v2.s[3]\n"
1055
- "fmla v18.4s, v9.4s, v29.4s\n"
1056
- "movi v9.4s, #0x0\n"
1057
- "movi v29.4s, #0x0\n"
1058
- ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n"
1059
- ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n"
1060
- "fmla v14.4s, v20.4s, v2.4s\n"
1061
- "movi v20.4s, #0x0\n"
1062
- "movi v2.4s, #0x0\n"
1063
- ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n"
1064
- ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
1065
- "ldr q24, [x23, #0x20]\n"
1066
- ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n"
1067
- ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n"
1068
- ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n"
1069
- ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n"
1070
- "ldr q10, [x23, #0x30]\n"
1071
- ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n"
1072
- ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
1073
- ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n"
1074
- ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
1075
- "ldr q24, [x23, #0x40]\n"
1076
- ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n"
1077
- ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n"
1078
- ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n"
1079
- ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n"
1080
- "ldr q10, [x23, #0x50]\n"
1081
- ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n"
1082
- ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
1083
- ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n"
1084
- ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
1085
- "ldr q24, [x23, #0x60]\n"
1086
- ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n"
1087
- ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n"
1088
- ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n"
1089
- ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n"
1090
- "ldr q10, [x23, #0x70]\n"
1091
- "add x23, x23, #0x88\n"
1092
- ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n"
1093
- ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
1094
- ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n"
1095
- ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
1096
- "ldr q24, [x22, #0x0]\n"
1097
- ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n"
1098
- ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n"
1099
- ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n"
1100
- ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n"
1101
- "fmul v10.4s, v17.4s, v26.s[0]\n"
1102
- "scvtf v9.4s, v9.4s, #0x4\n"
1103
- "scvtf v29.4s, v29.4s, #0x4\n"
1104
- "scvtf v20.4s, v20.4s, #0x4\n"
1105
- "scvtf v2.4s, v2.4s, #0x4\n"
1106
- "fmla v11.4s, v9.4s, v10.4s\n"
1107
- "ldr q9, [x22, #0x10]\n"
1108
- "fmul v10.4s, v17.4s, v26.s[1]\n"
1109
- "fmla v13.4s, v29.4s, v10.4s\n"
1110
- "ldr d29, [x22, #-0x8]\n"
1111
- "fmul v10.4s, v17.4s, v26.s[2]\n"
1112
- "fmul v26.4s, v17.4s, v26.s[3]\n"
1113
- "fcvtl v29.4s, v29.4h\n"
1114
- "fmla v23.4s, v20.4s, v10.4s\n"
1115
- "movi v20.4s, #0x0\n"
1116
- "movi v10.4s, #0x0\n"
1117
- "fmla v16.4s, v2.4s, v26.4s\n"
1118
- "movi v26.4s, #0x0\n"
1119
- "movi v2.4s, #0x0\n"
1120
- ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n"
1121
- ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
1122
- ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n"
1123
- ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
1124
- "ldr q24, [x22, #0x20]\n"
1125
- ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n"
1126
- ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
1127
- ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n"
1128
- ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n"
1129
- "ldr q9, [x22, #0x30]\n"
1335
+ // 1x8 tile = 2 x 4
1336
+ float32x4_t acc_f32[2];
1337
+
1338
+ const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
1339
+
1340
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
1341
+ const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
1342
+
1343
+ for (int i = 0; i < col_groups; i++) {
1344
+ acc_f32[i] = vdupq_n_f32(0);
1345
+ }
1346
+
1347
+ for (int b = 0; b < nb; b++) {
1348
+ float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3
1349
+ float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7
1350
+ float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
1351
+ float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
1352
+ float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);
1353
+
1354
+ int32x4_t acc[col_groups];
1355
+ for (int i = 0; i < col_groups; i++) {
1356
+ acc[i] = vdupq_n_s32(0);
1357
+ }
1358
+
1359
+ // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
1360
+ // Reused for bias and dequantization later
1361
+ int16_t q6_scales[16 * 8];
1362
+ for (int i = 0; i < 16; i++) {
1363
+ int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
1364
+ vst1q_s16(q6_scales + i * 8, scales);
1365
+ }
1366
+
1367
+ // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
1368
+ int32x4_t bias_lo = vdupq_n_s32(0);
1369
+ int32x4_t bias_hi = vdupq_n_s32(0);
1370
+
1371
+ // Load bsums in chunks of 4 to process with vectorized operations
1372
+ for (int i = 0; i < 16; i += 4) {
1373
+ int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i);
1374
+ int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
1375
+ int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
1376
+ int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
1377
+ int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
1378
+ int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
1379
+ int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
1380
+ int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
1381
+ int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);
1382
+
1383
+ bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
1384
+ bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
1385
+ bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
1386
+ bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
1387
+ bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
1388
+ bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
1389
+ bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
1390
+ bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
1391
+ }
1392
+ bias_lo = vshlq_n_s32(bias_lo, 5);
1393
+ bias_hi = vshlq_n_s32(bias_hi, 5);
1394
+
1395
+ // Process two 128-value halves per superblock
1396
+ for (int half = 0; half < 2; half++) {
1397
+ const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
1398
+ const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
1399
+
1400
+                // A subblock (sb) is a set of weights that share a scale
1401
+                // Since q6_K scales cover 16 elements each:
1402
+                // num sbs = 256 elements / (16 elements/scale * 2 elements/byte * 2 halves) = 4
1403
+ for (int sb = 0; sb < QK_K / 64; sb++) {
1404
+ const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
1405
+ const int8_t * q8_base_h = q8_base_l + 64;
1406
+
1407
+ // Load and duplicate q8 values (each register covers four interleaved columns of q6)
1408
+ int8x16_t q8_l[4];
1409
+ int8x16_t q8_h[4];
1410
+ for (int i = 0; i < 4; i++) {
1411
+ q8_l[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_l + i * 4));
1412
+ q8_h[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_h + i * 4));
1413
+ }
1414
+
1415
+ const int ql_off_base = sb * QK_K / 2;
1416
+ const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes
1417
+
1418
+ // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
1419
+ uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
1420
+ uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
1421
+ uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
1422
+ uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
1423
+
1424
+ // Adjust qh for subblocks 2 and 3 (shift right by 2)
1425
+ if (sb > 1) {
1426
+ q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
1427
+ q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
1428
+ q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
1429
+ q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
1430
+ q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
1431
+ q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
1432
+ q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
1433
+ q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
1434
+ }
1435
+
1436
+ const uint8x16_t q6_ql[8] = { q6_ql_0.val[0], q6_ql_0.val[1], q6_ql_0.val[2], q6_ql_0.val[3],
1437
+ q6_ql_1.val[0], q6_ql_1.val[1], q6_ql_1.val[2], q6_ql_1.val[3] };
1438
+ const uint8x16_t q6_qh[8] = { q6_qh_0.val[0], q6_qh_0.val[1], q6_qh_0.val[2], q6_qh_0.val[3],
1439
+ q6_qh_1.val[0], q6_qh_1.val[1], q6_qh_1.val[2], q6_qh_1.val[3] };
1440
+
1441
+ // Process column groups (0-3, 4-7)
1442
+ for (int g = 0; g < col_groups; g++) {
1443
+ int32x4_t sb_acc_l = vdupq_n_s32(0);
1444
+ int32x4_t sb_acc_h = vdupq_n_s32(0);
1445
+
1446
+ for (int chunk = 0; chunk < 4; chunk++) {
1447
+ const int idx = chunk * 2 + g;
1448
+
1449
+ const uint8x16_t q6_qs_l = q6_ql[idx];
1450
+ const uint8x16_t q6_qs_h = q6_qh[idx];
1451
+
1452
+ // Extract high 2 bits for upper nibble reconstruction
1453
+ const uint8x16_t q6_qs_hh = vandq_u8(q6_qs_h, mask_hi);
1454
+
1455
+ // q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
1456
+ const int8x16_t q6_l =
1457
+ vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_l, m4b), vandq_u8(q6_qs_h, mask_lo), 4));
1458
+ const int8x16_t q6_h = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_l, 4), q6_qs_hh));
1459
+
1460
+ sb_acc_l = vdotq_s32(sb_acc_l, q6_l, q8_l[chunk]);
1461
+ sb_acc_h = vdotq_s32(sb_acc_h, q6_h, q8_h[chunk]);
1462
+ }
1463
+
1464
+ const int scale_idx_l = half * 8 + sb;
1465
+ const int scale_idx_h = half * 8 + sb + 4;
1466
+
1467
+ const int32x4_t scale_vec_l = vmovl_s16(vld1_s16(q6_scales + scale_idx_l * 8 + g * 4));
1468
+ const int32x4_t scale_vec_h = vmovl_s16(vld1_s16(q6_scales + scale_idx_h * 8 + g * 4));
1469
+
1470
+ acc[g] = vmlaq_s32(acc[g], sb_acc_l, scale_vec_l);
1471
+ acc[g] = vmlaq_s32(acc[g], sb_acc_h, scale_vec_h);
1472
+ }
1473
+ }
1474
+ } // for half
1475
+
1476
+ // Bias correction
1477
+ acc[0] = vsubq_s32(acc[0], bias_lo);
1478
+ acc[1] = vsubq_s32(acc[1], bias_hi);
1479
+
1480
+ // Apply superblock scale (no mins for q6_K)
1481
+ // acc[g] has [c0, c1, c2, c3]
1482
+ float32x4_t w_0123 = vmulq_f32(vcvtq_f32_s32(acc[0]), sb_scale_0);
1483
+ float32x4_t w_4567 = vmulq_f32(vcvtq_f32_s32(acc[1]), sb_scale_1);
1484
+
1485
+ acc_f32[0] = vaddq_f32(acc_f32[0], w_0123);
1486
+ acc_f32[1] = vaddq_f32(acc_f32[1], w_4567);
1487
+ } // for b
1488
+
1489
+ int base = x * ncols_interleaved;
1490
+ vst1q_f32(s + base, acc_f32[0]);
1491
+ vst1q_f32(s + base + 4, acc_f32[1]);
1492
+ } // for x
1493
+ return;
1494
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1495
+ ggml_gemv_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
1496
+ }
1497
+
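Both q6_K kernels fold the constant -32 quant offset into a per-column bias built from the q8 bsums rather than subtracting it element-wise; the vshlq_n_s32(bias, 5) above is the x32. The algebra, as a scalar sketch with hypothetical names:

#include <stdint.h>

// For one q6_K sub-block with scale s and quants q_i in [0, 63]:
//   sum_i s*(q_i - 32)*y_i  =  s*sum_i(q_i*y_i)  -  32*s*sum_i(y_i)
// so the dot products can run on the offset-free quants.
static int32_t q6k_column(const uint8_t * q, const int8_t * y,
                          int s, int bsum, int n) {
    int32_t dot = 0;
    for (int i = 0; i < n; i++) {
        dot += (int32_t) q[i] * y[i]; // offset-free integer dot product
    }
    return s * dot - 32 * s * bsum;   // fused -32 correction
}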
1498
+ void ggml_gemv_q6_K_8x8_q8_K(int n,
1499
+ float * GGML_RESTRICT s,
1500
+ size_t bs,
1501
+ const void * GGML_RESTRICT vx,
1502
+ const void * GGML_RESTRICT vy,
1503
+ int nr,
1504
+ int nc) {
1505
+ constexpr int qk = QK_K;
1506
+ const int nb = n / qk;
1507
+
1508
+ constexpr int ncols_interleaved = 8;
1509
+ constexpr int blocklen = 8;
1510
+
1511
+ assert(n % qk == 0);
1512
+ assert(nc % ncols_interleaved == 0);
1513
+
1514
+ UNUSED(nb);
1515
+ UNUSED(ncols_interleaved);
1516
+ UNUSED(blocklen);
1517
+
1518
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1519
+ constexpr int col_pairs = ncols_interleaved / 2;
1520
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
1521
+ const uint8x16_t mask_lo = vdupq_n_u8(0x03);
1522
+ const uint8x16_t mask_hi = vdupq_n_u8(0x30);
1523
+
1524
+ // 1x8 tile = 2 x 4
1525
+ float32x4_t acc_f32[2];
1526
+
1527
+ const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
1528
+
1529
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
1530
+ const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
1531
+
1532
+ acc_f32[0] = vdupq_n_f32(0);
1533
+ acc_f32[1] = vdupq_n_f32(0);
1534
+
1535
+ for (int b = 0; b < nb; b++) {
1536
+ float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3
1537
+ float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7
1538
+ float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
1539
+ float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
1540
+ float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);
1541
+
1542
+ int32x2_t acc[col_pairs];
1543
+ for (int i = 0; i < col_pairs; i++) {
1544
+ acc[i] = vdup_n_s32(0);
1545
+ }
1546
+
1547
+ // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
1548
+ // Reused for bias and dequantization later
1549
+ int16_t q6_scales[16 * 8];
1550
+ for (int i = 0; i < 16; i++) {
1551
+ int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
1552
+ vst1q_s16(q6_scales + i * 8, scales);
1553
+ }
1554
+
1555
+ // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
1556
+ int32x4_t bias_lo = vdupq_n_s32(0);
1557
+ int32x4_t bias_hi = vdupq_n_s32(0);
1558
+
1559
+ // Load bsums in chunks of 4 to process with vectorized operations
1560
+ for (int i = 0; i < 16; i += 4) {
1561
+ int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i);
1562
+ int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
1563
+ int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
1564
+ int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
1565
+ int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
1566
+ int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
1567
+ int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
1568
+ int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
1569
+ int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);
1570
+
1571
+ bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
1572
+ bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
1573
+ bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
1574
+ bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
1575
+ bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
1576
+ bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
1577
+ bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
1578
+ bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
1579
+ }
1580
+ bias_lo = vshlq_n_s32(bias_lo, 5);
1581
+ bias_hi = vshlq_n_s32(bias_hi, 5);
1582
+
1583
+ // Process two 128-value halves per superblock
1584
+ for (int half = 0; half < 2; half++) {
1585
+ const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
1586
+ const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
1587
+
1588
+                // A subblock (sb) is a set of weights that share a scale
1589
+                // Since q6_K scales cover 16 elements each:
1590
+                // num sbs = 256 elements / (16 elements/scale * 2 elements/byte * 2 halves) = 4
1591
+ for (int sb = 0; sb < QK_K / 64; sb++) {
1592
+ const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
1593
+ const int8_t * q8_base_h = q8_base_l + 64;
1594
+
1595
+ // Load and duplicate q8 values (each register covers two interleaved columns of q6)
1596
+ int8x16_t q8_l[2];
1597
+ int8x16_t q8_h[2];
1598
+ for (int i = 0; i < 2; i++) {
1599
+ q8_l[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_l + i * 8));
1600
+ q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8));
1601
+ }
1602
+
1603
+ const int ql_off_base = sb * QK_K / 2;
1604
+ const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes
1605
+
1606
+ // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
1607
+ uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
1608
+ uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
1609
+ uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
1610
+ uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
1611
+
1612
+ // Adjust qh for subblocks 2 and 3 (shift right by 2)
1613
+ if (sb > 1) {
1614
+ q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
1615
+ q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
1616
+ q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
1617
+ q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
1618
+ q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
1619
+ q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
1620
+ q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
1621
+ q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
1622
+ }
1623
+
1624
+ // Process column pairs (0-1, 2-3, 4-5, 6-7)
1625
+ for (int cp = 0; cp < col_pairs; cp++) {
1626
+ const uint8x16_t q6_qs_cp_0_l = q6_ql_0.val[cp];
1627
+ const uint8x16_t q6_qs_cp_1_l = q6_ql_1.val[cp];
1628
+ const uint8x16_t q6_qs_cp_0_h = q6_qh_0.val[cp];
1629
+ const uint8x16_t q6_qs_cp_1_h = q6_qh_1.val[cp];
1630
+
1631
+ // Extract high 2 bits for upper nibble reconstruction
1632
+ const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi);
1633
+ const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi);
1634
+
1635
+ // q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
1636
+ const int8x16_t q6_l0 = vreinterpretq_s8_u8(
1637
+ vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4));
1638
+ const int8x16_t q6_l1 = vreinterpretq_s8_u8(
1639
+ vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4));
1640
+ const int8x16_t q6_h0 =
1641
+ vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh));
1642
+ const int8x16_t q6_h1 =
1643
+ vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh));
1644
+
1645
+ int32x4_t sb_acc_l = vdupq_n_s32(0);
1646
+ sb_acc_l = vdotq_s32(sb_acc_l, q6_l0, q8_l[0]);
1647
+ sb_acc_l = vdotq_s32(sb_acc_l, q6_l1, q8_l[1]);
1648
+
1649
+ int32x4_t sb_acc_h = vdupq_n_s32(0);
1650
+ sb_acc_h = vdotq_s32(sb_acc_h, q6_h0, q8_h[0]);
1651
+ sb_acc_h = vdotq_s32(sb_acc_h, q6_h1, q8_h[1]);
1652
+
1653
+ // Pairwise add to get per-column sums: [col0, col1]
1654
+ int32x2_t sum_l = vpadd_s32(vget_low_s32(sb_acc_l), vget_high_s32(sb_acc_l));
1655
+ int32x2_t sum_h = vpadd_s32(vget_low_s32(sb_acc_h), vget_high_s32(sb_acc_h));
1656
+
1657
+ const int scale_idx_l = half * 8 + sb;
1658
+ const int scale_idx_h = half * 8 + sb + 4;
1659
+
1660
+ // Access scales using array indexing (scales are interleaved by column)
1661
+ const int32x2_t scale_vec_l = { (int32_t) q6_scales[scale_idx_l * 8 + cp * 2],
1662
+ (int32_t) q6_scales[scale_idx_l * 8 + cp * 2 + 1] };
1663
+ const int32x2_t scale_vec_h = { (int32_t) q6_scales[scale_idx_h * 8 + cp * 2],
1664
+ (int32_t) q6_scales[scale_idx_h * 8 + cp * 2 + 1] };
1665
+
1666
+ // Accumulate scaled results
1667
+ acc[cp] = vmla_s32(acc[cp], sum_l, scale_vec_l);
1668
+ acc[cp] = vmla_s32(acc[cp], sum_h, scale_vec_h);
1669
+ }
1670
+ }
1671
+ } // for half
1672
+
1673
+ // Bias correction
1674
+ acc[0] = vsub_s32(acc[0], vget_low_s32(bias_lo));
1675
+ acc[1] = vsub_s32(acc[1], vget_high_s32(bias_lo));
1676
+ acc[2] = vsub_s32(acc[2], vget_low_s32(bias_hi));
1677
+ acc[3] = vsub_s32(acc[3], vget_high_s32(bias_hi));
1678
+
1679
+ // Apply superblock scale (no mins for q6_K)
1680
+ // acc[cp] has [c0, c1]
1681
+ float32x2_t w_01 = vmul_f32(vcvt_f32_s32(acc[0]), vget_low_f32(sb_scale_0));
1682
+ float32x2_t w_23 = vmul_f32(vcvt_f32_s32(acc[1]), vget_high_f32(sb_scale_0));
1683
+ float32x2_t w_45 = vmul_f32(vcvt_f32_s32(acc[2]), vget_low_f32(sb_scale_1));
1684
+ float32x2_t w_67 = vmul_f32(vcvt_f32_s32(acc[3]), vget_high_f32(sb_scale_1));
1685
+
1686
+ acc_f32[0] = vaddq_f32(acc_f32[0], vcombine_f32(w_01, w_23));
1687
+ acc_f32[1] = vaddq_f32(acc_f32[1], vcombine_f32(w_45, w_67));
1688
+ } // for b
1689
+
1690
+ int base = x * ncols_interleaved;
1691
+ vst1q_f32(s + base, acc_f32[0]);
1692
+ vst1q_f32(s + base + 4, acc_f32[1]);
1693
+ } // for x
1694
+ return;
1695
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1696
+ ggml_gemv_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
1697
+ }
1698
+
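In the 8x8 q6_K variant each vdotq accumulator interleaves two columns (blocklen = 8 bytes per column), so its lanes pair up as {A0, A1, B0, B1}. The vpadd_s32 reduction above is equivalent to this scalar fold (illustrative sketch):

#include <stdint.h>

// vpadd_s32(vget_low_s32(v), vget_high_s32(v)) == { v[0]+v[1], v[2]+v[3] }
static inline void fold_two_columns(const int32_t v[4], int32_t out[2]) {
    out[0] = v[0] + v[1]; // column A: lanes 0..1
    out[1] = v[2] + v[3]; // column B: lanes 2..3
}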
1699
+ void ggml_gemv_q8_0_4x4_q8_0(int n,
1700
+ float * GGML_RESTRICT s,
1701
+ size_t bs,
1702
+ const void * GGML_RESTRICT vx,
1703
+ const void * GGML_RESTRICT vy,
1704
+ int nr,
1705
+ int nc) {
1706
+ const int qk = QK8_0;
1707
+ const int nb = n / qk;
1708
+ const int ncols_interleaved = 4;
1709
+ const int blocklen = 4;
1710
+
1711
+ assert(n % qk == 0);
1712
+ assert(nc % ncols_interleaved == 0);
1713
+
1714
+ UNUSED(nb);
1715
+ UNUSED(ncols_interleaved);
1716
+ UNUSED(blocklen);
1717
+
1718
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1719
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
1720
+
1721
+ for (int c = 0; c < nc; c += ncols_interleaved) {
1722
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
1723
+ float32x4_t acc = vdupq_n_f32(0);
1724
+ for (int b = 0; b < nb; b++) {
1725
+ int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
1726
+ int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
1727
+ float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
1728
+
1729
+ int8x16x2_t a = vld1q_s8_x2(a_ptr->qs);
1730
+ float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
1731
+
1732
+ int32x4_t ret = vdupq_n_s32(0);
1733
+
1734
+ ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0);
1735
+ ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1);
1736
+ ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2);
1737
+ ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3);
1738
+
1739
+ ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0);
1740
+ ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1);
1741
+ ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2);
1742
+ ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3);
1743
+
1744
+ acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
1745
+ a_ptr++;
1746
+ b_ptr++;
1747
+ }
1748
+ vst1q_f32(s, acc);
1749
+ s += ncols_interleaved;
1750
+ }
1751
+ return;
1752
+
1753
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1754
+ ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
1755
+ }
1756
+
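The q8_0 paths lean on the lane-indexed dot product. As a reference, a scalar model of vdotq_laneq_s32 following the ACLE definition (hypothetical function name; "rows" carries four interleaved 4-byte groups, "act" the activations):

#include <stdint.h>

// acc[j] += dot(rows[4j..4j+3], act[4*lane..4*lane+3]); the same shape as
// vdotq_laneq_s32(acc, rows, act, lane) in the kernel above.
static void vdotq_laneq_s32_model(int32_t acc[4], const int8_t rows[16],
                                  const int8_t act[16], int lane) {
    for (int j = 0; j < 4; j++) {
        for (int k = 0; k < 4; k++) {
            acc[j] += (int32_t) rows[4 * j + k] * act[4 * lane + k];
        }
    }
}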
1757
+ void ggml_gemv_q8_0_4x8_q8_0(int n,
1758
+ float * GGML_RESTRICT s,
1759
+ size_t bs,
1760
+ const void * GGML_RESTRICT vx,
1761
+ const void * GGML_RESTRICT vy,
1762
+ int nr,
1763
+ int nc) {
1764
+ const int qk = QK8_0;
1765
+ const int nb = n / qk;
1766
+ const int ncols_interleaved = 4;
1767
+ const int blocklen = 8;
1768
+
1769
+ assert(n % qk == 0);
1770
+ assert(nc % ncols_interleaved == 0);
1771
+
1772
+ UNUSED(nb);
1773
+ UNUSED(ncols_interleaved);
1774
+ UNUSED(blocklen);
1775
+
1776
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1777
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
1778
+
1779
+ for (int c = 0; c < nc; c += ncols_interleaved) {
1780
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
1781
+ float32x4_t acc = vdupq_n_f32(0);
1782
+
1783
+ for (int b = 0; b < nb; b++) {
1784
+ int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
1785
+ int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
1786
+ float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
1787
+
1788
+ int8x8x4_t a_chunks = vld1_s8_x4(a_ptr->qs);
1789
+ int8x16_t a0 = vcombine_s8(a_chunks.val[0], a_chunks.val[0]);
1790
+ int8x16_t a1 = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);
1791
+ int8x16_t a2 = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);
1792
+ int8x16_t a3 = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);
1793
+ float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
1794
+
1795
+ int32x4_t ret0 = vdupq_n_s32(0);
1796
+ int32x4_t ret1 = vdupq_n_s32(0);
1797
+
1798
+ // 0..7
1799
+ ret0 = vdotq_s32(ret0, b_low.val[0], a0);
1800
+ ret1 = vdotq_s32(ret1, b_low.val[1], a0);
1801
+ // 8..15
1802
+ ret0 = vdotq_s32(ret0, b_low.val[2], a1);
1803
+ ret1 = vdotq_s32(ret1, b_low.val[3], a1);
1804
+ // 16..23
1805
+ ret0 = vdotq_s32(ret0, b_high.val[0], a2);
1806
+ ret1 = vdotq_s32(ret1, b_high.val[1], a2);
1807
+ // 24..31
1808
+ ret0 = vdotq_s32(ret0, b_high.val[2], a3);
1809
+ ret1 = vdotq_s32(ret1, b_high.val[3], a3);
1810
+
1811
+ int32x4_t ret = vpaddq_s32(ret0, ret1);
1812
+
1813
+ acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
1814
+ a_ptr++;
1815
+ b_ptr++;
1816
+ }
1817
+ vst1q_f32(s, acc);
1818
+ s += ncols_interleaved;
1819
+ }
1820
+ return;
1821
+
1822
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1823
+ ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
1824
+ }
1825
+
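In the 4x8 path above, ret0 ends up holding two partial sums for each of columns 0 and 1, and ret1 the same for columns 2 and 3; the final vpaddq_s32 folds them into the four ordered column sums. Scalar picture (illustrative, hypothetical helper name):

#include <stdint.h>

// ret0 = { c0a, c0b, c1a, c1b }, ret1 = { c2a, c2b, c3a, c3b };
// vpaddq_s32(ret0, ret1) = { c0a+c0b, c1a+c1b, c2a+c2b, c3a+c3b }.
static inline void fold_four_columns(const int32_t r0[4], const int32_t r1[4],
                                     int32_t out[4]) {
    out[0] = r0[0] + r0[1];
    out[1] = r0[2] + r0[3];
    out[2] = r1[0] + r1[1];
    out[3] = r1[2] + r1[3];
}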
1826
+ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
1827
+ const int qk = QK8_0;
1828
+ const int nb = n / qk;
1829
+ const int ncols_interleaved = 4;
1830
+ const int blocklen = 4;
1831
+
1832
+ assert (n % qk == 0);
1833
+ assert (nr % 4 == 0);
1834
+ assert (nc % ncols_interleaved == 0);
1835
+
1836
+ UNUSED(s);
1837
+ UNUSED(bs);
1838
+ UNUSED(vx);
1839
+ UNUSED(vy);
1840
+ UNUSED(nr);
1841
+ UNUSED(nc);
1842
+ UNUSED(nb);
1843
+ UNUSED(ncols_interleaved);
1844
+ UNUSED(blocklen);
1845
+
1846
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1847
+ const void * b_ptr = vx;
1848
+ const void * a_ptr = vy;
1849
+ float * res_ptr = s;
1850
+ size_t res_stride = bs * sizeof(float);
1851
+
1852
+ __asm__ __volatile__(
1853
+ "mov x10, %x[nr]\n"
1854
+ "mov x9, #0x88\n"
1855
+ "cmp x10, #0x10\n"
1856
+ "mul x9, %x[nb], x9\n"
1857
+ "blt 4f\n"
1858
+ "1:" // Row loop
1859
+ "add x28, %x[b_ptr], #0x8\n"
1860
+ "mov x27, %x[nc]\n"
1861
+ "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
1862
+ "2:" // Column loop
1863
+ "add x25, %x[a_ptr], #0x8\n"
1864
+ "movi v15.16b, #0x0\n"
1865
+ "movi v19.16b, #0x0\n"
1866
+ "mov x24, %x[nb]\n"
1867
+ "add x23, x25, x9\n"
1868
+ "movi v18.16b, #0x0\n"
1869
+ "movi v14.16b, #0x0\n"
1870
+ "add x22, x23, x9\n"
1871
+ "movi v11.16b, #0x0\n"
1872
+ "movi v13.16b, #0x0\n"
1873
+ "add x21, x22, x9\n"
1874
+ "movi v23.16b, #0x0\n"
1875
+ "movi v16.16b, #0x0\n"
1876
+ "movi v25.16b, #0x0\n"
1877
+ "movi v7.16b, #0x0\n"
1878
+ "movi v0.16b, #0x0\n"
1879
+ "movi v4.16b, #0x0\n"
1880
+ "movi v5.16b, #0x0\n"
1881
+ "movi v21.16b, #0x0\n"
1882
+ "movi v8.16b, #0x0\n"
1883
+ "movi v1.16b, #0x0\n"
1884
+ "3:" // Block loop
1885
+ "ldr q3, [x28, #0x0]\n"
1886
+ "ldr q31, [x25, #0x0]\n"
1887
+ "movi v28.16b, #0x4\n"
1888
+ "movi v10.4s, #0x0\n"
1889
+ "ldr q22, [x28, #0x10]\n"
1890
+ "ldr q6, [x25, #0x10]\n"
1891
+ "movi v29.4s, #0x0\n"
1892
+ "movi v9.4s, #0x0\n"
1893
+ "ldr q27, [x28, #0x20]\n"
1894
+ "ldr q30, [x28, #0x30]\n"
1895
+ "movi v20.4s, #0x0\n"
1896
+ "movi v24.16b, #0xf0\n"
1897
+ "ldr d2, [x25, #-0x8]\n"
1898
+ "ldr d26, [x23, #-0x8]\n"
1899
+ "sshl v12.16b, v3.16b, v28.16b\n"
1900
+ "sub x20, x28, #0x8\n"
1901
+ "ldr d17, [x20, #0x0]\n"
1902
+ "and v3.16b, v3.16b, v24.16b\n"
1903
+ "subs x24, x24, #0x1\n"
1904
+ "add x28, x28, #0x48\n"
1905
+ ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n"
1906
+ ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n"
1907
+ ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n"
1908
+ ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n"
1909
+ "sshl v31.16b, v22.16b, v28.16b\n"
1910
+ "and v22.16b, v22.16b, v24.16b\n"
1911
+ "fcvtl v17.4s, v17.4h\n"
1912
+ "fcvtl v2.4s, v2.4h\n"
1913
+ "fcvtl v26.4s, v26.4h\n"
1914
+ ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n"
1915
+ ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n"
1916
+ ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n"
1917
+ ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n"
1918
+ "sshl v6.16b, v27.16b, v28.16b\n"
1919
+ "sshl v28.16b, v30.16b, v28.16b\n"
1920
+ "and v27.16b, v27.16b, v24.16b\n"
1921
+ "and v30.16b, v30.16b, v24.16b\n"
1922
+ "ldr q24, [x25, #0x20]\n"
1923
+ ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n"
1924
+ ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
1925
+ ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n"
1926
+ ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n"
1927
+ "ldr q24, [x25, #0x30]\n"
1928
+ ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n"
1929
+ ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n"
1930
+ ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n"
1931
+ ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n"
1932
+ "ldr q24, [x25, #0x40]\n"
1933
+ ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n"
1934
+ ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
1935
+ ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n"
1936
+ ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n"
1937
+ "ldr q24, [x25, #0x50]\n"
1938
+ ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n"
1939
+ ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n"
1940
+ ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n"
1941
+ ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n"
1942
+ "ldr q24, [x25, #0x60]\n"
1943
+ ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n"
1944
+ ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
1945
+ ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n"
1946
+ ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n"
1947
+ "ldr q24, [x25, #0x70]\n"
1948
+ "add x25, x25, #0x88\n"
1949
+ ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n"
1950
+ ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n"
1951
+ ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n"
1952
+ ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n"
1953
+ "fmul v24.4s, v17.4s, v2.s[0]\n"
1954
+ "scvtf v10.4s, v10.4s, #0x4\n"
1955
+ "scvtf v29.4s, v29.4s, #0x4\n"
1956
+ "scvtf v9.4s, v9.4s, #0x4\n"
1957
+ "scvtf v20.4s, v20.4s, #0x4\n"
1958
+ "fmla v15.4s, v10.4s, v24.4s\n"
1959
+ "ldr q24, [x23, #0x0]\n"
1960
+ "fmul v10.4s, v17.4s, v2.s[1]\n"
1961
+ "fmla v19.4s, v29.4s, v10.4s\n"
1962
+ "ldr q10, [x23, #0x10]\n"
1963
+ "fmul v29.4s, v17.4s, v2.s[2]\n"
1964
+ "fmul v2.4s, v17.4s, v2.s[3]\n"
1965
+ "fmla v18.4s, v9.4s, v29.4s\n"
1966
+ "movi v9.4s, #0x0\n"
1967
+ "movi v29.4s, #0x0\n"
1968
+ ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n"
1969
+ ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n"
1970
+ "fmla v14.4s, v20.4s, v2.4s\n"
1971
+ "movi v20.4s, #0x0\n"
1972
+ "movi v2.4s, #0x0\n"
1973
+ ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n"
1974
+ ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
1975
+ "ldr q24, [x23, #0x20]\n"
1976
+ ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n"
1977
+ ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n"
1978
+ ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n"
1979
+ ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n"
1980
+ "ldr q10, [x23, #0x30]\n"
1981
+ ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n"
1982
+ ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
1983
+ ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n"
1984
+ ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
1985
+ "ldr q24, [x23, #0x40]\n"
1986
+ ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n"
1987
+ ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n"
1988
+ ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n"
1989
+ ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n"
1990
+ "ldr q10, [x23, #0x50]\n"
1991
+ ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n"
1992
+ ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
1993
+ ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n"
1994
+ ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
1995
+ "ldr q24, [x23, #0x60]\n"
1996
+ ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n"
1997
+ ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n"
1998
+ ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n"
1999
+ ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n"
2000
+ "ldr q10, [x23, #0x70]\n"
2001
+ "add x23, x23, #0x88\n"
2002
+ ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n"
2003
+ ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
2004
+ ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n"
2005
+ ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
2006
+ "ldr q24, [x22, #0x0]\n"
2007
+ ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n"
2008
+ ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n"
2009
+ ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n"
2010
+ ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n"
2011
+ "fmul v10.4s, v17.4s, v26.s[0]\n"
2012
+ "scvtf v9.4s, v9.4s, #0x4\n"
2013
+ "scvtf v29.4s, v29.4s, #0x4\n"
2014
+ "scvtf v20.4s, v20.4s, #0x4\n"
2015
+ "scvtf v2.4s, v2.4s, #0x4\n"
2016
+ "fmla v11.4s, v9.4s, v10.4s\n"
2017
+ "ldr q9, [x22, #0x10]\n"
2018
+ "fmul v10.4s, v17.4s, v26.s[1]\n"
2019
+ "fmla v13.4s, v29.4s, v10.4s\n"
2020
+ "ldr d29, [x22, #-0x8]\n"
2021
+ "fmul v10.4s, v17.4s, v26.s[2]\n"
2022
+ "fmul v26.4s, v17.4s, v26.s[3]\n"
2023
+ "fcvtl v29.4s, v29.4h\n"
2024
+ "fmla v23.4s, v20.4s, v10.4s\n"
2025
+ "movi v20.4s, #0x0\n"
2026
+ "movi v10.4s, #0x0\n"
2027
+ "fmla v16.4s, v2.4s, v26.4s\n"
2028
+ "movi v26.4s, #0x0\n"
2029
+ "movi v2.4s, #0x0\n"
2030
+ ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n"
2031
+ ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
2032
+ ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n"
2033
+ ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
2034
+ "ldr q24, [x22, #0x20]\n"
2035
+ ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n"
2036
+ ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
2037
+ ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n"
2038
+ ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n"
2039
+ "ldr q9, [x22, #0x30]\n"
1130
2040
  ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n"
1131
2041
  ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n"
1132
2042
  ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n"
@@ -2247,110 +3157,935 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
  );
  return;
  }
- #endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+ #endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
+ ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+ }
+
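Note on the hand-written kernel above: each `.inst` word is the raw 32-bit encoding of the `sdot` by-element instruction spelled out in its trailing comment, emitted as a literal word so that assemblers predating the dotprod extension still build the file. A minimal intrinsic equivalent of one such word, assuming __ARM_FEATURE_DOTPROD (a sketch, not part of the diff):

    #include <arm_neon.h>
    // Same operation that ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]"
    // encodes: per 32-bit lane, accumulate a 4-way int8 dot product against
    // the 4-byte group selected by the lane index of the second operand.
    int32x4_t sdot_lane0_sketch(int32x4_t acc, int8x16_t a, int8x16_t b) {
        return vdotq_laneq_s32(acc, a, b, 0);
    }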
+ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 4;
+ const int blocklen = 4;
+
+ assert (n % qk == 0);
+ assert (nr % 4 == 0);
+ assert (nc % ncols_interleaved == 0);
+
+ UNUSED(s);
+ UNUSED(bs);
+ UNUSED(vx);
+ UNUSED(vy);
+ UNUSED(nr);
+ UNUSED(nc);
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+
+ for (int y = 0; y < nr / 4; y++) {
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+ float32x4_t sumf[4];
+ for (int m = 0; m < 4; m++) {
+ sumf[m] = vdupq_n_f32(0);
+ }
+
+ for (int l = 0; l < nb; l++) {
+ float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
+ float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+
+ int32x4_t sumi_0 = vdupq_n_s32(0);
+ int32x4_t sumi_1 = vdupq_n_s32(0);
+ int32x4_t sumi_2 = vdupq_n_s32(0);
+ int32x4_t sumi_3 = vdupq_n_s32(0);
+
+ for (int k = 0; k < 4; k++) {
+ int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
+ int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
+
+ uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
+ int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
+ int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
+
+ sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
+ sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
+ sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
+ sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
+ sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
+ sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
+ sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
+ sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
+ }
+
+ sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
+ sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
+ sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
+ sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
+ }
+
+ for (int m = 0; m < 4; m++) {
+ vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+ }
+ }
+ }
+ return;
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+ ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+ }
+
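The iq4_nl path above dequantizes by table lookup: each packed byte carries two 4-bit indices into the 16-entry kvalues_iq4nl table, and vqtbl1q_s8 performs 16 of those lookups per instruction. A scalar sketch of the per-byte decode; the table values shown are assumed from upstream ggml's quantization tables:

    #include <cstdint>
    // Assumed copy of ggml's kvalues_iq4nl (non-linear 4-bit codebook).
    static const int8_t kvalues_iq4nl_sketch[16] = {
        -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
    };
    // One packed byte holds two 4-bit indices; the NEON kernel does 16 bytes at once.
    inline void decode_iq4_byte_sketch(uint8_t packed, int8_t & lo, int8_t & hi) {
        lo = kvalues_iq4nl_sketch[packed & 0x0F]; // low nibble  -> b_lo lane
        hi = kvalues_iq4nl_sketch[packed >> 4];   // high nibble -> b_hi lane
    }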
+ void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 4;
+ const int blocklen = 4;
+
+ assert (n % qk == 0);
+ assert (nr % 4 == 0);
+ assert (nc % ncols_interleaved == 0);
+
+ UNUSED(s);
+ UNUSED(bs);
+ UNUSED(vx);
+ UNUSED(vy);
+ UNUSED(nr);
+ UNUSED(nc);
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4);
+
+ for (int y = 0; y < nr / 4; y++) {
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
+
+ float32x4_t sumf[4];
+ for (int m = 0; m < 4; m++) {
+ sumf[m] = vdupq_n_f32(0);
+ }
+
+ for (int l = 0; l < nb; l++) {
+ float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
+ float32x4_t b_d = {
+ GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]),
+ GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]),
+ GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]),
+ GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]),
+ };
+
+ int32x4_t sumi_0 = vdupq_n_s32(0);
+ int32x4_t sumi_1 = vdupq_n_s32(0);
+ int32x4_t sumi_2 = vdupq_n_s32(0);
+ int32x4_t sumi_3 = vdupq_n_s32(0);
+
+ for (int k = 0; k < 4; k++) {
+ int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
+ int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
+
+ uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
+ int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
+ int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
+
+ sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
+ sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
+ sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
+ sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
+ sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
+ sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
+ sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
+ sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
+ }
+
+ sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
+ sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
+ sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
+ sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
+ }
+
+ for (int m = 0; m < 4; m++) {
+ vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+ }
+ }
+ }
+ return;
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+ ggml_gemm_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+ }
+
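Unlike the iq4_nl blocks, mxfp4 blocks carry one E8M0 exponent byte per interleaved column instead of an fp16 scale. A sketch of the decode, under the assumption that GGML_CPU_E8M0_TO_FP32_HALF(e) evaluates to 2^(e-127)/2, with the extra halving compensating for doubled kvalues_mxfp4 entries; the actual macro is defined in ggml's CPU headers:

    #include <cmath>
    #include <cstdint>
    // Assumption: E8M0 is a biased power-of-two exponent (bias 127), and the
    // *_HALF variant folds in an extra 0.5 factor.
    inline float e8m0_to_fp32_half_sketch(uint8_t e) {
        return std::ldexp(1.0f, (int) e - 127 - 1); // 2^(e-127) * 0.5
    }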
+ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+ constexpr int qk = QK_K;
+ const int nb = n / qk;
+
+ constexpr int ncols_interleaved = 8;
+ constexpr int blocklen = 4;
+
+ assert(n % qk == 0);
+ assert(nr % 4 == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ constexpr int q8_k_blocklen = 4;
+ constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+ // 8 accumulators: 2 row pairs × 4 col pairs
+ float32x4_t acc_f32[acc_size];
+
+ for (int y = 0; y < nr / q8_k_blocklen; y++) {
+ const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+ for (int i = 0; i < acc_size; i++) {
+ acc_f32[i] = vdupq_n_f32(0);
+ }
+
+ for (int b = 0; b < nb; b++) {
+ // d4 0 1 2 3, 4 5 6 7
+ float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));
+ float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
+ // d8 0 1 2 3
+ float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
+ // mins
+ float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
+ float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
+
+ // Precomputation of scales and mins
+ float32x4_t sbd_scale_0123[q8_k_blocklen];
+ float32x4_t sbd_scale_4567[q8_k_blocklen];
+ float32x4_t sbd_min_0123[q8_k_blocklen];
+ float32x4_t sbd_min_4567[q8_k_blocklen];
+
+ sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
+ sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
+ sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
+ sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
+
+ sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
+ sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
+ sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
+ sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
+
+ sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
+ sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
+ sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
+ sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
+
+ sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
+ sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
+ sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
+ sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
+
+ // Precomputation of bsums; each vpaddq computes all the bsums for one row
+ const int16x8_t bsums[q8_k_blocklen] = {
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+ };
+ int16_t bsums_arr[QK_K / 64][8];
+ for (int q8_row = 0; q8_row < 4; q8_row++) {
+ vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+ }
+
+ // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
+ int32x4_t bias_acc[acc_size];
+ for (int i = 0; i < acc_size; i++) {
+ bias_acc[i] = vdupq_n_s32(0);
+ }
+
+ for (int sb = 0; sb < QK_K / 64; sb++) {
+ // Int accumulators for qs vecdot (4 row x 2 col quartets)
+ int32x4_t acc_lo[acc_size];
+ int32x4_t acc_hi[acc_size];
+ for (int i = 0; i < acc_size; i++) {
+ acc_lo[i] = vdupq_n_s32(0);
+ acc_hi[i] = vdupq_n_s32(0);
+ }
+ // Need scales for the low and high nibbles
+ // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+ int16x8_t q4sb_scales[2];
+ int16x8_t q4sb_mins[2];
+ for (int i = 0; i < 2; i++) {
+ int8_t aux_q4sb[8];
+ const int offset = sb * 24 + i * 12;
+ decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+ q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+ }
+
+ constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
+ for (int k = 0; k < reads_per_sb; k++) {
+ const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
+ const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
+
+ // 0..3 & 32..35
+ const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);
+ const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
+
+ const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));
+ const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));
+
+ acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
+ acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
+ acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
+ acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
+
+ acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
+ acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
+ acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
+ acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
+
+ const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));
+ const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));
+
+ acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
+ acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
+ acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
+ acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
+
+ acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
+ acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
+ acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
+ acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
+ }
+
+ // Scale and bias application
+ // acc is stored interleaved to match output layout
+ const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
+ const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
+ const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
+ const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
+ for (int row = 0; row < q8_k_blocklen; row++) {
+ // Bias correction
+ // row c0123 blk0 and blk1
+ const float32x4_t sumf_0123 =
+ vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
+ vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
+ acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
+
+ // row c4567 blk0 and blk1
+ const float32x4_t sumf_4567 =
+ vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
+ vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
+ acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
+
+ // Bias
+ const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
+ const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
+
+ // row c0123 blk0 and blk1
+ bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+ bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+
+ // row c4567 blk0 and blk1
+ bias_acc[2 * row + 1] =
+ vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+ bias_acc[2 * row + 1] =
+ vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+ }
+ } // for sb
 
- #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
- ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+ for (int row = 0; row < q8_k_blocklen; row++) {
+ acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
+ acc_f32[2 * row + 1] =
+ vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
+ }
+ } // for b
+
+ for (int i = 0; i < q8_k_blocklen; i++) {
+ int row = y * q8_k_blocklen + i;
+ for (int j = 0; j < 2; j++) {
+ int col = x * ncols_interleaved + j * 4;
+ int offset = row * bs + col;
+ vst1q_f32(s + offset, acc_f32[2 * i + j]);
+ }
+ }
+ } // for x
+ } // for y
+ return;
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
  }
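The bsums/dmin bookkeeping in the kernel above follows from the q4_K dequantization identity w = d*sc*q - dmin*m: the activation dot product splits into a scale term driven by the raw integer dot and a bias term driven only by the activation sums, which is why the kernel can precompute the bsums once per block. A scalar sketch of one subblock's contribution:

    #include <cstdint>
    // sum_i(w_i * d8*q8_i) = (d*d8)*sc*dot(q, q8) - (dmin*d8)*m*bsum,
    // with bsum = sum_i(q8_i); this is exactly what bsums_arr holds.
    float q4K_subblock_dot_sketch(float d, float dmin, float d8, int sc, int m,
                                  const uint8_t * q, const int8_t * q8,
                                  int n, int bsum) {
        int32_t dot = 0;
        for (int i = 0; i < n; i++) {
            dot += (int32_t) q[i] * q8[i]; // raw 4-bit codes against int8 activations
        }
        return (d * d8) * (float)(sc * dot) - (dmin * d8) * (float)(m * bsum);
    }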
 
- void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
- const int qk = QK8_0;
- const int nb = n / qk;
- const int ncols_interleaved = 4;
- const int blocklen = 4;
+ void ggml_gemm_q5_K_8x4_q8_K(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ constexpr int qk = QK_K;
+ const int nb = n / qk;
 
- assert (n % qk == 0);
- assert (nr % 4 == 0);
- assert (nc % ncols_interleaved == 0);
+ constexpr int ncols_interleaved = 8;
+ constexpr int blocklen = 4;
+
+ assert(n % qk == 0);
+ assert(nr % 4 == 0);
+ assert(nc % ncols_interleaved == 0);
 
- UNUSED(s);
- UNUSED(bs);
- UNUSED(vx);
- UNUSED(vy);
- UNUSED(nr);
- UNUSED(nc);
  UNUSED(nb);
  UNUSED(ncols_interleaved);
  UNUSED(blocklen);
 
- #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
- const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ constexpr int q8_k_blocklen = 4;
+ constexpr int acc_size = 2 * 4; // 2 row pairs, 4 col pairs
+ constexpr int col_groups = ncols_interleaved / 4;
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
+ const uint8x16_t mone = vdupq_n_u8(1);
+ const uint8x16_t mtwo = vdupq_n_u8(2);
+
+ // 8 accumulators: 2 row pairs, 4 col pairs
+ float32x4_t acc_f32[acc_size];
+
+ for (int y = 0; y < nr / q8_k_blocklen; y++) {
+ const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
 
- for (int y = 0; y < nr / 4; y++) {
- const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
  for (int x = 0; x < nc / ncols_interleaved; x++) {
- const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+ const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
 
- float32x4_t sumf[4];
- for (int m = 0; m < 4; m++) {
- sumf[m] = vdupq_n_f32(0);
+ for (int i = 0; i < acc_size; i++) {
+ acc_f32[i] = vdupq_n_f32(0);
  }
 
- for (int l = 0; l < nb; l++) {
- float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
- float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+ for (int b = 0; b < nb; b++) {
+ // d5 0 1 2 3, 4 5 6 7
+ float32x4_t q5_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));
+ float32x4_t q5_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));
+ // d8 0 1 2 3
+ float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
+ // mins
+ float32x4_t q5_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));
+ float32x4_t q5_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4));
 
- int32x4_t sumi_0 = vdupq_n_s32(0);
- int32x4_t sumi_1 = vdupq_n_s32(0);
- int32x4_t sumi_2 = vdupq_n_s32(0);
- int32x4_t sumi_3 = vdupq_n_s32(0);
+ // Precomputation of scales and mins
+ float32x4_t sbd_scale_0123[q8_k_blocklen];
+ float32x4_t sbd_scale_4567[q8_k_blocklen];
+ float32x4_t sbd_min_0123[q8_k_blocklen];
+ float32x4_t sbd_min_4567[q8_k_blocklen];
 
- for (int k = 0; k < 4; k++) {
- int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
- int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
+ sbd_scale_0123[0] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 0);
+ sbd_scale_4567[0] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 0);
+ sbd_min_0123[0] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 0);
+ sbd_min_4567[0] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 0);
 
- uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
- int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
- int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
+ sbd_scale_0123[1] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 1);
+ sbd_scale_4567[1] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 1);
+ sbd_min_0123[1] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 1);
+ sbd_min_4567[1] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 1);
 
- sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
- sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
- sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
- sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
- sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
- sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
- sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
- sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
+ sbd_scale_0123[2] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 2);
+ sbd_scale_4567[2] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 2);
+ sbd_min_0123[2] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 2);
+ sbd_min_4567[2] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 2);
+
+ sbd_scale_0123[3] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 3);
+ sbd_scale_4567[3] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 3);
+ sbd_min_0123[3] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 3);
+ sbd_min_4567[3] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 3);
+
+ // Precomputation of bsums; each vpaddq computes all the bsums for one row
+ const int16x8_t bsums[q8_k_blocklen] = {
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+ };
+ int16_t bsums_arr[QK_K / 64][8];
+ for (int q8_row = 0; q8_row < 4; q8_row++) {
+ vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+ }
+
+ // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
+ int32x4_t bias_acc[acc_size];
+ for (int i = 0; i < acc_size; i++) {
+ bias_acc[i] = vdupq_n_s32(0);
+ }
+
+ uint8x16_t qh[col_groups][8];
+ for (int c = 0; c < col_groups; c++) {
+ for (int i = 0; i < 8; i++) {
+ qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
+ }
+ }
+
+ for (int sb = 0; sb < QK_K / 64; sb++) {
+ // Int accumulators for qs vecdot (4 row * 2 col quartets)
+ int32x4_t acc_lo[acc_size];
+ int32x4_t acc_hi[acc_size];
+ for (int i = 0; i < acc_size; i++) {
+ acc_lo[i] = vdupq_n_s32(0);
+ acc_hi[i] = vdupq_n_s32(0);
+ }
+ // Need scales for the low and high nibbles
+ // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+ int16x8_t q5sb_scales[2];
+ int16x8_t q5sb_mins[2];
+ for (int i = 0; i < 2; i++) {
+ int8_t aux_q5sb[8];
+ const int offset = sb * 24 + i * 12;
+ decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
+ q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
+ }
+
+ constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
+ for (int k = 0; k < reads_per_sb; k++) {
+ const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
+ const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
+
+ // 0..3 & 32..35
+ const uint8x16_t q5_0123 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k);
+ const uint8x16_t q5_4567 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k + 16);
+
+ // NOTE: this is the only difference from q4_K
+ const uint8x16_t hbit_lo_0123 = vandq_u8(qh[0][k], mone);
+ const uint8x16_t hbit_hi_0123 = vshlq_n_u8(vandq_u8(qh[0][k], mtwo), 3);
+ qh[0][k] = vshrq_n_u8(qh[0][k], 2);
+ const uint8x16_t hbit_lo_4567 = vandq_u8(qh[1][k], mone);
+ const uint8x16_t hbit_hi_4567 = vshlq_n_u8(vandq_u8(qh[1][k], mtwo), 3);
+ qh[1][k] = vshrq_n_u8(qh[1][k], 2);
+ // From here, same as q4_K
+
+ const int8x16_t q5_0123_lo =
+ vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_0123, m4b), hbit_lo_0123, 4));
+ const int8x16_t q5_0123_hi =
+ vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_0123, 4), hbit_hi_0123));
+
+ acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q5_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
+ acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q5_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
+ acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q5_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
+ acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q5_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
+
+ acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q5_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
+ acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q5_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
+ acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q5_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
+ acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q5_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
+
+ const int8x16_t q5_4567_lo =
+ vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4));
+ const int8x16_t q5_4567_hi =
+ vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567));
+
+ acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q5_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
+ acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q5_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
+ acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q5_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
+ acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q5_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
+
+ acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q5_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
+ acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q5_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
+ acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q5_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
+ acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q5_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
+ }
+
+ // Scale and bias application
+ // acc is stored interleaved to match output layout
+ const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]);
+ const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]);
+ const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]);
+ const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]);
+ for (int row = 0; row < q8_k_blocklen; row++) {
+ // Bias correction
+ // row c0123 blk0 and blk1
+ const float32x4_t sumf_0123 =
+ vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
+ vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
+ acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
+
+ // row c4567 blk0 and blk1
+ const float32x4_t sumf_4567 =
+ vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
+ vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
+ acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
+
+ // Bias
+ const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
+ const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
+
+ // row c0123 blk0 and blk1
+ bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
+ bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
+
+ // row c4567 blk0 and blk1
+ bias_acc[2 * row + 1] =
+ vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
+ bias_acc[2 * row + 1] =
+ vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
+ }
+ } // for sb
+
+ for (int row = 0; row < q8_k_blocklen; row++) {
+ acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
+ acc_f32[2 * row + 1] =
+ vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
+ }
+ } // for b
+
+ for (int i = 0; i < q8_k_blocklen; i++) {
+ int row = y * q8_k_blocklen + i;
+ for (int j = 0; j < 2; j++) {
+ int col = x * ncols_interleaved + j * 4;
+ int offset = row * bs + col;
+ vst1q_f32(s + offset, acc_f32[2 * i + j]);
  }
+ }
+ } // for x
+ } // for y
+ return;
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ ggml_gemm_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+ }
+
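As the in-kernel comments note, q5_K differs from q4_K only in the extra high bit per weight, reconstructed above with vsliq_n_u8/vorrq_u8 from the separately stored qh bytes. A scalar sketch of the same reconstruction for one qs byte and its two high bits:

    #include <cstdint>
    // low 32 values:  q = (qs & 0x0F) | ((qh & 1) << 4)
    // high 32 values: q = (qs >> 4)   | ((qh & 2) << 3)
    // then qh >>= 2 so the next subblock's two bits move into place.
    inline void decode_q5_pair_sketch(uint8_t qs, uint8_t & qh,
                                      uint8_t & lo, uint8_t & hi) {
        lo = (qs & 0x0F) | ((qh & 1) << 4);
        hi = (qs >> 4)   | ((qh & 2) << 3);
        qh >>= 2;
    }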
+ void ggml_gemm_q4_K_8x8_q8_K(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ constexpr int qk = QK_K;
+ const int nb = n / qk;
+
+ constexpr int ncols_interleaved = 8;
+ constexpr int blocklen = 8;
+
+ assert(n % qk == 0);
+ assert(nr % 4 == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+ #if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+ if (svcntb() * 8 == 256) {
+ constexpr int q8_k_blocklen = 4;
+ const svuint8_t m4b_1 = svdup_n_u8(0x0f);
+ // 8 accumulators: 2 row pairs × 4 col pairs
+ svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67;
+ uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 };
+ svbool_t pg = svptrue_pat_b32(SV_VL8);
+ svuint32_t idx = svld1(pg, idx_arr);
+
+ static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7};
+ svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data);
+
+ for (int y = 0; y < nr / q8_k_blocklen; y++) {
+ const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+ acc_f32_01 = svdup_n_f32(0);
+ acc_f32_23 = svdup_n_f32(0);
+ acc_f32_45 = svdup_n_f32(0);
+ acc_f32_67 = svdup_n_f32(0);
+
+ for (int b = 0; b < nb; b++) {
+ // bsums pairs belong to the same q8_k subblock
+ // 64 elements loaded; pair sums of 0-7 with 8-15, and of 16-23 with 24-31
+ const int16x8_t bsums[4]{
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+ };
+
+ int32_t bsums_arr32[4][8];
+
+ for (int q8_row = 0; q8_row < 4; q8_row++) {
+ int16x8_t v16 = bsums[q8_row];
+
+ // low 4
+ int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16));
+ vst1q_s32(&bsums_arr32[q8_row][0], v32_lo);
+
+ // high 4
+ int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16));
+ vst1q_s32(&bsums_arr32[q8_row][4], v32_hi);
+ }
+
+ svint32_t sb_acc_0 = svdup_n_s32(0);
+ svint32_t sb_acc_2 = svdup_n_s32(0);
+
+ svint32_t acc_00 = svdup_n_s32(0);
+ svint32_t acc_11 = svdup_n_s32(0);
+ svint32_t acc_22 = svdup_n_s32(0);
+ svint32_t acc_33 = svdup_n_s32(0);
+ svint32_t acc_44 = svdup_n_s32(0);
+ svint32_t acc_55 = svdup_n_s32(0);
+ svint32_t acc_66 = svdup_n_s32(0);
+ svint32_t acc_77 = svdup_n_s32(0);
+
+ svint32_t bias_acc_00 = svdup_n_s32(0);
+ svint32_t bias_acc_22 = svdup_n_s32(0);
+ svint32_t bias_acc_44 = svdup_n_s32(0);
+ svint32_t bias_acc_66 = svdup_n_s32(0);
+
+ for (int sb = 0; sb < QK_K / 64; sb++) {
+ // Need scales for the low and high nibbles
+ // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+ svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3;
+ svint32_t q4sb_mins_0, q4sb_mins_1;
+ {
+ // the two 12-byte scale/min groups for this subblock
+ const int offset = sb * 24 + 0 * 12;
+ const uint8_t * scales_in = &q4_ptr[b].scales[offset];
+
+ const int offset1 = sb * 24 + 12;
+ const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1];
+
+ constexpr uint32_t kmask1 = 0x3f3f3f3f;
+ constexpr uint32_t kmask2 = 0x0f0f0f0f;
+ constexpr uint32_t kmask3 = 0x03030303;
+ constexpr uint8_t scales_size = 12;
+
+ uint32_t sm[3];
+ memcpy(sm, scales_in, scales_size);
+
+ uint32_t sm1[3];
+ memcpy(sm1, scales_in1, scales_size);
+
+ const uint32_t mins_0_3 = sm[1] & kmask1;
+ const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
+
+ const uint32_t mins_0_3_1 = sm1[1] & kmask1;
+ const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4);
+
+ svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7));
+ svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1));
+
+ /* reinterpret u32 → u8 */
+ svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp);
+ svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1);
+
+ /* widen u8 → u16 → u32 (lower half only) */
+ svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8));
+ svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1));
+
+ q4sb_mins_0 = svreinterpret_s32_u32(mins_u16);
+ q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1);
+
+ uint32_t scales_u32_0 = sm[0] & kmask1;
+ uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
+ uint32_t scales_u32_2 = sm1[0] & kmask1;
+ uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4);
+
+ svuint32_t S01 = svdup_n_u32(scales_u32_0);
+ svuint32_t S23 = svdup_n_u32(scales_u32_1);
+ svuint32_t R01 = svdup_n_u32(scales_u32_2);
+ svuint32_t R23 = svdup_n_u32(scales_u32_3);
+
+ svint8_t S01_b = svreinterpret_s8_u32(S01);
+ svint8_t S23_b = svreinterpret_s8_u32(S23);
+ svint8_t R01_b = svreinterpret_s8_u32(R01);
+ svint8_t R23_b = svreinterpret_s8_u32(R23);
+
+ svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b)));
+ svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b)));
+ svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b)));
+ svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b)));
+
+ block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx);
+ block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx);
+ block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx);
+ block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx);
+ }
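The mask-and-shift block just above unpacks the K-quant 6-bit packed scales and mins (12 bytes yield 8 scales plus 8 mins): the first four of each are plain 6-bit fields, while the last four are split across the upper 2 bits of the first 8 bytes. A scalar sketch of the identical unpack for one 12-byte group, using the same masks:

    #include <cstdint>
    #include <cstring>
    void unpack_qK_scales_mins_sketch(const uint8_t * src, uint8_t * sc, uint8_t * mn) {
        uint32_t sm[3];
        memcpy(sm, src, 12);
        const uint32_t sc03 = sm[0] & 0x3f3f3f3f;                                      // scales 0..3
        const uint32_t mn03 = sm[1] & 0x3f3f3f3f;                                      // mins   0..3
        const uint32_t sc47 = (sm[2] & 0x0f0f0f0f)        | (((sm[0] >> 6) & 0x03030303) << 4); // scales 4..7
        const uint32_t mn47 = ((sm[2] >> 4) & 0x0f0f0f0f) | (((sm[1] >> 6) & 0x03030303) << 4); // mins   4..7
        memcpy(sc, &sc03, 4); memcpy(sc + 4, &sc47, 4);
        memcpy(mn, &mn03, 4); memcpy(mn + 4, &mn47, 4);
    }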
+
+ const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256;
+
+ // Load 32-byte per row pair, 1 subblock each time
+ // predicate for activating higher lanes for 16 int8 elements
+ const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
+ // predicate for activating lower lanes for 16 int8 elements
+ const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
+
+ svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));
+ svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));
+ svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));
+ svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));
+
+ svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));
+ svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));
+ svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));
+ svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));
+
+ // Q4s columns iterated in pairs (01, 23, 45, 67)
+ for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
+
+ sb_acc_0 = svdup_n_s32(0);
+ sb_acc_2 = svdup_n_s32(0);
+
+ svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);
+ svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);
+ svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);
+ svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);
+
+ svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));
+ svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
+ svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));
+ svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));
+
+ sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0);
+ sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2);
+
+ sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4);
+ sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6);
+
+ sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1);
+ sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3);
+
+ sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5);
+ sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7);
+
+ if(cp == 0) {
+ acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0);
+ acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0);
+ }
+ if(cp == 1) {
+ acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1);
+ acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1);
+ }
+ if(cp == 2) {
+ acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2);
+ acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2);
+ }
+ if(cp == 3) {
+ acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3);
+ acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3);
+ }
+ }
+
+ bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0);
+ bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1);
+
+ bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0);
+ bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1);
+
+ bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0);
+ bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1);
+
+ bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0);
+ bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1);
+ } // for sb
+
+
+ acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4));
+ acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4));
+ acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4));
+ acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4));
+ acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4));
+ acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4));
+ acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4));
+ acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4));
+
+ svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1);
+ svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1);
+
+ svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1);
+ svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1);
+
+ // Broadcast q8 scalar
+ svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]);
+
+ svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0)));
+
+ svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0)));
+
+ svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
+ svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
+
+ acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1);
+ acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1);
+
+ q8_d = svdup_f32(q8_ptr[b].d[1]);
+
+ scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
+ dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
 
- sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
- sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
- sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
- sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
- }
+ acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1);
+ acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1);
 
- for (int m = 0; m < 4; m++) {
- vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
- }
- }
- }
- return;
- #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
- ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
- }
+ q8_d = svdup_f32(q8_ptr[b].d[2]);
 
- void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
- constexpr int qk = QK_K;
- const int nb = n / qk;
 
- constexpr int ncols_interleaved = 8;
- constexpr int blocklen = 4;
+ scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
+ dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
 
- assert(n % qk == 0);
- assert(nr % 4 == 0);
- assert(nc % ncols_interleaved == 0);
+ acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1);
+ acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1);
 
- UNUSED(nb);
- UNUSED(ncols_interleaved);
- UNUSED(blocklen);
+ q8_d = svdup_f32(q8_ptr[b].d[3]);
 
- #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
+ dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
+
+ acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1);
+ acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1);
+
+ } // for b
+
+ // With the previous reorder, the tile is already in the correct memory layout.
+ // Predicate for exactly 4 lanes
+ svbool_t pg4 = svptrue_pat_b32(SV_VL4);
+ for (int i = 0; i < q8_k_blocklen; i++) {
+ int row = y * q8_k_blocklen + i;
+ for (int j = 0; j < 2; j++) {
+ int col = x * ncols_interleaved + j * 4;
+ int offset = row * bs + col;
+
+ if (i == 0 && j == 0) {
+ // acc_f32_0 → lower half of acc_f32_01
+ svst1_f32(pg4, s + offset, acc_f32_01);
+ } else if (i == 0 && j == 1) {
+ // acc_f32_1 → upper half of acc_f32_01
+ svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4));
+ } else if (i == 1 && j == 0) {
+ // acc_f32_2
+ svst1_f32(pg4, s + offset, acc_f32_23);
+ } else if (i == 1 && j == 1) {
+ // acc_f32_3
+ svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4));
+ } else if (i == 2 && j == 0) {
+ // acc_f32_4
+ svst1_f32(pg4, s + offset, acc_f32_45);
+ } else if (i == 2 && j == 1) {
+ // acc_f32_5
+ svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4));
+ } else if (i == 3 && j == 0) {
+ // acc_f32_6
+ svst1_f32(pg4, s + offset, acc_f32_67);
+ } else if (i == 3 && j == 1) {
+ // acc_f32_7
+ svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4));
+ }
+ }
+ }
+ } // for x
+ } // for y
+ return;
+ }
+ #endif // SVE compile-time end
+
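Note that the SVE path above is compiled in whenever SVE and I8MM are available, but it executes only when the runtime vector length is 256 bits; otherwise control falls through to the NEON I8MM implementation that follows. A sketch of the gate:

    #include <arm_sve.h>
    // svcntb() returns the runtime SVE vector length in bytes, so this
    // predicate selects the 256-bit implementation only (e.g. on 256-bit
    // SVE cores); narrower or wider implementations take the fallback path.
    inline bool sve_vl_is_256_sketch() {
        return svcntb() * 8 == 256;
    }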
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
  constexpr int q8_k_blocklen = 4;
- constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs
- const uint8x16_t m4b = vdupq_n_u8(0x0f);
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
 
  // 8 accumulators: 2 row pairs × 4 col pairs
- float32x4_t acc_f32[acc_size];
+ float32x4_t acc_f32[blocklen];
 
  for (int y = 0; y < nr / q8_k_blocklen; y++) {
  const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
@@ -2358,162 +4093,167 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
  for (int x = 0; x < nc / ncols_interleaved; x++) {
  const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
 
- for (int i = 0; i < acc_size; i++) {
+ for (int i = 0; i < blocklen; i++) {
  acc_f32[i] = vdupq_n_f32(0);
  }
 
  for (int b = 0; b < nb; b++) {
- // d4 0 1 2 3, 4 5 6 7
- float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));
- float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
- // d8 0 1 2 3
- float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
- // mins
- float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
- float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
-
- // Precomputation of scales and mins
- float32x4_t sbd_scale_0123[q8_k_blocklen];
- float32x4_t sbd_scale_4567[q8_k_blocklen];
- float32x4_t sbd_min_0123[q8_k_blocklen];
- float32x4_t sbd_min_4567[q8_k_blocklen];
-
- sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
- sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
- sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
- sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
-
- sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
- sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
- sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
- sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
-
- sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
- sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
- sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
- sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
-
- sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
- sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
- sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
- sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
-
- // Precomputation of bsums, each vpaddq calcs all the bsums for each row
- const int16x8_t bsums[q8_k_blocklen] = {
+ // bsums pairs belong to the same q8_k subblock
+ const int16x8_t bsums[4]{
  vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
  vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
  vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
  vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
  };
- int16_t bsums_arr[QK_K / 64][8];
+ int16_t bsums_arr[4][8];
  for (int q8_row = 0; q8_row < 4; q8_row++) {
  vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
  }
 
- // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
- int32x4_t bias_acc[acc_size];
- for (int i = 0; i < acc_size; i++) {
+ int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results
+ int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
+ int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
+ for (int i = 0; i < 8; i++) {
+ acc[i] = vdupq_n_s32(0);
  bias_acc[i] = vdupq_n_s32(0);
  }
 
  for (int sb = 0; sb < QK_K / 64; sb++) {
- // Int accumulators for qs vecdot (4 row x 2 col quartets)
- int32x4_t acc_lo[acc_size];
- int32x4_t acc_hi[acc_size];
- for (int i = 0; i < acc_size; i++) {
- acc_lo[i] = vdupq_n_s32(0);
- acc_hi[i] = vdupq_n_s32(0);
- }
  // Need scales for the low and high nibbles
  // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
- int16x8_t q4sb_scales[2];
- int16x8_t q4sb_mins[2];
+ int8_t q4sb_scales[2][8];
+ int16x8_t q4sb_mins[2]; // int16 as it's needed for bias_acc later
  for (int i = 0; i < 2; i++) {
- int8_t aux_q4sb[8];
  const int offset = sb * 24 + i * 12;
- decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
- q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+ decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
  }
 
- constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
- for (int k = 0; k < reads_per_sb; k++) {
- const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
- const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
+ // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
+ const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
 
- // 0..3 & 32..35
- const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);
- const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
+ int8x16_t q8_qs_01[8];
+ int8x16_t q8_qs_23[8];
 
- const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));
- const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));
+ // Load 32-byte per row pair, 1 subblock each time
+ for (int i = 0; i < 8; i++) {
+ const int offset = i * 32; // 16 for row 01, 16 for row 23
+ q8_qs_01[i] = vld1q_s8(q8_base + offset);
+ q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
+ }
 
- acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
- acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
- acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
- acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
+ const int8x16_t q8s[2][8] = {
+ { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
+ q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
+ { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
+ q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
+ };
 
- acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
- acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
- acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
- acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
+ // Q4s columns iterated in pairs (01, 23, 45, 67)
+ for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
+ for (int i = 0; i < 4; i++) {
+ sb_acc[i] = vdupq_n_s32(0);
+ }
 
- const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));
- const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));
+ uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
+ uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
+ uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
+ uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
+ const int8x16_t q4_nibbles[2][4] = {
+ {
+ vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
+ vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
+ vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
+ vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
+ },
+ {
+ vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
+ vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
+ vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
+ vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
+ }
+ };
 
- acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
- acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
- acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
- acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
+ // Computes the qs muladd of every q8 row pair (rp: rows 01 and 23)
+ // for each of the two internal 32-qs subblocks (blk)
+ for (int rp = 0; rp < 2; rp++) {
+ for (int blk = 0; blk < 2; blk++) {
+ const int8x16_t * q8 = &q8s[rp][4 * blk];
+ const int8x16_t * q4 = q4_nibbles[blk];
+ int32x4_t acc = sb_acc[2 * rp + blk];
+ // mul add for each qs in the same subblock
+ for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
+ acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
+ }
+ sb_acc[2 * rp + blk] = acc;
+ }
+ }
2467
4190
 
2468
- acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
2469
- acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
2470
- acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
2471
- acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
4191
+ // Scales[i] corresponds to column i
4192
+ const int scale_offset = cp * 2;
4193
+ const int32_t scale_00 = q4sb_scales[0][scale_offset];
4194
+ const int32_t scale_01 = q4sb_scales[0][scale_offset + 1];
4195
+ const int32_t scale_10 = q4sb_scales[1][scale_offset];
4196
+ const int32_t scale_11 = q4sb_scales[1][scale_offset + 1];
4197
+ const int32x4_t block_scale_0 = vcombine_s32(vdup_n_s32(scale_00), vdup_n_s32(scale_01));
4198
+ const int32x4_t block_scale_1 = vcombine_s32(vdup_n_s32(scale_10), vdup_n_s32(scale_11));
4199
+
4200
+ acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale_0);
4201
+ acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale_0);
4202
+ acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale_1);
4203
+ acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale_1);
2472
4204
  }
2473
4205
 
2474
- // Scale and bias application
2475
- // acc is stored interleaved to match output layout
2476
- const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
2477
- const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
2478
- const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
2479
- const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
2480
- for (int row = 0; row < q8_k_blocklen; row++) {
2481
- // Bias correction
2482
- // row c0123 blk0 and blk1
2483
- const float32x4_t sumf_0123 =
2484
- vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
2485
- vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
2486
- acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
2487
-
2488
- // row c4567 blk0 and blk1
2489
- const float32x4_t sumf_4567 =
2490
- vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
2491
- vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
2492
- acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
2493
-
2494
- // Bias
2495
- const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
2496
- const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
2497
-
2498
- // row c0123 blk0 and blk1
2499
- bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
2500
- bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
4206
+ // Multiply Acc bsum + mins
4207
+ for (int q8_row = 0; q8_row < 4; q8_row++) {
4208
+ // Each pair of subblocks share the same bsums
4209
+ // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
4210
+ int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
4211
+ int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
2501
4212
 
2502
- // row c4567 blk0 and blk1
2503
- bias_acc[2 * row + 1] =
2504
- vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
2505
- bias_acc[2 * row + 1] =
2506
- vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
4213
+ bias_acc[2 * q8_row] =
4214
+ vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
4215
+ bias_acc[2 * q8_row] =
4216
+ vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
4217
+ bias_acc[2 * q8_row + 1] =
4218
+ vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
4219
+ bias_acc[2 * q8_row + 1] =
4220
+ vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
2507
4221
  }
2508
4222
  } // for sb
2509
4223
 
2510
- for (int row = 0; row < q8_k_blocklen; row++) {
2511
- acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
2512
- acc_f32[2 * row + 1] =
2513
- vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
4224
+ // Reorder of i8mm output with bias and output layout
4225
+ for (int i = 0; i < 8; i++) {
4226
+ int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
4227
+ acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
4228
+ }
4229
+ int32x4_t reorder_acc[8] = {
4230
+ vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
4231
+ vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
4232
+ vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
4233
+ vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
4234
+ vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
4235
+ vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
4236
+ vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
4237
+ vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
4238
+ };
4239
+
4240
+ for (int i = 0; i < q8_k_blocklen; i++) {
4241
+ for (int j = 0; j < 2; j++) {
4242
+ float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
4243
+ float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
4244
+ const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d);
4245
+
4246
+ float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
4247
+ const float32x4_t scale = vmulq_f32(q4_d, q8_d);
4248
+
4249
+ acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
4250
+ acc_f32[2 * i + j] =
4251
+ vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
4252
+ }
2514
4253
  }
2515
4254
  } // for b
2516
4255
 
4256
+ // With the previous reorder, the tile is already in the correct memory layout.
2517
4257
  for (int i = 0; i < q8_k_blocklen; i++) {
2518
4258
  int row = y * q8_k_blocklen + i;
2519
4259
  for (int j = 0; j < 2; j++) {
@@ -2525,11 +4265,11 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
2525
4265
  } // for x
2526
4266
  } // for y
2527
4267
  return;
2528
- #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
2529
- ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
4268
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
4269
+ ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
2530
4270
  }
2531
4271
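
For reference: the new 8x8 kernels in this diff are built on the AArch64 i8mm SMMLA primitive. As a minimal standalone sketch (not part of the diff; requires i8mm support, e.g. compiling with -march=armv8.6-a+i8mm), vmmlaq_s32 accumulates a 2x2 int32 tile C += A * Bᵀ, where A and B each hold a 2x8 int8 matrix (one row per 8-byte half of the vector) and the result is laid out row-major as [c00, c01, c10, c11]. That layout is why the kernels pair q8 rows (01, 23) with q4 column pairs and why the vzip/vcombine reorder above is needed before the tiles can be stored row-major.

    #include <arm_neon.h>
    #include <cstdio>

    int main() {
        // Two 2x8 int8 matrices, one row per 8-byte half of each vector
        const int8_t a[16] = { 1,1,1,1,1,1,1,1,  2,2,2,2,2,2,2,2 };
        const int8_t b[16] = { 1,2,3,4,5,6,7,8,  1,1,1,1,1,1,1,1 };
        // C += A * B^T; the result is row-major [c00, c01, c10, c11]
        const int32x4_t c = vmmlaq_s32(vdupq_n_s32(0), vld1q_s8(a), vld1q_s8(b));
        int32_t out[4];
        vst1q_s32(out, c);
        printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]); // 36 8 72 16
        return 0;
    }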
 
- void ggml_gemm_q4_K_8x8_q8_K(int n,
+ void ggml_gemm_q5_K_8x8_q8_K(int n,
  float * GGML_RESTRICT s,
  size_t bs,
  const void * GGML_RESTRICT vx,
@@ -2552,7 +4292,10 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
 
  #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
  constexpr int q8_k_blocklen = 4;
+ constexpr int col_pairs = ncols_interleaved / 2;
  const uint8x16_t m4b = vdupq_n_u8(0x0f);
+ const uint8x16_t mone = vdupq_n_u8(1);
+ const uint8x16_t mtwo = vdupq_n_u8(2);
 
  // 8 accumulators: 2 row pairs × 4 col pairs
  float32x4_t acc_f32[blocklen];
@@ -2561,7 +4304,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
  const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
 
  for (int x = 0; x < nc / ncols_interleaved; x++) {
- const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+ const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
 
  for (int i = 0; i < blocklen; i++) {
  acc_f32[i] = vdupq_n_f32(0);
@@ -2588,14 +4331,24 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
  bias_acc[i] = vdupq_n_s32(0);
  }
 
+ // Load qh once per block and shift after each subblock
+ const uint8_t * qh_base = q5_ptr[b].qh;
+ uint8x16_t qh[col_pairs][4];
+ for (int cp = 0; cp < col_pairs; cp++) {
+ qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
+ qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
+ qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
+ qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
+ }
+
  for (int sb = 0; sb < QK_K / 64; sb++) {
  // Need scales for the low and high nibbles
  // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
- int8_t q4sb_scales[2][8];
- int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later
+ int8_t q5sb_scales[2][8];
+ int16x8_t q5sb_mins[2]; // int16 as it's needed for bias_acc later
  for (int i = 0; i < 2; i++) {
  const int offset = sb * 24 + i * 12;
- decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
+ decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]);
  }
 
  // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
@@ -2612,64 +4365,89 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
  }
 
  const int8x16_t q8s[2][8] = {
- { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
- q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
- { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
- q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
+ { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6],
+ q8_qs_01[7] },
+ { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6],
+ q8_qs_23[7] },
  };
 
- // Q4s columns iterated in pairs (01, 23, 45, 67)
- for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
+ // Q5s columns iterated in pairs (01, 23, 45, 67)
+ for (int cp = 0; cp < col_pairs; cp++) {
  for (int i = 0; i < 4; i++) {
  sb_acc[i] = vdupq_n_s32(0);
  }
 
- uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
- uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
- uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
- uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
- const int8x16_t q4_nibbles[2][4] = {
- {
- vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
- vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
- vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
- vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
- },
- {
- vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
- vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
- vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
- vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
- }
- };
-
- // Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8
- // for each of the internal 32 qs subblock (blk)
- for (int rp = 0; rp < 2; rp++) {
- for (int blk = 0; blk < 2; blk++) {
- const int8x16_t * q8 = &q8s[rp][4 * blk];
- const int8x16_t * q4 = q4_nibbles[blk];
- int32x4_t acc = sb_acc[2 * rp + blk];
- // mul add for each qs in the same subblock
- for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
- acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
- }
- sb_acc[2 * rp + blk] = acc;
- }
- }
+ uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
+ uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
+ uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
+ uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
+
+ // This is the only part of the algorithm that differs from Q4_K
+ // Extract the high bits and pack into 5-bit weights
+ uint8x16_t hbit_lo_0 = vandq_u8(qh[cp][0], mone);
+ uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3);
+ qh[cp][0] = vshrq_n_u8(qh[cp][0], 2);
+ // Same as Q4_K, using i8mm on the reconstructed weights.
+ const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4));
+ int32x4_t acc_0 = sb_acc[0];
+ acc_0 = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]);
+ int32x4_t acc_2 = sb_acc[2];
+ acc_2 = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]);
+ const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0));
+ int32x4_t acc_1 = sb_acc[1];
+ acc_1 = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]);
+ int32x4_t acc_3 = sb_acc[3];
+ acc_3 = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]);
+
+ // Repeat for the other 3 columns (8..15, 16..23, 24..31)
+ uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3);
+ uint8x16_t hbit_lo_1 = vandq_u8(qh[cp][1], mone);
+ qh[cp][1] = vshrq_n_u8(qh[cp][1], 2);
+ const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4));
+ acc_0 = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]);
+ acc_2 = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]);
+ const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1));
+ acc_1 = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]);
+ acc_3 = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]);
+
+ uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3);
+ uint8x16_t hbit_lo_2 = vandq_u8(qh[cp][2], mone);
+ qh[cp][2] = vshrq_n_u8(qh[cp][2], 2);
+ const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4));
+ acc_0 = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]);
+ acc_2 = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]);
+ const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2));
+ acc_1 = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]);
+ acc_3 = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]);
+
+ uint8x16_t hbit_lo_3 = vandq_u8(qh[cp][3], mone);
+ uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3);
+ qh[cp][3] = vshrq_n_u8(qh[cp][3], 2);
+ const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4));
+ acc_0 = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]);
+ sb_acc[0] = acc_0;
+ acc_2 = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]);
+ sb_acc[2] = acc_2;
 
  // Scales[i] corresponds to column i
- const int scale_offset = cp * 2;
- for (int blk = 0; blk < 2; blk++) {
- const int32x4_t block_scale = {
- (int32_t) q4sb_scales[blk][scale_offset],
- (int32_t) q4sb_scales[blk][scale_offset],
- (int32_t) q4sb_scales[blk][scale_offset + 1],
- (int32_t) q4sb_scales[blk][scale_offset + 1],
- };
- acc[cp] = vmlaq_s32(acc[cp], sb_acc[blk], block_scale);
- acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale);
- }
+ const int scale_offset = cp * 2;
+ const int32_t s0 = q5sb_scales[0][scale_offset];
+ const int32_t s1 = q5sb_scales[0][scale_offset + 1];
+ const int32x4_t block_scale = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1));
+ acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale);
+ acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale);
+
+ const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3));
+ acc_1 = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]);
+ sb_acc[1] = acc_1;
+ acc_3 = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]);
+ sb_acc[3] = acc_3;
+
+ const int32_t s2 = q5sb_scales[1][scale_offset];
+ const int32_t s3 = q5sb_scales[1][scale_offset + 1];
+ const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3));
+ acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale2);
+ acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2);
  }
 
  // Accumulate bsums × mins for the bias correction
@@ -2680,13 +4458,13 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
  int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1])
 
  bias_acc[2 * q8_row] =
- vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+ vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
  bias_acc[2 * q8_row] =
- vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+ vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
  bias_acc[2 * q8_row + 1] =
- vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+ vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
  bias_acc[2 * q8_row + 1] =
- vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+ vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
  }
  } // for sb
 
@@ -2709,11 +4487,11 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
  for (int i = 0; i < q8_k_blocklen; i++) {
  for (int j = 0; j < 2; j++) {
  float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
- float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
- const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d);
+ float32x4_t q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4)));
+ const float32x4_t dmins = vmulq_f32(q5_dmin, q8_d);
 
- float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
- const float32x4_t scale = vmulq_f32(q4_d, q8_d);
+ float32x4_t q5_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4)));
+ const float32x4_t scale = vmulq_f32(q5_d, q8_d);
 
  acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
  acc_f32[2 * i + j] =
@@ -2735,9 +4513,427 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
  } // for y
  return;
  #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
- ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+ ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+ }
+
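
For reference: in ggml_gemm_q5_K_8x8_q8_K above, each 5-bit weight is a packed 4-bit nibble from qs plus one high bit from qh, and the min/bias handling is identical to Q4_K; only the bit reconstruction differs. A scalar sketch of what the vsliq_n_u8/vorrq_u8 lanes compute (illustrative only; q5_weight is a hypothetical helper, not code from this diff):

    #include <stdint.h>
    #include <stdio.h>

    // Rebuild one 5-bit Q5_K weight (0..31) from its low nibble and high bit.
    // sub == 0: low nibble of qs plus bit 0 of qh; sub == 1: high nibble plus
    // bit 1. The kernel shifts qh right by 2 after each subblock so the next
    // subblock's high bits land in positions 0 and 1 again.
    static inline int q5_weight(uint8_t qs_byte, uint8_t qh_byte, int sub) {
        const int lo4  = sub ? (qs_byte >> 4)       : (qs_byte & 0x0F);
        const int hbit = sub ? ((qh_byte >> 1) & 1) : (qh_byte & 1);
        return lo4 | (hbit << 4);
    }

    int main(void) {
        // 0xA7: low nibble 7, high nibble 10; qh bits 0 and 1 both set
        printf("%d %d\n", q5_weight(0xA7, 0x03, 0), q5_weight(0xA7, 0x03, 1)); // 23 26
        return 0;
    }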
+ void ggml_gemm_q6_K_8x4_q8_K(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ constexpr int qk = QK_K;
+ const int nb = n / qk;
+
+ constexpr int ncols_interleaved = 8;
+ constexpr int blocklen = 4;
+
+ assert(n % qk == 0);
+ assert(nr % 4 == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ constexpr int q8_k_blocklen = 4;
+ constexpr int col_groups = ncols_interleaved / 4;
+ constexpr int acc_size = q8_k_blocklen * col_groups; // 4 rows, 2 column groups
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
+ const uint8x16_t mask_lo = vdupq_n_u8(0x03);
+ const uint8x16_t mask_hi = vdupq_n_u8(0x30);
+ const int8x16_t m32s = vdupq_n_s8(32);
+
+ float32x4_t acc_f32[acc_size];
+
+ for (int y = 0; y < nr / q8_k_blocklen; y++) {
+ const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
+
+ for (int i = 0; i < acc_size; i++) {
+ acc_f32[i] = vdupq_n_f32(0);
+ }
+
+ for (int b = 0; b < nb; b++) {
+ float32x4_t q6_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));
+ float32x4_t q6_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));
+ float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
+
+ float32x4_t sbd_scale_0123[q8_k_blocklen];
+ float32x4_t sbd_scale_4567[q8_k_blocklen];
+
+ sbd_scale_0123[0] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 0);
+ sbd_scale_4567[0] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 0);
+ sbd_scale_0123[1] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 1);
+ sbd_scale_4567[1] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 1);
+ sbd_scale_0123[2] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 2);
+ sbd_scale_4567[2] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 2);
+ sbd_scale_0123[3] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 3);
+ sbd_scale_4567[3] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 3);
+
+ int32x4_t acc_s32[acc_size];
+ for (int i = 0; i < acc_size; i++) {
+ acc_s32[i] = vdupq_n_s32(0);
+ }
+
+ int16_t q6_scales[8 * 16];
+ for (int i = 0; i < 16; i++) {
+ int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
+ vst1q_s16(q6_scales + i * 8, scales);
+ }
+
+ for (int half = 0; half < 2; half++) {
+ const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
+ const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
+
+ for (int sb = 0; sb < QK_K / 64; sb++) {
+ int32x4_t acc_lo[acc_size];
+ int32x4_t acc_hi[acc_size];
+ for (int i = 0; i < acc_size; i++) {
+ acc_lo[i] = vdupq_n_s32(0);
+ acc_hi[i] = vdupq_n_s32(0);
+ }
+
+ const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
+ const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;
+
+ // 4 rows * 16 elements per scale
+ // 4 reads of 16 bytes each
+ constexpr int reads_per_sb = 4;
+ int8x16_t q8_l[reads_per_sb];
+ int8x16_t q8_h[reads_per_sb];
+ for (int k = 0; k < reads_per_sb; k++) {
+ q8_l[k] = vld1q_s8(q8_base_l + 16 * k);
+ q8_h[k] = vld1q_s8(q8_base_h + 16 * k);
+ }
+
+ const int ql_off_base = sb * QK_K / 2;
+ const int qh_off_base = ql_off_base & 255;
+
+ uint8x16_t q6_ql_0123[reads_per_sb];
+ uint8x16_t q6_ql_4567[reads_per_sb];
+ uint8x16_t q6_qh_0123[reads_per_sb];
+ uint8x16_t q6_qh_4567[reads_per_sb];
+
+ for (int k = 0; k < reads_per_sb; k++) {
+ q6_ql_0123[k] = vld1q_u8(ql_base + ql_off_base + k * 32);
+ q6_ql_4567[k] = vld1q_u8(ql_base + ql_off_base + k * 32 + 16);
+ q6_qh_0123[k] = vld1q_u8(qh_base + qh_off_base + k * 32);
+ q6_qh_4567[k] = vld1q_u8(qh_base + qh_off_base + k * 32 + 16);
+ }
+
+ if (sb > 1) {
+ for (int k = 0; k < reads_per_sb; k++) {
+ q6_qh_0123[k] = vshrq_n_u8(q6_qh_0123[k], 2);
+ q6_qh_4567[k] = vshrq_n_u8(q6_qh_4567[k], 2);
+ }
+ }
+
+ for (int k = 0; k < reads_per_sb; k++) {
+ // q = (ql | qh) - 32
+ const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123[k], mask_lo);
+ const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123[k], mask_hi);
+ const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567[k], mask_lo);
+ const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567[k], mask_hi);
+
+ const int8x16_t q6_0123_lo = vsubq_s8(
+ vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123[k], m4b), hbit_lo_0123, 4)), m32s);
+ const int8x16_t q6_0123_hi = vsubq_s8(
+ vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123[k], 4), hbit_hi_0123)), m32s);
+
+ acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l[k], 0); // 0..3 r0 c0123
+ acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l[k], 1); // 0..3 r1 c0123
+ acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l[k], 2); // 0..3 r2 c0123
+ acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l[k], 3); // 0..3 r3 c0123
+
+ acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h[k], 0); // 64..67 r0 c0123
+ acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h[k], 1); // 64..67 r1 c0123
+ acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h[k], 2); // 64..67 r2 c0123
+ acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h[k], 3); // 64..67 r3 c0123
+
+ const int8x16_t q6_4567_lo = vsubq_s8(
+ vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567[k], m4b), hbit_lo_4567, 4)), m32s);
+ const int8x16_t q6_4567_hi = vsubq_s8(
+ vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567[k], 4), hbit_hi_4567)), m32s);
+
+ acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l[k], 0); // 0..3 r0 c4567
+ acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l[k], 1); // 0..3 r1 c4567
+ acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l[k], 2); // 0..3 r2 c4567
+ acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l[k], 3); // 0..3 r3 c4567
+
+ acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h[k], 0); // 64..67 r0 c4567
+ acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h[k], 1); // 64..67 r1 c4567
+ acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h[k], 2); // 64..67 r2 c4567
+ acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h[k], 3); // 64..67 r3 c4567
+ }
+
+ // Apply the per-subblock scales
+ const int scale_idx_l = half * 8 + sb;
+ const int scale_idx_h = half * 8 + sb + 4;
+
+ for (int g = 0; g < col_groups; g++) {
+ const int16x4_t scales_l16 = vld1_s16(q6_scales + scale_idx_l * 8 + g * 4);
+ const int16x4_t scales_h16 = vld1_s16(q6_scales + scale_idx_h * 8 + g * 4);
+ const int32x4_t scale_vec_l = vmovl_s16(scales_l16);
+ const int32x4_t scale_vec_h = vmovl_s16(scales_h16);
+ const int acc_offset = g * q8_k_blocklen;
+
+ for (int row = 0; row < q8_k_blocklen; row++) {
+ const int idx = row * 2 + g;
+ acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_lo[acc_offset + row], scale_vec_l);
+ acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_hi[acc_offset + row], scale_vec_h);
+ }
+ }
+ }
+ }
+
+ // Finally we apply the superblock scales
+ for (int row = 0; row < q8_k_blocklen; row++) {
+ const int idx0 = 2 * row;
+ const int idx1 = 2 * row + 1;
+ const int32x4_t acc_0123 = acc_s32[idx0];
+ const int32x4_t acc_4567 = acc_s32[idx1];
+
+ acc_f32[idx0] = vmlaq_f32(acc_f32[idx0], vcvtq_f32_s32(acc_0123), sbd_scale_0123[row]);
+ acc_f32[idx1] = vmlaq_f32(acc_f32[idx1], vcvtq_f32_s32(acc_4567), sbd_scale_4567[row]);
+ }
+ } // for b
+
+ for (int i = 0; i < q8_k_blocklen; i++) {
+ int row = y * q8_k_blocklen + i;
+ for (int j = 0; j < 2; j++) {
+ int col = x * ncols_interleaved + j * 4;
+ int offset = row * bs + col;
+ vst1q_f32(s + offset, acc_f32[2 * i + j]);
+ }
+ }
+ } // for x
+ } // for y
+ return;
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ ggml_gemm_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
  }
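
For reference: a scalar view of the Q6_K decode that ggml_gemm_q6_K_8x4_q8_K vectorizes with vsliq_n_u8/vorrq_u8 and a subtract of 32 (illustrative sketch only; q6_weight is a hypothetical helper, not code from this diff):

    #include <stdint.h>
    #include <stdio.h>

    // Rebuild one signed 6-bit Q6_K weight in [-32, 31] from a 4-bit low part
    // (ql) and a 2-bit high part (qh). upper == 0 uses the low nibble plus qh
    // bits 0-1 (mask 0x03); upper == 1 uses the high nibble plus qh bits 4-5
    // (mask 0x30, already in position). For subblocks 2 and 3 the kernel first
    // shifts qh right by 2, matching the `if (sb > 1)` step above.
    static inline int q6_weight(uint8_t ql_byte, uint8_t qh_byte, int upper) {
        const int lo4 = upper ? (ql_byte >> 4)       : (ql_byte & 0x0F);
        const int hi2 = upper ? ((qh_byte >> 4) & 3) : (qh_byte & 3);
        return (lo4 | (hi2 << 4)) - 32;
    }

    int main(void) {
        // ql 0xA7: low nibble 7, high nibble 10; qh 0x32: bits 0-1 = 2, bits 4-5 = 3
        printf("%d %d\n", q6_weight(0xA7, 0x32, 0), q6_weight(0xA7, 0x32, 1)); // 7 26
        return 0;
    }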
 
+ void ggml_gemm_q6_K_8x8_q8_K(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ constexpr int qk = QK_K;
+ const int nb = n / qk;
+
+ constexpr int ncols_interleaved = 8;
+ constexpr int blocklen = 8;
+
+ assert(n % qk == 0);
+ assert(nr % 4 == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+ constexpr int q8_k_blocklen = 4;
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
+ const uint8x16_t mask_lo = vdupq_n_u8(0x03);
+ const uint8x16_t mask_hi = vdupq_n_u8(0x30);
+ const int8x16_t m32s = vdupq_n_s8(32);
+
+ // 8 accumulators: 4 q8 rows × 2 col groups (0-3, 4-7)
+ float32x4_t acc_f32[blocklen];
+
+ for (int y = 0; y < nr / q8_k_blocklen; y++) {
+ const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
+
+ for (int i = 0; i < blocklen; i++) {
+ acc_f32[i] = vdupq_n_f32(0);
+ }
+
+ for (int b = 0; b < nb; b++) {
+ int32x4_t acc[8]; // rows 01 stored in [0][1][2][3], rows 23 stored in [4][5][6][7]
+ for (int i = 0; i < 8; i++) {
+ acc[i] = vdupq_n_s32(0);
+ }
+
+ // Q6_K has simple 8-bit scales, 16 per block (one per 16 values)
+ // Widened to int16 once here and reused when the scales are applied below
+ int16_t q6_scales[16 * 8];
+ for (int i = 0; i < 16; ++i) {
+ int16x8_t s16 = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
+ vst1q_s16(q6_scales + i * 8, s16);
+ }
+
+ // Process two 128-value halves per superblock
+ for (int half = 0; half < 2; half++) {
+
+ const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
+ const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
+
+ // A subblock (sb) is a set of weights that share a scale
+ // Since q6_K scales cover 16 elements each:
+ // num sbs = 256 elements / (16 elements/scale * 2 elements/byte * 2 halves) = 4
+ for (int sb = 0; sb < QK_K / 64; sb++) {
+ // The Q6_K weight index advances by 64 instead of 32, which requires
+ // loading q8 from two separate memory regions
+ const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
+ const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;
+
+ int8x16_t q8_l_01[2];
+ int8x16_t q8_l_23[2];
+ for (int i = 0; i < 2; i++) {
+ const int offset = i * 32;
+ q8_l_01[i] = vld1q_s8(q8_base_l + offset); // 0..7 & 8..15 (r01)
+ q8_l_23[i] = vld1q_s8(q8_base_l + offset + 16); // 0..7 & 8..15 (r23)
+ }
+
+ int8x16_t q8_h_01[2];
+ int8x16_t q8_h_23[2];
+ for (int i = 0; i < 2; i++) {
+ const int offset = i * 32;
+ q8_h_01[i] = vld1q_s8(q8_base_h + offset);
+ q8_h_23[i] = vld1q_s8(q8_base_h + offset + 16);
+ }
+
+ const int ql_off_base = sb * QK_K / 2;
+
+ uint8x16_t q6_ql_0[4];
+ uint8x16_t q6_ql_1[4];
+ for (int k = 0; k < 4; k++) {
+ q6_ql_0[k] = vld1q_u8(ql_base + ql_off_base + 16 * k);
+ q6_ql_1[k] = vld1q_u8(ql_base + ql_off_base + 64 + 16 * k);
+ }
+
+ const int qh_off_base = (sb * QK_K / 2) & 255; // wrap after 256 bytes
+ uint8x16_t q6_qh_0[4];
+ uint8x16_t q6_qh_1[4];
+ for (int k = 0; k < 4; k++) {
+ q6_qh_0[k] = vld1q_u8(qh_base + qh_off_base + 16 * k);
+ q6_qh_1[k] = vld1q_u8(qh_base + qh_off_base + 64 + 16 * k);
+ }
+
+ // Adjust for the proper high bits (sb 2 and 3)
+ if (sb > 1) {
+ for (int k = 0; k < 4; k++) {
+ q6_qh_0[k] = vshrq_n_u8(q6_qh_0[k], 2);
+ q6_qh_1[k] = vshrq_n_u8(q6_qh_1[k], 2);
+ }
+ }
+
+ // Process column pairs (0-1, 2-3, 4-5, 6-7)
+ for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
+ const uint8x16_t q6_qs_cp_0_l = q6_ql_0[cp];
+ const uint8x16_t q6_qs_cp_1_l = q6_ql_1[cp];
+ const uint8x16_t q6_qs_cp_0_h = q6_qh_0[cp];
+ const uint8x16_t q6_qs_cp_1_h = q6_qh_1[cp];
+
+ // Extract high 2 bits for upper nibble reconstruction
+ const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi);
+ const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi);
+
+ // q6 = (low4 | high2<<4) - 32
+ // vsliq_n_u8 does the shift-left-and-insert in a single instruction (as in Q5_K)
+ const int8x16_t q6_l0 = vsubq_s8(
+ vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4)),
+ m32s);
+ const int8x16_t q6_l1 = vsubq_s8(
+ vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4)),
+ m32s);
+ const int8x16_t q6_h0 = vsubq_s8(
+ vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh)), m32s);
+ const int8x16_t q6_h1 = vsubq_s8(
+ vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh)), m32s);
+
+ // row pair 0, base_l
+ int32x4_t sb_acc_0l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_01[0]);
+ sb_acc_0l = vmmlaq_s32(sb_acc_0l, q6_l1, q8_l_01[1]);
+ // row pair 0, base_h
+ int32x4_t sb_acc_0h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_01[0]);
+ sb_acc_0h = vmmlaq_s32(sb_acc_0h, q6_h1, q8_h_01[1]);
+ // row pair 1, base_l
+ int32x4_t sb_acc_1l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_23[0]);
+ sb_acc_1l = vmmlaq_s32(sb_acc_1l, q6_l1, q8_l_23[1]);
+ // row pair 1, base_h
+ int32x4_t sb_acc_1h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_23[0]);
+ sb_acc_1h = vmmlaq_s32(sb_acc_1h, q6_h1, q8_h_23[1]);
+
+ const int scale_idx_l = half * 8 + sb;
+ const int scale_idx_h = half * 8 + sb + 4;
+
+ const int32x4_t scale_vec_l = {
+ q6_scales[scale_idx_l * 8 + cp * 2 + 0],
+ q6_scales[scale_idx_l * 8 + cp * 2 + 0],
+ q6_scales[scale_idx_l * 8 + cp * 2 + 1],
+ q6_scales[scale_idx_l * 8 + cp * 2 + 1],
+ };
+ const int32x4_t scale_vec_h = {
+ q6_scales[scale_idx_h * 8 + cp * 2 + 0],
+ q6_scales[scale_idx_h * 8 + cp * 2 + 0],
+ q6_scales[scale_idx_h * 8 + cp * 2 + 1],
+ q6_scales[scale_idx_h * 8 + cp * 2 + 1],
+ };
+
+ acc[cp] = vmlaq_s32(acc[cp], sb_acc_0l, scale_vec_l);
+ acc[cp] = vmlaq_s32(acc[cp], sb_acc_0h, scale_vec_h);
+ acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1l, scale_vec_l);
+ acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1h, scale_vec_h);
+ }
+ }
+ } // for half
+
+ // Reorder i8mm output to match memory layout
+ for (int i = 0; i < 8; i++) {
+ int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
+ acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
+ }
+ int32x4_t reorder_acc[8] = {
+ vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
+ vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
+ vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
+ vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
+ vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
+ vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
+ vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
+ vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
+ };
+
+ // Apply superblock scale (no mins for q6_K)
+ for (int i = 0; i < q8_k_blocklen; i++) {
+ for (int j = 0; j < 2; j++) {
+ float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
+ float32x4_t q6_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q6_ptr[b].d + j * 4)));
+ const float32x4_t scale = vmulq_f32(q6_d, q8_d);
+
+ acc_f32[2 * i + j] =
+ vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
+ }
+ }
+ } // for b
+
+ // Store results
+ for (int i = 0; i < q8_k_blocklen; i++) {
+ int row = y * q8_k_blocklen + i;
+ for (int j = 0; j < 2; j++) {
+ int col = x * ncols_interleaved + j * 4;
+ int offset = row * bs + col;
+ vst1q_f32(s + offset, acc_f32[2 * i + j]);
+ }
+ }
+ } // for x
+ } // for y
+ return;
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+ ggml_gemm_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+ }
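
For reference: the bias_acc/bsums bookkeeping in the Q4_K and Q5_K kernels above rests on a simple identity. For a subblock with integer scale sc and min m, dot(q8, sc*q - m) = sc*dot(q8, q) - m*sum(q8), and sum(q8) is precomputed per row as bsums, so the min correction costs one multiply-accumulate per subblock instead of a full dot product. Q6_K has no per-subblock min, which is why its kernels skip the correction entirely. A tiny self-contained check of the identity (illustrative, with made-up numbers):

    #include <stdio.h>

    int main(void) {
        const int q8[8] = { 3, -1,  4, 1, -5, 9,  2, -6 }; // quantized activations
        const int q[8]  = { 7,  0, 12, 5,  9, 1, 14,  3 }; // unsigned sub-8-bit weights
        const int sc = 3, m = 2;                           // subblock scale and min
        long lhs = 0, dot = 0, bsum = 0;
        for (int i = 0; i < 8; i++) {
            lhs  += q8[i] * (sc * q[i] - m);
            dot  += q8[i] * q[i];
            bsum += q8[i];
        }
        // acc accumulates sc*dot and bias_acc accumulates m*bsum; the float
        // d/dmin factors are applied afterwards.
        printf("%ld == %ld\n", lhs, sc * dot - m * bsum);
        return 0;
    }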
 
  void ggml_gemm_q8_0_4x4_q8_0(int n,
  float * GGML_RESTRICT s,