whispercpp 1.3.6 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (828)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/README.md +38 -5
  5. data/Rakefile +18 -3
  6. data/ext/dependencies.rb +10 -4
  7. data/ext/dependencies_for_windows.rb +17 -0
  8. data/ext/extconf.rb +20 -8
  9. data/ext/options.rb +54 -14
  10. data/ext/options_for_windows.rb +51 -0
  11. data/ext/ruby_whisper.c +36 -42
  12. data/ext/ruby_whisper.h +135 -0
  13. data/ext/ruby_whisper_context.c +107 -28
  14. data/ext/ruby_whisper_log_queue.c +180 -0
  15. data/ext/ruby_whisper_log_settable.h +47 -0
  16. data/ext/ruby_whisper_parakeet.c +49 -0
  17. data/ext/ruby_whisper_parakeet_context.c +304 -0
  18. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  19. data/ext/ruby_whisper_parakeet_model.c +84 -0
  20. data/ext/ruby_whisper_parakeet_params.c +548 -0
  21. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  22. data/ext/ruby_whisper_parakeet_token.c +188 -0
  23. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  24. data/ext/ruby_whisper_params.c +256 -65
  25. data/ext/ruby_whisper_segment.c +6 -6
  26. data/ext/ruby_whisper_transcribe.cpp +42 -15
  27. data/ext/sources/CMakeLists.txt +41 -3
  28. data/ext/sources/CMakePresets.json +95 -0
  29. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  30. data/ext/sources/cmake/parakeet.pc.in +10 -0
  31. data/ext/sources/cmake/whisper.pc.in +1 -1
  32. data/ext/sources/examples/CMakeLists.txt +4 -2
  33. data/ext/sources/examples/bench/bench.cpp +1 -1
  34. data/ext/sources/examples/cli/cli.cpp +43 -9
  35. data/ext/sources/examples/common-ggml.cpp +2 -0
  36. data/ext/sources/examples/common-whisper.cpp +139 -67
  37. data/ext/sources/examples/common-whisper.h +11 -0
  38. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  39. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  40. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  41. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  42. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  43. data/ext/sources/examples/server/server.cpp +199 -163
  44. data/ext/sources/ggml/CMakeLists.txt +21 -13
  45. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  46. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  47. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  48. data/ext/sources/ggml/include/ggml-backend.h +72 -10
  49. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  50. data/ext/sources/ggml/include/ggml-rpc.h +3 -3
  51. data/ext/sources/ggml/include/ggml.h +101 -9
  52. data/ext/sources/ggml/include/gguf.h +10 -2
  53. data/ext/sources/ggml/src/CMakeLists.txt +22 -5
  54. data/ext/sources/ggml/src/ggml-alloc.c +5 -1
  55. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  56. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  57. data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
  58. data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
  59. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
  60. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
  61. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
  62. data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
  63. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
  64. data/ext/sources/ggml/src/ggml-common.h +11 -0
  65. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
  66. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
  67. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
  68. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
  69. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
  70. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  71. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  72. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
  73. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
  74. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
  75. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  76. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
  77. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
  78. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
  79. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  80. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
  81. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
  82. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  83. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
  84. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
  85. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
  86. data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
  87. data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
  88. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  89. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
  90. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
  91. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
  92. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  93. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  94. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  95. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  96. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  97. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  98. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  99. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  100. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  101. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  102. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  103. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  104. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  105. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  106. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  107. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
  108. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  109. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
  110. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  111. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  112. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
  113. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
  114. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  115. data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
  116. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  117. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  118. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  119. data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
  120. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  121. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
  122. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  123. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
  124. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
  125. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
  129. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
  130. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  131. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  132. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  133. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
  134. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  135. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
  136. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  137. data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
  138. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
  139. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
  140. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  141. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
  142. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
  143. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
  144. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
  145. data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
  146. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  147. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
  148. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  149. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
  150. data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
  151. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  152. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  153. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  154. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  155. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  156. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
  157. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  158. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  159. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  160. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  161. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
  162. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  163. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  164. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  165. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
  166. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  167. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  168. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  169. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  170. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  171. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  172. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  173. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  174. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  176. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  177. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  178. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  179. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  191. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
  192. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
  193. data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
  194. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  195. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
  196. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  197. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
  198. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  199. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
  200. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
  201. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
  202. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
  203. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
  204. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
  205. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  206. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  207. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
  208. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  209. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  210. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  211. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
  212. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  213. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
  214. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
  215. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
  216. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
  217. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
  218. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  219. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  220. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  221. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  222. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  223. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  224. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  225. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  226. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
  227. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
  228. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  229. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
  230. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
  231. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
  232. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
  233. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  235. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
  254. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
  255. data/ext/sources/ggml/src/ggml-impl.h +6 -1
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
  259. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
  260. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
  261. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
  262. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
  263. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
  264. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  265. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
  266. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
  322. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
  323. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
  324. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
  325. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
  326. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
  327. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  328. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
  329. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
  330. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  331. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
  332. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
  333. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
  334. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
  335. data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
  336. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  337. data/ext/sources/ggml/src/ggml-quants.c +289 -114
  338. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  339. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  340. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  341. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  342. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  343. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
  344. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
  345. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
  346. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  347. data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
  348. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
  349. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
  350. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  351. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  352. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  353. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  354. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  355. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  356. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
  357. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
  358. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  359. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  360. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
  361. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
  362. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
  363. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
  364. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
  365. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  366. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  367. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
  368. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
  369. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  370. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  371. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
  372. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  373. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  374. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  375. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  376. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  377. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
  378. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  379. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  380. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  381. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  382. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  383. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  384. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  385. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  386. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  387. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
  388. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
  389. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
  390. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
  391. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
  392. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
  393. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
  394. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
  395. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
  396. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
  397. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
  398. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
  399. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
  400. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
  401. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
  402. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
  403. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
  404. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
  405. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
  406. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
  407. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
  408. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
  409. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
  410. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
  411. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
  412. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
  413. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
  414. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
  415. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
  416. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
  417. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
  418. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
  420. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
  421. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
  422. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
  423. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  424. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  425. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  426. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
  427. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
  428. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
  429. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
  430. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
  431. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
  432. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
  433. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
  434. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
  484. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  485. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
  486. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
  487. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  488. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  489. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
  490. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
  491. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
  492. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  493. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
  494. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
  495. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  496. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  497. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  498. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  499. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  500. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  501. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
  502. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  503. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  504. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
  505. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  506. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  507. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  508. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
  509. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
  510. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
  511. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  512. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  513. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  514. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  515. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  516. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  517. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  518. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
  519. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  520. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
  521. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  522. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  523. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  524. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  525. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  526. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
  527. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  528. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
  529. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
  530. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
  531. data/ext/sources/ggml/src/ggml.c +110 -28
  532. data/ext/sources/ggml/src/gguf.cpp +173 -28
  533. data/ext/sources/include/parakeet.h +342 -0
  534. data/ext/sources/include/whisper.h +10 -0
  535. data/ext/sources/media/matmul.png +0 -0
  536. data/ext/sources/src/CMakeLists.txt +23 -0
  537. data/ext/sources/src/parakeet-arch.h +188 -0
  538. data/ext/sources/src/parakeet.cpp +3838 -0
  539. data/ext/sources/src/whisper.cpp +56 -12
  540. data/extsources.rb +26 -10
  541. data/lib/whisper/log_settable.rb +36 -0
  542. data/lib/whisper/model/uri.rb +13 -1
  543. data/lib/whisper/output.rb +74 -0
  544. data/sig/whisper.rbs +411 -62
  545. data/test/helper.rb +2 -0
  546. data/test/jfk_reader/jfk_reader.c +50 -7
  547. data/test/test_callback.rb +1 -0
  548. data/test/test_package.rb +6 -5
  549. data/test/test_parakeet.rb +28 -0
  550. data/test/test_parakeet_callback.rb +107 -0
  551. data/test/test_parakeet_context.rb +116 -0
  552. data/test/test_parakeet_context_params.rb +24 -0
  553. data/test/test_parakeet_model.rb +21 -0
  554. data/test/test_parakeet_params.rb +78 -0
  555. data/test/test_parakeet_segment.rb +42 -0
  556. data/test/test_parakeet_token.rb +73 -0
  557. data/test/test_params.rb +2 -0
  558. data/test/test_vad_segment.rb +1 -1
  559. data/test/test_whisper.rb +24 -6
  560. data/whispercpp.gemspec +2 -2
  561. metadata +215 -281
  562. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  563. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  564. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  565. data/ext/sources/bindings/javascript/package.json +0 -26
  566. data/ext/sources/bindings/javascript/whisper.js +0 -19
  567. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  568. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  569. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  570. data/ext/sources/examples/addon.node/index.js +0 -59
  571. data/ext/sources/examples/addon.node/package.json +0 -16
  572. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  573. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  574. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  575. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  576. data/ext/sources/examples/coi-serviceworker.js +0 -146
  577. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  578. data/ext/sources/examples/command/command.cpp +0 -802
  579. data/ext/sources/examples/command/commands.txt +0 -9
  580. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  581. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  582. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  583. data/ext/sources/examples/generate-karaoke.sh +0 -57
  584. data/ext/sources/examples/helpers.js +0 -191
  585. data/ext/sources/examples/livestream.sh +0 -112
  586. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  587. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  588. data/ext/sources/examples/lsp/whisper.vim +0 -362
  589. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  590. data/ext/sources/examples/python/whisper_processor.py +0 -54
  591. data/ext/sources/examples/server/bench.js +0 -29
  592. data/ext/sources/examples/server.py +0 -120
  593. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  594. data/ext/sources/examples/stream/stream.cpp +0 -437
  595. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  596. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  597. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  598. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  599. data/ext/sources/examples/sycl/build.sh +0 -22
  600. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  601. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  602. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
  603. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  604. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
  605. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
  606. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
  607. data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
  608. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
  609. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  610. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
  611. data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
  612. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
  613. data/ext/sources/examples/talk-llama/llama-context.h +0 -359
  614. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  615. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
  616. data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
  617. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  618. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  619. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
  620. data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
  621. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
  622. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
  623. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  624. data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
  625. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  626. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  627. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
  628. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  629. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
  630. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
  631. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  632. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
  633. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
  634. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  635. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  636. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
  637. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  638. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  639. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  640. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
  641. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  642. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
  643. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
  644. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
  645. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
  646. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
  647. data/ext/sources/examples/talk-llama/llama-model.h +0 -597
  648. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
  649. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  650. data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
  651. data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
  652. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
  653. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
  654. data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
  655. data/ext/sources/examples/talk-llama/llama.h +0 -1573
  656. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
  657. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  658. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  659. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
  660. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  661. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
  662. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
  663. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
  664. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
  665. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
  666. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  667. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  668. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  669. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  670. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  671. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  672. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  673. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
  674. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  675. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
  676. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
  677. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
  678. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
  679. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  680. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
  681. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  682. data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
  683. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
  684. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  685. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  686. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
  687. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  688. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  689. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  690. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  691. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  692. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  693. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  694. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
  695. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  696. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  697. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
  698. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
  699. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  700. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
  701. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  702. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
  703. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  704. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  705. data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
  706. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  707. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
  708. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
  709. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  710. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  711. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  712. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
  713. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  714. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
  715. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
  716. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
  717. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
  718. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
  719. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  720. data/ext/sources/examples/talk-llama/models/models.h +0 -704
  721. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
  722. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  723. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
  724. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  725. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  726. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  727. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  728. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  729. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  730. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  731. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  732. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
  733. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  734. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  735. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  736. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  737. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
  738. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  739. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
  740. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  741. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  742. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  743. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  744. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
  745. data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
  746. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
  747. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
  748. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
  749. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
  750. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
  751. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  752. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  753. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
  754. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  755. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  756. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
  757. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  758. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  759. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  760. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  761. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  762. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  763. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  764. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
  765. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  766. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  767. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  768. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  769. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  770. data/ext/sources/examples/talk-llama/speak +0 -40
  771. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  772. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  773. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  774. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  775. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  776. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
  777. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  778. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  779. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  780. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  781. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  782. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  783. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  784. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  785. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  786. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  787. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  788. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  789. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  790. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  791. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
  792. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  793. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  794. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
  795. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
  796. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  798. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
  799. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  800. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
  801. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  802. data/ext/sources/tests/CMakeLists.txt +0 -112
  803. data/ext/sources/tests/earnings21/eval.mk +0 -58
  804. data/ext/sources/tests/earnings21/eval.py +0 -68
  805. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  806. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  807. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  808. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  809. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  810. data/ext/sources/tests/en-0-ref.txt +0 -1
  811. data/ext/sources/tests/en-1-ref.txt +0 -1
  812. data/ext/sources/tests/en-2-ref.txt +0 -1
  813. data/ext/sources/tests/es-0-ref.txt +0 -1
  814. data/ext/sources/tests/librispeech/eval.mk +0 -39
  815. data/ext/sources/tests/librispeech/eval.py +0 -47
  816. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  817. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  818. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  819. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  820. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  821. data/ext/sources/tests/run-tests.sh +0 -130
  822. data/ext/sources/tests/test-c.c +0 -3
  823. data/ext/sources/tests/test-vad-full.cpp +0 -56
  824. data/ext/sources/tests/test-vad.cpp +0 -83
  825. data/ext/sources/tests/test-whisper.js +0 -58
  826. data/lib/whisper/context.rb +0 -15
  827. data/lib/whisper/segment.rb +0 -58
  828. /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
@@ -16,8 +16,9 @@
16
16
  #define GGML_COMMON_DECL_C
17
17
  #include "ggml-common.h"
18
18
  #include "htp-ctx.h"
19
- #include "htp-msg.h"
20
19
  #include "htp-ops.h"
20
+ #include "htp-ops.h"
21
+ #include "hmx-ops.h"
21
22
 
22
23
  #define MM_SPAD_SRC0_NROWS 16
23
24
  #define MM_SPAD_SRC1_NROWS 16
@@ -39,6 +40,11 @@ struct htp_matmul_context {
39
40
  const void * restrict vx0, const void * restrict vx1,
40
41
  const void * restrict vy0, const void * restrict vy1);
41
42
 
43
+ void (*vec_dot_4x1)(const int n, float * restrict s0,
44
+ const void * restrict vx0, const void * restrict vx1,
45
+ const void * restrict vx2, const void * restrict vx3,
46
+ const void * restrict vy0);
47
+
42
48
  // Precomputed values
43
49
  uint32_t src0_nrows_per_thread;
44
50
  uint32_t src1_nrows_per_thread;
@@ -47,6 +53,11 @@ struct htp_matmul_context {
47
53
  struct fastdiv_values mm_div_ne1;
48
54
  struct fastdiv_values mm_div_r2;
49
55
  struct fastdiv_values mm_div_r3;
56
+
57
+ // Fields for scattered mapping & HMX support in MUL_MAT_ID
58
+ const uint32_t * matrix_row_counts;
59
+ const struct mmid_row_mapping * matrix_rows;
60
+ bool hmx_eligible;
50
61
  };
51
62
 
52
63
  // vdelta control to expand first 32 e8m0 values into 32 uint32 elements
@@ -60,6 +71,16 @@ static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
60
71
  0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20,
61
72
  };
62
73
 
74
+ // IQ4_NL dequantization LUT: maps 4-bit index (0-15) to int8 kvalue
75
+ // kvalues: -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
76
+ static const uint8_t __attribute__((aligned(VLEN))) kvalues_iq4nl_lut[] = {
77
+ 0x81, 0, 0x98, 0, 0xAD, 0, 0xBF, 0, 0xCF, 0, 0xDD, 0, 0xEA, 0, 0xF6, 0, 0x01, 0, 0x0D, 0, 0x19, 0, 0x26, 0,
78
+ 0x35, 0, 0x45, 0, 0x59, 0, 0x71, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
79
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
80
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
81
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
82
+ };
83
+
63
84
  static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {
64
85
  0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0,
65
86
  0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
@@ -68,6 +89,73 @@ static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {
68
89
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
69
90
  };
70
91
 
92
+ static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_full(const uint8_t * restrict ptr) {
93
+ const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
94
+
95
+ HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
96
+ HVX_Vector v2_3 = vptr[1]; // ...
97
+ HVX_Vector v4_5 = vptr[2]; // ...
98
+ HVX_Vector v6_7 = vptr[3]; // ...
99
+
100
+ const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
101
+ const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut;
102
+
103
+ HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
104
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
105
+ HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F
106
+ HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4
107
+ HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F
108
+ HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4
109
+ HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F
110
+ HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
111
+
112
+ v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
113
+ v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
114
+ v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
115
+ v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
116
+ v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
117
+ v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
118
+ v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
119
+ v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
120
+
121
+ HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
122
+ return r;
123
+ }
124
+
125
+ static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_partial(const uint8_t * restrict ptr, uint32_t n) {
126
+ const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
127
+
128
+ const uint32_t qk = QK_Q4_0x4x2; // 256
129
+ const uint32_t nb = n / qk;
130
+ const uint32_t nloe = n % qk;
131
+
132
+ const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
133
+ const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut;
134
+
135
+ HVX_Vector_x8 r;
136
+ uint32_t i = 0;
137
+
138
+ #pragma unroll(2)
139
+ for (i = 0; i < nb; i++) {
140
+ HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
141
+ HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements
142
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements
143
+ r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
144
+ r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
145
+ }
146
+
147
+ if (nloe) {
148
+ HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
149
+ HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements
150
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements
151
+ HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...
152
+ r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0);
153
+ r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0);
154
+ }
155
+
156
+ return r;
157
+ }
158
+
71
159
  // q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales
