whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -3,105 +3,50 @@
3
3
  #pragma clang diagnostic ignored "-Wunused-variable"
4
4
  #pragma clang diagnostic ignored "-Wunused-but-set-variable"
5
5
 
6
- #ifdef HTP_DEBUG
7
- # define FARF_HIGH 1
8
- #endif
9
-
10
6
  #include <HAP_farf.h>
11
- #include <HAP_mem.h>
12
7
  #include <HAP_perf.h>
13
- #include <HAP_ps.h>
14
- #include <hexagon_protos.h>
15
- #include <hexagon_types.h>
8
+
16
9
  #include <math.h>
17
- #include <qurt_thread.h>
18
10
  #include <string.h>
19
11
 
12
+ #include "hex-dma.h"
13
+ #include "hvx-utils.h"
14
+ #include "hvx-dump.h"
15
+
20
16
  #define GGML_COMMON_DECL_C
21
17
  #include "ggml-common.h"
22
18
  #include "htp-ctx.h"
23
- #include "htp-dma.h"
24
19
  #include "htp-msg.h"
25
20
  #include "htp-ops.h"
26
- #include "hvx-utils.h"
27
- #include "ops-utils.h"
28
21
 
29
22
  #define MM_SPAD_SRC0_NROWS 16
30
23
  #define MM_SPAD_SRC1_NROWS 16
31
24
  #define MM_SPAD_DST_NROWS 2
32
25
 
33
- struct htp_matmul_type {
26
+ struct htp_matmul_context {
34
27
  const char * type;
35
- void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
36
- void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy);
37
- };
28
+ struct htp_ops_context * octx;
38
29
 
39
- typedef struct {
40
- HVX_Vector v[2];
41
- } HVX_Vector_x2;
42
-
43
- typedef struct {
44
- HVX_Vector v[4];
45
- } HVX_Vector_x4;
46
-
47
- typedef struct {
48
- HVX_Vector v[8];
49
- } HVX_Vector_x8;
50
-
51
- // vdelta control to replicate first 4x fp32 values across lanes
52
- static const uint8_t __attribute__((aligned(128))) repl_4x_fp32[128] = {
53
- 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
54
- 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
55
- 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04,
56
- 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
57
- 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04,
58
- 0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
59
- 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10,
60
- };
30
+ void (*vec_dot_1x1)(const int n, float * restrict s0,
31
+ const void * restrict vx0,
32
+ const void * restrict vy0);
61
33
 
62
- // vdelta control to replicate and interleave first 8x fp32 values across lanes
63
- static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_fp32[128] = {
64
- 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00,
65
- 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
66
- 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04,
67
- 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
68
- 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44,
69
- 0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
70
- 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20,
71
- };
34
+ void (*vec_dot_2x1)(const int n, float * restrict s0,
35
+ const void * restrict vx0, const void * restrict vx1,
36
+ const void * restrict vy0);
72
37
 
73
- // vdelta control to replicate first fp32 value across all elements
74
- static const uint8_t __attribute__((aligned(128))) repl_1x_fp32[128] = {
75
- 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
76
- 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
77
- 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08,
78
- 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08,
79
- 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04,
80
- 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10,
81
- 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
82
- };
38
+ void (*vec_dot_2x2)(const int n, float * restrict s0, float * restrict s1,
39
+ const void * restrict vx0, const void * restrict vx1,
40
+ const void * restrict vy0, const void * restrict vy1);
83
41
 
84
- // vdelta control to replicate first fp16 value across all elements
85
- static const uint8_t __attribute__((aligned(128))) repl_1x_fp16[128] = {
86
- 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02,
87
- 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04,
88
- 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08,
89
- 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02,
90
- 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02,
91
- 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10,
92
- 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
93
- };
42
+ // Precomputed values
43
+ uint32_t src0_nrows_per_thread;
44
+ uint32_t src1_nrows_per_thread;
94
45
 
95
- // vdelta control to replicate first fp16 value across all elements
96
- static const uint8_t __attribute__((aligned(128))) repl_2x_fp16[128] = {
97
- 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
98
- 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
99
- 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
100
- 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
101
- 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
102
- 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
103
- 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
104
- 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
46
+ struct fastdiv_values mm_div_ne12_ne1;
47
+ struct fastdiv_values mm_div_ne1;
48
+ struct fastdiv_values mm_div_r2;
49
+ struct fastdiv_values mm_div_r3;
105
50
  };
106
51
 
107
52
  // vdelta control to expand first 32 e8m0 values into 32 uint32 elements
@@ -129,10 +74,10 @@ static inline size_t q8x4x2_row_size(uint32_t ne) {
129
74
  // ensures perfect alignment of quants and full row
130
75
  const uint32_t qk = QK_Q8_0x4x2;
131
76
  const uint32_t nb = (ne + qk - 1) / qk;
132
- return htp_round_up(ne + nb * 8 * sizeof(__fp16), 128);
77
+ return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128);
133
78
  }
134
79
 
135
- static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
80
+ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8_full(const uint8_t * restrict ptr) {
136
81
  const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
137
82
 
138
83
  HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
@@ -141,10 +86,11 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
141
86
  HVX_Vector v6_7 = vptr[3]; // ...
142
87
 
143
88
  const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
89
+ const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
144
90
 
145
- HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
146
- HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
147
- HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F
91
+ HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements
92
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements
93
+ HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ...
148
94
  HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4
149
95
  HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F
150
96
  HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4
@@ -152,21 +98,54 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
152
98
  HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
153
99
 
154
100
  // Convert uint4 to int4 (i.e. x - 8)
155
- const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
156
- v0 = Q6_Vb_vsub_VbVb(v0, i8);
157
- v1 = Q6_Vb_vsub_VbVb(v1, i8);
158
- v2 = Q6_Vb_vsub_VbVb(v2, i8);
159
- v3 = Q6_Vb_vsub_VbVb(v3, i8);
160
- v4 = Q6_Vb_vsub_VbVb(v4, i8);
161
- v5 = Q6_Vb_vsub_VbVb(v5, i8);
162
- v6 = Q6_Vb_vsub_VbVb(v6, i8);
163
- v7 = Q6_Vb_vsub_VbVb(v7, i8);
101
+ v0 = Q6_Vb_vsub_VbVb(v0, i8);
102
+ v1 = Q6_Vb_vsub_VbVb(v1, i8);
103
+ v2 = Q6_Vb_vsub_VbVb(v2, i8);
104
+ v3 = Q6_Vb_vsub_VbVb(v3, i8);
105
+ v4 = Q6_Vb_vsub_VbVb(v4, i8);
106
+ v5 = Q6_Vb_vsub_VbVb(v5, i8);
107
+ v6 = Q6_Vb_vsub_VbVb(v6, i8);
108
+ v7 = Q6_Vb_vsub_VbVb(v7, i8);
164
109
 
165
110
  HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
166
111
  return r;
167
112
  }
168
113
 
169
- static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) {
114
+ static HVX_Vector_x8 hvx_vec_load_q4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) {
115
+ const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
116
+
117
+ const uint32_t qk = QK_Q4_0x4x2; // 256
118
+ const uint32_t nb = n / qk;
119
+ const uint32_t nloe = n % qk;
120
+
121
+ const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
122
+ const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
123
+
124
+ HVX_Vector_x8 r;
125
+ uint32_t i = 0;
126
+
127
+ #pragma unroll(2)
128
+ for (i=0; i < nb; i++) {
129
+ HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
130
+ HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements
131
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements
132
+ r.v[i*2+0] = Q6_Vb_vsub_VbVb(v0, i8);
133
+ r.v[i*2+1] = Q6_Vb_vsub_VbVb(v1, i8);
134
+ }
135
+
136
+ if (nloe) {
137
+ HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
138
+ HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements
139
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements
140
+ HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...
141
+ r.v[i*2+0] = Q6_Vb_vsub_VbVb(Q6_V_lo_W(v0_1_p), i8);
142
+ r.v[i*2+1] = Q6_Vb_vsub_VbVb(Q6_V_hi_W(v0_1_p), i8);
143
+ }
144
+
145
+ return r;
146
+ }
147
+
148
+ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_full(const uint8_t * restrict ptr) {
170
149
  const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
171
150
 
172
151
  HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
@@ -175,6 +154,7 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr)
175
154
  HVX_Vector v6_7 = vptr[3]; // ...
176
155
 
177
156
  const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
157
+ const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
178
158
 
179
159
  HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
180
160
  HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
@@ -185,21 +165,54 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr)
185
165
  HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F
186
166
  HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
187
167
 
188
- HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
189
- v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
190
- v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
191
- v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
192
- v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
193
- v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
194
- v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
195
- v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
196
- v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
168
+ v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
169
+ v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
170
+ v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
171
+ v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
172
+ v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
173
+ v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
174
+ v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
175
+ v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
197
176
 
198
177
  HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
199
178
  return r;
200
179
  }
201
180
 
202
- static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) {
181
+ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) {
182
+ const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
183
+
184
+ const uint32_t qk = QK_Q4_0x4x2; // 256
185
+ const uint32_t nb = n / qk;
186
+ const uint32_t nloe = n % qk;
187
+
188
+ const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
189
+ const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
190
+
191
+ HVX_Vector_x8 r;
192
+ uint32_t i = 0;
193
+
194
+ #pragma unroll(2)
195
+ for (i=0; i < nb; i++) {
196
+ HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
197
+ HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements
198
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements
199
+ r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
200
+ r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
201
+ }
202
+
203
+ if (nloe) {
204
+ HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
205
+ HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements
206
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements
207
+ HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...
208
+ r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0);
209
+ r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0);
210
+ }
211
+
212
+ return r;
213
+ }
214
+
215
+ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_full(const uint8_t * restrict ptr) {
203
216
  const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
204
217
 
205
218
  HVX_Vector v0 = vptr[0]; // first 128 vals
@@ -215,44 +228,8 @@ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) {
215
228
  return r;
216
229
  }
217
230
 
218
- static inline HVX_Vector_x4 hvx_vec_load_x4_f16(const uint8_t * restrict ptr) {
219
- const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
220
-
221
- HVX_Vector v0 = vptr[0]; // first 64 vals
222
- HVX_Vector v1 = vptr[1]; // second 64 vals
223
- HVX_Vector v2 = vptr[2]; // third 64 vals
224
- HVX_Vector v3 = vptr[3]; // forth 64 vals
225
-
226
- HVX_Vector_x4 r = { v0, v1, v2, v3 };
227
- return r;
228
- }
229
-
230
- static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict ptr) {
231
- const HVX_VectorPair * restrict vptr = (const HVX_VectorPair *) ptr;
232
-
233
- HVX_VectorPair v0 = vptr[0]; // first 64 vals
234
- HVX_VectorPair v1 = vptr[1]; // second 64 vals
235
- HVX_VectorPair v2 = vptr[2]; // third 64 vals
236
- HVX_VectorPair v3 = vptr[3]; // forth 64 vals
237
-
238
- HVX_Vector vq0_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v0), Q6_V_vzero());
239
- HVX_Vector vq0_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v0), Q6_V_vzero());
240
- HVX_Vector vq1_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v1), Q6_V_vzero());
241
- HVX_Vector vq1_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v1), Q6_V_vzero());
242
- HVX_Vector vq2_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v2), Q6_V_vzero());
243
- HVX_Vector vq2_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v2), Q6_V_vzero());
244
- HVX_Vector vq3_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v3), Q6_V_vzero());
245
- HVX_Vector vq3_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v3), Q6_V_vzero());
246
-
247
- HVX_Vector vh0 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq0_hi, vq0_lo));
248
- HVX_Vector vh1 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq1_hi, vq1_lo));
249
- HVX_Vector vh2 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq2_hi, vq2_lo));
250
- HVX_Vector vh3 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq3_hi, vq3_lo));
251
-
252
- // vcombine does a shuffle, use vdeal to undo
253
-
254
- HVX_Vector_x4 r = { Q6_Vh_vdeal_Vh(vh0), Q6_Vh_vdeal_Vh(vh1), Q6_Vh_vdeal_Vh(vh2), Q6_Vh_vdeal_Vh(vh3) };
255
- return r;
231
+ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_partial(const uint8_t * restrict ptr, uint32_t nloe) {
232
+ return hvx_vec_load_q8x4x8_full(ptr);
256
233
  }
257
234
 
258
235
  // Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors).
@@ -262,14 +239,14 @@ static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict
262
239
  // if() checks are optimized out at compile time -- make sure to pass N as a constexpr.