72
160
 
73
161
  static inline size_t q8x4x2_row_size(uint32_t ne) {
@@ -77,6 +165,13 @@ static inline size_t q8x4x2_row_size(uint32_t ne) {
77
165
  return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128);
78
166
  }
79
167
 
168
+ static inline size_t q8_1x4x2_row_size(uint32_t ne) {
169
+ // ensures perfect alignment of quants and full row
170
+ const uint32_t qk = QK_Q8_0x4x2;
171
+ const uint32_t nb = (ne + qk - 1) / qk;
172
+ return hex_round_up(ne + nb * 8 * 2 * sizeof(__fp16), 128);
173
+ }
174
+
80
175
  static inline HVX_Vector_x8 hvx_vec_load_q4x4x8_full(const uint8_t * restrict ptr) {
81
176
  const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
82
177
 
@@ -145,6 +240,62 @@ static HVX_Vector_x8 hvx_vec_load_q4x4x8_partial(const uint8_t * restrict ptr, u
145
240
  return r;
146
241
  }
147
242
 
243
+ static inline HVX_Vector_x8 hvx_vec_load_q4_1x4x8_full(const uint8_t * restrict ptr) {
244
+ const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
245
+
246
+ HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
247
+ HVX_Vector v2_3 = vptr[1]; // ...
248
+ HVX_Vector v4_5 = vptr[2]; // ...
249
+ HVX_Vector v6_7 = vptr[3]; // ...
250
+
251
+ const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
252
+
253
+ HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements
254
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements
255
+ HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ...
256
+ HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4
257
+ HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F
258
+ HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4
259
+ HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F
260
+ HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
261
+
262
+ HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
263
+ return r;
264
+ }
265
+
266
+ static HVX_Vector_x8 hvx_vec_load_q4_1x4x8_partial(const uint8_t * restrict ptr, uint32_t n) {
267
+ const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
268
+
269
+ const uint32_t qk = QK_Q4_0x4x2; // 256
270
+ const uint32_t nb = n / qk;
271
+ const uint32_t nloe = n % qk;
272
+
273
+ const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
274
+
275
+ HVX_Vector_x8 r;
276
+ uint32_t i = 0;
277
+
278
+ #pragma unroll(2)
279
+ for (i=0; i < nb; i++) {
280
+ HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
281
+ HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements
282
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements
283
+ r.v[i*2+0] = v0;
284
+ r.v[i*2+1] = v1;
285
+ }
286
+
287
+ if (nloe) {
288
+ HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
289
+ HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements
290
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements
291
+ HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...
292
+ r.v[i*2+0] = Q6_V_lo_W(v0_1_p);
293
+ r.v[i*2+1] = Q6_V_hi_W(v0_1_p);
294
+ }
295
+
296
+ return r;
297
+ }
298
+
148
299
  static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_full(const uint8_t * restrict ptr) {
149
300
  const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
150
301
 
@@ -323,82 +474,96 @@ static inline HVX_Vector hvx_vec_rmpy_x8_partial(HVX_Vector_x8 x, HVX_Vector_x8
323
474
  return hvx_vec_rmpy_x8_partial(x, y, 512);
324
475
  }
325
476
 
326
- static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
477
+ static void vec_dot_q4_1x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
327
478
  assert(n % 32 == 0); // min sub-block size
328
479
  assert((unsigned long) vx0 % 128 == 0);
329
480
  assert((unsigned long) vy0 % 128 == 0);
330
481
 
331
482
  const uint32_t qk = QK_Q4_0x4x2 * 4;
332
483
 
333
- const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
484
+ const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes
334
485
  const uint32_t x_qblk_size = qk / 2; // int4
335
486
  const uint32_t x_qrow_size = n / 2; // int4 (not padded)
336
487
 
337
- const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
488
+ const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes
338
489
  const uint32_t y_qblk_size = qk; // int8
339
490
  const uint32_t y_qrow_size = n; // int8 (not padded)
340
491
 
341
492
  const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
342
- const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
493
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales/offsets
343
494
 
344
495
  const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
345
- const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
496
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums
346
497
 
347
498
  // Row sum (sf)
348
499
  HVX_Vector r0_sum = Q6_V_vzero();
349
500
 
350
- // Multiply and accumulate into int32.
351
- // Compute combined scale (fp32).
352
- // Apply scale to acc and accumulate into the row sum (qf32).
353
-
354
501
  const uint32_t nb = n / qk; // num full blocks
355
502
  const uint32_t nloe = n % qk; // num leftover elemements
356
503
 
357
504
  uint32_t i = 0;
358
505
  for (; i < nb; i++) {
359
506
  HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
360
- HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
507
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size);
361
508
 
362
509
  HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
363
510
 
364
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
365
- HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
511
+ HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size);
512
+ HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2);
513
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal));
514
+ HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal));
515
+
516
+ HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
517
+ HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2);
518
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal));
519
+ HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal));
366
520
 
367
521
  HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
522
+ HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s)));
368
523
 
369
524
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
525
+ HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms);
370
526
 
371
- r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
527
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum));
372
528
  }
373
529
 
374
530
  // Process leftovers
375
531
  if (nloe) {
376
532
  HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
377
- HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
533
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
378
534
 
379
535
  HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
380
536
 
381
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
382
- HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
537
+ HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size);
538
+ HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2);
539
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal));
540
+ HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal));
541
+
542
+ HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
543
+ HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2);
544
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal));
545
+ HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal));
383
546
 
384
547
  HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
548
+ HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s)));
385
549
 
386
550
  // Zero out unused elements
387
551
  HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
388
552
  r0_dd = Q6_V_vand_QV(bmask, r0_dd);
553
+ r0_ms = Q6_V_vand_QV(bmask, r0_ms);
389
554
  r0_ia = Q6_V_vand_QV(bmask, r0_ia);
390
555
 
391
556
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
557
+ HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms);
392
558
 
393
- r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
559
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum));
394
560
  }
395
561
 
396
562
  r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
397
-
398
563
  hvx_vec_store_u(s0, 4, r0_sum);
399
564
  }
400
565
 
401
- static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
566
+ static void vec_dot_q4_1x4x2_q8x4x2_2x1(const int n, float * restrict s0,
402
567
  const void * restrict vx0, const void * restrict vx1,
403
568
  const void * restrict vy0) {
404
569
  assert(n % 32 == 0); // min sub-block size
@@ -408,11 +573,11 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
408
573
 
409
574
  const uint32_t qk = QK_Q4_0x4x2 * 4;
410
575
 
411
- const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
576
+ const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes
412
577
  const uint32_t x_qblk_size = qk / 2; // int4
413
578
  const uint32_t x_qrow_size = n / 2; // int4 (not padded)
414
579
 
415
- const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
580
+ const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes
416
581
  const uint32_t y_qblk_size = qk; // int8
417
582
  const uint32_t y_qrow_size = n; // int8 (not padded)
418
583
 
@@ -422,77 +587,306 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
422
587
  const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
423
588
 
424
589
  const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
425
- const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
590
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums
426
591
 
427
592
  // Row sum (sf)
428
593
  HVX_Vector r0_sum = Q6_V_vzero();
429
594
  HVX_Vector r1_sum = Q6_V_vzero();
430
595
 
431
- // Multiply and accumulate into int32.
432
- // Compute combined scale (fp32).
433
- // Apply scale to acc and accumulate into the row sum (qf32).
434
-
435
596
  const uint32_t nb = n / qk; // num full blocks
436
597
  const uint32_t nloe = n % qk; // num leftover elemements
437
598
 
438
599
  uint32_t i = 0;
439
600
  for (; i < nb; i++) {
440
601
  HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
441
- HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
442
- HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);
602
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size);
603
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size);
443
604
 
444
605
  HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
445
606
  HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
446
607
 
447
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
448
- HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
449
- HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
608
+ HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size);
609
+ HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2);
610
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal));
611
+ HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal));
612
+
613
+ HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
614
+ HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2);
615
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal));
616
+ HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal));
617
+
618
+ HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
619
+ HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2);
620
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal));
621
+ HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal));
450
622
 
451
623
  HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
624
+ HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s)));
625
+
452
626
  HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
627
+ HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s)));
453
628
 
454
629
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
630
+ HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms);
631
+
455
632
  HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
633
+ HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms);
456
634
 
457
- r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
458
- r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
635
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum));
636
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum));
459
637
  }
460
638
 
461
639
  // Process leftovers
462
640
  if (nloe) {
463
641
  HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
464
- HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
465
- HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
642
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
643
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
466
644
 
467
645
  HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
468
646
  HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
469
647
 
470
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
471
- HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
472
- HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
648
+ HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size);
649
+ HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2);
650
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal));
651
+ HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal));
652
+
653
+ HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
654
+ HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2);
655
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal));
656
+ HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal));
657
+
658
+ HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
659
+ HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2);
660
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal));
661
+ HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal));
473
662
 
474
663
  HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
664
+ HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s)));
665
+
475
666
  HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
667
+ HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s)));
476
668
 
477
669
  // Zero out unused elements
478
670
  HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
479
671
  r0_dd = Q6_V_vand_QV(bmask, r0_dd);
672
+ r0_ms = Q6_V_vand_QV(bmask, r0_ms);
480
673
  r1_dd = Q6_V_vand_QV(bmask, r1_dd);
674
+ r1_ms = Q6_V_vand_QV(bmask, r1_ms);
481
675
  r0_ia = Q6_V_vand_QV(bmask, r0_ia);
482
676
  r1_ia = Q6_V_vand_QV(bmask, r1_ia);
483
677
 
484
678
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
679
+ HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms);
680
+
485
681
  HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
682
+ HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms);
486
683
 
487
- r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
488
- r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
684
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum));
685
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum));
489
686
  }
490
687
 
491
688
  HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
492
689
  hvx_vec_store_u(s0, 8, rsum);
493
690
  }
494
691
 
495
- static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
692
+ static void vec_dot_q4_1x4x2_q8x4x2_4x1(const int n, float * restrict s0,
693
+ const void * restrict vx0, const void * restrict vx1,
694
+ const void * restrict vx2, const void * restrict vx3,
695
+ const void * restrict vy0) {
696
+ assert(n % 32 == 0); // min sub-block size
697
+ assert((unsigned long) vx0 % 128 == 0);
698
+ assert((unsigned long) vx1 % 128 == 0);
699
+ assert((unsigned long) vx2 % 128 == 0);
700
+ assert((unsigned long) vx3 % 128 == 0);
701
+ assert((unsigned long) vy0 % 128 == 0);
702
+
703
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
704
+
705
+ const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes
706
+ const uint32_t x_qblk_size = qk / 2; // int4
707
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
708
+
709
+ const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes
710
+ const uint32_t y_qblk_size = qk; // int8
711
+ const uint32_t y_qrow_size = n; // int8 (not padded)
712
+
713
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
714
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
715
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
716
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
717
+ const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first
718
+ const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales
719
+ const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first
720
+ const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales
721
+
722
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
723
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums
724
+
725
+ // Row sum (sf)
726
+ HVX_Vector r0_sum = Q6_V_vzero();
727
+ HVX_Vector r1_sum = Q6_V_vzero();
728
+ HVX_Vector r2_sum = Q6_V_vzero();
729
+ HVX_Vector r3_sum = Q6_V_vzero();
730
+
731
+ const uint32_t nb = n / qk; // num full blocks
732
+ const uint32_t nloe = n % qk; // num leftover elements
733
+
734
+ uint32_t i = 0;
735
+ for (; i < nb; i++) {
736
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
737
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size);
738
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size);
739
+ HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_full(r2_x_q + i * x_qblk_size);
740
+ HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_full(r3_x_q + i * x_qblk_size);
741
+
742
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
743
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
744
+ HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q));
745
+ HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q));
746
+
747
+ HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size);
748
+ HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2);
749
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal));
750
+ HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal));
751
+
752
+ HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
753
+ HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2);
754
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal));
755
+ HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal));
756
+
757
+ HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
758
+ HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2);
759
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal));
760
+ HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal));
761
+
762
+ HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size);
763
+ HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2);
764
+ HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal));
765
+ HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal));
766
+
767
+ HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size);
768
+ HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2);
769
+ HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal));
770
+ HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal));
771
+
772
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
773
+ HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s)));
774
+
775
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
776
+ HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s)));
777
+
778
+ HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d)));
779
+ HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s)));
780
+
781
+ HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d)));
782
+ HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s)));
783
+
784
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
785
+ HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms);
786
+
787
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
788
+ HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms);
789
+
790
+ HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
791
+ HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms);
792
+
793
+ HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
794
+ HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms);
795
+
796
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum));
797
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum));
798
+ r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum));
799
+ r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum));
800
+ }
801
+
802
+ if (nloe) {
803
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
804
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
805
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
806
+ HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_partial(r2_x_q + i * x_qblk_size, nloe);
807
+ HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_partial(r3_x_q + i * x_qblk_size, nloe);
808
+
809
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
810
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
811
+ HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe));
812
+ HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe));
813
+
814
+ HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size);
815
+ HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2);
816
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal));
817
+ HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal));
818
+
819
+ HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
820
+ HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2);
821
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal));
822
+ HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal));
823
+
824
+ HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
825
+ HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2);
826
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal));
827
+ HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal));
828
+
829
+ HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size);
830
+ HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2);
831
+ HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal));
832
+ HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal));
833
+
834
+ HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size);
835
+ HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2);
836
+ HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal));
837
+ HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal));
838
+
839
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
840
+ HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s)));
841
+
842
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
843
+ HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s)));
844
+
845
+ HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d)));
846
+ HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s)));
847
+
848
+ HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d)));
849
+ HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s)));
850
+
851
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
852
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
853
+ r0_ms = Q6_V_vand_QV(bmask, r0_ms);
854
+ r1_dd = Q6_V_vand_QV(bmask, r1_dd);
855
+ r1_ms = Q6_V_vand_QV(bmask, r1_ms);
856
+ r2_dd = Q6_V_vand_QV(bmask, r2_dd);
857
+ r2_ms = Q6_V_vand_QV(bmask, r2_ms);
858
+ r3_dd = Q6_V_vand_QV(bmask, r3_dd);
859
+ r3_ms = Q6_V_vand_QV(bmask, r3_ms);
860
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
861
+ r1_ia = Q6_V_vand_QV(bmask, r1_ia);
862
+ r2_ia = Q6_V_vand_QV(bmask, r2_ia);
863
+ r3_ia = Q6_V_vand_QV(bmask, r3_ia);
864
+
865
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
866
+ HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms);
867
+
868
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
869
+ HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms);
870
+
871
+ HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
872
+ HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms);
873
+
874
+ HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
875
+ HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms);
876
+
877
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum));
878
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum));
879
+ r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum));
880
+ r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum));
881
+ }
882
+
883
+ HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } };
884
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in);
885
+ hvx_vec_store_u(s0, 16, rsum);
886
+ }
887
+
888
+
889
+ static void vec_dot_q4_1x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
496
890
  const void * restrict vx0, const void * restrict vx1,
497
891
  const void * restrict vy0, const void * restrict vy1) {
498
892
  assert(n % 32 == 0);
@@ -503,11 +897,11 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
503
897
 
504
898
  const uint32_t qk = QK_Q4_0x4x2 * 4;
505
899
 
506
- const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
900
+ const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes
507
901
  const uint32_t x_qblk_size = qk / 2; // int4
508
902
  const uint32_t x_qrow_size = n / 2; // int4 (not padded)
509
903
 
510
- const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
904
+ const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes
511
905
  const uint32_t y_qblk_size = qk; // int8
512
906
  const uint32_t y_qrow_size = n; // int8 (not padded)
513
907
 
@@ -517,9 +911,9 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
517
911
  const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
518
912
 
519
913
  const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
520
- const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
914
+ const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales/sums
521
915
  const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
522
- const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
916
+ const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales/sums
523
917
 
524
918
  // Row sums (sf) - 4 accumulators for 2×2 tile
525
919
  HVX_Vector r0_c0_sum = Q6_V_vzero();
@@ -532,13 +926,13 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
532
926
 
533
927
  uint32_t i = 0;
534
928
  for (; i < nb; i++) {
535
- // Load src1 columns (reused across both src0 rows)
929
+ // Load src1 columns
536
930
  HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
537
931
  HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
538
932
 
539
- // Load src0 rows (reused across both src1 columns)
540
- HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
541
- HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);
933
+ // Load src0 rows
934
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size);
935
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size);
542
936
 
543
937
  // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
544
938
  HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
@@ -547,16 +941,38 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
547
941
  HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
548
942
 
549
943
  // Load scales
550
- HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
551
- HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
552
- HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
553
- HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
944
+ HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size);
945
+ HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2);
946
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal));
947
+ HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal));
948
+
949
+ HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size);
950
+ HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2);
951
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal));
952
+ HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal));
953
+
954
+ HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
955
+ HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2);
956
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal));
957
+ HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal));
958
+
959
+ HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
960
+ HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2);
961
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal));
962
+ HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal));
554
963
 
555
964
  // Compute combined scales
556
965
  HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
966
+ HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s)));
967
+
557
968
  HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
969
+ HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s)));
970
+
558
971
  HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
972
+ HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s)));
973
+
559
974
  HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
975
+ HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s)));
560
976
 
561
977
  // Apply scales and accumulate
562
978
  HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
@@ -564,40 +980,72 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
564
980
  HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
565
981
  HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
566
982
 
567
- r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
568
- r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
569
- r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
570
- r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
983
+ HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms);
984
+ HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms);
985
+ HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms);
986
+ HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms);
987
+
988
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum));
989
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum));
990
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum));
991
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum));
571
992
  }
572
993
 
573
994
  // Process leftovers
574
995
  if (nloe) {
575
996
  HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
576
997
  HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
577
- HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
578
- HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
998
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
999
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
579
1000
 
580
1001
  HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
581
1002
  HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
582
1003
  HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
583
1004
  HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
584
1005
 
585
- HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
586
- HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
587
- HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
588
- HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
1006
+ HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size);
1007
+ HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2);
1008
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal));
1009
+ HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal));
1010
+
1011
+ HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size);
1012
+ HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2);
1013
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal));
1014
+ HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal));
1015
+
1016
+ HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
1017
+ HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2);
1018
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal));
1019
+ HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal));
1020
+
1021
+ HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
1022
+ HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2);
1023
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal));
1024
+ HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal));
589
1025
 
590
1026
  HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
1027
+ HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s)));
1028
+
591
1029
  HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
1030
+ HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s)));
1031
+
592
1032
  HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
1033
+ HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s)));
1034
+
593
1035
  HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
1036
+ HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s)));
594
1037
 
595
- // Zero out unused scales
1038
+ // Zero out unused elements
596
1039
  HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
597
1040
  r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
1041
+ r0_c0_ms = Q6_V_vand_QV(bmask, r0_c0_ms);
598
1042
  r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
1043
+ r0_c1_ms = Q6_V_vand_QV(bmask, r0_c1_ms);
599
1044
  r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
1045
+ r1_c0_ms = Q6_V_vand_QV(bmask, r1_c0_ms);
600
1046
  r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
1047
+ r1_c1_ms = Q6_V_vand_QV(bmask, r1_c1_ms);
1048
+
601
1049
  r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
602
1050
  r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
603
1051
  r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
@@ -608,10 +1056,15 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
608
1056
  HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
609
1057
  HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
610
1058
 
611
- r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
612
- r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
613
- r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
614
- r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
1059
+ HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms);
1060
+ HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms);
1061
+ HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms);
1062
+ HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms);
1063
+
1064
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum));
1065
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum));
1066
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum));
1067
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum));
615
1068
  }
616
1069
 
617
1070
  // Reduce and store results
@@ -622,26 +1075,26 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
622
1075
  hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1
623
1076
  }
624
1077
 
625
- static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
1078
+ static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
626
1079
  assert(n % 32 == 0); // min sub-block size
627
1080
  assert((unsigned long) vx0 % 128 == 0);
628
1081
  assert((unsigned long) vy0 % 128 == 0);
629
1082
 
630
1083
  const uint32_t qk = QK_Q4_0x4x2 * 4;
631
1084
 
632
- const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
633
- const uint32_t x_qblk_size = qk; // int8
634
- const uint32_t x_qrow_size = n; // int8 (not padded)
1085
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
1086
+ const uint32_t x_qblk_size = qk / 2; // int4
1087
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
635
1088
 
636
- const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
637
- const uint32_t y_qblk_size = qk; // int8
638
- const uint32_t y_qrow_size = n; // int8 (not padded)
1089
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
1090
+ const uint32_t y_qblk_size = qk; // int8
1091
+ const uint32_t y_qrow_size = n; // int8 (not padded)
639
1092
 
640
- const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
641
- const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
1093
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
1094
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
642
1095
 
643
- const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
644
- const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
1096
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
1097
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
645
1098
 
646
1099
  // Row sum (sf)
647
1100
  HVX_Vector r0_sum = Q6_V_vzero();
@@ -651,12 +1104,12 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo
651
1104
  // Apply scale to acc and accumulate into the row sum (qf32).
652
1105
 
653
1106
  const uint32_t nb = n / qk; // num full blocks
654
- int32_t nloe = n % qk; // num leftover elemements (must be signed)
1107
+ const uint32_t nloe = n % qk; // num leftover elemements
655
1108
 
656
1109
  uint32_t i = 0;
657
1110
  for (; i < nb; i++) {
658
1111
  HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
659
- HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
1112
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
660
1113
 
661
1114
  HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
662
1115
 
@@ -673,7 +1126,7 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo
673
1126
  // Process leftovers
674
1127
  if (nloe) {
675
1128
  HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
676
- HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
1129
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
677
1130
 
678
1131
  HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
679
1132
 
@@ -697,7 +1150,7 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo
697
1150
  hvx_vec_store_u(s0, 4, r0_sum);
698
1151
  }
699
1152
 
700
- static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
1153
+ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
701
1154
  const void * restrict vx0, const void * restrict vx1,
702
1155
  const void * restrict vy0) {
703
1156
  assert(n % 32 == 0); // min sub-block size
@@ -708,8 +1161,8 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
708
1161
  const uint32_t qk = QK_Q4_0x4x2 * 4;
709
1162
 
710
1163
  const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
711
- const uint32_t x_qblk_size = qk; // int8
712
- const uint32_t x_qrow_size = n; // int8 (not padded)
1164
+ const uint32_t x_qblk_size = qk / 2; // int4
1165
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
713
1166
 
714
1167
  const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
715
1168
  const uint32_t y_qblk_size = qk; // int8
@@ -723,7 +1176,7 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
723
1176
  const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
724
1177
  const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
725
1178
 
726
- // Row sum (qf32)
1179
+ // Row sum (sf)
727
1180
  HVX_Vector r0_sum = Q6_V_vzero();
728
1181
  HVX_Vector r1_sum = Q6_V_vzero();
729
1182
 
@@ -732,13 +1185,13 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
732
1185
  // Apply scale to acc and accumulate into the row sum (qf32).
733
1186
 
734
1187
  const uint32_t nb = n / qk; // num full blocks
735
- int32_t nloe = n % qk; // num leftover elemements (must be signed)
1188
+ const uint32_t nloe = n % qk; // num leftover elemements
736
1189
 
737
1190
  uint32_t i = 0;
738
1191
  for (; i < nb; i++) {
739
1192
  HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
740
- HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
741
- HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);
1193
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
1194
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);
742
1195
 
743
1196
  HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
744
1197
  HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
@@ -760,13 +1213,13 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
760
1213
  // Process leftovers
761
1214
  if (nloe) {
762
1215
  HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
763
- HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
764
- HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
1216
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
1217
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
765
1218
 
766
1219
  HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
767
1220
  HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
768
1221
 
769
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
1222
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
770
1223
  HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
771
1224
  HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
772
1225
 
@@ -791,7 +1244,134 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
791
1244
  hvx_vec_store_u(s0, 8, rsum);
792
1245
  }
793
1246
 
794
- static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
1247
+ static void vec_dot_q4x4x2_q8x4x2_4x1(const int n, float * restrict s0,
1248
+ const void * restrict vx0, const void * restrict vx1,
1249
+ const void * restrict vx2, const void * restrict vx3,
1250
+ const void * restrict vy0) {
1251
+ assert(n % 32 == 0); // min sub-block size
1252
+ assert((unsigned long) vx0 % 128 == 0);
1253
+ assert((unsigned long) vx1 % 128 == 0);
1254
+ assert((unsigned long) vx2 % 128 == 0);
1255
+ assert((unsigned long) vx3 % 128 == 0);
1256
+ assert((unsigned long) vy0 % 128 == 0);
1257
+
1258
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
1259
+
1260
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
1261
+ const uint32_t x_qblk_size = qk / 2; // int4
1262
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
1263
+
1264
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
1265
+ const uint32_t y_qblk_size = qk; // int8
1266
+ const uint32_t y_qrow_size = n; // int8 (not padded)
1267
+
1268
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;
1269
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;
1270
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;
1271
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;
1272
+ const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0;
1273
+ const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size;
1274
+ const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0;
1275
+ const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size;
1276
+
1277
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);
1278
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);
1279
+
1280
+ // Row sum (sf)
1281
+ HVX_Vector r0_sum = Q6_V_vzero();
1282
+ HVX_Vector r1_sum = Q6_V_vzero();
1283
+ HVX_Vector r2_sum = Q6_V_vzero();
1284
+ HVX_Vector r3_sum = Q6_V_vzero();
1285
+
1286
+ const uint32_t nb = n / qk; // num full blocks
1287
+ const uint32_t nloe = n % qk; // num leftover elements
1288
+
1289
+ uint32_t i = 0;
1290
+ for (; i < nb; i++) {
1291
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
1292
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
1293
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);
1294
+ HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_full(r2_x_q + i * x_qblk_size);
1295
+ HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_full(r3_x_q + i * x_qblk_size);
1296
+
1297
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
1298
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
1299
+ HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q));
1300
+ HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q));
1301
+
1302
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
1303
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
1304
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
1305
+ HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size));
1306
+ HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size));
1307
+
1308
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
1309
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
1310
+ HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d)));
1311
+ HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d)));
1312
+
1313
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
1314
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
1315
+ HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
1316
+ HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
1317
+
1318
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
1319
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
1320
+ r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum));
1321
+ r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum));
1322
+ }
1323
+
1324
+ if (nloe) {
1325
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
1326
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
1327
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
1328
+ HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_partial(r2_x_q + i * x_qblk_size, nloe);
1329
+ HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_partial(r3_x_q + i * x_qblk_size, nloe);
1330
+
1331
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
1332
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
1333
+ HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe));
1334
+ HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe));
1335
+
1336
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
1337
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
1338
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
1339
+ HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size));
1340
+ HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size));
1341
+
1342
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
1343
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
1344
+ HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d)));
1345
+ HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d)));
1346
+
1347
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
1348
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
1349
+ r1_dd = Q6_V_vand_QV(bmask, r1_dd);
1350
+ r2_dd = Q6_V_vand_QV(bmask, r2_dd);
1351
+ r3_dd = Q6_V_vand_QV(bmask, r3_dd);
1352
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
1353
+ r1_ia = Q6_V_vand_QV(bmask, r1_ia);
1354
+ r2_ia = Q6_V_vand_QV(bmask, r2_ia);
1355
+ r3_ia = Q6_V_vand_QV(bmask, r3_ia);
1356
+
1357
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
1358
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
1359
+ HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
1360
+ HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
1361
+
1362
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
1363
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
1364
+ r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum));
1365
+ r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum));
1366
+ }
1367
+
1368
+ HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } };
1369
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in);
1370
+ hvx_vec_store_u(s0, 16, rsum);
1371
+ }
1372
+
1373
+
1374
+ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
795
1375
  const void * restrict vx0, const void * restrict vx1,
796
1376
  const void * restrict vy0, const void * restrict vy1) {
797
1377
  assert(n % 32 == 0);
@@ -800,11 +1380,11 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
800
1380
  assert((unsigned long) vy0 % 128 == 0);
801
1381
  assert((unsigned long) vy1 % 128 == 0);
802
1382
 
803
- const uint32_t qk = QK_Q8_0x4x2 * 4;
1383
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
804
1384
 
805
1385
  const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
806
- const uint32_t x_qblk_size = qk; // int8
807
- const uint32_t x_qrow_size = n; // int8 (not padded)
1386
+ const uint32_t x_qblk_size = qk / 2; // int4
1387
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
808
1388
 
809
1389
  const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
810
1390
  const uint32_t y_qblk_size = qk; // int8
@@ -836,8 +1416,8 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
836
1416
  HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
837
1417
 
838
1418
  // Load src0 rows (reused across both src1 columns)
839
- HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
840
- HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);
1419
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
1420
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);
841
1421
 
842
1422
  // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
843
1423
  HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