263
240
 
264
241
  static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
265
- HVX_Vector r0 = Q6_V_vsplat_R(0);
266
- HVX_Vector r1 = Q6_V_vsplat_R(0);
267
- HVX_Vector r2 = Q6_V_vsplat_R(0);
268
- HVX_Vector r3 = Q6_V_vsplat_R(0);
269
- HVX_Vector r4 = Q6_V_vsplat_R(0);
270
- HVX_Vector r5 = Q6_V_vsplat_R(0);
271
- HVX_Vector r6 = Q6_V_vsplat_R(0);
272
- HVX_Vector r7 = Q6_V_vsplat_R(0);
242
+ HVX_Vector r0 = Q6_V_vzero();
243
+ HVX_Vector r1 = Q6_V_vzero();
244
+ HVX_Vector r2 = Q6_V_vzero();
245
+ HVX_Vector r3 = Q6_V_vzero();
246
+ HVX_Vector r4 = Q6_V_vzero();
247
+ HVX_Vector r5 = Q6_V_vzero();
248
+ HVX_Vector r6 = Q6_V_vzero();
249
+ HVX_Vector r7 = Q6_V_vzero();
273
250
 
274
251
  HVX_VectorPair p3;
275
252
  HVX_VectorPair p2;
@@ -308,40 +285,67 @@ static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, uns
308
285
  }
309
286
 
310
287
  static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) {
311
- return hvx_vec_rmpy_x8_n(x, y, 1024);
288
+ HVX_Vector r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]);
289
+ HVX_Vector r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]);
290
+ HVX_Vector r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]);
291
+ HVX_Vector r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]);
292
+ HVX_Vector r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]);
293
+ HVX_Vector r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]);
294
+ HVX_Vector r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]);
295
+ HVX_Vector r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]);
296
+
297
+ HVX_VectorPair p0 = Q6_W_vdeal_VVR(r1, r0, -4);
298
+ HVX_VectorPair p1 = Q6_W_vdeal_VVR(r3, r2, -4);
299
+ HVX_VectorPair p2 = Q6_W_vdeal_VVR(r5, r4, -4);
300
+ HVX_VectorPair p3 = Q6_W_vdeal_VVR(r7, r6, -4);
301
+
302
+ r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
303
+ r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1));
304
+ r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2));
305
+ r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3));
306
+
307
+ p0 = Q6_W_vdeal_VVR(r1, r0, -4);
308
+ p1 = Q6_W_vdeal_VVR(r3, r2, -4);
309
+
310
+ r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
311
+ r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1));
312
+
313
+ p0 = Q6_W_vdeal_VVR(r1, r0, -4);
314
+ r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
315
+
316
+ return r0;
312
317
  }
313
318
 
314
- // Handle most common cases of tensors not multiple of 1024.
315
- static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
316
- if (n <= 256) { return hvx_vec_rmpy_x8_n(x, y, 256); };
317
- if (n <= 512) { return hvx_vec_rmpy_x8_n(x, y, 512); };
318
- if (n <= 768) { return hvx_vec_rmpy_x8_n(x, y, 768); };
319
- return hvx_vec_rmpy_x8_n(x, y, 1024);
319
+ static inline HVX_Vector hvx_vec_rmpy_x8_partial(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
320
+ if (n >= 512)
321
+ return hvx_vec_rmpy_x8_full(x, y);
322
+
323
+ return hvx_vec_rmpy_x8_partial(x, y, 512);
320
324
  }
321
325
 
322
- static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
326
+ static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
323
327
  assert(n % 32 == 0); // min sub-block size
324
- assert((unsigned long) vx % 128 == 0);
325
- assert((unsigned long) vy % 128 == 0);
328
+ assert((unsigned long) vx0 % 128 == 0);
329
+ assert((unsigned long) vy0 % 128 == 0);
326
330
 
327
331
  const uint32_t qk = QK_Q4_0x4x2 * 4;
328
332
 
329
- const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
330
- const uint32_t x_qblk_size = qk / 2; // int4
331
- const uint32_t x_qrow_size = n / 2; // int4 (not padded)
333
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
334
+ const uint32_t x_qblk_size = qk / 2; // int4
335
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
332
336
 
333
- const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
334
- const uint32_t y_qblk_size = qk; // int8
335
- const uint32_t y_qrow_size = n; // int8 (not padded)
337
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
338
+ const uint32_t y_qblk_size = qk; // int8
339
+ const uint32_t y_qrow_size = n; // int8 (not padded)
336
340
 
337
- const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first
338
- const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales
341
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
342
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
339
343
 
340
- const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
341
- const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
344
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
345
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
342
346
 
343
- // Row sum (qf32)
344
- HVX_Vector r0_sum = Q6_V_vsplat_R(0);
347
+ // Row sum (sf)
348
+ HVX_Vector r0_sum = Q6_V_vzero();
345
349
 
346
350
  // Multiply and accumulate into int32.
347
351
  // Compute combined scale (fp32).
@@ -352,79 +356,77 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void *
352
356
 
353
357
  uint32_t i = 0;
354
358
  for (; i < nb; i++) {
355
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
356
- HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
359
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
360
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
357
361
 
358
362
  HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
359
363
 
360
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
364
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
361
365
  HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
362
366
 
363
367
  HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
364
368
 
365
369
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
366
370
 
367
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
371
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
368
372
  }
369
373
 
370
- // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
374
+ // Process leftovers
371
375
  if (nloe) {
372
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
373
- HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
376
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
377
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
374
378
 
375
- HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
379
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
376
380
 
377
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
381
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
378
382
  HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
379
383
 
380
384
  HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
381
385
 
382
- // Zero out unused scales
386
+ // Zero out unused elements
383
387
  HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
384
388
  r0_dd = Q6_V_vand_QV(bmask, r0_dd);
389
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
385
390
 
386
391
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
387
392
 
388
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
393
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
389
394
  }
390
395
 
391
- // Reduce and convert into fp32
392
- r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
396
+ r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
393
397
 
394
- hvx_vec_store_u(&s[0], 4, r0_sum);
398
+ hvx_vec_store_u(s0, 4, r0_sum);
395
399
  }
396
400
 
397
- static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
398
- float * restrict s,
399
- const void * restrict vx,
400
- uint32_t vx_row_size,
401
- const void * restrict vy) {
401
+ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
402
+ const void * restrict vx0, const void * restrict vx1,
403
+ const void * restrict vy0) {
402
404
  assert(n % 32 == 0); // min sub-block size
403
- assert((unsigned long) vx % 128 == 0);
404
- assert((unsigned long) vy % 128 == 0);
405
+ assert((unsigned long) vx0 % 128 == 0);
406
+ assert((unsigned long) vx1 % 128 == 0);
407
+ assert((unsigned long) vy0 % 128 == 0);
405
408
 
406
409
  const uint32_t qk = QK_Q4_0x4x2 * 4;
407
410
 
408
- const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
409
- const uint32_t x_qblk_size = qk / 2; // int4
410
- const uint32_t x_qrow_size = n / 2; // int4 (not padded)
411
-
412
- const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
413
- const uint32_t y_qblk_size = qk; // int8
414
- const uint32_t y_qrow_size = n; // int8 (not padded)
411
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
412
+ const uint32_t x_qblk_size = qk / 2; // int4
413
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
415
414
 
416
- const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first
417
- const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales
415
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
416
+ const uint32_t y_qblk_size = qk; // int8
417
+ const uint32_t y_qrow_size = n; // int8 (not padded)
418
418
 
419
- const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first
420
- const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales
419
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
420
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
421
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
422
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
421
423
 
422
- const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
423
- const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
424
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
425
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
424
426
 
425
- // Row sum (qf32)
426
- HVX_Vector r0_sum = Q6_V_vsplat_R(0);
427
- HVX_Vector r1_sum = Q6_V_vsplat_R(0);
427
+ // Row sum (sf)
428
+ HVX_Vector r0_sum = Q6_V_vzero();
429
+ HVX_Vector r1_sum = Q6_V_vzero();
428
430
 
429
431
  // Multiply and accumulate into int32.
430
432
  // Compute combined scale (fp32).
@@ -435,14 +437,14 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
435
437
 
436
438
  uint32_t i = 0;
437
439
  for (; i < nb; i++) {
438
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
439
- HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
440
- HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
440
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
441
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
442
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);
441
443
 
442
444
  HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
443
445
  HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
444
446
 
445
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
447
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
446
448
  HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
447
449
  HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
448
450
 
@@ -452,50 +454,178 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
452
454
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
453
455
  HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
454
456
 
455
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
456
- r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
457
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
458
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
457
459
  }
458
460
 
459
- // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
461
+ // Process leftovers
460
462
  if (nloe) {
461
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
462
- HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
463
- HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
463
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
464
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
465
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
464
466
 
465
- HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
466
- HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
467
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
468
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
467
469
 
468
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
470
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
469
471
  HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
470
472
  HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
471
473
 
472
474
  HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
473
475
  HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
474
476
 
475
- // Zero out unused scales
477
+ // Zero out unused elements
476
478
  HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
477
479
  r0_dd = Q6_V_vand_QV(bmask, r0_dd);
478
480
  r1_dd = Q6_V_vand_QV(bmask, r1_dd);
481
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
482
+ r1_ia = Q6_V_vand_QV(bmask, r1_ia);
479
483
 
480
484
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
481
485
  HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
482
486
 
483
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
484
- r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
487
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
488
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
485
489
  }
486
490
 
487
- // Convert into fp32 and reduce
488
- r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
489
- r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
490
- HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
491
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
492
+ hvx_vec_store_u(s0, 8, rsum);
493
+ }
494
+
495
+ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
496
+ const void * restrict vx0, const void * restrict vx1,
497
+ const void * restrict vy0, const void * restrict vy1) {
498
+ assert(n % 32 == 0);
499
+ assert((unsigned long) vx0 % 128 == 0);
500
+ assert((unsigned long) vx1 % 128 == 0);
501
+ assert((unsigned long) vy0 % 128 == 0);
502
+ assert((unsigned long) vy1 % 128 == 0);
503
+
504
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
505
+
506
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
507
+ const uint32_t x_qblk_size = qk / 2; // int4
508
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
509
+
510
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
511
+ const uint32_t y_qblk_size = qk; // int8
512
+ const uint32_t y_qrow_size = n; // int8 (not padded)
513
+
514
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
515
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
516
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
517
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
518
+
519
+ const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
520
+ const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
521
+ const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
522
+ const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
523
+
524
+ // Row sums (sf) - 4 accumulators for 2×2 tile
525
+ HVX_Vector r0_c0_sum = Q6_V_vzero();
526
+ HVX_Vector r0_c1_sum = Q6_V_vzero();
527
+ HVX_Vector r1_c0_sum = Q6_V_vzero();
528
+ HVX_Vector r1_c1_sum = Q6_V_vzero();
529
+
530
+ const uint32_t nb = n / qk; // num full blocks
531
+ const uint32_t nloe = n % qk; // num leftover elements
532
+
533
+ uint32_t i = 0;
534
+ for (; i < nb; i++) {
535
+ // Load src1 columns (reused across both src0 rows)
536
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
537
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
538
+
539
+ // Load src0 rows (reused across both src1 columns)
540
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
541
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);
542
+
543
+ // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
544
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
545
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
546
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
547
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
548
+
549
+ // Load scales
550
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
551
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
552
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
553
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
554
+
555
+ // Compute combined scales
556
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
557
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
558
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
559
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
560
+
561
+ // Apply scales and accumulate
562
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
563
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
564
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
565
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
566
+
567
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
568
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
569
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
570
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
571
+ }
572
+
573
+ // Process leftovers
574
+ if (nloe) {
575
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
576
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
577
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
578
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
579
+
580
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
581
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
582
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
583
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
584
+
585
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
586
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
587
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
588
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
589
+
590
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
591
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
592
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
593
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
594
+
595
+ // Zero out unused scales
596
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
597
+ r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
598
+ r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
599
+ r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
600
+ r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
601
+ r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
602
+ r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
603
+ r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
604
+ r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
605
+
606
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
607
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
608
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
609
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
610
+
611
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
612
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
613
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
614
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
615
+ }
616
+
617
+ // Reduce and store results
618
+ HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
619
+ HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
491
620
 
492
- hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
621
+ hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0
622
+ hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1
493
623
  }
494
624
 