@@ -873,8 +1453,8 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
873
1453
  if (nloe) {
874
1454
  HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
875
1455
  HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
876
- HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
877
- HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
1456
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
1457
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
878
1458
 
879
1459
  HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
880
1460
  HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
@@ -891,63 +1471,1016 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
891
1471
  HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
892
1472
  HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
893
1473
 
894
- // Zero out unused elements
1474
+ // Zero out unused scales
1475
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
1476
+ r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
1477
+ r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
1478
+ r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
1479
+ r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
1480
+ r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
1481
+ r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
1482
+ r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
1483
+ r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
1484
+
1485
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
1486
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
1487
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
1488
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
1489
+
1490
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
1491
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
1492
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
1493
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
1494
+ }
1495
+
1496
+ // Reduce and store results
1497
+ HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
1498
+ HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
1499
+
1500
+ hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0
1501
+ hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1
1502
+ }
1503
+
1504
+ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
1505
+ assert(n % 32 == 0); // min sub-block size
1506
+ assert((unsigned long) vx0 % 128 == 0);
1507
+ assert((unsigned long) vy0 % 128 == 0);
1508
+
1509
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
1510
+
1511
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
1512
+ const uint32_t x_qblk_size = qk; // int8
1513
+ const uint32_t x_qrow_size = n; // int8 (not padded)
1514
+
1515
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
1516
+ const uint32_t y_qblk_size = qk; // int8
1517
+ const uint32_t y_qrow_size = n; // int8 (not padded)
1518
+
1519
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
1520
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
1521
+
1522
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
1523
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
1524
+
1525
+ // Row sum (sf)
1526
+ HVX_Vector r0_sum = Q6_V_vzero();
1527
+
1528
+ // Multiply and accumulate into int32.
1529
+ // Compute combined scale (fp32).
1530
+ // Apply scale to acc and accumulate into the row sum (qf32).
1531
+
1532
+ const uint32_t nb = n / qk; // num full blocks
1533
+ int32_t nloe = n % qk; // num leftover elemements (must be signed)
1534
+
1535
+ uint32_t i = 0;
1536
+ for (; i < nb; i++) {
1537
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
1538
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
1539
+
1540
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
1541
+
1542
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
1543
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
1544
+
1545
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
1546
+
1547
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
1548
+
1549
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
1550
+ }
1551
+
1552
+ // Process leftovers
1553
+ if (nloe) {
1554
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
1555
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
1556
+
1557
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
1558
+
1559
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
1560
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
1561
+
1562
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
1563
+
1564
+ // Zero out unused elements
1565
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
1566
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
1567
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
1568
+
1569
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
1570
+
1571
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
1572
+ }
1573
+
1574
+ r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
1575
+
1576
+ hvx_vec_store_u(s0, 4, r0_sum);
1577
+ }
1578
+
1579
+ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
1580
+ const void * restrict vx0, const void * restrict vx1,
1581
+ const void * restrict vy0) {
1582
+ assert(n % 32 == 0); // min sub-block size
1583
+ assert((unsigned long) vx0 % 128 == 0);
1584
+ assert((unsigned long) vx1 % 128 == 0);
1585
+ assert((unsigned long) vy0 % 128 == 0);
1586
+
1587
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
1588
+
1589
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
1590
+ const uint32_t x_qblk_size = qk; // int8
1591
+ const uint32_t x_qrow_size = n; // int8 (not padded)
1592
+
1593
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
1594
+ const uint32_t y_qblk_size = qk; // int8
1595
+ const uint32_t y_qrow_size = n; // int8 (not padded)
1596
+
1597
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
1598
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
1599
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
1600
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
1601
+
1602
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
1603
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
1604
+
1605
+ // Row sum (qf32)
1606
+ HVX_Vector r0_sum = Q6_V_vzero();
1607
+ HVX_Vector r1_sum = Q6_V_vzero();
1608
+
1609
+ // Multiply and accumulate into int32.
1610
+ // Compute combined scale (fp32).
1611
+ // Apply scale to acc and accumulate into the row sum (qf32).
1612
+
1613
+ const uint32_t nb = n / qk; // num full blocks
1614
+ int32_t nloe = n % qk; // num leftover elemements (must be signed)
1615
+
1616
+ uint32_t i = 0;
1617
+ for (; i < nb; i++) {
1618
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
1619
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
1620
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);
1621
+
1622
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
1623
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
1624
+
1625
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
1626
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
1627
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
1628
+
1629
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
1630
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
1631
+
1632
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
1633
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
1634
+
1635
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
1636
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
1637
+ }
1638
+
1639
+ // Process leftovers
1640
+ if (nloe) {
1641
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
1642
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
1643
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
1644
+
1645
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
1646
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
1647
+
1648
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
1649
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
1650
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
1651
+
1652
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
1653
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
1654
+
1655
+ // Zero out unused elements
1656
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
1657
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
1658
+ r1_dd = Q6_V_vand_QV(bmask, r1_dd);
1659
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
1660
+ r1_ia = Q6_V_vand_QV(bmask, r1_ia);
1661
+
1662
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
1663
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
1664
+
1665
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
1666
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
1667
+ }
1668
+
1669
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
1670
+ hvx_vec_store_u(s0, 8, rsum);
1671
+ }
1672
+
1673
+ static void vec_dot_q8x4x2_q8x4x2_4x1(const int n, float * restrict s0,
1674
+ const void * restrict vx0, const void * restrict vx1,
1675
+ const void * restrict vx2, const void * restrict vx3,
1676
+ const void * restrict vy0) {
1677
+ assert(n % 32 == 0); // min sub-block size
1678
+ assert((unsigned long) vx0 % 128 == 0);
1679
+ assert((unsigned long) vx1 % 128 == 0);
1680
+ assert((unsigned long) vx2 % 128 == 0);
1681
+ assert((unsigned long) vx3 % 128 == 0);
1682
+ assert((unsigned long) vy0 % 128 == 0);
1683
+
1684
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
1685
+
1686
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
1687
+ const uint32_t x_qblk_size = qk; // int8
1688
+ const uint32_t x_qrow_size = n; // int8 (not padded)
1689
+
1690
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
1691
+ const uint32_t y_qblk_size = qk; // int8
1692
+ const uint32_t y_qrow_size = n; // int8 (not padded)
1693
+
1694
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
1695
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
1696
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
1697
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
1698
+ const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first
1699
+ const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales
1700
+ const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first
1701
+ const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales
1702
+
1703
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
1704
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
1705
+
1706
+ // Row sum (qf32)
1707
+ HVX_Vector r0_sum = Q6_V_vzero();
1708
+ HVX_Vector r1_sum = Q6_V_vzero();
1709
+ HVX_Vector r2_sum = Q6_V_vzero();
1710
+ HVX_Vector r3_sum = Q6_V_vzero();
1711
+
1712
+ const uint32_t nb = n / qk; // num full blocks
1713
+ int32_t nloe = n % qk; // num leftover elemements (must be signed)
1714
+
1715
+ uint32_t i = 0;
1716
+ for (; i < nb; i++) {
1717
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
1718
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
1719
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);
1720
+ HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_full(r2_x_q + i * x_qblk_size);
1721
+ HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_full(r3_x_q + i * x_qblk_size);
1722
+
1723
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
1724
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
1725
+ HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q));
1726
+ HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q));
1727
+
1728
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
1729
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
1730
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
1731
+ HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size));
1732
+ HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size));
1733
+
1734
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
1735
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
1736
+ HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d)));
1737
+ HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d)));
1738
+
1739
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
1740
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
1741
+ HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
1742
+ HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
1743
+
1744
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
1745
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
1746
+ r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum));
1747
+ r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum));
1748
+ }
1749
+
1750
+ if (nloe) {
1751
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
1752
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
1753
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
1754
+ HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_partial(r2_x_q + i * x_qblk_size, nloe);
1755
+ HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_partial(r3_x_q + i * x_qblk_size, nloe);
1756
+
1757
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
1758
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
1759
+ HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe));
1760
+ HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe));
1761
+
1762
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
1763
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
1764
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
1765
+ HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size));
1766
+ HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size));
1767
+
1768
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
1769
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
1770
+ HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d)));
1771
+ HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d)));
1772
+
1773
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
1774
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
1775
+ r1_dd = Q6_V_vand_QV(bmask, r1_dd);
1776
+ r2_dd = Q6_V_vand_QV(bmask, r2_dd);
1777
+ r3_dd = Q6_V_vand_QV(bmask, r3_dd);
1778
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
1779
+ r1_ia = Q6_V_vand_QV(bmask, r1_ia);
1780
+ r2_ia = Q6_V_vand_QV(bmask, r2_ia);
1781
+ r3_ia = Q6_V_vand_QV(bmask, r3_ia);
1782
+
1783
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
1784
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
1785
+ HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
1786
+ HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
1787
+
1788
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
1789
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
1790
+ r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum));
1791
+ r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum));
1792
+ }
1793
+
1794
+ HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } };
1795
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in);
1796
+ hvx_vec_store_u(s0, 16, rsum);
1797
+ }
1798
+
1799
+
1800
+ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
1801
+ const void * restrict vx0, const void * restrict vx1,
1802
+ const void * restrict vy0, const void * restrict vy1) {
1803
+ assert(n % 32 == 0);
1804
+ assert((unsigned long) vx0 % 128 == 0);
1805
+ assert((unsigned long) vx1 % 128 == 0);
1806
+ assert((unsigned long) vy0 % 128 == 0);
1807
+ assert((unsigned long) vy1 % 128 == 0);
1808
+
1809
+ const uint32_t qk = QK_Q8_0x4x2 * 4;
1810
+
1811
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
1812
+ const uint32_t x_qblk_size = qk; // int8
1813
+ const uint32_t x_qrow_size = n; // int8 (not padded)
1814
+
1815
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
1816
+ const uint32_t y_qblk_size = qk; // int8
1817
+ const uint32_t y_qrow_size = n; // int8 (not padded)
1818
+
1819
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
1820
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
1821
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
1822
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
1823
+
1824
+ const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
1825
+ const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
1826
+ const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
1827
+ const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
1828
+
1829
+ // Row sums (sf) - 4 accumulators for 2×2 tile
1830
+ HVX_Vector r0_c0_sum = Q6_V_vzero();
1831
+ HVX_Vector r0_c1_sum = Q6_V_vzero();
1832
+ HVX_Vector r1_c0_sum = Q6_V_vzero();
1833
+ HVX_Vector r1_c1_sum = Q6_V_vzero();
1834
+
1835
+ const uint32_t nb = n / qk; // num full blocks
1836
+ const uint32_t nloe = n % qk; // num leftover elements
1837
+
1838
+ uint32_t i = 0;
1839
+ for (; i < nb; i++) {
1840
+ // Load src1 columns (reused across both src0 rows)
1841
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
1842
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
1843
+
1844
+ // Load src0 rows (reused across both src1 columns)
1845
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
1846
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);
1847
+
1848
+ // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
1849
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
1850
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
1851
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
1852
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
1853
+
1854
+ // Load scales
1855
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
1856
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
1857
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
1858
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
1859
+
1860
+ // Compute combined scales
1861
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
1862
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
1863
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
1864
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
1865
+
1866
+ // Apply scales and accumulate
1867
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
1868
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
1869
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
1870
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
1871
+
1872
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
1873
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
1874
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
1875
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
1876
+ }
1877
+
1878
+ // Process leftovers
1879
+ if (nloe) {
1880
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
1881
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
1882
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
1883
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
1884
+
1885
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
1886
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
1887
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
1888
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
1889
+
1890
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
1891
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
1892
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
1893
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
1894
+
1895
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
1896
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
1897
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
1898
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
1899
+
1900
+ // Zero out unused elements
1901
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
1902
+ r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
1903
+ r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
1904
+ r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
1905
+ r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
1906
+ r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
1907
+ r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
1908
+ r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
1909
+ r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
1910
+
1911
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
1912
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
1913
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
1914
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
1915
+
1916
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
1917
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
1918
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
1919
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
1920
+ }
1921
+
1922
+ // Reduce and store results
1923
+ HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
1924
+ HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
1925
+
1926
+ hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
1927
+ hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
1928
+ }
1929
+
1930
+ // ======== IQ4_NL x Q8_0 vec_dot kernels ========
1931
+ // Same structure as Q4_0 vec_dot but uses IQ4_NL LUT-based load (4-bit index -> int8 kvalue).
1932
+ // Scale format is identical to Q4_0 (fp16 scales).
1933
+
1934
+ static void vec_dot_iq4nlx4x2_q8x4x2_1x1(const int n,
1935
+ float * restrict s0,
1936
+ const void * restrict vx0,
1937
+ const void * restrict vy0) {
1938
+ assert(n % 32 == 0);
1939
+ assert((unsigned long) vx0 % 128 == 0);
1940
+ assert((unsigned long) vy0 % 128 == 0);
1941
+
1942
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
1943
+
1944
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
1945
+ const uint32_t x_qblk_size = qk / 2; // int4
1946
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
1947
+
1948
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
1949
+ const uint32_t y_qblk_size = qk; // int8
1950
+ const uint32_t y_qrow_size = n; // int8 (not padded)
1951
+
1952
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
1953
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
1954
+
1955
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
1956
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
1957
+
1958
+ HVX_Vector r0_sum = Q6_V_vzero();
1959
+
1960
+ const uint32_t nb = n / qk;
1961
+ const uint32_t nloe = n % qk;
1962
+
1963
+ uint32_t i = 0;
1964
+ for (; i < nb; i++) {
1965
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
1966
+ HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
1967
+
1968
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
1969
+
1970
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
1971
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
1972
+
1973
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
1974
+
1975
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
1976
+
1977
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
1978
+ }
1979
+
1980
+ if (nloe) {
1981
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
1982
+ HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
1983
+
1984
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
1985
+
1986
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
1987
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
1988
+
1989
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
1990
+
1991
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
1992
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
1993
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
1994
+
1995
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
1996
+
1997
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
1998
+ }
1999
+
2000
+ r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
2001
+
2002
+ hvx_vec_store_u(s0, 4, r0_sum);
2003
+ }
2004
+
2005
+ static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n,
2006
+ float * restrict s0,
2007
+ const void * restrict vx0,
2008
+ const void * restrict vx1,
2009
+ const void * restrict vy0) {
2010
+ assert(n % 32 == 0);
2011
+ assert((unsigned long) vx0 % 128 == 0);
2012
+ assert((unsigned long) vx1 % 128 == 0);
2013
+ assert((unsigned long) vy0 % 128 == 0);
2014
+
2015
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
2016
+
2017
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
2018
+ const uint32_t x_qblk_size = qk / 2; // int4
2019
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
2020
+
2021
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
2022
+ const uint32_t y_qblk_size = qk; // int8
2023
+ const uint32_t y_qrow_size = n; // int8 (not padded)
2024
+
2025
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
2026
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
2027
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
2028
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
2029
+
2030
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
2031
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
2032
+
2033
+ HVX_Vector r0_sum = Q6_V_vzero();
2034
+ HVX_Vector r1_sum = Q6_V_vzero();
2035
+
2036
+ const uint32_t nb = n / qk;
2037
+ const uint32_t nloe = n % qk;
2038
+
2039
+ uint32_t i = 0;
2040
+ for (; i < nb; i++) {
2041
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
2042
+ HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
2043
+ HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size);
2044
+
2045
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
2046
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
2047
+
2048
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
2049
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
2050
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
2051
+
2052
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
2053
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
2054
+
2055
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
2056
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
2057
+
2058
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
2059
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
2060
+ }
2061
+
2062
+ if (nloe) {
2063
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
2064
+ HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
2065
+ HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe);
2066
+
2067
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
2068
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
2069
+
2070
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
2071
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
2072
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
2073
+
2074
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
2075
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
2076
+
2077
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
2078
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
2079
+ r1_dd = Q6_V_vand_QV(bmask, r1_dd);
2080
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
2081
+ r1_ia = Q6_V_vand_QV(bmask, r1_ia);
2082
+
2083
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
2084
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
2085
+
2086
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
2087
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
2088
+ }
2089
+
2090
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
2091
+ hvx_vec_store_u(s0, 8, rsum);
2092
+ }
2093
+
2094
+ static void vec_dot_iq4nlx4x2_q8x4x2_4x1(const int n,
2095
+ float * restrict s0,
2096
+ const void * restrict vx0,
2097
+ const void * restrict vx1,
2098
+ const void * restrict vx2,
2099
+ const void * restrict vx3,
2100
+ const void * restrict vy0) {
2101
+ assert(n % 32 == 0);
2102
+ assert((unsigned long) vx0 % 128 == 0);
2103
+ assert((unsigned long) vx1 % 128 == 0);
2104
+ assert((unsigned long) vx2 % 128 == 0);
2105
+ assert((unsigned long) vx3 % 128 == 0);
2106
+ assert((unsigned long) vy0 % 128 == 0);
2107
+
2108
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
2109
+
2110
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
2111
+ const uint32_t x_qblk_size = qk / 2; // int4
2112
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
2113
+
2114
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
2115
+ const uint32_t y_qblk_size = qk; // int8
2116
+ const uint32_t y_qrow_size = n; // int8 (not padded)
2117
+
2118
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
2119
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
2120
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
2121
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
2122
+ const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first
2123
+ const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales
2124
+ const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first
2125
+ const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales
2126
+
2127
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
2128
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
2129
+
2130
+ HVX_Vector r0_sum = Q6_V_vzero();
2131
+ HVX_Vector r1_sum = Q6_V_vzero();
2132
+ HVX_Vector r2_sum = Q6_V_vzero();
2133
+ HVX_Vector r3_sum = Q6_V_vzero();
2134
+
2135
+ const uint32_t nb = n / qk;
2136
+ const uint32_t nloe = n % qk;
2137
+
2138
+ uint32_t i = 0;
2139
+ for (; i < nb; i++) {
2140
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
2141
+ HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
2142
+ HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size);
2143
+ HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_full(r2_x_q + i * x_qblk_size);
2144
+ HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_full(r3_x_q + i * x_qblk_size);
2145
+
2146
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
2147
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
2148
+ HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q));
2149
+ HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q));
2150
+
2151
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
2152
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
2153
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
2154
+ HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size));
2155
+ HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size));
2156
+
2157
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
2158
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
2159
+ HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d)));
2160
+ HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d)));
2161
+
2162
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
2163
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
2164
+ HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
2165
+ HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
2166
+
2167
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
2168
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
2169
+ r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum));
2170
+ r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum));
2171
+ }
2172
+
2173
+ if (nloe) {
2174
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
2175
+ HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
2176
+ HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe);
2177
+ HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_partial(r2_x_q + i * x_qblk_size, nloe);
2178
+ HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_partial(r3_x_q + i * x_qblk_size, nloe);
2179
+
2180
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
2181
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
2182
+ HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe));
2183
+ HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe));
2184
+
2185
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
2186
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
2187
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
2188
+ HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size));
2189
+ HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size));
2190
+
2191
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
2192
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
2193
+ HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d)));
2194
+ HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d)));
2195
+
2196
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
2197
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
2198
+ r1_dd = Q6_V_vand_QV(bmask, r1_dd);
2199
+ r2_dd = Q6_V_vand_QV(bmask, r2_dd);
2200
+ r3_dd = Q6_V_vand_QV(bmask, r3_dd);
2201
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
2202
+ r1_ia = Q6_V_vand_QV(bmask, r1_ia);
2203
+ r2_ia = Q6_V_vand_QV(bmask, r2_ia);
2204
+ r3_ia = Q6_V_vand_QV(bmask, r3_ia);
2205
+
2206
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
2207
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
2208
+ HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
2209
+ HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
2210
+
2211
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
2212
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
2213
+ r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum));
2214
+ r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum));
2215
+ }
2216
+
2217
+ HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } };
2218
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in);
2219
+ hvx_vec_store_u(s0, 16, rsum);
2220
+ }
2221
+
2222
+
2223
+ static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n,
2224
+ float * restrict s0,
2225
+ float * restrict s1,
2226
+ const void * restrict vx0,
2227
+ const void * restrict vx1,
2228
+ const void * restrict vy0,
2229
+ const void * restrict vy1) {
2230
+ assert(n % 32 == 0);
2231
+ assert((unsigned long) vx0 % 128 == 0);
2232
+ assert((unsigned long) vx1 % 128 == 0);
2233
+ assert((unsigned long) vy0 % 128 == 0);
2234
+ assert((unsigned long) vy1 % 128 == 0);
2235
+
2236
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
2237
+
2238
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
2239
+ const uint32_t x_qblk_size = qk / 2; // int4
2240
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
2241
+
2242
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
2243
+ const uint32_t y_qblk_size = qk; // int8
2244
+ const uint32_t y_qrow_size = n; // int8 (not padded)
2245
+
2246
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;
2247
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;
2248
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;
2249
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;
2250
+
2251
+ const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0;
2252
+ const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size;
2253
+ const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0;
2254
+ const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size;
2255
+
2256
+ HVX_Vector r0_c0_sum = Q6_V_vzero();
2257
+ HVX_Vector r0_c1_sum = Q6_V_vzero();
2258
+ HVX_Vector r1_c0_sum = Q6_V_vzero();
2259
+ HVX_Vector r1_c1_sum = Q6_V_vzero();
2260
+
2261
+ const uint32_t nb = n / qk;
2262
+ const uint32_t nloe = n % qk;
2263
+
2264
+ uint32_t i = 0;
2265
+ for (; i < nb; i++) {
2266
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
2267
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
2268
+ HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
2269
+ HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size);
2270
+
2271
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
2272
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
2273
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
2274
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
2275
+
2276
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
2277
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
2278
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
2279
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
2280
+
2281
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
2282
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
2283
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
2284
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
2285
+
2286
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
2287
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
2288
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
2289
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
2290
+
2291
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
2292
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
2293
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
2294
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
2295
+ }
2296
+
2297
+ if (nloe) {
2298
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
2299
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
2300
+ HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
2301
+ HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe);
2302
+
2303
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
2304
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
2305
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
2306
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
2307
+
2308
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
2309
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
2310
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
2311
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
2312
+
2313
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
2314
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
2315
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
2316
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
2317
+
2318
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
2319
+ r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
2320
+ r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
2321
+ r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
2322
+ r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
2323
+ r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
2324
+ r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
2325
+ r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
2326
+ r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
2327
+
2328
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
2329
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
2330
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
2331
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
2332
+
2333
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
2334
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
2335
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
2336
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
2337
+ }
2338
+
2339
+ HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
2340
+ HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
2341
+
2342
+ hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum);
2343
+ hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum);
2344
+ }
2345
+
2346
+ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
2347
+ assert(n % 32 == 0); // min sub-block size
2348
+ assert((unsigned long) vx0 % 128 == 0);
2349
+ assert((unsigned long) vy0 % 128 == 0);
2350
+
2351
+ const uint32_t qk = QK_MXFP4x4x2 * 4;
2352
+
2353
+ const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
2354
+ const uint32_t x_qblk_size = qk / 2; // fp4
2355
+ const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
2356
+
2357
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
2358
+ const uint32_t y_qblk_size = qk; // int8
2359
+ const uint32_t y_qrow_size = n; // int8 (not padded)
2360
+
2361
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
2362
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
2363
+
2364
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
2365
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
2366
+
2367
+ // Row sum (sf)
2368
+ HVX_Vector r0_sum = Q6_V_vzero();
2369
+
2370
+ // Multiply and accumulate into int32.
2371
+ // Compute combined scale (fp32).
2372
+ // Apply scale to acc and accumulate into the row sum (qf32).
2373
+
2374
+ const uint32_t nb = n / qk; // num full blocks
2375
+ int32_t nloe = n % qk; // num leftover elemements (must be signed)
2376
+
2377
+ uint32_t i = 0;
2378
+ for (; i < nb; i++) {
2379
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size);
2380
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
2381
+
2382
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
2383
+
2384
+ HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
2385
+ HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
2386
+
2387
+ // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
2388
+ HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
2389
+ vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
2390
+ vy_d = Q6_Vsf_equals_Vqf32(vy_d);
2391
+
2392
+ // Convert rX_d scales from e8m0 to fp32
2393
+ // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
2394
+ // Left shift with zero fill to create FP32
2395
+ // FIXME: might need to handle zero as a special case (see ggml-cpu code)
2396
+ HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
2397
+ HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
2398
+ r0_d = Q6_V_vdelta_VV(r0_d, expand);
2399
+ r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
2400
+ r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
2401
+
2402
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
2403
+
2404
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
2405
+
2406
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
2407
+ }
2408
+
2409
+ // Process leftovers
2410
+ if (nloe) {
2411
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe);
2412
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
2413
+
2414
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
2415
+
2416
+ HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
2417
+ HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
2418
+
2419
+ // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
2420
+ HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
2421
+ vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
2422
+ vy_d = Q6_Vsf_equals_Vqf32(vy_d);
2423
+
2424
+ // Convert rX_d scales from e8m0 to fp32
2425
+ // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
2426
+ // Left shift with zero fill to create FP32
2427
+ // FIXME: might need to handle zero as a special case (see ggml-cpu code)
2428
+ HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
2429
+ HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
2430
+ r0_d = Q6_V_vdelta_VV(r0_d, expand);
2431
+ r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
2432
+ r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
2433
+
2434
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
2435
+
2436
+ // Zero-out unused scales
895
2437
  HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