495
- static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
625
+ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
496
626
  assert(n % 32 == 0); // min sub-block size
497
- assert((unsigned long) vx % 128 == 0);
498
- assert((unsigned long) vy % 128 == 0);
627
+ assert((unsigned long) vx0 % 128 == 0);
628
+ assert((unsigned long) vy0 % 128 == 0);
499
629
 
500
630
  const uint32_t qk = QK_Q4_0x4x2 * 4;
501
631
 
@@ -507,14 +637,14 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
507
637
  const uint32_t y_qblk_size = qk; // int8
508
638
  const uint32_t y_qrow_size = n; // int8 (not padded)
509
639
 
510
- const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first
511
- const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales
640
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
641
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
512
642
 
513
- const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
514
- const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
643
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
644
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
515
645
 
516
- // Row sum (qf32)
517
- HVX_Vector r0_sum = Q6_V_vsplat_R(0);
646
+ // Row sum (sf)
647
+ HVX_Vector r0_sum = Q6_V_vzero();
518
648
 
519
649
  // Multiply and accumulate into int32.
520
650
  // Compute combined scale (fp32).
@@ -525,79 +655,77 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
525
655
 
526
656
  uint32_t i = 0;
527
657
  for (; i < nb; i++) {
528
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
529
- HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
658
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
659
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
530
660
 
531
661
  HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
532
662
 
533
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
663
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
534
664
  HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
535
665
 
536
666
  HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
537
667
 
538
668
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
539
669
 
540
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
670
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
541
671
  }
542
672
 
543
- // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
673
+ // Process leftovers
544
674
  if (nloe) {
545
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
546
- HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
675
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
676
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
547
677
 
548
- HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
678
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
549
679
 
550
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
680
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
551
681
  HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
552
682
 
553
683
  HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
554
684
 
555
- // Zero out unused scales
685
+ // Zero out unused elements
556
686
  HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
557
687
  r0_dd = Q6_V_vand_QV(bmask, r0_dd);
688
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
558
689
 
559
690
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
560
691
 
561
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
692
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
562
693
  }
563
694
 
564
- // Reduce and convert into fp32
565
- r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
695
+ r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
566
696
 
567
- hvx_vec_store_u(&s[0], 4, r0_sum);
697
+ hvx_vec_store_u(s0, 4, r0_sum);
568
698
  }
569
699
 
570
- static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
571
- float * restrict s,
572
- const void * restrict vx,
573
- uint32_t vx_row_size,
574
- const void * restrict vy) {
700
+ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
701
+ const void * restrict vx0, const void * restrict vx1,
702
+ const void * restrict vy0) {
575
703
  assert(n % 32 == 0); // min sub-block size
576
- assert((unsigned long) vx % 128 == 0);
577
- assert((unsigned long) vy % 128 == 0);
704
+ assert((unsigned long) vx0 % 128 == 0);
705
+ assert((unsigned long) vx1 % 128 == 0);
706
+ assert((unsigned long) vy0 % 128 == 0);
578
707
 
579
708
  const uint32_t qk = QK_Q4_0x4x2 * 4;
580
709
 
581
- const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
582
- const uint32_t x_qblk_size = qk; // int8
583
- const uint32_t x_qrow_size = n; // int8 (not padded)
584
-
585
- const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
586
- const uint32_t y_qblk_size = qk; // int8
587
- const uint32_t y_qrow_size = n; // int8 (not padded)
710
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
711
+ const uint32_t x_qblk_size = qk; // int8
712
+ const uint32_t x_qrow_size = n; // int8 (not padded)
588
713
 
589
- const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first
590
- const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales
714
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
715
+ const uint32_t y_qblk_size = qk; // int8
716
+ const uint32_t y_qrow_size = n; // int8 (not padded)
591
717
 
592
- const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first
593
- const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales
718
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
719
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
720
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
721
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
594
722
 
595
- const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
596
- const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
723
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
724
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
597
725
 
598
726
  // Row sum (qf32)
599
- HVX_Vector r0_sum = Q6_V_vsplat_R(0);
600
- HVX_Vector r1_sum = Q6_V_vsplat_R(0);
727
+ HVX_Vector r0_sum = Q6_V_vzero();
728
+ HVX_Vector r1_sum = Q6_V_vzero();
601
729
 
602
730
  // Multiply and accumulate into int32.
603
731
  // Compute combined scale (fp32).
@@ -608,14 +736,14 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
608
736
 
609
737
  uint32_t i = 0;
610
738
  for (; i < nb; i++) {
611
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
612
- HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
613
- HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
739
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
740
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
741
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);
614
742
 
615
743
  HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
616
744
  HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
617
745
 
618
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
746
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
619
747
  HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
620
748
  HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
621
749
 
@@ -625,18 +753,18 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
625
753
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
626
754
  HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
627
755
 
628
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
629
- r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
756
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
757
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
630
758
  }
631
759
 
632
- // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
760
+ // Process leftovers
633
761
  if (nloe) {
634
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
635
- HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
636
- HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
762
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
763
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
764
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
637
765
 
638
- HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
639
- HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
766
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
767
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
640
768
 
641
769
  HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
642
770
  HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
@@ -645,33 +773,158 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
645
773
  HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
646
774
  HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
647
775
 
648
- // Zero out unused scales
776
+ // Zero out unused elements
649
777
  HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
650
778
  r0_dd = Q6_V_vand_QV(bmask, r0_dd);
651
779
  r1_dd = Q6_V_vand_QV(bmask, r1_dd);
780
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
781
+ r1_ia = Q6_V_vand_QV(bmask, r1_ia);
652
782
 
653
783
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
654
784
  HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
655
785
 
656
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
657
- r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
786
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
787
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
658
788
  }
659
789
 
660
- // Convert into fp32 and reduce
661
- r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
662
- r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
663
- HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
790
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
791
+ hvx_vec_store_u(s0, 8, rsum);
792
+ }
793
+
794
+ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
795
+ const void * restrict vx0, const void * restrict vx1,
796
+ const void * restrict vy0, const void * restrict vy1) {
797
+ assert(n % 32 == 0);
798
+ assert((unsigned long) vx0 % 128 == 0);
799
+ assert((unsigned long) vx1 % 128 == 0);
800
+ assert((unsigned long) vy0 % 128 == 0);
801
+ assert((unsigned long) vy1 % 128 == 0);
802
+
803
+ const uint32_t qk = QK_Q8_0x4x2 * 4;
804
+
805
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
806
+ const uint32_t x_qblk_size = qk; // int8
807
+ const uint32_t x_qrow_size = n; // int8 (not padded)
808
+
809
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
810
+ const uint32_t y_qblk_size = qk; // int8
811
+ const uint32_t y_qrow_size = n; // int8 (not padded)
812
+
813
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
814
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
815
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
816
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
817
+
818
+ const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
819
+ const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
820
+ const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
821
+ const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
822
+
823
+ // Row sums (sf) - 4 accumulators for 2×2 tile
824
+ HVX_Vector r0_c0_sum = Q6_V_vzero();
825
+ HVX_Vector r0_c1_sum = Q6_V_vzero();
826
+ HVX_Vector r1_c0_sum = Q6_V_vzero();
827
+ HVX_Vector r1_c1_sum = Q6_V_vzero();
828
+
829
+ const uint32_t nb = n / qk; // num full blocks
830
+ const uint32_t nloe = n % qk; // num leftover elements
831
+
832
+ uint32_t i = 0;
833
+ for (; i < nb; i++) {
834
+ // Load src1 columns (reused across both src0 rows)
835
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
836
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
837
+
838
+ // Load src0 rows (reused across both src1 columns)
839
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
840
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);
841
+
842
+ // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
843
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
844
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
845
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
846
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
847
+
848
+ // Load scales
849
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
850
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
851
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
852
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
853
+
854
+ // Compute combined scales
855
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
856
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
857
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
858
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
859
+
860
+ // Apply scales and accumulate
861
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
862
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
863
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
864
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
865
+
866
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
867
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
868
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
869
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
870
+ }
664
871
 
665
- hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
872
+ // Process leftovers
873
+ if (nloe) {
874
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
875
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
876
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
877
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
878
+
879
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
880
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
881
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
882
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
883
+
884
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
885
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
886
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
887
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
888
+
889
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
890
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
891
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
892
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
893
+
894
+ // Zero out unused elements
895
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
896
+ r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
897
+ r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
898
+ r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
899
+ r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
900
+ r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
901
+ r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
902
+ r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
903
+ r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
904
+
905
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
906
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
907
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
908
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
909
+
910
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
911
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
912
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
913
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
914
+ }
915
+
916
+ // Reduce and store results
917
+ HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
918
+ HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
919
+
920
+ hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
921
+ hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
666
922
  }
667
923
 
668
- static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
669
- float * restrict s,
670
- const void * restrict vx,
671
- const void * restrict vy) {
924
+ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
672
925
  assert(n % 32 == 0); // min sub-block size
673
- assert((unsigned long) vx % 128 == 0);
674
- assert((unsigned long) vy % 128 == 0);
926
+ assert((unsigned long) vx0 % 128 == 0);
927
+ assert((unsigned long) vy0 % 128 == 0);
675
928
 
676
929
  const uint32_t qk = QK_MXFP4x4x2 * 4;
677
930
 
@@ -683,14 +936,14 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
683
936
  const uint32_t y_qblk_size = qk; // int8
684
937
  const uint32_t y_qrow_size = n; // int8 (not padded)
685
938
 
686
- const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first
687
- const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales
939
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
940
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
688
941
 
689
- const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
690
- const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
942
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
943
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
691
944
 
692
- // Row sum (qf32)
693
- HVX_Vector r0_sum = Q6_V_vsplat_R(0);
945
+ // Row sum (sf)
946
+ HVX_Vector r0_sum = Q6_V_vzero();
694
947
 
695
948
  // Multiply and accumulate into int32.
696
949
  // Compute combined scale (fp32).
@@ -701,8 +954,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
701
954
 
702
955
  uint32_t i = 0;
703
956
  for (; i < nb; i++) {
704
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
705
- HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
957
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size);
958
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
706
959
 
707
960
  HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
708
961
 
@@ -728,17 +981,17 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
728
981
 
729
982
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
730
983
 
731
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
984
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
732
985
  }
733
986
 
734
987
  // Process leftovers
735
988
  if (nloe) {
736
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
737
- HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
989
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe);
990
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
738
991
 
739
- HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
992
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
740
993
 
741
- HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
994
+ HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
742
995
  HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
743
996
 
744
997
  // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
@@ -761,62 +1014,60 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
761
1014
  // Zero-out unused scales
762
1015
  HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
763
1016
  r0_dd = Q6_V_vand_QV(bmask, r0_dd);
1017
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
764
1018
 
765
1019
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
766
1020
 
767
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
1021
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
768
1022
  }
769
1023
 
770
- // Reduce and convert into fp32
771
- r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
1024
+ r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
772
1025
 
773
- hvx_vec_store_u(&s[0], 4, r0_sum);
1026
+ hvx_vec_store_u(s0, 4, r0_sum);
774
1027
  }
775
1028
 
776
- static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
777
- float * restrict s,
778
- const void * restrict vx,
779
- uint32_t vx_row_size,
780
- const void * restrict vy) {
1029
+ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
1030
+ const void * restrict vx0, const void * restrict vx1,
1031
+ const void * restrict vy0) {
781
1032
  assert(n % 32 == 0); // min sub-block size
782
- assert((unsigned long) vx % 128 == 0);
783
- assert((unsigned long) vy % 128 == 0);
1033
+ assert((unsigned long) vx0 % 128 == 0);
1034
+ assert((unsigned long) vx1 % 128 == 0);
1035
+ assert((unsigned long) vy0 % 128 == 0);
784
1036
 
785
1037
  const uint32_t qk = QK_MXFP4x4x2 * 4;
786
1038
 
787
- const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
788
- const uint32_t x_qblk_size = qk / 2; // fp4
789
- const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
1039
+ const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
1040
+ const uint32_t x_qblk_size = qk / 2; // fp4
1041
+ const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
790
1042
 
791
- const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
792
- const uint32_t y_qblk_size = qk; // int8
793
- const uint32_t y_qrow_size = n; // int8 (not padded)
1043
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
1044
+ const uint32_t y_qblk_size = qk; // int8
1045
+ const uint32_t y_qrow_size = n; // int8 (not padded)
794
1046
 
795
- const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first
796
- const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales
1047
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
1048
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
1049
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
1050
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
797
1051
 
798
- const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first
799
- const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales
1052
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first
1053
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
800
1054
 
801
- const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
802
- const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
803
-
804
- // Row sum (qf32)
805
- HVX_Vector r0_sum = Q6_V_vsplat_R(0);
806
- HVX_Vector r1_sum = Q6_V_vsplat_R(0);
1055
+ // Row sum (sf)
1056
+ HVX_Vector r0_sum = Q6_V_vzero();
1057
+ HVX_Vector r1_sum = Q6_V_vzero();
807
1058
 
808
1059
  // Multiply and accumulate into int32.
809
1060
  // Compute combined scale (fp32).
810
- // Apply scale to acc and accumulate into the row sum (qf32).
1061
+ // Apply scale to acc and accumulate into the row sum (f32).
811
1062
 
812
1063
  const uint32_t nb = n / qk; // num full blocks
813
1064
  int32_t nloe = n % qk; // num leftover elemements (must be signed)
814
1065
 
815
1066
  uint32_t i = 0;
816
1067
  for (; i < nb; i++) {
817
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
818
- HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
819
- HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
1068
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size);
1069
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
1070
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size);
820
1071
 