896
- r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
897
- r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
898
- r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
899
- r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
900
- r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
901
- r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
902
- r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
903
- r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
2438
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
2439
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
904
2440
 
905
- HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
906
- HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
907
- HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
908
- HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
2441
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
909
2442
 
910
- r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
911
- r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
912
- r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
913
- r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
2443
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
914
2444
  }
915
2445
 
916
- // Reduce and store results
917
- HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
918
- HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
2446
+ r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
919
2447
 
920
- hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
921
- hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
2448
+ hvx_vec_store_u(s0, 4, r0_sum);
922
2449
  }
923
2450
 
924
- static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
2451
+ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
2452
+ const void * restrict vx0, const void * restrict vx1,
2453
+ const void * restrict vy0) {
925
2454
  assert(n % 32 == 0); // min sub-block size
926
2455
  assert((unsigned long) vx0 % 128 == 0);
2456
+ assert((unsigned long) vx1 % 128 == 0);
927
2457
  assert((unsigned long) vy0 % 128 == 0);
928
2458
 
929
2459
  const uint32_t qk = QK_MXFP4x4x2 * 4;
930
2460
 
931
- const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
932
- const uint32_t x_qblk_size = qk / 2; // fp4
933
- const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
2461
+ const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
2462
+ const uint32_t x_qblk_size = qk / 2; // fp4
2463
+ const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
934
2464
 
935
- const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
936
- const uint32_t y_qblk_size = qk; // int8
937
- const uint32_t y_qrow_size = n; // int8 (not padded)
2465
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
2466
+ const uint32_t y_qblk_size = qk; // int8
2467
+ const uint32_t y_qrow_size = n; // int8 (not padded)
938
2468
 
939
- const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
940
- const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
2469
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
2470
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
2471
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
2472
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
941
2473
 
942
- const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
943
- const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
2474
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first
2475
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
944
2476
 
945
2477
  // Row sum (sf)
946
2478
  HVX_Vector r0_sum = Q6_V_vzero();
2479
+ HVX_Vector r1_sum = Q6_V_vzero();
947
2480
 
948
2481
  // Multiply and accumulate into int32.
949
2482
  // Compute combined scale (fp32).
950
- // Apply scale to acc and accumulate into the row sum (qf32).
2483
+ // Apply scale to acc and accumulate into the row sum (f32).
951
2484
 
952
2485
  const uint32_t nb = n / qk; // num full blocks
953
2486
  int32_t nloe = n % qk; // num leftover elemements (must be signed)
@@ -956,11 +2489,14 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const
956
2489
  for (; i < nb; i++) {
957
2490
  HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size);
958
2491
  HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
2492
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size);
959
2493
 
960
2494
  HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
2495
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
961
2496
 
962
2497
  HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
963
2498
  HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
2499
+ HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
964
2500
 
965
2501
  // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
966
2502
  HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
@@ -976,23 +2512,32 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const
976
2512
  r0_d = Q6_V_vdelta_VV(r0_d, expand);
977
2513
  r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
978
2514
  r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
2515
+ r1_d = Q6_V_vdelta_VV(r1_d, expand);
2516
+ r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
2517
+ r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
979
2518
 
980
2519
  HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
2520
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
981
2521
 
982
2522
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
2523
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
983
2524
 
984
2525
  r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
2526
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
985
2527
  }
986
2528
 
987
2529
  // Process leftovers
988
2530
  if (nloe) {
989
2531
  HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe);
990
2532
  HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
2533
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
991
2534
 
992
- HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
2535
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
2536
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
993
2537
 
994
2538
  HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
995
2539
  HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
2540
+ HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
996
2541
 
997
2542
  // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
998
2543
  HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
@@ -1008,30 +2553,40 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const
1008
2553
  r0_d = Q6_V_vdelta_VV(r0_d, expand);
1009
2554
  r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
1010
2555
  r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
2556
+ r1_d = Q6_V_vdelta_VV(r1_d, expand);
2557
+ r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
2558
+ r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
1011
2559
 
1012
2560
  HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
2561
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
1013
2562
 
1014
- // Zero-out unused scales
2563
+ // Zero-out unused values
1015
2564
  HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
1016
2565
  r0_dd = Q6_V_vand_QV(bmask, r0_dd);
2566
+ r1_dd = Q6_V_vand_QV(bmask, r1_dd);
1017
2567
  r0_ia = Q6_V_vand_QV(bmask, r0_ia);
2568
+ r1_ia = Q6_V_vand_QV(bmask, r1_ia);
1018
2569
 
1019
2570
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
2571
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
1020
2572
 
1021
2573
  r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
2574
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
1022
2575
  }
1023
2576
 
1024
- r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
1025
-
1026
- hvx_vec_store_u(s0, 4, r0_sum);
2577
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
2578
+ hvx_vec_store_u(s0, 8, rsum);
1027
2579
  }
1028
2580
 
1029
- static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
2581
+ static void vec_dot_mxfp4x4x2_q8x4x2_4x1(const int n, float * restrict s0,
1030
2582
  const void * restrict vx0, const void * restrict vx1,
2583
+ const void * restrict vx2, const void * restrict vx3,
1031
2584
  const void * restrict vy0) {
1032
2585
  assert(n % 32 == 0); // min sub-block size
1033
2586
  assert((unsigned long) vx0 % 128 == 0);
1034
2587
  assert((unsigned long) vx1 % 128 == 0);
2588
+ assert((unsigned long) vx2 % 128 == 0);
2589
+ assert((unsigned long) vx3 % 128 == 0);
1035
2590
  assert((unsigned long) vy0 % 128 == 0);
1036
2591
 
1037
2592
  const uint32_t qk = QK_MXFP4x4x2 * 4;
@@ -1048,17 +2603,19 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
1048
2603
  const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
1049
2604
  const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
1050
2605
  const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
2606
+ const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first
2607
+ const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales
2608
+ const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first
2609
+ const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales
1051
2610
 
1052
2611
  const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first
1053
- const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
2612
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
1054
2613
 
1055
2614
  // Row sum (sf)
1056
2615
  HVX_Vector r0_sum = Q6_V_vzero();
1057
2616
  HVX_Vector r1_sum = Q6_V_vzero();
1058
-
1059
- // Multiply and accumulate into int32.
1060
- // Compute combined scale (fp32).
1061
- // Apply scale to acc and accumulate into the row sum (f32).
2617
+ HVX_Vector r2_sum = Q6_V_vzero();
2618
+ HVX_Vector r3_sum = Q6_V_vzero();
1062
2619
 
1063
2620
  const uint32_t nb = n / qk; // num full blocks
1064
2621
  int32_t nloe = n % qk; // num leftover elemements (must be signed)
@@ -1068,13 +2625,19 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
1068
2625
  HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size);
1069
2626
  HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
1070
2627
  HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size);
2628
+ HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_full(r2_x_q + i * x_qblk_size);
2629
+ HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_full(r3_x_q + i * x_qblk_size);
1071
2630
 
1072
2631
  HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
1073
2632
  HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
2633
+ HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q));
2634
+ HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q));
1074
2635
 
1075
- HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
2636
+ HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
1076
2637
  HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
1077
2638
  HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
2639
+ HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size);
2640
+ HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size);
1078
2641
 
1079
2642
  // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
1080
2643
  HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
@@ -1082,9 +2645,6 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
1082
2645
  vy_d = Q6_Vsf_equals_Vqf32(vy_d);
1083
2646
 
1084
2647
  // Convert rX_d scales from e8m0 to fp32
1085
- // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
1086
- // Left shift with zero fill to create FP32
1087
- // FIXME: might need to handle zero as a special case (see ggml-cpu code)
1088
2648
  HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
1089
2649
  HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
1090
2650
  r0_d = Q6_V_vdelta_VV(r0_d, expand);
@@ -1093,29 +2653,46 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
1093
2653
  r1_d = Q6_V_vdelta_VV(r1_d, expand);
1094
2654
  r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
1095
2655
  r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
2656
+ r2_d = Q6_V_vdelta_VV(r2_d, expand);
2657
+ r2_d = Q6_V_vand_VV(r2_d, e8m0_mask);
2658
+ r2_d = Q6_Vw_vasl_VwR(r2_d, 23);
2659
+ r3_d = Q6_V_vdelta_VV(r3_d, expand);
2660
+ r3_d = Q6_V_vand_VV(r3_d, e8m0_mask);
2661
+ r3_d = Q6_Vw_vasl_VwR(r3_d, 23);
1096
2662
 
1097
2663
  HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
1098
2664
  HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
2665
+ HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d));
2666
+ HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d));
1099
2667
 
1100
2668
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
1101
2669
  HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
2670
+ HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
2671
+ HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
1102
2672
 
1103
2673
  r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
1104
2674
  r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
2675
+ r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum));
2676
+ r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum));
1105
2677
  }
1106
2678
 
1107
- // Process leftovers
1108
2679
  if (nloe) {
1109
2680
  HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe);
1110
2681
  HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
1111
2682
  HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
2683
+ HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_partial(r2_x_q + i * x_qblk_size, nloe);
2684
+ HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_partial(r3_x_q + i * x_qblk_size, nloe);
1112
2685
 
1113
2686
  HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
1114
2687
  HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
2688
+ HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q));
2689
+ HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q));
1115
2690
 
1116
2691
  HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
1117
2692
  HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
1118
2693
  HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
2694
+ HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size);
2695
+ HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size);
1119
2696
 
1120
2697
  // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
1121
2698
  HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
@@ -1123,9 +2700,6 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
1123
2700
  vy_d = Q6_Vsf_equals_Vqf32(vy_d);
1124
2701
 
1125
2702
  // Convert rX_d scales from e8m0 to fp32
1126
- // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
1127
- // Left shift with zero fill to create FP32
1128
- // FIXME: might need to handle zero as a special case (see ggml-cpu code)
1129
2703
  HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
1130
2704
  HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
1131
2705
  r0_d = Q6_V_vdelta_VV(r0_d, expand);
@@ -1134,28 +2708,46 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
1134
2708
  r1_d = Q6_V_vdelta_VV(r1_d, expand);
1135
2709
  r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
1136
2710
  r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
2711
+ r2_d = Q6_V_vdelta_VV(r2_d, expand);
2712
+ r2_d = Q6_V_vand_VV(r2_d, e8m0_mask);
2713
+ r2_d = Q6_Vw_vasl_VwR(r2_d, 23);
2714
+ r3_d = Q6_V_vdelta_VV(r3_d, expand);
2715
+ r3_d = Q6_V_vand_VV(r3_d, e8m0_mask);
2716
+ r3_d = Q6_Vw_vasl_VwR(r3_d, 23);
1137
2717
 
1138
2718
  HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
1139
2719
  HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
2720
+ HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d));
2721
+ HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d));
1140
2722
 
1141
2723
  // Zero-out unused values
1142
2724
  HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
1143
2725
  r0_dd = Q6_V_vand_QV(bmask, r0_dd);
1144
2726
  r1_dd = Q6_V_vand_QV(bmask, r1_dd);
2727
+ r2_dd = Q6_V_vand_QV(bmask, r2_dd);
2728
+ r3_dd = Q6_V_vand_QV(bmask, r3_dd);
1145
2729
  r0_ia = Q6_V_vand_QV(bmask, r0_ia);
1146
2730
  r1_ia = Q6_V_vand_QV(bmask, r1_ia);
2731
+ r2_ia = Q6_V_vand_QV(bmask, r2_ia);
2732
+ r3_ia = Q6_V_vand_QV(bmask, r3_ia);
1147
2733
 
1148
2734
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
1149
2735
  HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
2736
+ HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
2737
+ HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
1150
2738
 
1151
2739
  r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
1152
2740
  r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
2741
+ r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum));
2742
+ r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum));
1153
2743
  }
1154
2744
 
1155
- HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
1156
- hvx_vec_store_u(s0, 8, rsum);
2745
+ HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } };
2746
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in);
2747
+ hvx_vec_store_u(s0, 16, rsum);
1157
2748
  }
1158
2749
 
2750
+
1159
2751
  static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
1160
2752
  const void * restrict vx0, const void * restrict vx1,
1161
2753
  const void * restrict vy0, const void * restrict vy1) {
@@ -1326,6 +2918,176 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float
1326
2918
  hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
1327
2919
  }
1328
2920
 
2921
/*
 * Element-wise fp32 add / multiply of two HVX vectors holding IEEE sf lanes,
 * always yielding an sf vector.
 *
 * From HVX arch 79 on, the intrinsics that take and return sf directly are used.
 * On older architectures the qf32-producing forms are used, and each result is
 * converted back to sf with Q6_Vsf_equals_Vqf32.
 */
#if __HVX_ARCH__ >= 79
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#else
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#endif
2928
+
2929
+ static void vec_dot_f32_f32_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2930
+ const HVX_Vector * restrict x = (const HVX_Vector *) vx;
2931
+ const HVX_Vector * restrict y = (const HVX_Vector *) vy;
2932
+
2933
+ uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors
2934
+ uint32_t nloe = n % VLEN_FP32; // leftover elements
2935
+
2936
+ HVX_Vector rsum = Q6_V_vzero();
2937
+
2938
+ uint32_t i = 0;
2939
+
2940
+ #pragma unroll(4)
2941
+ for (i = 0; i < nvec; i++) {
2942
+ HVX_Vector prod = HVX_OP_MUL_F32(x[i], y[i]);
2943
+ rsum = HVX_OP_ADD_F32(rsum, prod);
2944
+ }
2945
+
2946
+ if (nloe) {
2947
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
2948
+ HVX_Vector x_sf = Q6_V_vand_QV(bmask, x[i]);
2949
+ HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
2950
+ HVX_Vector prod = HVX_OP_MUL_F32(x_sf, y_sf);
2951
+ rsum = HVX_OP_ADD_F32(rsum, prod);
2952
+ }
2953
+
2954
+ *s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum));
2955
+ }
2956
+
2957
+ static void vec_dot_f32_f32_aa_2x1(const int n, float * restrict s0,
2958
+ const void * restrict vx0, const void * restrict vx1,
2959
+ const void * restrict vy0) {
2960
+ const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
2961
+ const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
2962
+ const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
2963
+
2964
+ uint32_t nvec = n / VLEN_FP32;
2965
+ uint32_t nloe = n % VLEN_FP32;
2966
+
2967
+ HVX_Vector rsum0 = Q6_V_vzero();
2968
+ HVX_Vector rsum1 = Q6_V_vzero();
2969
+
2970
+ uint32_t i = 0;
2971
+
2972
+ #pragma unroll(2)
2973
+ for (i = 0; i < nvec; i++) {
2974
+ HVX_Vector y_sf = y[i];
2975
+ HVX_Vector prod0 = HVX_OP_MUL_F32(x0[i], y_sf);
2976
+ HVX_Vector prod1 = HVX_OP_MUL_F32(x1[i], y_sf);
2977
+ rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
2978
+ rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
2979
+ }
2980
+
2981
+ if (nloe) {
2982
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
2983
+ HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
2984
+ HVX_Vector x0_sf = Q6_V_vand_QV(bmask, x0[i]);
2985
+ HVX_Vector x1_sf = Q6_V_vand_QV(bmask, x1[i]);
2986
+ HVX_Vector prod0 = HVX_OP_MUL_F32(x0_sf, y_sf);
2987
+ HVX_Vector prod1 = HVX_OP_MUL_F32(x1_sf, y_sf);
2988
+ rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
2989
+ rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
2990
+ }
2991
+
2992
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
2993
+ HVX_VectorAlias va;
2994
+ va.v = rsum;
2995
+ s0[0] = va.fp32[0];
2996
+ s0[1] = va.fp32[1];
2997
+ }
2998
+
2999
+ static void vec_dot_f32_f32_aa_2x2(const int n, float * restrict s0, float * restrict s1,
3000
+ const void * restrict vx0, const void * restrict vx1,
3001
+ const void * restrict vy0, const void * restrict vy1) {
3002
+ const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
3003
+ const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
3004
+ const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
3005
+ const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
3006
+
3007
+ uint32_t nvec = n / VLEN_FP32;
3008
+ uint32_t nloe = n % VLEN_FP32;
3009
+
3010
+ HVX_Vector r0_c0_sum = Q6_V_vzero();
3011
+ HVX_Vector r0_c1_sum = Q6_V_vzero();
3012
+ HVX_Vector r1_c0_sum = Q6_V_vzero();
3013
+ HVX_Vector r1_c1_sum = Q6_V_vzero();
3014
+
3015
+ uint32_t i = 0;
3016
+
3017
+ #pragma unroll(2)
3018
+ for (i = 0; i < nvec; i++) {
3019
+ HVX_Vector r0_sf = x0[i];
3020
+ HVX_Vector r1_sf = x1[i];
3021
+ HVX_Vector c0_sf = y0[i];
3022
+ HVX_Vector c1_sf = y1[i];
3023
+
3024
+ r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
3025
+ r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
3026
+ r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
3027
+ r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
3028
+ }
3029
+
3030
+ if (nloe) {
3031
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
3032
+
3033
+ HVX_Vector r0_sf = Q6_V_vand_QV(bmask, x0[i]);
3034
+ HVX_Vector r1_sf = Q6_V_vand_QV(bmask, x1[i]);
3035
+ HVX_Vector c0_sf = Q6_V_vand_QV(bmask, y0[i]);
3036
+ HVX_Vector c1_sf = Q6_V_vand_QV(bmask, y1[i]);
3037
+
3038
+ r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
3039
+ r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
3040
+ r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
3041
+ r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
3042
+ }
3043
+
3044
+ // Reduce and store results
3045
+ HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
3046
+ HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
3047
+
3048
+ HVX_VectorAlias va0, va1;
3049
+ va0.v = r0_r1_c0_sum;
3050
+ va1.v = r0_r1_c1_sum;
3051
+ s0[0] = va0.fp32[0];
3052
+ s0[1] = va0.fp32[1];
3053
+ s1[0] = va1.fp32[0];
3054
+ s1[1] = va1.fp32[1];
3055
+ }
3056
+
3057
+ static void vec_dot_f32_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
3058
+ const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
3059
+ const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
3060
+
3061
+ uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors
3062
+ uint32_t nloe = n % VLEN_FP32; // leftover elements
3063
+
3064
+ HVX_Vector rsum = Q6_V_vzero();
3065
+
3066
+ uint32_t i = 0;
3067
+
3068
+ #pragma unroll(2)
3069
+ for (i = 0; i < nvec; i++) {
3070
+ HVX_Vector x_sf = vx[i];
3071
+ HVX_Vector y_sf = vy[i];
3072
+
3073
+ rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
3074
+ }
3075
+
3076
+ if (nloe) {
3077
+ HVX_Vector x_sf = vx[i];
3078
+ HVX_Vector y_sf = vy[i];
3079
+
3080
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
3081
+ x_sf = Q6_V_vand_QV(bmask, x_sf);
3082
+ y_sf = Q6_V_vand_QV(bmask, y_sf);
3083
+
3084
+ rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
3085
+ }
3086
+
3087
+ rsum = hvx_vec_reduce_sum_f32(rsum);
3088
+ hvx_vec_store_u(&s[0], 4, rsum);
3089
+ }
3090
+
1329
3091
  static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1330
3092
  const HVX_Vector * restrict x = (const HVX_Vector *) vx;
1331
3093
  const HVX_Vector * restrict y = (const HVX_Vector *) vy;
@@ -1533,11 +3295,11 @@ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void *
1533
3295
  hvx_vec_store_u(&s[0], 4, rsum);
1534
3296
  }
1535
3297
 
1536
- #define htp_matmul_tensors_preamble \
1537
- struct htp_tensor * restrict src0 = &octx->src0; \
1538
- struct htp_tensor * restrict src1 = &octx->src1; \
1539
- struct htp_tensor * restrict src2 = &octx->src2; \
1540
- struct htp_tensor * restrict dst = &octx->dst; \
3298
+ #define htp_matmul_tensors_preamble \
3299
+ const struct htp_tensor * restrict src0 = octx->src[0]; \
3300
+ const struct htp_tensor * restrict src1 = octx->src[1]; \
3301
+ const struct htp_tensor * restrict src2 = octx->src[2]; \
3302
+ const struct htp_tensor * restrict dst = octx->dst; \
1541
3303
  struct htp_spad * restrict src0_spad = &octx->src0_spad; \
1542
3304
  struct htp_spad * restrict src1_spad = &octx->src1_spad; \
1543
3305
  struct htp_spad * restrict dst_spad = &octx->dst_spad; \
@@ -1744,7 +3506,7 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
1744
3506
  // Process the last row (if any)
1745
3507
  if (src0_end_row != src0_end_row_x2) {
1746
3508
  uint32_t ir0 = src0_end_row_x2;
1747
- const int is0 = (ir0 - src0_start_row);
3509
+ const int is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
1748
3510
  dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
1749
3511
  src0_stride, src0_row_size, 1);
1750
3512
  const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
@@ -1773,7 +3535,6 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
1773
3535
 
1774
3536
  const uint32_t src0_start_row = src0_nrows_per_thread * ith;
1775
3537
  const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
1776
- const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
1777
3538
 
1778
3539
  // no work for this thread
1779
3540
  if (src0_start_row >= src0_end_row) {
@@ -1803,39 +3564,89 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
1803
3564
  const uint8_t * restrict src1_col = (const uint8_t *) src1_data;
1804
3565
  float * restrict dst_col = (float *) dst->data;
1805
3566
 
1806
- // Prefill spad with 2x src0 rows
1807
- #pragma unroll(2)
1808
- for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1809
- const uint32_t is0 = (ir0 - src0_start_row);
1810
- if (is0 >= MM_SPAD_SRC0_NROWS) {
1811
- break;
3567
+ if (mmctx->vec_dot_4x1 != NULL) {
3568
+ const uint32_t src0_end_row_x4 = src0_start_row + ((src0_end_row - src0_start_row) & ~3U);
3569
+
3570
+ // Prefill spad with 4x src0 rows
3571
+ #pragma unroll(4)
3572
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) {
3573
+ const uint32_t is0 = (ir0 - src0_start_row);
3574
+ if (is0 >= MM_SPAD_SRC0_NROWS) {
3575
+ break;
3576
+ }
3577
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
3578
+ src0_stride, src0_row_size, 4);
1812
3579
  }
1813
- dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
1814
- src0_stride, src0_row_size, 2);
1815
- }
1816
3580
 
1817
- // Process src0 rows
1818
- for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1819
- const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1820
- mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
3581
+ // Process src0 rows
3582
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) {
3583
+ const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
3584
+ mmctx->vec_dot_4x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, ss0 + 2 * src0_stride, ss0 + 3 * src0_stride, src1_col);
1821
3585
 
1822
- // Prefetch next (n + spad_nrows) row
1823
- const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
1824
- const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
1825
- if (pr0 < src0_end_row_x2) {
1826
- dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
3586
+ // Prefetch next (n + spad_nrows) row
3587
+ const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
3588
+ const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
3589
+ if (pr0 < src0_end_row_x4) {
3590
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
3591
+ src0_stride, src0_row_size, 4);
3592
+ }
3593
+ }
3594
+
3595
+ // Process leftovers
3596
+ uint32_t ir0 = src0_end_row_x4;
3597
+ if (ir0 + 2 <= src0_end_row) {
3598
+ const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
3599
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
3600
+ src0_stride, src0_row_size, 2);
3601
+ const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
3602
+ mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
3603
+ ir0 += 2;
3604
+ }
3605
+ if (ir0 < src0_end_row) {
3606
+ const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
3607
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
3608
+ src0_stride, src0_row_size, 1);
3609
+ const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
3610
+ mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
3611
+ ir0 += 1;
3612
+ }
3613
+ } else {
3614
+ const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
3615
+
3616
+ // Prefill spad with 2x src0 rows
3617
+ #pragma unroll(2)
3618
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
3619
+ const uint32_t is0 = (ir0 - src0_start_row);
3620
+ if (is0 >= MM_SPAD_SRC0_NROWS) {
3621
+ break;
3622
+ }
3623
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
1827
3624
  src0_stride, src0_row_size, 2);
1828
3625
  }
1829
- }
1830
3626
 
1831
- // Process the last row (if any)
1832
- if (src0_end_row != src0_end_row_x2) {
1833
- const uint32_t ir0 = src0_end_row_x2;
1834
- const uint32_t is0 = (ir0 - src0_start_row);
1835
- dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
1836
- src0_stride, src0_row_size, 1);
1837
- const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1838
- mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
3627
+ // Process src0 rows
3628
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
3629
+ const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
3630
+ mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
3631
+
3632
+ // Prefetch next (n + spad_nrows) row
3633
+ const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
3634
+ const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
3635
+ if (pr0 < src0_end_row_x2) {
3636
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
3637
+ src0_stride, src0_row_size, 2);
3638
+ }
3639
+ }
3640
+
3641
+ // Process the last row (if any)
3642
+ if (src0_end_row != src0_end_row_x2) {
3643
+ const uint32_t ir0 = src0_end_row_x2;
3644
+ const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
3645
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
3646
+ src0_stride, src0_row_size, 1);
3647
+ const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
3648
+ mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
3649
+ }
1839
3650
  }
1840
3651
 
1841
3652
  hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row);
@@ -1859,8 +3670,8 @@ struct mmid_row_mapping {
1859
3670
  static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
1860
3671
  htp_matmul_preamble;
1861
3672
 
1862
- struct htp_tensor * restrict ids = &octx->src2;
1863
- struct htp_spad * restrict src2_spad = &octx->src2_spad;
3673
+ const struct htp_tensor * restrict ids = octx->src[2];
3674
+ struct htp_spad * restrict src2_spad = &octx->src2_spad;
1864
3675
 
1865
3676
  uint64_t t1, t2;
1866
3677
  t1 = HAP_perf_get_qtimer_count();
@@ -1880,11 +3691,8 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
1880
3691
  const uint32_t n_ids = ids->ne[0]; // n_expert_used
1881
3692
  const uint32_t n_as = ne02; // n_expert
1882
3693
 
1883
- const size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
1884
- const size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
1885
-
1886
- const uint32_t * matrix_row_counts = (const uint32_t *) src2_spad->data + 0;
1887
- const struct mmid_row_mapping * matrix_rows = (const void *) src2_spad->data + matrix_row_counts_size;
3694
+ const uint32_t * matrix_row_counts = mmctx->matrix_row_counts;
3695
+ const struct mmid_row_mapping * matrix_rows = mmctx->matrix_rows;
1888
3696
 
1889
3697
  const size_t dst_row_size = nb1;
1890
3698
  const size_t src0_row_size = nb01;
@@ -1906,6 +3714,10 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
1906
3714
  continue;
1907
3715
  }