821
1072
  HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
822
1073
  HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
@@ -849,20 +1100,20 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
849
1100
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
850
1101
  HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
851
1102
 
852
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
853
- r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
1103
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
1104
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
854
1105
  }
855
1106
 
856
1107
  // Process leftovers
857
1108
  if (nloe) {
858
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
859
- HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
860
- HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
1109
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe);
1110
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
1111
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
861
1112
 
862
1113
  HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
863
1114
  HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
864
1115
 
865
- HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
1116
+ HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
866
1117
  HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
867
1118
  HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
868
1119
 
@@ -887,111 +1138,326 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
887
1138
  HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
888
1139
  HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
889
1140
 
890
- // Zero-out unused scales
1141
+ // Zero-out unused values
891
1142
  HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
892
1143
  r0_dd = Q6_V_vand_QV(bmask, r0_dd);
893
1144
  r1_dd = Q6_V_vand_QV(bmask, r1_dd);
1145
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
1146
+ r1_ia = Q6_V_vand_QV(bmask, r1_ia);
894
1147
 
895
1148
  HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
896
1149
  HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
897
1150
 
898
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
899
- r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
1151
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
1152
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
900
1153
  }
901
1154
 
902
- // Convert into fp32 and reduce
903
- r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
904
- r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
905
- HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
1155
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
1156
+ hvx_vec_store_u(s0, 8, rsum);
1157
+ }
1158
+
1159
+ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
1160
+ const void * restrict vx0, const void * restrict vx1,
1161
+ const void * restrict vy0, const void * restrict vy1) {
1162
+ assert(n % 32 == 0);
1163
+ assert((unsigned long) vx0 % 128 == 0);
1164
+ assert((unsigned long) vx1 % 128 == 0);
1165
+ assert((unsigned long) vy0 % 128 == 0);
1166
+ assert((unsigned long) vy1 % 128 == 0);
1167
+
1168
+ const uint32_t qk = QK_MXFP4x4x2 * 4;
1169
+
1170
+ const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
1171
+ const uint32_t x_qblk_size = qk / 2; // fp4
1172
+ const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
1173
+
1174
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
1175
+ const uint32_t y_qblk_size = qk; // int8
1176
+ const uint32_t y_qrow_size = n; // int8 (not padded)
1177
+
1178
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
1179
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
1180
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
1181
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
1182
+
1183
+ const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
1184
+ const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
1185
+ const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
1186
+ const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
1187
+
1188
+ // Row sums (sf) - 4 accumulators for 2×2 tile
1189
+ HVX_Vector r0_c0_sum = Q6_V_vzero();
1190
+ HVX_Vector r0_c1_sum = Q6_V_vzero();
1191
+ HVX_Vector r1_c0_sum = Q6_V_vzero();
1192
+ HVX_Vector r1_c1_sum = Q6_V_vzero();
1193
+
1194
+ const uint32_t nb = n / qk; // num full blocks
1195
+ const uint32_t nloe = n % qk; // num leftover elements
1196
+
1197
+ uint32_t i = 0;
1198
+ for (; i < nb; i++) {
1199
+ // Load src1 columns (reused across both src0 rows)
1200
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
1201
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
1202
+
1203
+ // Load src0 rows (reused across both src1 columns)
1204
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
1205
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size);
1206
+
1207
+ // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
1208
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
1209
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
1210
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
1211
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
1212
+
1213
+ // Load scales
1214
+ HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size);
1215
+ HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size);
1216
+ HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
1217
+ HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
1218
+
1219
+ // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
1220
+ HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
1221
+ vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));
1222
+ vy0_d = Q6_Vsf_equals_Vqf32(vy0_d);
1223
+ vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));
1224
+ vy1_d = Q6_Vsf_equals_Vqf32(vy1_d);
1225
+
1226
+ // Convert rX_d scales from e8m0 to fp32
1227
+ // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
1228
+ // Left shift with zero fill to create FP32
1229
+ // FIXME: might need to handle zero as a special case (see ggml-cpu code)
1230
+ HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
1231
+ HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
1232
+ r0_d = Q6_V_vdelta_VV(r0_d, expand);
1233
+ r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
1234
+ r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
1235
+ r1_d = Q6_V_vdelta_VV(r1_d, expand);
1236
+ r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
1237
+ r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
1238
+
1239
+ // Compute combined scales
1240
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));
1241
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));
1242
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));
1243
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));
1244
+
1245
+ // Apply scales and accumulate
1246
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
1247
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
1248
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
1249
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
1250
+
1251
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
1252
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
1253
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
1254
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
1255
+ }
1256
+
1257
+ // Process leftovers
1258
+ if (nloe) {
1259
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial( y0_q + i * y_qblk_size, nloe);
1260
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial( y1_q + i * y_qblk_size, nloe);
1261
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
1262
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
1263
+
1264
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
1265
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
1266
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
1267
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
1268
+
1269
+ HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size);
1270
+ HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size);
1271
+ HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
1272
+ HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
906
1273
 
907
- hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
1274
+ // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
1275
+ HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
1276
+ vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));
1277
+ vy0_d = Q6_Vsf_equals_Vqf32(vy0_d);
1278
+ vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));
1279
+ vy1_d = Q6_Vsf_equals_Vqf32(vy1_d);
1280
+
1281
+ // Convert rX_d scales from e8m0 to fp32
1282
+ // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
1283
+ // Left shift with zero fill to create FP32
1284
+ // FIXME: might need to handle zero as a special case (see ggml-cpu code)
1285
+ HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
1286
+ HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
1287
+ r0_d = Q6_V_vdelta_VV(r0_d, expand);
1288
+ r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
1289
+ r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
1290
+ r1_d = Q6_V_vdelta_VV(r1_d, expand);
1291
+ r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
1292
+ r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
1293
+
1294
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));
1295
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));
1296
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));
1297
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));
1298
+
1299
+ // Zero out unused scales
1300
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
1301
+ r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
1302
+ r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
1303
+ r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
1304
+ r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
1305
+ r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
1306
+ r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
1307
+ r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
1308
+ r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
1309
+
1310
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
1311
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
1312
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
1313
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
1314
+
1315
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
1316
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
1317
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
1318
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
1319
+ }
1320
+
1321
+ // Reduce and store results
1322
+ HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
1323
+ HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
1324
+
1325
+ hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
1326
+ hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
908
1327
  }
909
1328
 
910
- static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1329
+ static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
911
1330
  const HVX_Vector * restrict x = (const HVX_Vector *) vx;
912
1331
  const HVX_Vector * restrict y = (const HVX_Vector *) vy;
913
1332
 
914
1333
  uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
915
1334
  uint32_t nloe = n % VLEN_FP16; // leftover elements
916
1335
 
917
- HVX_Vector rsum = Q6_V_vsplat_R(0);
1336
+ HVX_VectorPair rsum_p = Q6_W_vzero();
918
1337
 
919
1338
  uint32_t i = 0;
920
1339
 
921
1340
  #pragma unroll(4)
922
1341
  for (i = 0; i < nvec; i++) {
923
- HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
924
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
1342
+ rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x[i], y[i]);
925
1343
  }
926
1344
 
927
1345
  if (nloe) {
928
1346
  HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
929
1347
  HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
930
1348
  HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
931
-
932
- HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
933
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
1349
+ rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
934
1350
  }
935
1351
 
936
- rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
937
- hvx_vec_store_u(&s[0], 4, rsum);
1352
+ HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
1353
+ hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum));
938
1354
  }
939
1355
 
940
- static void vec_dot_f16_f16_aa_rx2(const int n,
941
- float * restrict s,
942
- const void * restrict vx,
943
- uint32_t vx_row_size,
944
- const void * restrict vy) {
945
- const HVX_Vector * restrict x0 = (const HVX_Vector *) vx;
946
- const HVX_Vector * restrict x1 = (const HVX_Vector *) ((const uint8_t *) vx + vx_row_size);
947
- const HVX_Vector * restrict y = (const HVX_Vector *) vy;
1356
+ static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0,
1357
+ const void * restrict vx0, const void * restrict vx1,
1358
+ const void * restrict vy0) {
1359
+ const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
1360
+ const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
1361
+ const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
948
1362
 
949
1363
  uint32_t nvec = n / VLEN_FP16;
950
1364
  uint32_t nloe = n % VLEN_FP16;
951
1365
 
952
- HVX_Vector rsum0 = Q6_V_vsplat_R(0);
953
- HVX_Vector rsum1 = Q6_V_vsplat_R(0);
1366
+ HVX_VectorPair rsum0_p = Q6_W_vzero();
1367
+ HVX_VectorPair rsum1_p = Q6_W_vzero();
954
1368
 
955
1369
  uint32_t i = 0;
956
1370
 
957
1371
  #pragma unroll(2)
958
1372
  for (i = 0; i < nvec; i++) {
959
1373
  HVX_Vector y_hf = y[i];
960
- HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf);
961
- HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf);
962
-
963
- rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
964
- rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
1374
+ rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0[i], y_hf);
1375
+ rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1[i], y_hf);
965
1376
  }
966
1377
 
967
1378
  if (nloe) {
968
1379
  HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
1380
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
969
1381
  HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);
970
1382
  HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);
971
- HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
1383
+ rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
1384
+ rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
1385
+ }
1386
+
1387
+ HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
1388
+ HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
1389
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
1390
+ hvx_vec_store_u(s0, 8, rsum);
1391
+ }
972
1392
 
973
- HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
974
- HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
1393
+ static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * restrict s1,
1394
+ const void * restrict vx0, const void * restrict vx1,
1395
+ const void * restrict vy0, const void * restrict vy1) {
1396
+ const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
1397
+ const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
1398
+ const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
1399
+ const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
1400
+
1401
+ uint32_t nvec = n / VLEN_FP16;
1402
+ uint32_t nloe = n % VLEN_FP16;
975
1403
 
976
- rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
977
- rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
1404
+ // Row sums (sf) - 4 accumulators for 2×2 tile
1405
+ HVX_VectorPair r0_c0_sum_p = Q6_W_vzero();
1406
+ HVX_VectorPair r0_c1_sum_p = Q6_W_vzero();
1407
+ HVX_VectorPair r1_c0_sum_p = Q6_W_vzero();
1408
+ HVX_VectorPair r1_c1_sum_p = Q6_W_vzero();
1409
+
1410
+ uint32_t i = 0;
1411
+
1412
+ #pragma unroll(2)
1413
+ for (i = 0; i < nvec; i++) {
1414
+ HVX_Vector r0_hf = x0[i];
1415
+ HVX_Vector r1_hf = x1[i];
1416
+ HVX_Vector c0_hf = y0[i];
1417
+ HVX_Vector c1_hf = y1[i];
1418
+
1419
+ // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
1420
+ r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
1421
+ r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
1422
+ r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
1423
+ r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
1424
+ }
1425
+
1426
+ if (nloe) {
1427
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
1428
+
1429
+ HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]);
1430
+ HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]);
1431
+ HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]);
1432
+ HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]);
1433
+
1434
+ r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
1435
+ r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
1436
+ r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
1437
+ r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
978
1438
  }
979
1439
 