1908
3716
 
3717
+ if (mmctx->hmx_eligible) {
3718
+ continue;
3719
+ }
3720
+
1909
3721
  const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0);
1910
3722
 
1911
3723
  // Prefill spad with src0 rows
@@ -1947,7 +3759,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
1947
3759
  // Process the last row (if any)
1948
3760
  if (src0_end_row != src0_end_row_x2) {
1949
3761
  uint32_t ir0 = src0_end_row_x2;
1950
- const uint32_t is0 = (ir0 - src0_start_row);
3762
+ const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
1951
3763
  dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
1952
3764
  src0_row_size_padded, src0_row_size, 1);
1953
3765
  const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
@@ -1978,8 +3790,8 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
1978
3790
  static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
1979
3791
  htp_matmul_preamble;
1980
3792
 
1981
- struct htp_tensor * restrict ids = &octx->src2;
1982
- struct htp_spad * restrict src2_spad = &octx->src2_spad;
3793
+ const struct htp_tensor * restrict ids = octx->src[2];
3794
+ struct htp_spad * restrict src2_spad = &octx->src2_spad;
1983
3795
 
1984
3796
  uint64_t t1, t2;
1985
3797
  t1 = HAP_perf_get_qtimer_count();
@@ -2049,7 +3861,7 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
2049
3861
  // Process the last row (if any)
2050
3862
  if (src0_end_row != src0_end_row_x2) {
2051
3863
  uint32_t ir0 = src0_end_row_x2;
2052
- const uint32_t is0 = (ir0 - src0_start_row);
3864
+ const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
2053
3865
  dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
2054
3866
  src0_row_size_padded, src0_row_size, 1);
2055
3867
  const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
@@ -2067,6 +3879,94 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
2067
3879
 
2068
3880
  // *** dynamic quant
2069
3881
 
3882
+ static inline void quantize_block_f32_q8_1x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
3883
+ assert((unsigned long) x % 128 == 0);
3884
+ assert((unsigned long) y_q % 128 == 0);
3885
+
3886
+ HVX_Vector * vx = (HVX_Vector *) x;
3887
+ HVX_Vector zero = Q6_V_vzero();
3888
+
3889
+ // Use reduce max fp32 to find max(abs(e)) first
3890
+ HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0]));
3891
+ HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1]));
3892
+ HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2]));
3893
+ HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3]));
3894
+
3895
+ // Load and convert into QF32
3896
+ HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
3897
+ HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
3898
+ HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
3899
+ HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
3900
+
3901
+ // Convert to QF32
3902
+ HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
3903
+ HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
3904
+ HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
3905
+ HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
3906
+
3907
+ // Combine and convert to fp16
3908
+ HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
3909
+ HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf)));
3910
+
3911
+ // Convert into fp16
3912
+ HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
3913
+ HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
3914
+
3915
+ HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
3916
+ HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
3917
+ HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
3918
+ HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);
3919
+
3920
+ // Divide input by the scale
3921
+ HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
3922
+ HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
3923
+ vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
3924
+ vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
3925
+
3926
+ // Convert to int8
3927
+ HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
3928
+ HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
3929
+ HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
3930
+
3931
+ *(HVX_Vector *) y_q = vx_i8;
3932
+
3933
+ // --- Sum calculation ---
3934
+ const HVX_Vector ones = Q6_Vb_vsplat_R(1);
3935
+ HVX_Vector v_sums = Q6_Vw_vrmpy_VbVb(vx_i8, ones); // sum every 4 consecutive elements
3936
+ // Sum 8 elements:
3937
+ v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 4));
3938
+ v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 8));
3939
+ v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 16));
3940
+
3941
+ // Copy to stack to extract sums and vmaxes
3942
+ float vmax0[32] __attribute__((aligned(128)));
3943
+ float vmax1[32] __attribute__((aligned(128)));
3944
+ float vmax2[32] __attribute__((aligned(128)));
3945
+ float vmax3[32] __attribute__((aligned(128)));
3946
+ int32_t sums[32] __attribute__((aligned(128)));
3947
+
3948
+ hvx_vec_store_u(vmax0, 128, vmax0_sf);
3949
+ hvx_vec_store_u(vmax1, 128, vmax1_sf);
3950
+ hvx_vec_store_u(vmax2, 128, vmax2_sf);
3951
+ hvx_vec_store_u(vmax3, 128, vmax3_sf);
3952
+ hvx_vec_store_u(sums, 128, v_sums);
3953
+
3954
+ float d0 = vmax0[0] / 127.0f;
3955
+ float d1 = vmax1[0] / 127.0f;
3956
+ float d2 = vmax2[0] / 127.0f;
3957
+ float d3 = vmax3[0] / 127.0f;
3958
+
3959
+ __fp16 * y_d_half = (__fp16 *) y_d;
3960
+ y_d_half[0] = d0;
3961
+ y_d_half[1] = (float) sums[0] * d0;
3962
+ y_d_half[2] = d1;
3963
+ y_d_half[3] = (float) sums[8] * d1;
3964
+ y_d_half[4] = d2;
3965
+ y_d_half[5] = (float) sums[16] * d2;
3966
+ y_d_half[6] = d3;
3967
+ y_d_half[7] = (float) sums[24] * d3;
3968
+ }
3969
+
2070
3970
  static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
2071
3971
  assert((unsigned long) x % 128 == 0);
2072
3972
  assert((unsigned long) y_q % 128 == 0);
@@ -2248,7 +4148,7 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data)
2248
4148
  struct htp_matmul_context * mmctx = data;
2249
4149
  struct htp_ops_context * octx = mmctx->octx;
2250
4150
 
2251
- const struct htp_tensor * src = &octx->src1;
4151
+ const struct htp_tensor * src = octx->src[1];
2252
4152
  uint8_t * restrict dst = octx->src1_spad.data;
2253
4153
  struct htp_spad * spad = &octx->src0_spad;
2254
4154
  uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
@@ -2291,11 +4191,123 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data)
2291
4191
  ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
2292
4192
  }
2293
4193
 
4194
+ static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
4195
+ assert(k % 32 == 0);
4196
+ const uint32_t qk = QK_Q8_0x4x2;
4197
+ const uint32_t nb = (k + qk - 1) / qk;
4198
+
4199
+ const uint32_t qrow_size = k; // int8
4200
+
4201
+ const uint32_t dblk_size = 8 * 4; // 8x (d, s) __fp16 = 32 bytes
4202
+ const uint32_t qblk_size = QK_Q8_0x4x2; // int8
4203
+
4204
+ uint8_t * restrict y_q = (y + 0); // quants first
4205
+ uint8_t * restrict y_d = (y + qrow_size); // then scales/sums
4206
+
4207
+ // Temp scales override input since we're working off of the aligned temp buffer in VTCM
4208
+ uint8_t * restrict t_d = (uint8_t *) x;
4209
+
4210
+ for (uint32_t i = 0; i < nb; i++) {
4211
+ quantize_block_f32_q8_1x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
4212
+ quantize_block_f32_q8_1x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
4213
+ }
4214
+
4215
+ // now copy the scales/sums into final location
4216
+ hvx_copy_f16_ua(y_d, t_d, nb * 16);
4217
+ }
4218
+
4219
+ static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * data) {
4220
+ struct htp_matmul_context * mmctx = data;
4221
+ struct htp_ops_context * octx = mmctx->octx;
4222
+
4223
+ const struct htp_tensor * src = octx->src[1];
4224
+ uint8_t * restrict dst = octx->src1_spad.data;
4225
+ struct htp_spad * spad = &octx->src0_spad;
4226
+ uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
4227
+
4228
+ uint64_t t1 = HAP_perf_get_qtimer_count();
4229
+
4230
+ const uint32_t ne0 = src->ne[0];
4231
+ const uint32_t ne1 = src->ne[1];
4232
+ const uint32_t ne2 = src->ne[2];
4233
+ const uint32_t ne3 = src->ne[3];
4234
+
4235
+ const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
4236
+
4237
+ const uint32_t ir_first = nrows_per_thread * ith; // first row
4238
+ const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
4239
+
4240
+ const size_t src_row_size = src->nb[1];
4241
+ const size_t dst_row_size = q8_1x4x2_row_size(ne0);
4242
+
4243
+ uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first);
4244
+ uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first);
4245
+ uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith);
4246
+
4247
+ const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float));
4248
+ memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding
4249
+
4250
+ for (uint32_t i = ir_first; i < ir_last; ++i) {
4251
+ hex_l2fetch(src_data, src_row_size, src_row_size, 2);
4252
+ hvx_copy_f32_aa(tmp_data, src_data, ne0);
4253
+
4254
+ quantize_row_f32_q8_1x4x2((float *) tmp_data, dst_data, ne0);
4255
+ dst_data += dst_row_size;
4256
+ src_data += src_row_size;
4257
+ }
4258
+
4259
+ uint64_t t2 = HAP_perf_get_qtimer_count();
4260
+
4261
+ FARF(HIGH, "quantize-f32-q8_1x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
4262
+ ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
4263
+ }
4264
+
4265
+ static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) {
4266
+ struct htp_matmul_context * mmctx = data;
4267
+ struct htp_ops_context * octx = mmctx->octx;
4268
+
4269
+ const struct htp_tensor * src = octx->src[1];
4270
+ uint8_t * restrict dst = octx->src1_spad.data;
4271
+ uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
4272
+ uint32_t dst_stride = octx->src1_spad.stride;
4273
+
4274
+ uint64_t t1 = HAP_perf_get_qtimer_count();
4275
+
4276
+ const uint32_t ne0 = src->ne[0];
4277
+ const uint32_t ne1 = src->ne[1];
4278
+ const uint32_t ne2 = src->ne[2];
4279
+ const uint32_t ne3 = src->ne[3];
4280
+
4281
+ const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
4282
+
4283
+ const uint32_t ir_first = nrows_per_thread * ith; // first row
4284
+ const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
4285
+
4286
+ const size_t src_row_size = ne0 * sizeof(float);
4287
+ const size_t src_stride = src->nb[1];
4288
+
4289
+ uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);
4290
+ uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
4291
+
4292
+ for (uint32_t i = ir_first; i < ir_last; ++i) {
4293
+ hex_l2fetch(src_data, src_row_size, src_stride, 2);
4294
+ hvx_copy_f32_au(dst_data, src_data, ne0);
4295
+
4296
+ dst_data += dst_stride;
4297
+ src_data += src_stride;
4298
+ }
4299
+
4300
+ uint64_t t2 = HAP_perf_get_qtimer_count();
4301
+
4302
+ FARF(HIGH, "quantize-f32-f32: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
4303
+ ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
4304
+ }
4305
+
2294
4306
  static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
2295
4307
  struct htp_matmul_context * mmctx = data;
2296
4308
  struct htp_ops_context * octx = mmctx->octx;
2297
4309
 
2298
- const struct htp_tensor * src = &octx->src1;
4310
+ const struct htp_tensor * src = octx->src[1];
2299
4311
  uint8_t * restrict dst = octx->src1_spad.data;
2300
4312
  uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
2301
4313
  uint32_t dst_stride = octx->src1_spad.stride;
@@ -2337,7 +4349,7 @@ static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
2337
4349
  struct htp_matmul_context * mmctx = data;
2338
4350
  struct htp_ops_context * octx = mmctx->octx;
2339
4351
 
2340
- const struct htp_tensor * src = &octx->src1;
4352
+ const struct htp_tensor * src = octx->src[1];
2341
4353
  uint8_t * restrict dst = octx->src1_spad.data;
2342
4354
  uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
2343
4355
  uint32_t dst_stride = octx->src1_spad.stride;
@@ -2386,18 +4398,35 @@ static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_t
2386
4398
  mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1;
2387
4399
  mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1;
2388
4400
  mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2;
4401
+ mmctx->vec_dot_4x1 = vec_dot_q4x4x2_q8x4x2_4x1;
4402
+ return 0;
4403
+ case HTP_TYPE_Q4_1:
4404
+ mmctx->type = "q4_1x4x2-f32";
4405
+ mmctx->vec_dot_1x1 = vec_dot_q4_1x4x2_q8x4x2_1x1;
4406
+ mmctx->vec_dot_2x1 = vec_dot_q4_1x4x2_q8x4x2_2x1;
4407
+ mmctx->vec_dot_2x2 = vec_dot_q4_1x4x2_q8x4x2_2x2;
4408
+ mmctx->vec_dot_4x1 = vec_dot_q4_1x4x2_q8x4x2_4x1;
2389
4409
  return 0;
2390
4410
  case HTP_TYPE_Q8_0:
2391
4411
  mmctx->type = "q8x4x2-f32";
2392
4412
  mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1;
2393
4413
  mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1;
2394
4414
  mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2;
4415
+ mmctx->vec_dot_4x1 = vec_dot_q8x4x2_q8x4x2_4x1;
4416
+ return 0;
4417
+ case HTP_TYPE_IQ4_NL:
4418
+ mmctx->type = "iq4nlx4x2-f32";
4419
+ mmctx->vec_dot_1x1 = vec_dot_iq4nlx4x2_q8x4x2_1x1;
4420
+ mmctx->vec_dot_2x1 = vec_dot_iq4nlx4x2_q8x4x2_2x1;
4421
+ mmctx->vec_dot_2x2 = vec_dot_iq4nlx4x2_q8x4x2_2x2;
4422
+ mmctx->vec_dot_4x1 = vec_dot_iq4nlx4x2_q8x4x2_4x1;
2395
4423
  return 0;
2396
4424
  case HTP_TYPE_MXFP4:
2397
4425
  mmctx->type = "mxfp4x4x2-f32";
2398
4426
  mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1;
2399
4427
  mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1;
2400
4428
  mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2;
4429
+ mmctx->vec_dot_4x1 = vec_dot_mxfp4x4x2_q8x4x2_4x1;
2401
4430
  return 0;
2402
4431
  default:
2403
4432
  return -1;
@@ -2430,7 +4459,7 @@ static void htp_mminit_spad(struct htp_ops_context * octx,
2430
4459
  octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2431
4460
  }
2432
4461
 
2433
- int op_matmul(struct htp_ops_context * octx) {
4462
+ static int op_matmul_hvx(struct htp_ops_context * octx) {
2434
4463
  htp_matmul_tensors_preamble;
2435
4464
 
2436
4465
  struct htp_matmul_context mmctx_struct = {0};
@@ -2454,7 +4483,7 @@ int op_matmul(struct htp_ops_context * octx) {
2454
4483
  worker_callback_t quant_job_func;
2455
4484
  worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d;
2456
4485
 
2457
- bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE);
4486
+ bool need_quant = true;
2458
4487
 
2459
4488
  if (src0->type == HTP_TYPE_F16) {
2460
4489
  // Try optimized f16-f16 path first (src1 in VTCM)
@@ -2468,7 +4497,7 @@ int op_matmul(struct htp_ops_context * octx) {
2468
4497
  // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
2469
4498
  // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
2470
4499
  const bool is_batched = (ne02 > 1) || (ne03 > 1);
2471
- const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
4500
+ const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]);
2472
4501
 
2473
4502
  if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
2474
4503
  // Optimized path
@@ -2516,6 +4545,60 @@ int op_matmul(struct htp_ops_context * octx) {
2516
4545
  mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
2517
4546
  mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
2518
4547
 
4548
+ need_quant = false;
4549
+ }
4550
+ } else if (src0->type == HTP_TYPE_F32) {
4551
+ // Try optimized f32-f32 path first (src1 in VTCM)
4552
+ const size_t f32_src1_row_size = hex_round_up(ne10 * 4, 128);
4553
+ const size_t f32_src1_spad_size = hex_round_up(f32_src1_row_size * src1_nrows, 256);
4554
+ const size_t f32_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
4555
+ const size_t f32_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
4556
+
4557
+ const size_t f32_total_size = f32_src1_spad_size + f32_src0_spad_size + f32_dst_spad_size;
4558
+
4559
+ const bool is_batched = (ne02 > 1) || (ne03 > 1);
4560
+ const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]);
4561
+
4562
+ if (!is_batched && !is_permuted && f32_total_size <= octx->ctx->vtcm_size) {
4563
+ // Optimized path
4564
+ quant_job_func = quantize_f32_f32;
4565
+ mmctx->type = "f32-f32";
4566
+ mmctx->vec_dot_1x1 = vec_dot_f32_f32_aa_1x1;
4567
+ mmctx->vec_dot_2x1 = vec_dot_f32_f32_aa_2x1;
4568
+ mmctx->vec_dot_2x2 = vec_dot_f32_f32_aa_2x2;
4569
+
4570
+ src1_row_size = f32_src1_row_size;
4571
+
4572
+ octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
4573
+ octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
4574
+ octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
4575
+
4576
+ octx->src1_spad.size = octx->src1_spad.size_per_thread;
4577
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
4578
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
4579
+ } else {
4580
+ // Fallback to DDR / broadcasting
4581
+ quant_job_func = NULL;
4582
+ mmctx->type = "f32-f32";
4583
+ mmctx->vec_dot_1x1 = vec_dot_f32_f32_uu_1x1;
4584
+ matmul_job_func = matmul_4d;
4585
+
4586
+ src1_row_size = nb11;
4587
+
4588
+ octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
4589
+ octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
4590
+ octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
4591
+
4592
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
4593
+ octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
4594
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
4595
+
4596
+ // Init fastdiv for matmul_4d (supports broadcasting)
4597
+ mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
4598
+ mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]);
4599
+ mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
4600
+ mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
4601
+
2519
4602
  need_quant = false;
2520
4603
  }
2521
4604
  } else {
@@ -2523,8 +4606,13 @@ int op_matmul(struct htp_ops_context * octx) {
2523
4606
  return HTP_STATUS_NO_SUPPORT;
2524
4607
  }
2525
4608
 
2526
- quant_job_func = quantize_f32_q8x4x2;
2527
- src1_row_size = q8x4x2_row_size(ne10);
4609
+ if (src0->type == HTP_TYPE_Q4_1) {
4610
+ quant_job_func = quantize_f32_q8_1x4x2;
4611
+ src1_row_size = q8_1x4x2_row_size(ne10);
4612
+ } else {
4613
+ quant_job_func = quantize_f32_q8x4x2;
4614
+ src1_row_size = q8x4x2_row_size(ne10);
4615
+ }
2528
4616
  htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0);
2529
4617
  }
2530
4618
 
@@ -2545,27 +4633,148 @@ int op_matmul(struct htp_ops_context * octx) {
2545
4633
  return HTP_STATUS_VTCM_TOO_SMALL;
2546
4634
  }
2547
4635
 
2548
- octx->src0_spad.data = octx->ctx->vtcm_base;
2549
- octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
2550
- octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
4636
+ // Place src1 spad first. We use it for dyn.quant and may reuse between ops
4637
+ octx->src1_spad.data = octx->ctx->vtcm_base;
4638
+ octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size;
4639
+ octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
4640
+
4641
+ octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL;
4642
+ octx->src0_spad.src = NULL;
4643
+ octx->dst_spad.src = NULL;
2551
4644
 
2552
4645
  octx->src0_spad.stride = src0_row_size_padded;
2553
4646
  octx->src1_spad.stride = src1_row_size;
2554
4647
 
2555
- if (need_quant) {
4648
+ if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)
4649
+ return HTP_STATUS_OK;
4650
+
4651
+ if (need_quant && !octx->src1_spad.src) {
2556
4652
  const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
2557
4653
  mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
2558
4654
  worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
4655
+ octx->src1_spad.src = src1;
2559
4656
  }
2560
4657
 
2561
- if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
2562
- const uint32_t n_matmul_jobs = octx->n_threads;
2563
- worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs);
2564
- }
4658
+ const uint32_t n_matmul_jobs = octx->n_threads;
4659
+ worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs);
2565
4660
 
2566
4661
  return HTP_STATUS_OK;
2567
4662
  }
2568
4663
 
4664
+ int op_matmul(struct htp_ops_context * octx) {
4665
+ htp_matmul_tensors_preamble;
4666
+
4667
+ #ifndef HTP_HAS_HMX
4668
+ return op_matmul_hvx(octx);
4669
+ #else
4670
+ if (!octx->ctx->hmx_enabled) {
4671
+ return op_matmul_hvx(octx);
4672
+ }
4673
+
4674
+ // HMX weight tile requires N to be 32-aligned.
4675
+ if (src0->ne[1] % 32 != 0) {
4676
+ return op_matmul_hvx(octx);
4677
+ }
4678
+
4679
+ // HMX supports F16, F32, Q4_0, Q8_0, IQ4_NL, MXFP4 weights.
4680
+ // Other types fall back to HVX.
4681
+ uint32_t wtype = src0->type;
4682
+ if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) {
4683
+ return op_matmul_hvx(octx);
4684
+ }
4685
+
4686
+ // Quantised HMX path requires K aligned to 256 (x4x2 super-block).
4687
+ // F16 and F32 HMX paths require K aligned to 32 (tile width).
4688
+ if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && src0->ne[0] % 256 != 0) {
4689
+ return op_matmul_hvx(octx);
4690
+ }
4691
+
4692
+ if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && src0->ne[0] % 32 != 0) {
4693
+ return op_matmul_hvx(octx);
4694
+ }
4695
+
4696
+ const bool is_batched = (src0->ne[2] * src0->ne[3] > 1 || src1->ne[2] * src1->ne[3] > 1);
4697
+
4698
+ // Quantised HMX kernels only handle flat 2D matmul (host already rejects
4699
+ // batched quantised, but guard here too). F16 batched matmul is handled
4700
+ // by the dedicated wrapper in hmx-matmul-ops.c.
4701
+ if (is_batched && src0->type != HTP_TYPE_F16) {
4702
+ return op_matmul_hvx(octx);
4703
+ }
4704
+
4705
+ // HMX assumes contiguous row-major layout. Fall back for permuted
4706
+ // tensors where strides are non-monotonic (e.g. transposed KV cache).
4707
+ if (src0->nb[0] > src0->nb[1] || src1->nb[0] > src1->nb[1]) {
4708
+ return op_matmul_hvx(octx);
4709
+ }
4710
+
4711
+ // M alignment: Use HMX when M >= 32, the last partial tile (m_total % 32 rows)
4712
+ // is handled by HMX itself; when M < 32 fall back to HVX.
4713
+ const int m_total = (int) src1->ne[1];
4714
+ const int m_hmx = m_total & ~31; // 0 when M < 32
4715
+ if (m_hmx == 0) {
4716
+ return op_matmul_hvx(octx);
4717
+ }
4718
+
4719
+ // Always re-quantize src1 since HMX kernel overwrites vtcm/spad,
4720
+ // so any previously cached quantized data is invalid.
4721
+ octx->src1_spad.src = NULL;
4722
+
4723
+ int k = (int) src0->ne[0]; // inner dimension
4724
+ int n = (int) src0->ne[1]; // weight columns
4725
+
4726
+ int ret = -1;
4727
+
4728
+ // Row strides in elements. For compact tensors these equal k; for
4729
+ // permuted attention views they can be larger, so pass the real stride.
4730
+ const int act_stride = (int)(src1->nb[1] / sizeof(float));
4731
+ const int wgt_stride = (int)(src0->nb[1] / sizeof(__fp16));
4732
+
4733
+ if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
4734
+ return HTP_STATUS_OK;
4735
+ }
4736
+
4737
+ if (is_batched) {
4738
+ if (src0->type == HTP_TYPE_F16) {
4739
+ hmx_matmul_f16_f32_batched_params_t batch_params = {
4740
+ .dst = (float *) dst->data,
4741
+ .activation = (float *) src1->data,
4742
+ .permuted_weight = (const __fp16 *) src0->data,
4743
+ .m = m_total,
4744
+ .k = k,
4745
+ .n = n,
4746
+ .act_stride = act_stride,
4747
+ .weight_stride = wgt_stride,
4748
+ .dst_stride = (int) (dst->nb[1] / sizeof(float)),
4749
+ .ne02 = ne02,
4750
+ .ne03 = ne03,
4751
+ .ne12 = ne12,
4752
+ .ne13 = ne13,
4753
+ .src0_nb2 = src0->nb[2],
4754
+ .src0_nb3 = src0->nb[3],
4755
+ .src1_nb2 = src1->nb[2],
4756
+ .src1_nb3 = src1->nb[3],
4757
+ .dst_nb2 = dst->nb[2],
4758
+ .dst_nb3 = dst->nb[3],
4759
+ };
4760
+ ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params);
4761
+ } else {
4762
+ return op_matmul_hvx(octx);
4763
+ }
4764
+ } else {
4765
+ ret = hmx_matmul_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data,
4766
+ m_total, k, n, act_stride, (int) src0->nb[1], (int) src0->type);
4767
+ }
4768
+
4769
+ if (ret != 0) {
4770
+ FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret);
4771
+ return op_matmul(octx);
4772
+ }
4773
+
4774
+ return 0;
4775
+ #endif // HTP_HAS_HMX
4776
+ }
4777
+
2569
4778
  int op_matmul_id(struct htp_ops_context * octx) {
2570
4779
  htp_matmul_tensors_preamble;
2571
4780
 
@@ -2573,7 +4782,7 @@ int op_matmul_id(struct htp_ops_context * octx) {
2573
4782
  struct htp_matmul_context * mmctx = &mmctx_struct;
2574
4783
  mmctx->octx = octx;
2575
4784
 
2576
- struct htp_tensor * restrict ids = &octx->src2;
4785
+ const struct htp_tensor * restrict ids = octx->src[2];
2577
4786
 
2578
4787
  const size_t src0_row_size = nb01;
2579
4788
  const size_t dst_row_size = nb1;
@@ -2599,15 +4808,42 @@ int op_matmul_id(struct htp_ops_context * octx) {
2599
4808
 
2600
4809
  size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
2601
4810
  size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
4811
+ const size_t total_map_size = matrix_row_counts_size + matrix_row_map_size;
4812
+
4813
+ void * mapping_buf = NULL;
4814
+ bool must_free_mapping = false;
4815
+
4816
+ if (octx->ctx->ddr_spad_base && total_map_size <= octx->ctx->ddr_spad_size) {
4817
+ mapping_buf = octx->ctx->ddr_spad_base;
4818
+ } else {
4819
+ mapping_buf = memalign(128, total_map_size);
4820
+ if (mapping_buf) {
4821
+ must_free_mapping = true;
4822
+ } else {
4823
+ return HTP_STATUS_INTERNAL_ERR;
4824
+ }
4825
+ }
4826
+
4827
+ uint32_t * matrix_row_counts = (uint32_t *) mapping_buf;
4828
+ struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) ((uint8_t *) mapping_buf + matrix_row_counts_size);
4829
+
4830
+ mmctx->matrix_row_counts = matrix_row_counts;
4831
+ mmctx->matrix_rows = matrix_rows;
2602
4832
 
2603
4833
  if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
4834
+ if (must_free_mapping) free(mapping_buf);
2604
4835
  return HTP_STATUS_NO_SUPPORT;
2605
4836
  }
2606
4837
 
2607
- quant_job_func = quantize_f32_q8x4x2;
2608
- src1_row_size = q8x4x2_row_size(ne10);
4838
+ if (src0->type == HTP_TYPE_Q4_1) {
4839
+ quant_job_func = quantize_f32_q8_1x4x2;
4840
+ src1_row_size = q8_1x4x2_row_size(ne10);
4841
+ } else {
4842
+ quant_job_func = quantize_f32_q8x4x2;
4843
+ src1_row_size = q8x4x2_row_size(ne10);
4844
+ }
2609
4845
 
2610
- const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
4846
+ const size_t src2_spad_size_per_thread = 0; // We moved the mapping to DDR!
2611
4847
  htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread);
2612
4848
 
2613
4849
  size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
@@ -2623,22 +4859,26 @@ int op_matmul_id(struct htp_ops_context * octx) {
2623
4859
  // Make sure the reserved vtcm size is sufficient
2624
4860
  if (octx->ctx->vtcm_size < spad_size) {
2625
4861
  FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size);
4862
+ if (must_free_mapping) free(mapping_buf);
2626
4863
  return HTP_STATUS_VTCM_TOO_SMALL;
2627
4864
  }
2628
4865
 
2629
- octx->src0_spad.data = octx->ctx->vtcm_base;
2630
- octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
2631
- octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
4866
+ // Place src1 spad first. We use it for dyn.quant and may reuse in subseq ops.
4867
+ octx->src1_spad.data = octx->ctx->vtcm_base;
4868
+ octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size;
4869
+ octx->src2_spad.data = octx->src0_spad.data + octx->src0_spad.size;
2632
4870
  octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size;
2633
4871
 
4872
+ octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL;
4873
+ octx->src0_spad.src = NULL;
4874
+ octx->src2_spad.src = NULL;
4875
+ octx->dst_spad.src = NULL;
4876
+
2634
4877
  octx->src0_spad.stride = src0_row_size_padded;
2635
4878
  octx->src1_spad.stride = src1_row_size;
2636
4879
 
2637
4880
  if (src1_nrows > 1) {
2638
4881
  // initialize matrix_row_counts and map
2639
- uint32_t * matrix_row_counts = (uint32_t *) octx->src2_spad.data + 0;
2640
- struct mmid_row_mapping * matrix_rows = (void *) octx->src2_spad.data + matrix_row_counts_size;
2641
-
2642
4882
  memset(matrix_row_counts, 0, n_as * sizeof(uint32_t));
2643
4883
 
2644
4884
  // group rows by src0 matrix
@@ -2648,23 +4888,71 @@ int op_matmul_id(struct htp_ops_context * octx) {
2648
4888
 
2649
4889
  assert(i02 >= 0 && i02 < n_as);
2650
4890
 
2651
- MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 };
4891
+ matrix_rows[i02 * n_ids * ids->ne[1] + matrix_row_counts[i02]] = (struct mmid_row_mapping) { id, iid1 };
2652
4892
  matrix_row_counts[i02] += 1;
2653
4893
  }
2654
4894
  }
2655
4895
  }
2656
4896
 
2657
- // Setup worker pool callbacks
2658
- if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) {
4897
+ if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
4898
+ if (must_free_mapping) free(mapping_buf);
4899
+ return HTP_STATUS_OK;
4900
+ }
4901
+
4902
+ bool hmx_eligible = false;
4903
+ #ifdef HTP_HAS_HMX
4904
+ if (octx->ctx->hmx_enabled && src1_nrows > 1) {
4905
+ uint32_t wtype = src0->type;
4906
+ if (ne01 % 32 == 0 &&
4907
+ (wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32 || wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 || wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL || wtype == HTP_TYPE_MXFP4)) {
4908
+ if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && ne00 % 32 == 0) {
4909
+ hmx_eligible = true;
4910
+ } else if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && ne00 % 256 == 0) {
4911
+ hmx_eligible = true;
4912
+ }
4913
+ }
4914
+ }
4915
+ #endif
4916
+
4917
+ mmctx->hmx_eligible = hmx_eligible;
4918
+
4919
+ if (hmx_eligible) {
4920
+ for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) {
4921
+ const int32_t cne1 = matrix_row_counts[cur_a];
4922
+ if (cne1 == 0) continue;
4923
+
4924
+ int ret = hmx_matmul_id_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data,
4925
+ (const uint8_t *) src0->data + cur_a * nb02,
4926
+ cne1, ne00, ne01,
4927
+ ne11,
4928
+ nb11, nb12,
4929
+ nb1, nb2,
4930
+ (int) src0->nb[1], (int) src0->type,
4931
+ matrix_rows, cur_a, n_ids * ids->ne[1]);
4932
+ if (ret != 0) {
4933
+ FARF(ERROR, "HMX matmul failed for expert %u, error %d\n", cur_a, ret);
4934
+ if (must_free_mapping) free(mapping_buf);
4935
+ return HTP_STATUS_NO_SUPPORT;
4936
+ }
4937
+ }
4938
+
4939
+ // HMX has overwritten VTCM, so force dynamic quantization cache to clear
4940
+ octx->src1_spad.src = NULL;
4941
+
4942
+ if (must_free_mapping) free(mapping_buf);
4943
+ return HTP_STATUS_OK;
4944
+ }
4945
+
4946
+ if (octx->src1_spad.src != src1) {
2659
4947
  const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
2660
4948
  mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
2661
4949
  worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
4950
+ octx->src1_spad.src = src1;
2662
4951
  }
2663
4952
 
2664
- if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
2665
- const uint32_t n_matmul_jobs = octx->n_threads;
2666
- worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);
2667
- }
4953
+ const uint32_t n_matmul_jobs = octx->n_threads;
4954
+ worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);
2668
4955
 
4956
+ if (must_free_mapping) free(mapping_buf);
2669
4957
  return HTP_STATUS_OK;
2670
4958
  }