980
- rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum0));
981
- rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum1));
982
- HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4);
1440
+ HVX_Vector r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c0_sum_p), Q6_V_hi_W(r0_c0_sum_p)));
1441
+ HVX_Vector r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c1_sum_p), Q6_V_hi_W(r0_c1_sum_p)));
1442
+ HVX_Vector r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c0_sum_p), Q6_V_hi_W(r1_c0_sum_p)));
1443
+ HVX_Vector r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c1_sum_p), Q6_V_hi_W(r1_c1_sum_p)));
1444
+
1445
+ // Reduce and store results
1446
+ HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
1447
+ HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
983
1448
 
984
- hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
1449
+ hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
1450
+ hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
985
1451
  }
986
1452
 
987
- static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1453
+ static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
988
1454
  const HVX_UVector * restrict x = (const HVX_UVector *) vx;
989
1455
  const HVX_UVector * restrict y = (const HVX_UVector *) vy;
990
1456
 
991
1457
  uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
992
1458
  uint32_t nloe = n % VLEN_FP16; // leftover elements
993
1459
 
994
- HVX_Vector rsum = Q6_V_vsplat_R(0);
1460
+ HVX_Vector rsum = Q6_V_vzero();
995
1461
 
996
1462
  uint32_t i = 0;
997
1463
 
@@ -1010,20 +1476,20 @@ static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * res
1010
1476
  rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
1011
1477
  }
1012
1478
 
1013
- rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
1479
+ rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
1014
1480
  hvx_vec_store_u(&s[0], 4, rsum);
1015
1481
  }
1016
1482
 
1017
- static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
1483
+ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
1018
1484
  const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
1019
1485
  const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
1020
1486
 
1021
1487
  uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
1022
1488
  uint32_t nloe = n % VLEN_FP16; // leftover elements
1023
1489
 
1024
- const HVX_Vector zero = Q6_V_vsplat_R(0);
1490
+ const HVX_Vector zero = Q6_V_vzero();
1025
1491
 
1026
- HVX_Vector rsum = Q6_V_vsplat_R(0);
1492
+ HVX_Vector rsum = Q6_V_vzero();
1027
1493
 
1028
1494
  uint32_t i = 0;
1029
1495
 
@@ -1062,7 +1528,8 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res
1062
1528
  rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
1063
1529
  }
1064
1530
 
1065
- rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
1531
+ // Convert into fp32 and reduce
1532
+ rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
1066
1533
  hvx_vec_store_u(&s[0], 4, rsum);
1067
1534
  }
1068
1535
 
@@ -1110,14 +1577,16 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res
1110
1577
  const uint32_t nb2 = dst->nb[2]; \
1111
1578
  const uint32_t nb3 = dst->nb[3];
1112
1579
 
1113
- #define htp_matmul_preamble \
1114
- htp_matmul_tensors_preamble; \
1115
- dma_queue *dma_queue = octx->ctx->dma[ith]; \
1116
- uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
1580
+ #define htp_matmul_preamble \
1581
+ struct htp_matmul_context * mmctx = data; \
1582
+ struct htp_ops_context * octx = mmctx->octx; \
1583
+ htp_matmul_tensors_preamble; \
1584
+ dma_queue *dma_queue = octx->ctx->dma[ith]; \
1585
+ uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread;
1117
1586
 
1118
1587
  // *** matmul with support for 4d tensors and full broadcasting
1119
1588
 
1120
- static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
1589
+ static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
1121
1590
  htp_matmul_preamble;
1122
1591
 
1123
1592
  uint64_t t1, t2;
@@ -1163,13 +1632,13 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx
1163
1632
  for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
1164
1633
  for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
1165
1634
  for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
1166
- const uint32_t i13 = fastdiv(ir1, &octx->mm_div_ne12_ne1);
1167
- const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &octx->mm_div_ne1);
1635
+ const uint32_t i13 = fastdiv(ir1, &mmctx->mm_div_ne12_ne1);
1636
+ const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &mmctx->mm_div_ne1);
1168
1637
  const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
1169
1638
 
1170
1639
  // broadcast src0 into src1
1171
- const uint32_t i03 = fastdiv(i13, &octx->mm_div_r3);
1172
- const uint32_t i02 = fastdiv(i12, &octx->mm_div_r2);
1640
+ const uint32_t i03 = fastdiv(i13, &mmctx->mm_div_r3);
1641
+ const uint32_t i02 = fastdiv(i12, &mmctx->mm_div_r2);
1173
1642
 
1174
1643
  const uint32_t i1 = i11;
1175
1644
  const uint32_t i2 = i12;
@@ -1182,7 +1651,7 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx
1182
1651
  const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
1183
1652
  for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
1184
1653
  const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
1185
- mt->vec_dot(ne00, &dst_col[ir0], src0_row, src1_col);
1654
+ mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col);
1186
1655
  }
1187
1656
  }
1188
1657
  }
@@ -1197,7 +1666,7 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx
1197
1666
  }
1198
1667
 
1199
1668
  // src1 tensor is already in VTCM spad
1200
- static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
1669
+ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
1201
1670
  htp_matmul_preamble;
1202
1671
 
1203
1672
  const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
@@ -1222,7 +1691,7 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
1222
1691
  // Per-thread VTCM scratchpads for all tensors
1223
1692
  // Note that the entire src1 tensor is already in VTCM
1224
1693
  // For other tensors we allocate N rows per thread, padded to HVX vector size
1225
- uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
1694
+ uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
1226
1695
  uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
1227
1696
  uint8_t * restrict src1_data = src1_spad->data;
1228
1697
 
@@ -1246,11 +1715,21 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
1246
1715
  for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1247
1716
  const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1248
1717
 
1249
- #pragma unroll(2)
1250
- for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
1718
+ // Process src1 columns in pairs (2×2 tiling)
1719
+ uint32_t ir1 = 0;
1720
+ for (; ir1 + 1 < src1_nrows; ir1 += 2) {
1721
+ const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride);
1722
+ const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride);
1723
+ float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size));
1724
+ float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size));
1725
+ mmctx->vec_dot_2x2(ne00, &dst_row0[ir0], &dst_row1[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1);
1726
+ }
1727
+
1728
+ // Handle remaining src1 rows (fallback to 2×1)
1729
+ for (; ir1 < src1_nrows; ++ir1) {
1251
1730
  const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
1252
1731
  float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
1253
- mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col);
1732
+ mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col);
1254
1733
  }
1255
1734
 
1256
1735
  // Prefetch next (n + spad_nrows) row
@@ -1274,20 +1753,20 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
1274
1753
  for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
1275
1754
  const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
1276
1755
  float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
1277
- mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
1756
+ mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
1278
1757
  }
1279
1758
  }
1280
1759
 
1281
1760
  t2 = HAP_perf_get_qtimer_count();
1282
1761
 
1283
- FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth,
1762
+ FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth,
1284
1763
  src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
1285
1764
  src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
1286
1765
  (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1287
1766
  }
1288
1767
 
1289
1768
  // q8x4x2 src1 tensor is already in VTCM spad
1290
- static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
1769
+ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
1291
1770
  htp_matmul_preamble;
1292
1771
 
1293
1772
  const uint32_t src0_nrows = ne01;
@@ -1338,7 +1817,7 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
1338
1817
  // Process src0 rows
1339
1818
  for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1340
1819
  const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1341
- mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_stride, src1_col);
1820
+ mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
1342
1821
 
1343
1822
  // Prefetch next (n + spad_nrows) row
1344
1823
  const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
@@ -1356,14 +1835,14 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
1356
1835
  dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
1357
1836
  src0_stride, src0_row_size, 1);
1358
1837
  const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1359
- mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
1838
+ mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
1360
1839
  }
1361
1840
 
1362
- hvx_copy_fp32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row);
1841
+ hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row);
1363
1842
 
1364
1843
  t2 = HAP_perf_get_qtimer_count();
1365
1844
 
1366
- FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth,
1845
+ FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth,
1367
1846
  src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
1368
1847
  src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
1369
1848
  (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
@@ -1377,7 +1856,7 @@ struct mmid_row_mapping {
1377
1856
  };
1378
1857
 
1379
1858
  // src1 tensor is already in VTCM spad
1380
- static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
1859
+ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
1381
1860
  htp_matmul_preamble;
1382
1861
 
1383
1862
  struct htp_tensor * restrict ids = &octx->src2;
@@ -1411,7 +1890,7 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
1411
1890
  const size_t src0_row_size = nb01;
1412
1891
  const size_t src1_row_size = q8x4x2_row_size(ne10);
1413
1892
 
1414
- const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
1893
+ const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
1415
1894
 
1416
1895
  // Per-thread VTCM scratchpads for all tensors
1417
1896
  // Note that the entire src1 tensor is already in VTCM
@@ -1450,11 +1929,10 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
1450
1929
  const int rm2 = row_mapping.i2; // token idx
1451
1930
 
1452
1931
  const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx
1453
- const uint8_t * restrict src1_col =
1454
- (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
1932
+ const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
1455
1933
  float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
1456
1934
 
1457
- mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
1935
+ mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
1458
1936
  }
1459
1937
 
1460
1938
  // Prefetch next (n + spad_nrows) row
@@ -1480,25 +1958,24 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
1480
1958
  const int rm2 = row_mapping.i2; // token idx
1481
1959
 
1482
1960
  const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx
1483
- const uint8_t * restrict src1_col =
1484
- (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
1961
+ const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
1485
1962
  float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
1486
1963
 
1487
- mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
1964
+ mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
1488
1965
  }
1489
1966
  }
1490
1967
  }
1491
1968
 
1492
1969
  t2 = HAP_perf_get_qtimer_count();
1493
1970
 
1494
- FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type,
1971
+ FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type,
1495
1972
  ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
1496
1973
  src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1],
1497
1974
  dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1498
1975
  }
1499
1976
 
1500
1977
  // src1 tensor is already in VTCM spad
1501
- static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
1978
+ static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
1502
1979
  htp_matmul_preamble;
1503
1980
 
1504
1981
  struct htp_tensor * restrict ids = &octx->src2;
@@ -1524,7 +2001,7 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
1524
2001
  const size_t src0_row_size = nb01;
1525
2002
  const size_t src1_row_size = q8x4x2_row_size(ne10);
1526
2003
 
1527
- const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
2004
+ const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
1528
2005
 
1529
2006
  const uint32_t n_aids = src2->ne[0]; // num activated experts
1530
2007
  const uint32_t n_ids = ne02; // num experts
@@ -1558,7 +2035,7 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
1558
2035
  // Process src0 rows
1559
2036
  for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1560
2037
  const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1561
- mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
2038
+ mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
1562
2039
 
1563
2040
  // Prefetch next (n + spad_nrows) row
1564
2041
  const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
@@ -1576,13 +2053,13 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
1576
2053
  dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
1577
2054
  src0_row_size_padded, src0_row_size, 1);
1578
2055
  const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
1579
- mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
2056
+ mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
1580
2057
  }
1581
2058
  }
1582
2059
 
1583
2060
  t2 = HAP_perf_get_qtimer_count();
1584
2061
 
1585
- FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type,
2062
+ FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type,
1586
2063
  ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
1587
2064
  src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0],
1588
2065
  dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
@@ -1590,18 +2067,18 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
1590
2067
 
1591
2068
  // *** dynamic quant
1592
2069
 
1593
- static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
2070
+ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
1594
2071
  assert((unsigned long) x % 128 == 0);
1595
2072
  assert((unsigned long) y_q % 128 == 0);
1596
2073
 
1597
2074
  HVX_Vector * vx = (HVX_Vector *) x;
1598
- HVX_Vector zero = Q6_V_vsplat_R(0);
2075
+ HVX_Vector zero = Q6_V_vzero();
1599
2076
 
1600
2077
  // Use reduce max fp32 to find max(abs(e)) first
1601
- HVX_Vector vmax0_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[0]));
1602
- HVX_Vector vmax1_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[1]));
1603
- HVX_Vector vmax2_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[2]));
1604
- HVX_Vector vmax3_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[3]));
2078
+ HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0]));
2079
+ HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1]));
2080
+ HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2]));
2081
+ HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3]));
1605
2082
  // Load and convert into QF32
1606
2083
  HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
1607
2084
  HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
@@ -1609,10 +2086,10 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri
1609
2086
  HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
1610
2087
 
1611
2088
  // Convert to QF32
1612
- HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
1613
- HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
1614
- HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
1615
- HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
2089
+ HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); // replicated over all lanes
2090
+ HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); // replicated over all lanes
2091
+ HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes
2092
+ HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes
1616
2093
 
1617
2094
  // Combine and convert to fp16
1618
2095
  HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
@@ -1622,11 +2099,6 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri
1622
2099
  HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
1623
2100
  HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
1624
2101
 
1625
- // Replicate first fp16 scale across all lanes
1626
- HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_fp16;
1627
- vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
1628
- vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
1629
-
1630
2102
  HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
1631
2103
  HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
1632
2104
  HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
@@ -1641,8 +2113,8 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri
1641
2113
  hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf);
1642
2114
 
1643
2115
  // Divide input by the scale
1644
- HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf);
1645
- HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf);
2116
+ HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
2117
+ HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
1646
2118
  vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
1647
2119
  vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
1648
2120
 
@@ -1654,14 +2126,14 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri
1654
2126
  *(HVX_Vector *) y_q = vx_i8;
1655
2127
  }
1656
2128
 
1657
- static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
2129
+ static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
1658
2130
  assert((unsigned long) x % 128 == 0);
1659
2131
  assert((unsigned long) y_q % 128 == 0);
1660
2132
 
1661
2133
  HVX_Vector * vx = (HVX_Vector *) x;
1662
2134
 
1663
2135
  // Load and convert into QF32
1664
- HVX_Vector zero = Q6_V_vsplat_R(0);
2136
+ HVX_Vector zero = Q6_V_vzero();
1665
2137
  HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
1666
2138
  HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
1667
2139
  HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
@@ -1672,13 +2144,8 @@ static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restri
1672
2144
  HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
1673
2145
 
1674
2146
  // Compute max and scale
1675
- HVX_Vector vmax01_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf));
1676
- HVX_Vector vmax23_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx23_hf));
1677
-
1678
- // Replicate first fp16 scale across all lanes
1679
- HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16;
1680
- vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
1681
- vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
2147
+ HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); // replicated over all lanes
2148
+ HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); // replicated over all lanes
1682
2149
 
1683
2150
  HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
1684
2151
  HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
@@ -1689,8 +2156,8 @@ static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restri
1689
2156
  hvx_vec_store_u(y_d + 4, 4, vd23_hf);
1690
2157
 
1691
2158
  // Divide input by the scale
1692
- HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf);
1693
- HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf);
2159
+ HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
2160
+ HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
1694
2161
  vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
1695
2162
  vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
1696
2163
 
@@ -1702,14 +2169,14 @@ static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restri
1702
2169
  *(HVX_Vector *) y_q = vx_i8;
1703
2170
  }
1704
2171
 
1705
- static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
2172
+ static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
1706
2173
  assert((unsigned long) x % 128 == 0);
1707
2174
  assert((unsigned long) y_q % 128 == 0);
1708
2175
 
1709
2176
  HVX_Vector * vx = (HVX_Vector *) x;
1710
2177
 
1711
2178
  // Load and convert into QF32
1712
- HVX_Vector zero = Q6_V_vsplat_R(0);
2179
+ HVX_Vector zero = Q6_V_vzero();
1713
2180
  HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
1714
2181
  HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
1715
2182
  HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
@@ -1720,12 +2187,8 @@ static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restri
1720
2187
  HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
1721
2188
 
1722
2189
  // Compute max and scale
1723
- HVX_Vector vmax_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf));
1724
- vmax_hf = hvx_vec_reduce_max2_fp16(hvx_vec_abs_fp16(vx23_hf), vmax_hf);
1725
-
1726
- // Replicate first fp16 scale across all lanes
1727
- HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16;
1728
- vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl);
2190
+ HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
2191
+ vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); // replicated over all lanes
1729
2192
 
1730
2193
  HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
1731
2194
  HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16);
@@ -1733,7 +2196,7 @@ static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restri
1733
2196
  *(HVX_UVector *) y_d = vd_hf;
1734
2197
 
1735
2198
  // Divide input by the scale
1736
- HVX_Vector vd_inv_hf = hvx_vec_inverse_fp16(vd_hf);
2199
+ HVX_Vector vd_inv_hf = hvx_vec_inverse_f16(vd_hf);
1737
2200
  vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf));
1738
2201
  vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf));
1739
2202
 
@@ -1746,7 +2209,7 @@ static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restri
1746
2209
  }
1747
2210
 
1748
2211
  // Overrides input x
1749
- static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
2212
+ static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
1750
2213
  assert(k % 32 == 0);
1751
2214
  const uint32_t qk = QK_Q8_0x4x2;
1752
2215
  const uint32_t nb = (k + qk - 1) / qk;
@@ -1764,29 +2227,31 @@ static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, u
1764
2227
 
1765
2228
  for (uint32_t i = 0; i < nb; i++) {
1766
2229
  #if FP32_QUANTIZE_GROUP_SIZE == 32
1767
- quantize_block_fp32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
1768
- quantize_block_fp32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
2230
+ quantize_block_f32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
2231
+ quantize_block_f32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
1769
2232
  #elif FP32_QUANTIZE_GROUP_SIZE == 64
1770
- quantize_block_fp32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
1771
- quantize_block_fp32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
2233
+ quantize_block_f32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
2234
+ quantize_block_f32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
1772
2235
  #elif FP32_QUANTIZE_GROUP_SIZE == 128
1773
- quantize_block_fp32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
1774
- quantize_block_fp32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
2236
+ quantize_block_f32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
2237
+ quantize_block_f32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
1775
2238
  #else
1776
2239
  #error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128"
1777
2240
  #endif
1778
2241
  }
1779
2242
 
1780
2243
  // now copy the scales into final location
1781
- hvx_copy_fp16_ua(y_d, t_d, nb * 8);
2244
+ hvx_copy_f16_ua(y_d, t_d, nb * 8);
1782
2245
  }
1783
2246
 
1784
- static void quantize_fp32_q8x4x2(const struct htp_tensor * src,
1785
- uint8_t * restrict dst,
1786
- struct htp_spad * spad,
1787
- uint32_t nth,
1788
- uint32_t ith,
1789
- uint32_t nrows_per_thread) {
2247
+ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) {
2248
+ struct htp_matmul_context * mmctx = data;
2249
+ struct htp_ops_context * octx = mmctx->octx;
2250
+
2251
+ const struct htp_tensor * src = &octx->src1;
2252
+ uint8_t * restrict dst = octx->src1_spad.data;
2253
+ struct htp_spad * spad = &octx->src0_spad;
2254
+ uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
1790
2255
 
1791
2256
  uint64_t t1 = HAP_perf_get_qtimer_count();
1792
2257
 
@@ -1807,27 +2272,33 @@ static void quantize_fp32_q8x4x2(const struct htp_tensor * src,
1807
2272
  uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first);
1808
2273
  uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith);
1809
2274
 
1810
- const size_t src_row_size_padded = htp_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float));
2275
+ const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float));
1811
2276
  memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding
1812
2277
 
1813
2278
  for (uint32_t i = ir_first; i < ir_last; ++i) {
1814
- htp_l2fetch(src_data, 2, src_row_size, src_row_size);
1815
- hvx_copy_fp32_aa(tmp_data, src_data, ne0);
2279
+ hex_l2fetch(src_data, src_row_size, src_row_size, 2);
2280
+ hvx_copy_f32_aa(tmp_data, src_data, ne0);
1816
2281
 
1817
2282
  // FARF(HIGH, "quantize-q8x4-row: %u\n", i);
1818
- quantize_row_fp32_q8x4x2((float *) tmp_data, dst_data, ne0);
2283
+ quantize_row_f32_q8x4x2((float *) tmp_data, dst_data, ne0);
1819
2284
  dst_data += dst_row_size;
1820
2285
  src_data += src_row_size;
1821
2286
  }
1822
2287
 
1823
2288
  uint64_t t2 = HAP_perf_get_qtimer_count();
1824
2289
 
1825
- FARF(HIGH, "quantize-fp32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
2290
+ FARF(HIGH, "quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
1826
2291
  ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1827
2292
  }
1828
2293
 
1829
- static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
1830
- uint32_t nrows_per_thread, uint32_t dst_stride) {
2294
+ static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
2295
+ struct htp_matmul_context * mmctx = data;
2296
+ struct htp_ops_context * octx = mmctx->octx;
2297
+
2298
+ const struct htp_tensor * src = &octx->src1;
2299
+ uint8_t * restrict dst = octx->src1_spad.data;
2300
+ uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
2301
+ uint32_t dst_stride = octx->src1_spad.stride;
1831
2302
 
1832
2303
  uint64_t t1 = HAP_perf_get_qtimer_count();
1833
2304
 
@@ -1848,8 +2319,8 @@ static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict
1848
2319
  uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
1849
2320
 
1850
2321
  for (uint32_t i = ir_first; i < ir_last; ++i) {
1851
- htp_l2fetch(src_data, 2, src_row_size, src_stride);
1852
- hvx_copy_fp16_fp32_au(dst_data, src_data, ne0);
2322
+ hex_l2fetch(src_data, src_row_size, src_stride, 2);
2323
+ hvx_copy_f16_f32_au(dst_data, src_data, ne0);
1853
2324
 
1854
2325
  dst_data += dst_stride;
1855
2326
  src_data += src_stride;
@@ -1857,13 +2328,19 @@ static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict
1857
2328
 
1858
2329
  uint64_t t2 = HAP_perf_get_qtimer_count();
1859
2330
 
1860
- FARF(HIGH, "quantize-fp32-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
2331
+ FARF(HIGH, "quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
1861
2332
  ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1862
2333
  }
1863
2334
 
1864
2335
  // TODO just a plain copy that should be done via the DMA during the Op setup
1865
- static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
1866
- uint32_t nrows_per_thread, uint32_t dst_stride) {
2336
+ static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
2337
+ struct htp_matmul_context * mmctx = data;
2338
+ struct htp_ops_context * octx = mmctx->octx;
2339
+
2340
+ const struct htp_tensor * src = &octx->src1;
2341
+ uint8_t * restrict dst = octx->src1_spad.data;
2342
+ uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
2343
+ uint32_t dst_stride = octx->src1_spad.stride;
1867
2344
 
1868
2345
  uint64_t t1 = HAP_perf_get_qtimer_count();
1869
2346
 
@@ -1884,8 +2361,8 @@ static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict
1884
2361
  uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
1885
2362
 
1886
2363
  for (uint32_t i = ir_first; i < ir_last; ++i) {
1887
- htp_l2fetch(src_data, 2, src_row_size, src_stride);
1888
- hvx_copy_fp16_au(dst_data, src_data, ne0);
2364
+ hex_l2fetch(src_data, src_row_size, src_stride, 2);
2365
+ hvx_copy_f16_au(dst_data, src_data, ne0);
1889
2366
 
1890
2367
  dst_data += dst_stride;
1891
2368
  src_data += src_stride;
@@ -1893,400 +2370,177 @@ static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict
1893
2370
 
1894
2371
  uint64_t t2 = HAP_perf_get_qtimer_count();
1895
2372
 
1896
- FARF(HIGH, "quantize-fp16-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
2373
+ FARF(HIGH, "quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
1897
2374
  ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1898
2375
  }
1899
2376
 
1900
- static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) {
1901
- struct htp_ops_context * octx = data;
1902
- quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread);
1903
- }
1904
-
1905
- static void htp_quantize_fp32_fp16(unsigned int n, unsigned int i, void * data) {
1906
- struct htp_ops_context * octx = data;
1907
- quantize_fp32_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
1908
- }
1909
-
1910
- static void htp_quantize_fp16_fp16(unsigned int n, unsigned int i, void * data) {
1911
- struct htp_ops_context * octx = data;
1912
- quantize_fp16_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
1913
- }
1914
-
1915
- // ** matmul/matvec callbacks for worker_pool
1916
-
1917
- static void htp_matvec_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1918
- struct htp_ops_context * octx = data;
1919
-
1920
- struct htp_matmul_type mt;
1921
- mt.type = "q4x4x2-q8x4x2";
1922
- mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
1923
- mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
1924
-
1925
- matvec_2d(&mt, octx, n, i);
1926
- }
1927
-
1928
- static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1929
- struct htp_ops_context * octx = data;
1930
-
1931
- struct htp_matmul_type mt;
1932
- mt.type = "q4x4x2-q8x4x2";
1933
- mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
1934
- mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
1935
-
1936
- matmul_2d(&mt, octx, n, i);
1937
- }
1938
-
1939
- static void htp_matvec_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1940
- struct htp_ops_context * octx = data;
1941
-
1942
- struct htp_matmul_type mt;
1943
- mt.type = "q8x4x2-q8x4x2";
1944
- mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
1945
- mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
1946
-
1947
- matvec_2d(&mt, octx, n, i);
1948
- }
1949
-
1950
- static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1951
- struct htp_ops_context * octx = data;
1952
-
1953
- struct htp_matmul_type mt;
1954
- mt.type = "q8x4x2-q8x4x2";
1955
- mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
1956
- mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
1957
-
1958
- matmul_2d(&mt, octx, n, i);
1959
- }
1960
-
1961
- static void htp_matvec_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1962
- struct htp_ops_context * octx = data;
1963
-
1964
- struct htp_matmul_type mt;
1965
- mt.type = "mxfp4x4x2-q8x4x2";
1966
- mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
1967
- mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
1968
-
1969
- matvec_2d(&mt, octx, n, i);
1970
- }
1971
-
1972
- static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1973
- struct htp_ops_context * octx = data;
1974
-
1975
- struct htp_matmul_type mt;
1976
- mt.type = "mxfp4x4x2-q8x4x2";
1977
- mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
1978
- mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
1979
-
1980
- matmul_2d(&mt, octx, n, i);
1981
- }
1982
-
1983
- static void htp_matvec_2d_f16_f16(unsigned int n, unsigned int i, void * data) {
1984
- struct htp_ops_context * octx = data;
1985
-
1986
- struct htp_matmul_type mt;
1987
- mt.type = "f16-f16";
1988
- mt.vec_dot = vec_dot_f16_f16_aa;
1989
- mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2;
1990
-
1991
- matvec_2d(&mt, octx, n, i);
1992
- }
1993
-
1994
- static void htp_matmul_2d_f16_f16(unsigned int n, unsigned int i, void * data) {
1995
- struct htp_ops_context * octx = data;
1996
-
1997
- struct htp_matmul_type mt;
1998
- mt.type = "f16-f16";
1999
- mt.vec_dot = vec_dot_f16_f16_aa;
2000
- mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2;
2001
-
2002
- matmul_2d(&mt, octx, n, i);
2003
- }
2004
-
2005
- static void htp_matmul_4d_f16_f32(unsigned int n, unsigned int i, void * data) {
2006
- struct htp_ops_context * octx = data;
2007
-
2008
- struct htp_matmul_type mt;
2009
- mt.type = "f16-f32";
2010
- mt.vec_dot = vec_dot_f16_f32_uu;
2011
-
2012
- matmul_4d(&mt, octx, n, i);
2013
- }
2014
-
2015
- static void htp_matmul_4d_f16_f16(unsigned int n, unsigned int i, void * data) {
2016
- struct htp_ops_context * octx = data;
2017
-
2018
- struct htp_matmul_type mt;
2019
- mt.type = "f16-f16";
2020
- mt.vec_dot = vec_dot_f16_f16_uu;
2021
-
2022
- matmul_4d(&mt, octx, n, i);
2023
- }
2024
-
2025
- // ** matmul-id callbacks for worker_pool
2026
-
2027
- static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
2028
- struct htp_ops_context * octx = data;
2029
-
2030
- struct htp_matmul_type mt;
2031
- mt.type = "q4x4x2-q8x4x2";
2032
- mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
2033
- mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
2034
-
2035
- matvec_id(&mt, octx, n, i);
2036
- }
2037
-
2038
- static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
2039
- struct htp_ops_context * octx = data;
2040
-
2041
- struct htp_matmul_type mt;
2042
- mt.type = "q4x4x2-q8x4x2";
2043
- mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
2044
- mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
2045
-
2046
- matmul_id(&mt, octx, n, i);
2047
- }
2048
-
2049
- static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
2050
- struct htp_ops_context * octx = data;
2051
-
2052
- struct htp_matmul_type mt;
2053
- mt.type = "q8x4x2-q8x4x2";
2054
- mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
2055
- mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
2056
-
2057
- matvec_id(&mt, octx, n, i);
2058
- }
2059
-
2060
- static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
2061
- struct htp_ops_context * octx = data;
2062
-
2063
- struct htp_matmul_type mt;
2064
- mt.type = "q8x4x2-q8x4x2";
2065
- mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
2066
- mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
2067
2377
 
2068
- matmul_id(&mt, octx, n, i);
2378
+ static inline bool htp_is_permuted(const struct htp_tensor * t) {
2379
+ return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3];
2069
2380
  }
2070
2381
 
2071
- static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
2072
- struct htp_ops_context * octx = data;
2073
-
2074
- struct htp_matmul_type mt;
2075
- mt.type = "mxfp4x4x2-q8x4x2";
2076
- mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
2077
- mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
2078
-
2079
- matvec_id(&mt, octx, n, i);
2382
+ static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_type type) {
2383
+ switch (type) {
2384
+ case HTP_TYPE_Q4_0:
2385
+ mmctx->type = "q4x4x2-f32";
2386
+ mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1;
2387
+ mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1;
2388
+ mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2;
2389
+ return 0;
2390
+ case HTP_TYPE_Q8_0:
2391
+ mmctx->type = "q8x4x2-f32";
2392
+ mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1;
2393
+ mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1;
2394
+ mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2;
2395
+ return 0;
2396
+ case HTP_TYPE_MXFP4:
2397
+ mmctx->type = "mxfp4x4x2-f32";
2398
+ mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1;
2399
+ mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1;
2400
+ mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2;
2401
+ return 0;
2402
+ default:
2403
+ return -1;
2404
+ }
2080
2405
  }
2081
2406
 
2082
- static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
2083
- struct htp_ops_context * octx = data;
2084
-
2085
- struct htp_matmul_type mt;
2086
- mt.type = "mxfp4x4x2-q8x4x2";
2087
- mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
2088
- mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
2089
-
2090
- matmul_id(&mt, octx, n, i);
2091
- }
2407
+ static void htp_mminit_spad(struct htp_ops_context * octx,
2408
+ size_t dst_row_size,
2409
+ size_t src0_row_size_padded,
2410
+ size_t src1_row_size,
2411
+ uint32_t src1_nrows,
2412
+ size_t src2_spad_size_per_thread) {
2413
+ octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2414
+ octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2415
+ octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
2416
+
2417
+ if (src2_spad_size_per_thread > 0) {
2418
+ octx->src2_spad.size_per_thread = src2_spad_size_per_thread;
2419
+ octx->src2_spad.size = octx->src2_spad.size_per_thread;
2420
+ }
2092
2421
 
2093
- // ** main matmul entry point
2422
+ // src0 spad is also used in dynamic quantizer to store padded src1 rows
2423
+ size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
2424
+ if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
2425
+ octx->src0_spad.size_per_thread = src1_row_size_padded;
2426
+ }
2094
2427
 
2095
- static inline bool htp_is_permuted(const struct htp_tensor * t) {
2096
- return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3];
2428
+ octx->src1_spad.size = octx->src1_spad.size_per_thread;
2429
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2430
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2097
2431
  }
2098
2432
 
2099
2433
  int op_matmul(struct htp_ops_context * octx) {
2100
2434
  htp_matmul_tensors_preamble;
2101
2435
 
2102
- const char * op_type;
2436
+ struct htp_matmul_context mmctx_struct = {0};
2437
+ struct htp_matmul_context * mmctx = &mmctx_struct;
2438
+ mmctx->octx = octx;
2103
2439
 
2104
2440
  const uint32_t src0_nrows = ne01 * ne02 * ne03;
2105
2441
  const uint32_t src1_nrows = ne11 * ne12 * ne13;
2106
2442
 
2443
+ // Compute src0_nrows_per_thread
2444
+ mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
2445
+ mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even
2446
+
2107
2447
  const size_t src0_row_size = nb01;
2108
2448
  const size_t dst_row_size = nb1;
2109
2449
  size_t src1_row_size = nb11;
2110
2450
 
2111
- const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
2451
+ const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
2112
2452
  size_t src1_row_size_padded;
2113
2453
 
2114
2454
  worker_callback_t quant_job_func;
2115
- worker_callback_t matmul_job_func;
2455
+ worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d;
2116
2456
 
2117
2457
  bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE);
2118
2458
 
2119
- switch (src0->type) {
2120
- case HTP_TYPE_Q4_0:
2121
- op_type = "q4x4x2-fp32";
2122
- quant_job_func = htp_quantize_fp32_q8x4x2;
2123
- if (src1_nrows > 1) {
2124
- matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2;
2125
- } else {
2126
- matmul_job_func = htp_matvec_2d_q4x4x2_q8x4x2;
2127
- }
2459
+ if (src0->type == HTP_TYPE_F16) {
2460
+ // Try optimized f16-f16 path first (src1 in VTCM)
2461
+ const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128);
2462
+ const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256);
2463
+ const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
2464
+ const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
2128
2465
 
2129
- src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
2466
+ const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;
2130
2467
 
2131
- // Entire src1 tensor is placed into the VTCM
2132
- // For other tensors we allocate N rows per thread, padded to HVX vector size
2468
+ // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
2469
+ // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
2470
+ const bool is_batched = (ne02 > 1) || (ne03 > 1);
2471
+ const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
2133
2472
 
2134
- octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2135
- octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2136
- octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
2473
+ if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
2474
+ // Optimized path
2475
+ quant_job_func = (src1->type == HTP_TYPE_F32) ? quantize_f32_f16 : quantize_f16_f16;
2476
+ mmctx->type = "f16-f16";
2477
+ mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1;
2478
+ mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1;
2479
+ mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2;
2137
2480
 
2138
- // src0 spad is also used in dynamic quantizer to store padded src1 rows
2139
- src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
2140
- if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
2141
- octx->src0_spad.size_per_thread = src1_row_size_padded;
2142
- }
2481
+ src1_row_size = f16_src1_row_size; // row size post quantization
2143
2482
 
2144
- octx->src1_spad.size = octx->src1_spad.size_per_thread;
2145
- octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2146
- octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2147
- break;
2148
-
2149
- case HTP_TYPE_Q8_0:
2150
- op_type = "q8x4x2-fp32";
2151
- quant_job_func = htp_quantize_fp32_q8x4x2;
2152
- if (src1_nrows > 1) {
2153
- matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2;
2154
- } else {
2155
- matmul_job_func = htp_matvec_2d_q8x4x2_q8x4x2;
2156
- }
2157
-
2158
- src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
2159
-
2160
- // Entire src1 tensor is placed into the VTCM
2161
- // For other tensors we allocate N rows per thread, padded to HVX vector size
2162
-
2163
- octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2164
- octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2165
- octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
2166
-
2167
- // src0 spad is also used in dynamic quantizer to store padded src1 rows
2168
- src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
2169
- if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
2170
- octx->src0_spad.size_per_thread = src1_row_size_padded;
2171
- }
2483
+ octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2484
+ octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2485
+ octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
2172
2486
 
2173
2487
  octx->src1_spad.size = octx->src1_spad.size_per_thread;
2174
2488
  octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2175
2489
  octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2176
- break;
2177
-
2178
- case HTP_TYPE_MXFP4:
2179
- op_type = "mxfp4x4x2-f32";
2180
- quant_job_func = htp_quantize_fp32_q8x4x2;
2181
- if (src1_nrows > 1) {
2182
- matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2;
2490
+ } else {
2491
+ // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required
2492
+ quant_job_func = NULL;
2493
+ if (src1->type == HTP_TYPE_F32) {
2494
+ mmctx->type = "f16-f32";
2495
+ mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1;
2496
+ matmul_job_func = matmul_4d;
2183
2497
  } else {
2184
- matmul_job_func = htp_matvec_2d_mxfp4x4x2_q8x4x2;
2498
+ mmctx->type = "f16-f16";
2499
+ mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1;
2500
+ matmul_job_func = matmul_4d;
2185
2501
  }
2186
2502
 
2187
- src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
2503
+ src1_row_size = nb11; // original row size in DDR
2188
2504
 
2189
- // Entire src1 tensor is placed into the VTCM
2190
- // For other tensors we allocate N rows per thread, padded to HVX vector size
2505
+ octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2506
+ octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
2507
+ octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
2191
2508
 
2192
- octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2193
- octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2194
- octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
2195
-
2196
- // src0 spad is also used in dynamic quantizer to store padded src1 rows
2197
- src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
2198
- if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
2199
- octx->src0_spad.size_per_thread = src1_row_size_padded;
2200
- }
2201
-
2202
- octx->src1_spad.size = octx->src1_spad.size_per_thread;
2203
2509
  octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2510
+ octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
2204
2511
  octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2205
- break;
2206
2512
 
2207
- case HTP_TYPE_F16:
2208
- {
2209
- // Try optimized f16-f16 path first (src1 in VTCM)
2210
- const size_t f16_src1_row_size = htp_round_up(ne10 * 2, 128);
2211
- const size_t f16_src1_spad_size = htp_round_up(f16_src1_row_size * src1_nrows, 256);
2212
- const size_t f16_src0_spad_size = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
2213
- const size_t f16_dst_spad_size = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
2214
-
2215
- const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;
2216
-
2217
- // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
2218
- // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
2219
- const bool is_batched = (ne02 > 1) || (ne03 > 1);
2220
- const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
2221
-
2222
- if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
2223
- // Optimized path
2224
- op_type = "f16-f16";
2225
- quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_fp32_fp16 : htp_quantize_fp16_fp16;
2226
- if (src1_nrows > 1) {
2227
- matmul_job_func = htp_matmul_2d_f16_f16;
2228
- } else {
2229
- matmul_job_func = htp_matvec_2d_f16_f16;
2230
- }
2231
-
2232
- src1_row_size = f16_src1_row_size; // row size post quantization
2233
-
2234
- octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2235
- octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2236
- octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
2237
-
2238
- octx->src1_spad.size = octx->src1_spad.size_per_thread;
2239
- octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2240
- octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2241
- } else {
2242
- // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required
2243
- quant_job_func = NULL;
2244
- if (src1->type == HTP_TYPE_F32) {
2245
- op_type = "f16-f32";
2246
- matmul_job_func = htp_matmul_4d_f16_f32;
2247
- } else {
2248
- op_type = "f16-f16";
2249
- matmul_job_func = htp_matmul_4d_f16_f16;
2250
- }
2251
-
2252
- src1_row_size = nb11; // original row size in DDR
2253
-
2254
- octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2255
- octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
2256
- octx->src1_spad.size_per_thread = htp_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
2257
-
2258
- octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2259
- octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
2260
- octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2261
-
2262
- // Init fastdiv for matmul_4d (supports broadcasting)
2263
- octx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
2264
- octx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]);
2265
- octx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
2266
- octx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
2267
-
2268
- need_quant = false;
2269
- }
2270
- }
2271
- break;
2513
+ // Init fastdiv for matmul_4d (supports broadcasting)
2514
+ mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
2515
+ mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]);
2516
+ mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
2517
+ mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
2272
2518
 
2273
- default:
2519
+ need_quant = false;
2520
+ }
2521
+ } else {
2522
+ if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
2274
2523
  return HTP_STATUS_NO_SUPPORT;
2524
+ }
2525
+
2526
+ quant_job_func = quantize_f32_q8x4x2;
2527
+ src1_row_size = q8x4x2_row_size(ne10);
2528
+ htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0);
2275
2529
  }
2276
2530
 
2277
2531
  // VTCM scratchpads for all tensors
2278
2532
  size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
2279
2533
 
2280
- FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", op_type,
2534
+ FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type,
2281
2535
  octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size);
2282
2536
 
2283
- FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, src0->ne[0],
2537
+ FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0],
2284
2538
  src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],
2285
2539
  dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data);
2286
2540
 
2287
2541
  // Make sure the reserved vtcm size is sufficient
2288
2542
  if (octx->ctx->vtcm_size < spad_size) {
2289
- FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
2543
+ FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type,
2290
2544
  octx->ctx->vtcm_size, spad_size);
2291
2545
  return HTP_STATUS_VTCM_TOO_SMALL;
2292
2546
  }
@@ -2295,48 +2549,47 @@ int op_matmul(struct htp_ops_context * octx) {
2295
2549
  octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
2296
2550
  octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
2297
2551
 
2298
- octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
2299
- octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even
2300
-
2301
2552
  octx->src0_spad.stride = src0_row_size_padded;
2302
2553
  octx->src1_spad.stride = src1_row_size;
2303
2554
 
2304
2555
  if (need_quant) {
2305
- // Run quant jobs
2306
- const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
2307
- octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
2308
- worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs);
2556
+ const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
2557
+ mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
2558
+ worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
2309
2559
  }
2310
2560
 
2311
2561
  if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
2312
- // Run matmul jobs
2313
2562
  const uint32_t n_matmul_jobs = octx->n_threads;
2314
- worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, octx, n_matmul_jobs);
2563
+ worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs);
2315
2564
  }
2316
2565
 
2317
2566
  return HTP_STATUS_OK;
2318
2567
  }
2319
2568
 
2320
- // ** main matmul-id entry point
2321
-
2322
2569
  int op_matmul_id(struct htp_ops_context * octx) {
2323
2570
  htp_matmul_tensors_preamble;
2324
2571
 
2325
- struct htp_tensor * restrict ids = &octx->src2;
2326
-
2327
- const char * op_type;
2572
+ struct htp_matmul_context mmctx_struct = {0};
2573
+ struct htp_matmul_context * mmctx = &mmctx_struct;
2574
+ mmctx->octx = octx;
2328
2575
 
2329
- worker_callback_t quant_job_func;
2330
- worker_callback_t matmul_id_job_func;
2576
+ struct htp_tensor * restrict ids = &octx->src2;
2331
2577
 
2332
2578
  const size_t src0_row_size = nb01;
2333
2579
  const size_t dst_row_size = nb1;
2334
2580
 
2335
- const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
2581
+ const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
2336
2582
 
2337
2583
  const uint32_t src0_nrows = ne01; // per expert
2338
2584
  const uint32_t src1_nrows = ne11 * ne12 * ne13;
2339
2585
 
2586
+ worker_callback_t quant_job_func;
2587
+ worker_callback_t matmul_id_job_func = src1_nrows > 1 ? matmul_id : matvec_id;
2588
+
2589
+ // Compute src0_nrows_per_thread
2590
+ mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
2591
+ mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even
2592
+
2340
2593
  size_t src1_row_size;
2341
2594
  size_t src1_row_size_padded;
2342
2595
 
@@ -2347,112 +2600,29 @@ int op_matmul_id(struct htp_ops_context * octx) {
2347
2600
  size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
2348
2601
  size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
2349
2602
 
2350
- switch (src0->type) {
2351
- case HTP_TYPE_Q4_0:
2352
- op_type = "q4x2x2-f32";
2353
- quant_job_func = htp_quantize_fp32_q8x4x2;
2354
- src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
2355
- if (src1_nrows > 1) {
2356
- matmul_id_job_func = htp_matmul_id_q4x4x2_q8x4x2;
2357
- } else {
2358
- matmul_id_job_func = htp_matvec_id_q4x4x2_q8x4x2;
2359
- }
2360
-
2361
- // Entire src1 tensor is placed into the VTCM
2362
- // For other tensors we allocate N rows per thread, padded to HVX vector size
2363
- octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2364
- octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2365
- octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
2366
- octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
2367
-
2368
- // src0 spad is also used in dynamic quantizer to store padded src1 rows
2369
- src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
2370
- if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
2371
- octx->src0_spad.size_per_thread = src1_row_size_padded;
2372
- }
2373
-
2374
- octx->src2_spad.size = octx->src2_spad.size_per_thread;
2375
- octx->src1_spad.size = octx->src1_spad.size_per_thread;
2376
- octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2377
- octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2378
- break;
2379
-
2380
- case HTP_TYPE_Q8_0:
2381
- op_type = "q8x2x2-f32";
2382
- quant_job_func = htp_quantize_fp32_q8x4x2;
2383
- src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
2384
- if (src1_nrows > 1) {
2385
- matmul_id_job_func = htp_matmul_id_q8x4x2_q8x4x2;
2386
- } else {
2387
- matmul_id_job_func = htp_matvec_id_q8x4x2_q8x4x2;
2388
- }
2389
-
2390
- // Entire src1 tensor is placed into the VTCM
2391
- // For other tensors we allocate N rows per thread, padded to HVX vector size
2392
- octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2393
- octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2394
- octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
2395
- octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
2396
-
2397
- // src0 spad is also used in dynamic quantizer to store padded src1 rows
2398
- src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
2399
- if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
2400
- octx->src0_spad.size_per_thread = src1_row_size_padded;
2401
- }
2402
-
2403
- octx->src2_spad.size = octx->src2_spad.size_per_thread;
2404
- octx->src1_spad.size = octx->src1_spad.size_per_thread;
2405
- octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2406
- octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2407
- break;
2408
-
2409
- case HTP_TYPE_MXFP4:
2410
- op_type = "mxfp4x2x2-f32";
2411
- quant_job_func = htp_quantize_fp32_q8x4x2;
2412
- src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
2413
- if (src1_nrows > 1) {
2414
- matmul_id_job_func = htp_matmul_id_mxfp4x4x2_q8x4x2;
2415
- } else {
2416
- matmul_id_job_func = htp_matvec_id_mxfp4x4x2_q8x4x2;
2417
- }
2418
-
2419
- // Entire src1 tensor is placed into the VTCM
2420
- // For other tensors we allocate N rows per thread, padded to HVX vector size
2421
- octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2422
- octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2423
- octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
2424
- octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
2425
-
2426
- // src0 spad is also used in dynamic quantizer to store padded src1 rows
2427
- src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
2428
- if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
2429
- octx->src0_spad.size_per_thread = src1_row_size_padded;
2430
- }
2603
+ if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
2604
+ return HTP_STATUS_NO_SUPPORT;
2605
+ }
2431
2606
 
2432
- octx->src2_spad.size = octx->src2_spad.size_per_thread;
2433
- octx->src1_spad.size = octx->src1_spad.size_per_thread;
2434
- octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2435
- octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2436
- break;
2607
+ quant_job_func = quantize_f32_q8x4x2;
2608
+ src1_row_size = q8x4x2_row_size(ne10);
2437
2609
 
2438
- default:
2439
- return HTP_STATUS_NO_SUPPORT;
2440
- }
2610
+ const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
2611
+ htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread);
2441
2612
 
2442
2613
  size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
2443
2614
 
2444
- FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", op_type,
2615
+ FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type,
2445
2616
  octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size);
2446
2617
 
2447
- FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type,
2618
+ FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type,
2448
2619
  src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
2449
2620
  ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data,
2450
2621
  src1->data, dst->data);
2451
2622
 
2452
2623
  // Make sure the reserved vtcm size is sufficient
2453
2624
  if (octx->ctx->vtcm_size < spad_size) {
2454
- FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
2455
- octx->ctx->vtcm_size, spad_size);
2625
+ FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size);
2456
2626
  return HTP_STATUS_VTCM_TOO_SMALL;
2457
2627
  }
2458
2628
 
@@ -2461,8 +2631,8 @@ int op_matmul_id(struct htp_ops_context * octx) {
2461
2631
  octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
2462
2632
  octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size;
2463
2633
 
2464
- octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
2465
- octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even
2634
+ octx->src0_spad.stride = src0_row_size_padded;
2635
+ octx->src1_spad.stride = src1_row_size;
2466
2636
 
2467
2637
  if (src1_nrows > 1) {
2468
2638
  // initialize matrix_row_counts and map
@@ -2474,8 +2644,7 @@ int op_matmul_id(struct htp_ops_context * octx) {
2474
2644
  // group rows by src0 matrix
2475
2645
  for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { // token idx
2476
2646
  for (uint32_t id = 0; id < n_ids; ++id) { // expert idx
2477
- const uint32_t i02 =
2478
- *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
2647
+ const uint32_t i02 = *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
2479
2648
 
2480
2649
  assert(i02 >= 0 && i02 < n_as);
2481
2650
 
@@ -2487,16 +2656,14 @@ int op_matmul_id(struct htp_ops_context * octx) {
2487
2656
 
2488
2657
  // Setup worker pool callbacks
2489
2658
  if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) {
2490
- // Run quant jobs
2491
2659
  const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
2492
- octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
2493
- worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs);
2660
+ mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
2661
+ worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
2494
2662
  }
2495
2663
 
2496
2664
  if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
2497
- // Run matmul-id jobs
2498
2665
  const uint32_t n_matmul_jobs = octx->n_threads;
2499
- worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, octx, n_matmul_jobs);
2666
+ worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);
2500
2667
  }
2501
2668
 
2502
2669
  return HTP_STATUS_OK;